안녕하세요. 지난 포스팅의 [Transformer] Transformer in Transformer (NIPS2021)에서는 큰 패치로 나눈 뒤 그 패치들을 다시 나누어 서브 패치 간의 관계성을 학습하는 TNT에 대해서 알아보았습니다. 오늘은 JFT-300M 데이터셋과 같은 대규모 데이터셋에 사전학습의 필요성을 줄이기 위한 시도 중 하나인 Compact Transformer에 대해서 알아보겠습니다.
Background
Convolution Neural Network (CNN) 이후로 최근 다양한 Transformer 기반의 모델들이 각광받고 있습니다. 특히, Vision Transformer (ViT)의 등장으로 Computer Vision 분야에서 엄청난 관심을 이끌게 되었죠. 하지만, 합성곱 연산의 local feature extraction과 같은 inductive bias가 존재하지 않아 JFT-300M 데이터셋과 같은 대규모 데이터셋에 사전학습을 해야만 하는 문제점이 있었죠. 또한, JFT-300M은 구글에서만 사용하는 비공개 데이터셋이기 때문에 이러한 문제점은 더욱 강조됩니다.
이러한 문제점을 해결하기 위해 본 논문에서는 CIFAR 또는 ImageNet-1K에 직접 학습했을 때도 성능을 보장할 수 있는 모델을 설계하였습니다. 이를 Compact Transformer라고 부르며 접근법에 따라 아래의 그림과 같 Compact Vision Transformer (CVT)와 Compact Convolution Transformer (CCT)로 나뉘게 됩니다.
Vision Transformer Lite version
위 그림은 ViT, CVT, 그리고 CCT 간의 차이점을 보여주고 있습니다. 여기서 가장 큰 차이점은 Class Token이 빠지고 Sequence Pooling (SeqPool)이라는 계층이 추가된 것을 볼 수 있죠. 또한, CVT와 CCT 사이의 차이점은 patch extraction 및 embedding 과정을 convolution을 이용해서 할 지에 대한 것임도 확인할 수 있습니다.
1) Transformer-based Backbone
기본적으로 본 논문에서 사용하는 Transformer는 original ViT를 기반으로 설계되었습니다. 따라서 흔히 저희가 알고 있는 Multi-head Self-attention, Multi-layer Perceptron, Layer Normalization, GeLU 등과 같은 구조는 동일하게 사용됩니다. 또한, positional encoding 역시 적용될 수 있으며 본 논문에서는 이를 optional로 두고 있습니다. 사용하지 않을 수도 있고, 사용하게 된다면 learnable하게 하거나 sinusoidal 하게 할 수도 있죠. 두 방법 모두 성능 상으로는 효과가 있었다고 합니다.
하지만, 이를 CIFAR와 같은 저해상도의 소규모 데이터셋에 활용하기 위해서는 ViT의 configuration을 수정하여 더 작은 규모의 모델을 만들필요가 있습니다.
위 표는 Vision Transformer 논문의 모델 별 configuration입니다. 본 논문에서는 더 작은 모델인 ViT-Lite를 설계하기 위해 ViT-Base를 기반으로 수정합니다. 이를 위해 ViT-Lite-12/16은 12개의 Transformer encoder 구조를 사용하며 $16 \times 16$의 패치로 나눈다는 것으로 가정하겠습니다.
위 표는 이를 기반으로 구성한 configuration으로 ViT-Base 보다 transformer encoder, head, ratio, hidden dimension을 절반 가량 줄인 것을 볼 수 있습니다.
2) SeqPool
다음은 본 논문의 핵심 중 하나인 SeqPool입니다. 기본적으로 Transformer 마지막 encoder에서는 입력 영상의 여러 부분에 걸쳐 관련된 정보가 종합되어 있습니다. 본 논문에서는 이러한 정보를 보존하면 성능이 향상될 수 있을 것이라고 생각하여 class token을 없애고 최종 단계에서 출력 시퀀스에서의 self-attention을 수행합니다. 이 과정이 SeqPool이 되는 것 이죠. 여기서 class token이 없어짐으로써 계산량이 약간 감소한다는 장점이 있습니다. 핵심은 어떤 변환 함수 $T: \mathbb{R}^{b \times n \times d} \rightarrow \mathbb{R}^{b \times d}$를 학습하는 것이 목표입니다. 이를 위해 다음과 같은 과정을 거치게 됩니다.
STEP1. $\mathbf{x}_{L} = f(\mathbf{x}_{0}) \in \mathbb{R}^{b \times n \times d}$를 $L$개의 transformer encoder를 통과한 출력 시퀀스라고 하자. 여기서 $b$는 배치 사이즈, $n$은 출력 시퀀스의 길이, $d$는 hidden dimension이다.
STEP2. $\mathbf{x}_{L}$을 선형함수 $g(\mathbf{x}_{L}) \in \mathbb{R}^{d \times 1}$에 입력한 두 softmax function을 적용한다.
$$\mathbf{z} = \text{softmax} (g(\mathbf{x}_{L})^{T}) \times \mathbf{x}_{L} \in \mathbb{R}^{1 \times 1 \times d}$$
STEP3. 출력 시퀀스 $\mathbf{z}$를 classifier에 입력한다.
3) Convolutional Tokenizer
여기서부터는 CCT를 위한 설명입니다. 이전 포스팅에서도 설명드렸다 싶이 Transformer에서는 CNN이 가지고 있는 inductive bias가 부족하기 때문에 이를 주입해주는 것도 한가지 연구방향입니다. 본 논문의 CCT에서는 Patch Embedding 시 단순한 convolution layer를 이용하여 추출하였습니다. 이로 인해 indunctive bias가 포함되었기 때문에 positional encoding은 선택적으로 적용하면 되겠네요.
$$\mathbf{x}_{0} = \text{MaxPool} (\text{ReLU} (\text{Conv2D} (\mathbf{x})))$$
Experiment Results