안녕하세요. 지난 포스팅의 [Transformer] PvT v2: Improved baselines with pyramid vision transformer (Springer CBM2022)에서는 기존의 PVT v1 구조에서 Linear SRA, Overlapping Patch Embedding, Convolutional FFN이 추가된 PVT v2를 소개하였습니다. 이러한 구조로 인해 더욱 효율적인 모델이 만들어졌으며 inductive bias를 주입할 수 있어 positional embedding의 필요성을 낮추게 되었습니다. 오늘은 이전에 소개시켜드렸던 DeiT와 유사하게 Knowledge Distillation을 기반으로 학습하는 모델이지만 더 빠르게 그리고 더 강력한 모델을 만드는 학습 프레임워크와 이를 위한 더 가벼운 Transformer인 TinyViT에 대해서 말씀드리겠습니다.
Background
기본적으로 저희가 보았던 Transformer 기반 모델들 (ViT, DeiT, SwinT, CvT, ...)은 모두 너무 큰 모델 사이즈와 Pretraining 시 요구되는 대규모 데이터셋 (ImageNet-21K, JFT-300M)으로 인해 모바일 또는 IoT 계열에서 적합하지 않는다는 문제점이 있습니다. 이러한 문제점은 CNN에서 제안되었던 MobileNet과 마찬가지로 Lightweight Transformer를 만드는 것에 목표를 두게 되었습니다. 기존 모델들보다도 작은 스케일이지만 그럼에도 불구하고 downstream task에 대한 전이능력을 향상시키는 것도 목표로 두게 되죠.
이러한 목표들로 인해 본 논문에서는 두 가지 근본적인 질문을 던집니다.
- How to effectively transfer the knowledge of existing large-scale transformers to small ones, as well as unleash the power of large-scale data to elevate the representing of small models?
- Is there any possible way for small model to absorb knowledge from massive data and further unveil their capacities?
두 질문들의 핵심은 기존의 대규모 스케일을 가지는 Transformer의 지식을 소규모 스케일의 Transformer에 이전함으로써 대규모 데이터를 직접적으로 학습하지 않고 소규모 모델의 표현력을 높히는 방법에 대한 질문입니다. 본 논문에서는 이러한 문제를 해결하기 위해 빠르고 다양하게 활용가능한 새로운 Knowledge Distillation 전략을 소개합니다. 그리고 이를 통해 제안된 새로운 소규모 Transformer 모델인 TinyViT는 ImageNet-1K에서 기존 DeiT와 Swin-T 보다도 높은 성능을 달성할 뿐만 아니라 Fine-grained Image Classification 및 Object Detection과 같은 downstream task에서도 높은 성능을 달성하게 되었습니다.
Tiny Vision Transformer (TinyViT)
1) Overall Framework
그림 2는 본 논문에서 제안하는 "Fast and Scalable Knowledge Distillation"에 대한 전체적인 프레임워크를 보여주고 있습니다. 기본적인 컨셉부터 설명드리자면 일반적으로 Knowledge Distillation에서는 Teacher & Student 모델로 구성되어 Teacher 모델의 출력 결과가 Student 모델의 레이블이 되는 방식으로 선생님의 지식을 학생이 흡수하는 형태로 수행됩니다. 이를 위해서는 Teacher 모델은 Student 모델보다 더 큰 규모의 모델로 구성되어야합니다. 예를 들어, DeiT에서는 학습 시 ViT에 CNN이 가진 Inductive Bias를 주입해주기 위해 Teacher 모델로 RegNetY라는 모델을 사용하였습니다.
하지만 이러한 구성에서 큰 문제점은 학습 시 Teacher & Student 모델로부터 출력값을 모두 얻어야하기 때문에 Forwarding 과정과 Backwarding 과정을 두 번 수행해야합니다. 이는 생각보다 큰 GPU 메모리 소모가 발생하게 되고 상대적으로 배치 사이즈를 적게 잡고 학습할 수 밖에 없기 때문에 전체적인 학습 속도가 떨어질 수 밖에 없습니다.
그렇다면 단순하게 Teacher 모델의 Forwarding 및 Backwarding 과정을 없애버리면 안될까요? 미리 영상-출력 결과에 대한 쌍으로 데이터베이스에 저장해놓은 뒤 나중에 Student 모델을 학습할 때 특정 영상이 입력되었을 때 이 데이터베이스에서 꺼내어 쓰면 될 것 입니다! 해당 방식이 본 논문에서 제시하는 가장 중요한 핵심입니다.
여기서 궁금한 점이 생길 수도 있을 것 입니다. "왜 Knowledge Distillation을 수행해야하지? 그냥 Student Model에 직접적으로 Pretraining을 하면 안되나?". 이 역시 본 논문의 관찰을 통해 소규모 모델에서는 ImageNet-21K와 같은 대규모 데이터셋에 직접적으로 학습하게 되면 오히려 성능이 떨어지는 현상을 관찰했다고 합니다. 이유는 이후에 더 자세하게 설명드리도록 하겠습니다.
2) Fast Pretraining Distillation
일반적으로 딥 러닝 모델을 학습할 때는 데이터 증강 (Data Augmentation)도 함께 들어가며 Knowledge Distillation 상황에서는 RandAugment, CutMix와 같은 더 강력한 데이터 증강이 적용되기도 합니다. 이러한 데이터 증강들을 $\mathcal{A}$라고 정의하고 Teacher 모델 $T$가 주어졌을 때 $\hat{\mathbf{y}} = T(\mathcal{A}(x))$를 Teacher 모델의 출력 결과라고 가정하겠습니다. 여기서 중요한 점은 데이터 증강을 수행하면서 Teacher 모델을 학습하게 되면 같은 영상이라고 하더라도 내재된 무작위성에 의해 서로 다른 영상을 만들게 됩니다. 따라서, 어떤 데이터 증강 $\mathcal{A}$이 적용되었을 때 그에 대한 출력 결과 $\hat{\mathbf{y}}$를 쌍 $(\mathcal{A}, \hat{\mathbf{y}})$를 함께 저장해야합니다.
Teacher 모델이 학습이 다 끝났다고 가정하면 다음 단계는 Student 모델 $S$을 학습해야합니다. 이를 위해 $(\mathcal{A}, \hat{\mathbf{y}})$이 저장된 데이터베이스로부터 예측 결과를 복원하여 다음과 같이 학습을 진행합니다.
$$\mathcal{L} = \text{CE} (\hat{\mathbf{y}}, S(\mathcal{A} (x)))$$
여기서 재밌는 점은 학습할 때 입력 영상 $x$에 대응되는 ground truth $\mathbf{y}$를 사용하지 않고 오직 Teacher 모델의 출력 결과 $\hat{\mathbf{y}}$만을 사용하여 학습한다는 점입니다. 이러한 학습 framework는 나중에 레이블이 없는 상황에서도 유연하게 학습을 수행할 수 있는 방식을 제공하게 됩니다. 또한, ImageNet-21K의 레이블 자체가 서로 상호배타적으로 구성되어 있지 않습니다. 예를 들어, "의자", "가구"와 같은 형태가 있죠. 이는 레이블 간 상관성이 존재하기 때문에 Teacher 모델로 부터 이러한 지식을 이해하는 형태로 생각해볼 수 있습니다.
하지만 Teacher 모델의 출력 결과를 저장하기 위해서는 생각보다 큰 비용이 소모됩니다. 일단, ImageNet-21K의 레이블 개수는 약 2만 2천개의 레이블이 존재합니다. 만약, Hard Label이 아닌 Soft Label로 저장한다고 가정했을 때 한 장의 영상은 22000의 길이를 가지는 벡터로 각 클래스별 확률로 약 32bit를 가지기 때문에 대략 88킬로바이트의 저장용량을 차지하겠죠. 그런데 ImageNet-21K는 총 14,197,122개의 영상으로 구성되었기 때문에 총 1.24TB가 요구됩니다. 생각보다 엄청난 크기의 저장용량이 필요합니다. 본 논문에서는 이러한 문제를 해결하기 위해 Teacher 모델 기준으로 Top-K 개의 Soft Label을 선택하여 저장합니다. 이를 통해 저장용량을 크게 절약할 수 있게 됩니다.
또한, 데이터 증강의 무작위성으로 인해 서로 다른 영상들이 만들어지기 때문에 파라미터 $\mathbf{d}$도 함께 저장해주어야합니다. 이를 위해 인코더 $ \epsilon ( \cdot )$을 이용하여 $d_{0} = \epsilon (\mathbf{d})$ 가볍게 변환해주어 저장해줍니다. 나중에 Student 모델을 학습할 때는 디코더 $ \epsilon^{-1} (\cdot) $을 이용하여 $\mathbf{d} = \epsilon^{-1} (d_{0}) $로 복원한 뒤 사용하면 됩니다. 본 논문에서는 PCG라는 방식을 디코더와 인코더에 동일하게 사용했다고 하네요.
3) Model Architecture
본 논문에서 학습할 때 사용한 Student 모델은 기본적으로 Hierarchical Vision Transformer 구조를 가지고 있습니다. 그리고 Swin Transformer와 LeViT와 동일하게 4개의 스테이지로 구성되어 점점 특징 맵의 해상도가 감소하게 됩니다. Pach Embedding에서는 $3 \times 3$ 크기의 커널 크기, stride 2, padding 1로 구성된 2개의 합성곱 계층을 활용합니다. 또한, 모든 스테이지에서 Transformer를 사용하게 되면 Resolution의 제곱만큼 복잡도가 증가하는 Self-Attention의 특성상 비효율적입니다. 이러한 문제를 해결하기 위해 본 논문에서는 Stage 1에서는 MBConv를 사용하고 나머지 Stage들에서는 Transformer Block을 이용합니다. 그리고 MLP 내에서는 $3 \times 3$ 크기의 Depthwise Separable Convolution을 활용하여 내부적으로 지역적 특징을 잡을 수 있도록 도와줍니다. 그리고 다른 Transformer 구조와 마찬가지로 residual connection도 함께 추가하여 학습의 안정성도 추가해줍니다.
TinyViT의 핵심 파라미터들은 위와 같습니다. 다만 전체적인 모델의 복잡도는 $\gamma_{D_{1-4}}$가 통제하게 됩니다. 나머지 파라미터들은 스케일에 관계없이 모두 동일하게 셋팅하여 다음과 같습니다.
- $\{ \gamma_{N_{1}}, \gamma_{N_{2}} , \gamma_{N_{3}} , \gamma_{N_{4}} \} = \{ 2, 2, 6, 2 \}$
- $\{ \gamma_{W_{2}}, \gamma_{W_{3}} , \gamma_{W_{4}} \} = \{ 7, 14, 7 \}$
- $\{ \gamma_{R}, \gamma_{M} , \gamma_{E} \} = \{ 4, 4, 32 \}$
그리고 TinyViT의 스케일에 따른 $\gamma_{D_{1-4}}$는 다음과 같이 정해집니다.
- TinyViT-21M: $\{ \gamma_{D_{1}}, \gamma_{D_{2}} , \gamma_{D_{3}} , \gamma_{D_{4}} \} = \{ 96, 192, 384, 576 \}$
- TinyViT-11M: $\{ \gamma_{D_{1}}, \gamma_{D_{2}} , \gamma_{D_{3}} , \gamma_{D_{4}} \} = \{ 64, 128, 256, 448 \}$
- TinyViT-5M: $\{ \gamma_{D_{1}}, \gamma_{D_{2}} , \gamma_{D_{3}} , \gamma_{D_{4}} \} = \{ 64, 128, 160, 320 \}$
Analysis and Discussion
1) What are the underlying factors limiting small models to fit large data?
그림 1에서 보았다싶이 단순히 작은 모델에 바로 ImageNet-21K를 학습하게 되면 downstream task에서 성능이 충분히 좋아지지 않았습니다. 그 이유는 ImageNet-21K의 특성 상 틀린 레이블이 생각보다 많이 존재하고 유사한 영상임에도 불구하고 서로 다른 형태의 레이블링이 되어있기 때문이라고 볼 수 있습니다. 이를 확인하기 위해 본 논문에서는 실제로 ImageNet-21K에 데이터 정제를 수행한 Cleaned ImageNet-21K를 이용해서 학습하게 되면 Swin Transformer와 TinyViT에서 Original ViT에 비해 더 높은 성능이 나오는 것으로 확인할 수 있습니다. 하지만 이러한 과정없이 Knowledge Distillation을 수행하게 되면 더 높은 성능을 얻을 수 있는 데, 그 이유는 Teacher 모델의 Soft Label을 활용하여 학습함으로써 어느정도 정제된 레이블을 받아 Student 모델이 학습할 수 있게 되는 것으로 해석할 수 있죠.
2) Why can distillation improve the performance of small models on large datasets?
또한, ImageNet-21K의 레이블 간 상호관계를 Knowledge Distillation없이 학습하는 이해하지 못한다는 문제점이 있습니다. 이는 그림 3 (b)에서 볼 수 있죠. 하지만, Knowledge Distillation을 수행하게 되면 Teacher 모델과 마찬가지로 어느정도 Domain Knowledge를 Student 모델이 이해한 것을 볼 수있습니다. (그림 3 (a)와 (c))
Experiment Results
1) ImageNet-1K
2) Few-Shot Image Classification
3) Object Detection
Ablation Study
1) Pretraining Strategy
2) Hyperparameter Test
3) Different Teacher Models for Pretraining Distillation