TinyViT

Introduction

ViT는 많은 발전이 있었지만 edge device에 적용하기에는 모델이 너무 컸다. 하지만 모두들 알다싶이 작은 모델 representation이 작다. 따라서 small dataset에 적합할지 몰라도 large dataset에서는 빠르게 saturation이 되면서 underfitting이 발생하게 되어 잘 사용하지 못한다. 저자는 이에 대하여 고민을 했고 knowledge distillation을 사용해야한다는 결론에 도달하게 된다. 이에따라 small model이 downstream task에 transfer가 잘 되는 것을 확인했다. 하지만 기존의 knowledge distillation은 train시 teacher network를 메모리에 올리기 때문에 teacher network가 gpu memory를 다수 잡아먹게 되어 batch size 조절이 힘들다. 또한 학습에 필요한 soft-label을 그때그때 제작하기 때문에 학습속도도 느려진다. 따라서 저자는 이를 해결하기 위해 data augmentation과 soft-label을 먼저 저장하여 student model이 학습 시 이를 사용하는 방법을 사용했다. 결론적으로 TinyViT는 21M의 파라미터로 ImageNet에서 84.8% top-1 accuracy를 달성했으며 88M으로 85.8%를 달성한 Swin-B보다 4.2배 적은 파라미터이다. 이미지 크기를 키웠을 때 SOTA를 찍었으며 COCO object detection도 Swin-T 우위에 있다는 것을 확인했다.

TinyViT

Fast Pretraining Distillation

위 사진에서 볼 수 있듯 small model을 바로 큰 데이터셋에 학습하면 성능이 낮아진다. 따라서 knowledge distillation을 이용하려고 하는데 저자는 downstream task를 위해 finetuning-distillation이 아니라 pretraining distillation에 주목을 한다. 이에 따라 ImageNet-1K에는 distillation없이 finetuning을 이용하여 학습을 진행한다. 하지만 pretraining distillation은 large teacher model이 많은 데이터셋을 inference 해야 하기 때문에 비용이 많이 들고 비효율적이다. 따라서 Fig.2에서 나오듯 사전에 teacher model에 사용되는 augmentation과 그에 맞는 label을 만들어 스토리지에 저장하게 된다. 이를 통해 student model이 학습 시 teacher model의 forward computation 없이 knoweldge distillation을 할 수 있다. 수학적으로 표현하자면, input image를 \(x\), RandAugment나 CutMix와 같은 strong augmentation을 \(\mathcal{A}\)라 하자. Teacher 학습시 augmentation \(\mathcal{A}\), techer prediction \(\hat{y}=T(\mathcal{A}(s))\)의 pair인 \((\mathcal{A},\hat{y})\)를 저장한다. 이 때 \(T(\cdot )\)은 teacher model을 의미한다. 그 후 student \(S(\cdot)\)과 cross entropy \(CE(\cdot)\)에 대해 loss 연산을 진행한다.

\[\mathcal{L}=CE(\hat{y},S(\mathcal{A}(x)))\]

Sparse Soft Label

Teacher model의 output을 그대로 저장하는 것은 storage를 많이 사용한다. 따라서 저자는 \(\hat{y}\)의 top-K value과 그들의 indices \(\{\mathcal{I}(k)\}_{k=1}^K\)만을 저장하고 나머지는 label smoothing을 이용하여 reconstruct한다. 학습 시 ground truth인 hard label을 사용하지 않고 pseudo label인 soft label만을 사용하여 학습한다.

\[\hat{y}_c = \begin{cases} \hat{y}_{\mathcal{I}(k)} & \text{if} \ c=\mathcal{I}(k)\\ \frac{1-\sum_{k=1}^K\hat{y}_{\mathcal{I}(k)}}{C-K} & \text{otherwise} \end{cases}\]

