안녕하세요. 지난 포스팅의 [Transformer] P2T: Pyramid Pooling Transformer for Scene Understanding (IEEE TPAMI2022)에서는 기존의 Pyramid Vision Transformer와 Multi-Scale ViT에서 다루지 않은 Pyramid Pooling을 통한 연산량 감소 및 강력한 특징 표현을 얻을 수 있는 P2T에 대한 설명을 드렸습니다. 그런데, 실제 P2T 구현에서는 이상하게 positional embedding이 없고 MobileNetV1에서 제안된 Depth-wise Separable Convolution을 사용하는 것을 볼 수 있었습니다. 저는 이 부분에 대해 궁금증이 생겨 찾아보니 관련 논문 중 Convolutions to Vision Transformer (CvT)라는 논문을 찾을 수 있었습니다. 오늘은 해당 논문에 대한 소개와 함께 어째서 이후 논문들이 positional embedding이 필요없어지고 Depth-wise Separable Convolution이 사용되기 시작했는 지 알아보도록 하겠습니다.
Background
Transformer 기반 모델의 근본적인 문제는 크게 두 가지로 정리할 수 있습니다. 1) 높은 연산량으로 인한 고해상도 영상을 사용해야하는 Dense Prediction에서는 사용할 수 없음. 2) CNN에서 sliding window를 통해 얻을 수 있는 Spatial Correlation을 얻을 수 없음. 이러한 문제로 인해 작은 규모의 데이터셋 (ImageNet-1K)에서 ViT 기반의 모델들을 학습했을 때 성능이 유사한 규모의 CNN 기반 모델 (ResNet)에 비해 많이 떨어지게 되는 것이죠.
따라서, 본 논문은 Transformer가 가지고 있는 Dynamic Attention, Global Context Fusion, Generalizability 능력과 CNN이 가지고 있는 Local Receptive Field, Shared Weights, Spatial Subsampling과 같은 특징을 결합하고자 Convolutions to Vision Transformer (CvT)를 제안합니다. 근본적으로 positional embedding을 사용해야하는 이유는 패치 간 위치 정보를 배양하기 위함이지만 CNN 구조를 적용함으로써 그 정보는 암시적으로 적용될 수 있겠죠. 이는 positional embedding이 이제부터 필요없어지는 계기가 될뿐만 아니라 다양한 크기의 해상도를 가지는 영상에도 학습할 수 있는 계기를 마련합니다. 본 논문의 핵심 기여도를 정리하면 다음과 같습니다.
- 기존 Vision Transformer 구조에 처음으로 Convolution 연산을 결합한 모델인 Convolution to Vision Transformer (CvT)를 제안
- CvT가 CNN과 ViT의 두 가지 장점을 모두 함유하게 됨
- CNN을 결합하면서 Positional Embedding을 암시적으로 적용할 수 있게 됨으로써 초기 positional embedding이 필요없어지게 되고 이는 다른 vision task에서 다양한 해상도를 가지는 영상을 사용할 수 있게 되는 계기 마련
- 소규모 모델인 ImageNet-1K에서 학습했을 때 가벼우며 효율적으로 state-of-the-art 성능을 달성
Convolutional Vision Transformer (CvT)
1). Overall Architecture
그림 2는 CvT는 전체적인 구조를 보여주고 있습니다. 본 논문에서 제안하는 모델의 핵심은 2가지 CNN 기반 모듈로써 patch extraction을 convolution을 이용해 수행하는 Convolutional Token Embedding 그리고 Transformer Block에 Depth-wise Separable Convolution을 도입한 Convolutional Projection 입니다.
2). Convolutional Token Embedding
그림 2의 가장 왼쪽의 input image $x_{0}$에서 색이 들어간 사각형이 대각선 방향으로 움직이고 있는 듯한 그림을 볼 수 있습니다. 본 모듈의 목표는 기존 ViT에서 가지고 있지 않은 계층적 구조 (Hierarchical Structure)를 주입하여 low-level edge부터 high-level semantic primitive까지의 지역적 공간 특징을 모델링하는 것 입니다. 이는 입력 영상에 convolution 연산을 적용하는 단계로써 다음과 같은 단계로 수행됩니다.
STEP1. 2D 영상 또는 2D reshaped token map $x_{i - 1} \in \mathbb{R}^{H_{i - 1} \times W_{i - 1} \times C_{i - 1}}$을 이전 Transformer Block로부터 입력
STEP2. 함수 $f(\cdot)$을 이용해 다음 token map 생성 (이때, $f(\cdot)$은 $s \times s$ 크기의 커널과 $s - o$ 크기의 stride, $p$의 패딩을 가지는 2D convolution operation으로 정의)
$$f(x_{i - 1}) \in \mathbb{R}^{H_{i} \times W_{i} \times C_{i}}$$
STEP3. 새로운 token map $f(x_{i - 1})$를 flatten하여 $H_{i}W_{i} \times C_{i}$의 shape을 가지도록 변경
STEP4. $i$번째 Convolutional Transformer Block에 입력되기 전 Layer Normalization 적용
이때, 새롭게 얻는 token map의 크기는 다음과 같습니다.
$$H_{i} = \lfloor \frac{H_{i - 1} + 2p - s}{s - 0} + 1 \rfloor, W_{i} = \lfloor \frac{W_{i - 1} + 2p - s}{s - o} + 1 \rfloor$$
실제로 현재 나오는 다양한 Transformer 기반 모델들에서는 위와 같이 patch embedding을 convolution operation으로 바꾸어 수행하는 편 입니다. 이와 같이 patch 단위로 분해된 token map이라고 하더라도 CNN을 통해 Spatial Correlation을 주입해줄 수 있기 때문에 positional embedding이 필요없어지는 한가지 이유가 될 수 있겠습니다. 뿐만 아니라 각 토큰은 보다 복잡한 visual pattern을 인식할 수 있기 때문에 이러한 구조는 CNN과 어느정도 유사성을 띠게 설계되었다고도 볼 수 있습니다.
3). Convolutional Projection for Attention
그림 3은 개인적으로 CvT의 제일 핵심이라고 생각하는 부분 중에 하나 입니다. 최근 등장하는 다양한 Transformer 기반 모델들 모두 Depth-wise Separable Convolution을 사용하는 데 본 논문은 그 이유를 뒷받침해주기 때문이죠. 이 구조를 통해 CvT는 CNN이 가지고 있던 local spatial context를 Transformer에 부여하고 $K$와 $V$의 sub-sampling을 통한 효율성도 함께 챙길 수 있게 됩니다.
사실 연산하는 과정도 굉장히 단순합니다. 기본적으로 입력되는 $x_{i}$는 현재 1D로 flatten된 patch 형태이기 때문에 이를 2D로 바꾸는 Reshape2D를 먼저 거칩니다. 그래야 2D Convolution을 적용할 수 있겠죠. 다음으로 Depth-wise Convolution + Batch Normalization + Point-wise Convolution으로 구성된 Conv2d를 적용합니다. 여기서 $s$는 Depth-wise Convolution에서 사용하는 커널의 크기를 의미합니다. 그 다음 Flatten을 거치면서 최종적으로 $Q, K, V$를 얻을 수 있습니다. 이는 이후 Multi-Head Self-Attention을 통과하여 흔히 저희가 알고 있는 Self-Attention을 수행하게 됩니다.
기본적으로 해당 모듈은 Depth-wise Separable Convolution을 사용하기 때문에 연산량 자체가 기존 Convolution 연산보다는 적습니다. 이에 대한 자세한 설명은 이전에 설명드린 MobileNetV1을 참고해주세요. 또한, 일반적으로 MHSA에서 연산량이 많이 들어가는 데 이는 P2T에서도 했듯이 $K$와 $V$의 크기를 줄여주면 됩니다. 본 논문에서는 이를 커널 크기 $s$로 조절할 수 있습니다 (그림 3 (b)과 그림 3 (c)). 본 논문에서는 default 값으로 $s = 2$를 사용했다고 하네요. 이 단계에서도 Convolution 연산을 통해 위치 정보를 embedding할 수 있기 때문에 positional embedding이 필요없어지는 두번째 이유가 되기도 합니다.
4). Model Variants
본 모델도 다른 Transformer 모델과 마찬가지로 파라미터의 개수에 따른 모델을 여러 제시합니다. 모델의 규모는 주로 각 stage의 블록 개수와 hidden feature의 개수에 따라 결정됩니다. 주의하셔야할 점은 기존 Transformer 모델과는 다르게 3개의 Stage로 구성되어 있다는 점 입니다. 그리고 $X$를 CvT 내에 존재하는 Transformer 블록의 개수라고 할 때 CvT-$X$라고 적을 수 있습니다. 즉, CvT-13와 CvT-21은 각각 13개, 21개의 Transformer 블록을 가지는 모델입니다.
여기서 마지막으로 굉장히 큰 규모의 모델인 CvT-W24가 존재합니다. 이 모델은 기존 두 모델인 CvT-13와 CvT-21보다 더 "Wide 하다"의 의미로 쓰였습니다. 처음 hidden feature의 개수부터 192개로 CvT-13와 CvT-21에 비해 3배 더 많은 것을 볼 수 있죠. 이로 인해 총 276.7M개의 파라미터를 가지는 대규모 모델을 만들 수 있습니다.
표 1은 이전에 제안된 Transformer 기반 모델들과의 각 요소별 비교를 하고 있습니다.
Experiment Results
1). ImageNet Classification
- Dataset
- ImageNet-1K: 1.28 million training images & 50K validation images with 1,000 classes
- Data Augmentation (Same as DeiT & PVT): random cropping, random horizontal flipping, label-smooting regularization, MixUp, CutMix, Random Erasing
- Optimization: AdamW
- momentum: -
- weight decay: 0.05 for CvT-13 & 0.1 for CvT-21 and CvT-W24
- learing rate: 0.02 with cosine learning rate decay scheduler
- batch size: 2,048
- epochs: 300
- GPU Type 언급 X
- 그림 1 (a)는 ViT, BiT 그리고 본 논문에서 제안하는 CvT 사이의 비교를 수행하고 있습니다. BiT에 비해 파라미터는 훨씬 적지만 성능은 훨씬 높네요. 그림 1 (b)에서는 주로 Transformer 기반 모델들과 비교하고 있습니다. 다른 Transformer 모델들과도 비교했을 때 속도와 성능을 모두 압도하고 있습니다.
- 표 1은 ImageNet-1K에 바로 학습한 모델들과 ImageNet-22K (ImageNet-21K와 동일)에 pre-trained된 모델을 fine-tuning했을 때 실험결과를 보여주고 있습니다. 결론적으로 모든 모델들보다 훨씬 높은 성능을 가지고 있습니다.
- ImageNet-Real과 ImageNet-V2는 모두 기존 ImageNet-1K와 다른 새로운 평가 데이터셋입니다. 해당 데이터셋에서도 모두 좋은 결과를 보여주고 있네요.
2). Downstream Task Transfer
- 표 4는 CvT의 다른 데이터셋으로의 Transferability를 측정하고 있습니다. 각 모델들은 모두 ImageNet-22K에 사전학습하고 각 데이터셋 (CIFAR10, CIFAR100, Pets, Flowers102)에 fine-tune을 진행하였습니다.
Ablation Studies
1). Position Embedding
- 표 5는 positional embedding에 대한 실험 결과 입니다. CvT-13을 이용해 실험하였으며 각각 모든 stage, 첫번째 stage, 마지막 stage, 아예 뺀 경우에 대한 실험을 수행하였습니다. CvT는 positional embedding의 유무에 따른 성능 변화가 전혀 크지 않은 것을 볼 수 있습니다. 하지만, DeiT는 positional embedding를 제거하면 성능이 크게 떨어지는 것을 볼 수 있죠.
- 이러한 실험 결과는 CvT에 내장된 convolution 연산이 positional embedding을 대신한다고 결론 지을 수 있습니다.
2). Convolutional Token Embedding
- 표 6은 Convolutional Token Embedding에 대한 실험 결과입니다. 여기서도 마찬가지로 positional embedding 유무에 대한 실험과 함께 Convolutional Token Embedding vs patchify에 대한 실험을 수행합니다.
- Convolutional Token Embedding과 positional embedding를 모두 넣게 되면 성능이 오히려 떨어지는 것을 볼 수 있죠. 이는 CNN에서 추출한 spatial correlation을 오히려 방해한다고 볼 수 있을 거 같습니다. 심지어 positional embedding을 넣게 되면 파라미터 개수도 0.3M만큼 늘어나기 때문에 큰 손해입니다.
3). Convolutional Projection
- 표 7은 stride의 크기 (1 또는 2)에 따른 성능 및 FLOPs 변화를 보여주고 있습니다.
- stride의 크기를 2로 늘렸을 뿐인데 FLOPs가 6.55G에서 4.53G가 엄청나게 큰 감소를 하는 것을 볼 수 있습니다. 하지만, 성능도 약 0.7% 정도 떨어지죠. CvT에서는 이러한 성능 하락에도 불구하고 효율성을 위해 stride를 2로 맞춰서 학습하였습니다.
- 표 8은 Convolutional Projection (CvT) vs Position-wise Linear Projection (ViT)를 stage 별로 추가했을 때 성능 변화를 보여주고 있습니다.
- Convolutional Projection을 점점 많은 stage에 추가할 수록 성능이 점점 증가하는 것을 볼 수 있죠.
Code Analysis
from typing import List, Optional, Callable
import torch
import torch.nn as nn
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
from timm.models.layers import trunc_normal_
from einops import rearrange
from einops.layers.torch import Rearrange
class PreNorm(nn.Module):
def __init__(self,
dim: int,
fn: nn.Module) -> None:
super(PreNorm, self).__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
x = self.norm(x)
return self.fn(x, **kwargs)
class DepthwiseSepConv2d(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int=1,
padding: int=0,
dilation: int=1,) -> None:
super(DepthwiseSepConv2d, self).__init__()
self.depthwise_conv = nn.Conv2d(in_channels=in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=in_channels)
self.bn = nn.BatchNorm2d(in_channels)
self.pointwise_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.depthwise_conv(x)
x = self.bn(x)
x = self.pointwise_conv(x)
return x
class FeedForward(nn.Module):
def __init__(self,
dim: int,
mlp_dim: int,
drop_rate: float = 0.) -> None:
super(FeedForward, self).__init__()
self.net = nn.Sequential(
nn.Linear(dim, mlp_dim),
nn.GELU(),
nn.Dropout(drop_rate),
nn.Linear(mlp_dim, dim),
nn.Dropout(drop_rate)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
class ConvAttention(nn.Module):
def __init__(self,
dim: int,
img_size: int,
num_heads: int,
dim_head: int = 64,
kernel_size: int = 3,
q_stride: int = 1,
k_stride: int = 2,
v_stride: int = 2,
attn_drop_rate: float = 0.,
last_stage: bool=False) -> None:
super(ConvAttention, self).__init__()
self.last_stage = last_stage
self.num_heads = num_heads
self.img_size = img_size
project_out = not (num_heads == 1 and dim_head == dim)
self.scale = dim_head ** -0.5
padding = (kernel_size - q_stride) // 2
self.to_q = DepthwiseSepConv2d(dim, dim_head * num_heads, kernel_size=kernel_size, stride=q_stride, padding=padding)
self.to_k = DepthwiseSepConv2d(dim, dim_head * num_heads, kernel_size=kernel_size, stride=k_stride, padding=padding)
self.to_v = DepthwiseSepConv2d(dim, dim_head * num_heads, kernel_size=kernel_size, stride=v_stride, padding=padding)
self.to_out = nn.Sequential(
nn.Linear(dim_head * num_heads, dim),
nn.Dropout(attn_drop_rate)
) if project_out else nn.Identity()
def forward(self, x):
B, N, _, h = *x.shape, self.num_heads
if self.last_stage:
cls_token = x[:, 0]
x = x[:, 1:]
cls_token = rearrange(cls_token.unsqueeze(1), 'b n (h d) -> b h n d', h = h)
print(cls_token.shape)
x = rearrange(x, 'b (l w) n -> b n l w', l=self.img_size, w=self.img_size)
q = self.to_q(x)
k = self.to_k(x)
v = self.to_v(x)
q = rearrange(q, 'b (h d) l w -> b h (l w) d', h=h)
k = rearrange(k, 'b (h d) l w -> b h (l w) d', h=h)
v = rearrange(v, 'b (h d) l w -> b h (l w) d', h=h)
if self.last_stage:
q = torch.cat((cls_token, q), dim=2)
k = torch.cat((cls_token, k), dim=2)
v = torch.cat((cls_token, v), dim=2)
dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = dots.softmax(dim=-1)
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return out
class ConvTransformerBlock(nn.Module):
def __init__(self,
dim: int,
img_size: int,
depth: int,
num_heads: int,
mlp_dim: int,
drop_rate: float=0.,
attn_drop_rate: float=0.,
norm_layer: Optional[Callable[..., nn.Module]]=nn.LayerNorm,
last_stage: bool=False) -> None:
super(ConvTransformerBlock, self).__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim=dim, fn=ConvAttention(dim=dim, img_size=img_size, num_heads=num_heads, attn_drop_rate=attn_drop_rate, last_stage=last_stage)),
PreNorm(dim=dim, fn=FeedForward(dim=dim, mlp_dim=mlp_dim, drop_rate=drop_rate))
]))
def forward(self, x: torch.Tensor) -> torch.Tensor:
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class ConvolutionalTokenEmbedding(nn.Module):
def __init__(self,
img_size: int = 224,
in_channels: int = 3,
embed_dim: int = 64,
kernel_size: int = 7,
stride: int = 4,
norm_layer: Optional[Callable[..., nn.Module]]=nn.LayerNorm) -> None:
super(ConvolutionalTokenEmbedding, self).__init__()
padding = (kernel_size - 1) // 2
self.conv = nn.Sequential(
nn.Conv2d(in_channels, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding),
Rearrange('b c h w -> b (h w) c', h=img_size, w=img_size),
norm_layer(embed_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv(x)
return x
class CvT(nn.Module):
def __init__(self,
img_size: int = 224,
in_chans: int = 3,
num_classes: int = 1000,
embed_dims: List[int]=[64, 192, 384],
num_heads: List[int]=[1, 3, 6],
dim: int = 64,
scale_dim: int = 4,
qkv_bias: bool = True,
qk_scale: Optional[float] = None,
drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
norm_layer: Optional[Callable[..., nn.Module]] = nn.LayerNorm,
depths: List[int]=[1, 2, 10],
num_stages: int=3,) -> None:
super(CvT, self).__init__()
self.num_classes = num_classes
self.depths = depths
self.embed_dims = embed_dims
self.num_stages = num_stages
self.img_size = img_size
for i in range(num_stages):
conv_token_embedding = ConvolutionalTokenEmbedding(img_size=img_size // (2 ** (i + 2)),
in_channels=in_chans if i == 0 else embed_dims[i - 1],
embed_dim=embed_dims[i],
kernel_size=7 if i == 0 else 3,
stride=4 if i == 0 else 2,
norm_layer=norm_layer)
conv_transformer_block = nn.Sequential(ConvTransformerBlock(dim=embed_dims[i],
num_heads=num_heads[i],
mlp_dim=dim * scale_dim,
img_size=img_size // (2 ** (i + 2)),
depth=depths[i],
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
norm_layer=norm_layer,
last_stage=False),
Rearrange('b (h w) c -> b c h w', h=img_size // (2 ** (i + 2)), w=img_size // (2 ** (i + 2))))
setattr(self, f"conv_token_embedding_{i}", conv_token_embedding)
setattr(self, f"conv_transformer_block_{i}", conv_transformer_block)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.drop_large = nn.Dropout(drop_rate)
self.mlp_head = nn.Sequential(
norm_layer(196),
nn.Linear(196, num_classes)
)
print("Weights initialization")
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward_feature(self, x):
for i in range(self.num_stages):
conv_token_embedding = getattr(self, f"conv_token_embedding_{i}")
conv_transformer_block = getattr(self, f"conv_transformer_block_{i}")
x = conv_token_embedding(x)
x = conv_transformer_block(x)
x = self.drop_large(x)
x = torch.mean(x, dim=1)
x = x.reshape(x.shape[0], -1)
return x
def forward(self, x):
x = self.forward_feature(x)
print(x.shape)
x = self.mlp_head(x)
return x
@register_model
def cvt_13(pretrained=False, **kwargs):
model = CvT(
embed_dims=[64, 192, 384], qkv_bias=True, norm_layer=nn.LayerNorm, depths=[1, 2, 10], **kwargs)
model.default_cfg = _cfg()
return model
if __name__ == '__main__':
model = cvt_13()
inp = torch.randn(2, 3, 224, 224)
out = model(inp)