이 때 \(\hat{y}_c\)는recovered teacher logit for student model distillation이라고 하고, \(\hat{y}=[\hat{y}_1, ... ,\hat{y}_c, ... ,\hat{y}C]\)이다. 만약 \(K << C\)라면 메모리 감소가 크다.

Data augmentation encoding

Data augmentation 정보 또한 그대로 저장하면 storage를 많이 사용한다. (rotation degree, crop coordinate 등) 따라서 set of data augmenatation parameter를 \(\mathbf{d}\)라고 하고, encoder \(\mathcal{E}(\cdot)\) 를 사용하여 \(d_0=\mathcal{E}(\mathbf{d})\)로 변환하여 저장한다. 그 후 training process에는 \(\mathcal{E}^{-1}(\cdot)\)을 decoder로사용하여 augmetation을 진행한다.

Model Architecture

저자는 기본적으로 ViT를 기반으로 모델을 설계했다. 또한 Swin Transformer나 LeViT와 같이 hierarchical한 구조를 채택했다. Patch embedding block은 kernel size 3, stride 2, padding 1인 두 개의 convolution을 사용했다. 하지만 처음부터 끝까지 transformer block을 사용하는 것은 연산적으로 무리가 된다. 따라서 MobileNetV2에서 사용하는 MBConv를 사용하여 stage 1과 downsampling을 구성했다. 또한 MBConv에 residual conntection 또한 적용했다. 마지막 3개 stage 에서는 transformer block을 사용했다. 모든 layer와 block에서 activation function은 요즘 성능이 좋은 GELU를 이용했고, normalization은 conolution layer에는 batch norm, linear layer에는 layer norm을 사용했다.

Contraction factors

Modern model와 같이 factor를 사용하여 모델의 크기를 조절한다.

  • \(\gamma_{D_{1-4}}\): embeded demension in four stages
  • \(\gamma_{N_{1-4}}\): number of block in four stages
  • \(\gamma_{W_{2-4}}\): window size of last three stages
  • \(\gamma_{R}\): channel expansion ratio of MBConv
  • \(\gamma_{M}\): channel expansion MLP of transformer blocks
  • \(\gamma_{E}\): the dimension of each head in multi-head attention

모든 모델이 \(\gamma_{D_{1-4}}\)을 제외하고 같은 factor를 갖는다.

\[\{\gamma_{N_1}, \gamma_{N_1}, \gamma_{N_1}, \gamma_{N_1}\} = \{2, 2, 6, 2\}\] \[\{\gamma_{W_1}, \gamma_{W_1}, \gamma_{W_1}\} = \{7, 14, 7\}\] \[\{\gamma_{R}, \gamma_{M}, \gamma_{E}\} = \{4, 4, 32\}\]

Embeded dimensiondm \(\{\gamma_{D_1}, \gamma_{D_2}, \gamma_{D_3}, \gamma_{D_4}\}\)는 TinyViT-21M: {96, 192, 384, 576},TinyViT-11M: {64, 128, 256, 448},TinyViT-5M: {64, 128, 160, 320} 으로 구성했다.

Analysis and Discussion

작은 모델은 Image21K와 같은 large scale dataset에서는 underfitting이 발생하여 학습이 잘되지 않는다. 그렇다면 다음 두 가지의 의문이 들 수 있다.

1. 어떠한 요소가 small model의 학습을 방해하는가?

이를 확인하려면 IN-21K의 문제점을 알아야한다. Label의 오류와 비슷한 이미지가 다른 label인 경우가 존재한다. 이와 같은 데이터가 전체의 10%로 hard sample이라고 불린다. Large model은 이러한 hard sample을 학습할 수 있자만 small model은 hard sample을 예측하기 어렵고 이는 small model을 large model 보다 낮은 성능으로 이끈다. 이를 증명하기 위해 강력한 모델인 Florence를 이용하여 data마다 florence prediction의 top 5 중에 label이 존재하는지 찾아봤다. 그 결과 ImageNet-21K의 14%에 해당하는 hard sample을 골라냈다. 또한 Florence를 이용하여 TinyViT-21M/Swin-T에 knowledge distllation을 적용했다. 그 결과 small model을 곧바로 ImageNet-21K에 학습하는 것은 성능에서 제약이 존재했고, ImageNet-21K에서 hard sample을 제거하면 1%p 정도 성능향상이 있었다. 놀라운 점은 knowledge distillation이 hard sample의 defact을 줄여줬다.

2. 왜 distillation이 small model의 large data 학습에 도움이 되는가?

Student model이 teacher model의 지식을 바로 학습할 수 있기 때문이다. Gound truth는 각 물체간 상간관계를 보여주지 못한다. 하지만 teacher model의 inference 값은 그것을 알 수 있다.

위의 그림을 보았을 때도 distillation 없이 TinyViT를 학습했을 때 비슷한 물체간 상관관계가 크지 않았으나 distillation 진행시 상관관계가 크게 나와 teacher model을 제대로 따라할 수 있었다.

Experiment

Impact of pretraining distillation on existing small ViTs

앞서 설명을 했 듯 ImageNet-21K로 distillation을 진행한 후 down stream task를 진행했다. 이 학습방법론이 효과가 있을지 다른 가벼운 ViT를 사용하여 실험한 결과 해당 방법론은 효과가 있었다.

Impact of pretraining data scale

Distillation을 pretranining할 때 사용한다. 이 때 사용한 데이터에 따라 성능차이를 확인해보니 데이터가 많아질수록 성능이 좋아졌다.

Impact of the number of saved logits

Soft-label의 용량을 줄이기 위해서 top-K만 저장한다고 했다. K가 늘어날 수록 메모리 사용량은 늘어나나 정확도는 비슷하다. 따라서 이를 밸런스 있게 가져가기 위해서 ImageNet-1K에서는 \(K=10\), ImageNet-21K에서는 \(K=100\)으로 설정했다.

Impact of teacher models

Teacher Model을 어떤 것으로 설정하는지에 따라서도 성능차이가 났다. 더 강력하고 좋은 teacher model을 사용하면 성능이 좋아지나 그 만큼 학습시간이 늘어나는 것을 볼 수 있었다.

Result on ImageNet

위 표에서 볼 수 있듯 TinyViT는 경량화 ViT 중에서 좋은 성능을 내었다.

Transfer Learning Results

Self-supervised learning으로 MOCO를 사용하여 Linear probe를 하거나 few shot learning을 할 때도 성능이 좋았다.

또한 MS COCO를 사용하여 object detection에 사용할 때도 성능이 좋았다.

Comment

사실 TinyViT 자체는 knowedge distillation 내용이라 새로운 것이 별로 없다. 결과들을 미리 저장하는 것 역시 가끔씩 쓰는 트릭이라 크게 중요하지 않다. 해당 논문은 그런 것 보다 small ViT가 ImageNet에서 hard example에 취약하여 이를 골라내는 과정이 중요한 것 같다. 또한 저자는 말하지 않았지만 (몰랐을 수도 있지만) sparse soft label 자체가 teacher model의 noise를 제거하는 역할을 하여 small ViT에 더 좋은 성능을 가져와 주었을 것이다. (비교적 어려운 이미지들은 confidence가 낮은 이미지들의 probability가 noise로 작용하기도 한다.) 끝마치면서 드는 생각은 ImageNet-21K로 pretrain하고 이를 ImageNet-1K로 finetuning하는 것이 workshop에서 맞는 방법이지 않나 싶으면서도 ViT를 생각해보면 그럴수도 있다는 생각이 든다.




Enjoy Reading This Article?

Here are some more articles you might like to read next:

  • Meta Pseudo Labels
  • Proper Reuse of Image Classification Features Improves Object Detection
  • Integral Neural Network
  • DINE: Domain Adaptation from Single and Multiple Black-box Predictors
  • EdgeViT