안녕하세요. 지난 포스팅의 [IC2D] Xception: Deep Learning with Depthwise Separable Convolutions (CVPR2017)에서는 Inception 모델의 최종 변형 구조인 Xception에 대해서 소개해드렸습니다. Xception은 실제로 많은 논문에서 ResNet과 같이 다양한 downstream task에서 backbone으로 사용되고 있으며 특히 Deepfake detection에서 많이 활용되고 있는 추세입니다. 오늘은 지금까지 성능을 향상시키기 위한 파라미터였던 깊이, 너비, cardinality, diversity가 아닌 attention의 개념을 컴퓨터 비전에 접목한 RAN에 대해서 소개시켜드리도록 하겠습니다.
Background
Attention 메커니즘은 많은 자연어 처리에서 사용되는 방법입니다. 하지만, 지금까지 많은 영상 분류 모델을 보았지만 아직까지는 attention 메커니즘을 적용하는 모델은 없었습니다. 이러한 자연어 처리에서의 attention에 영감을 받아 본 논문은 Residual Attention Network (RAN)이라는 모델을 제안합니다. RAN은 아주 깊은 모델 사이에 아주 쉽게 추가할 수 있는 모듈의 형태로 구현되었으며 이로 인해 깊어질수록 feature의 정보가 삭제되는 diminishing problem을 어느정도 해결하게 되었습니다. 또한, 유연하게 residual network 기반의 다양한 모델에 쉽게 적용할 수 있습니다.
RAN의 특징을 정리하면 다음과 같습니다.
1). Stacked Network Structure: 여러 개의 attention 모듈을 쌓아서 RAN을 구성합니다.
2). Attention Residual Learning: 단순히 attention 모듈을 여러 개 쌓아서 깊은 네트워크를 구성하는 것은 오히려 성능 감소가 나타났기 때문에 이를 막기 위해 새로운 학습 기법을 추가합니다.
3). Bottom-up Top-down feedforward attention: 이러한 방식의 구조는 다양한 컴퓨터 비전에 활용되었습니다. RAN에서는 이 구조를 attention 모듈에 추가하여 soft weights를 feature에 추가하는 방식으로 적용합니다.
Residual Attention Network
1). Mask branch and Trunk branch
그림1은 RAN에서 제안하는 전체적인 모듈의 블록 다이어그램 (왼쪽)과 적용된 결과 (오른쪽)을 보여주고 있습니다. 왼쪽 그림을 보입력 특징 맵은 2개의 branch로 나누어 입력이 들어갑니다. 왼쪽은 mask branch, 오른쪽은 trunk branch라고 부르도록 하겠습니다. trunk branch는 기존의 SOTA 영상 분류 모델에 쉽게 적용될 수 있는 feature processing을 담당합니다. 이를 $T(x)$라고 정의하도록 하죠. $x$는 입력 특징 맵입니다. 그리고 mask branch는 bottom-up top-down 구조를 가지는 branch로 입력 특징 맵과 동일한 크기의 마스크 $M(x)$를 얻기 위한 과정입니다. mask branch의 출력인 $M(x)$는 trunk branch의 출력인 $T(x)$를 softly weigting를 적용하게 됩니다. 즉, $M(x)$는 trunk branch $T(x)$의 출력을 조절하는 역할을 하는 면에서 Highway Network와 유사한 점이 있습니다. 따라서, 하나의 attention 모듈 $H$는 다음과 같이 쓸 수 있습니다.
$$H_{i, c}(x) = M_{i, c}(x) * T_{i, c}(x)$$
여기서, $i$는 공간 인덱스, $c \in \{1, \dots, C\}$는 채널의 인덱스를 의미합니다.
이때, attention module에서 중요한 것은 attention maks $M(x)$가 feature selection만 하지 않고 gradient update 과정에서 filtering도 함께 적용하고 있는 점 입니다. 아래 수식을 보시길 바랍니다.
$$\frac{\partial M(x, \theta)T(x, \phi)}{\partial \phi} = M(x, \theta) \frac{\partial T(x, \phi)}{\partial \phi}$$
여기서, $\theta$는 mask branch의 파라미터, $\phi$는 trunk branch의 파라미터로 정의됩니다. 수식을 보시면 trunk branch로 흘러들어가는 gradient과 마스크 $M(x, \theta)$로 인해 attention되는 영역만 흘러가는 것을 볼 수 있습니다. 이러한 결과는 attention 모듈이 모델을 noisy label에 강건하게 만들 수 있다는 것을 의미합니다.
2). Attention Residual Learning
아쉽게도 위 구조를 그대로 쌓게 되면 오히려 성능 하락이 발생했다고 합니다. 이유는 다음과 같습니다.
(1). mask $M(x)$가 0 ~ 1 사이의 값을 가지고 이를 지속적으로 곱하게 되면 깊은 계층에서 모든 값들이 0에 가까워지기 때문에 feature 정보가 없어지게 됨
(2). mask $M(x)$가 residual unit의 중요한 특성인 identity mapping을 없애기 때문에 학습에 오히려 방해가 될 수 있음
본 논문에서는 이를 해결하기 위해 attention residual learning을 제안합니다. 방식은 간단합니다. 저희가 soft mask weight를 하기 때문에 없어지는 identity mapping을 다시 추가하자는 것이죠. 따라서, attention module은 다음과 같이 바뀝니다.
$$H_{i, c}(x) = (1 + M_{i, c}(x)) * F_{i, c}(x)$$
이렇게 되면 $M(x)$에서 0에 가까운 값들이 있더라도 mask의 값이 1 ~ 2가 되기 때문에 0으로 수렴하지 않게 됩니다. 따라서, 강조되어야할 영역은 점점 값이 커지게 되는 것을 볼 수 있죠.
3). Soft Mask Branch
본 논문에서 제안하는 mask branch는 fast feed-forward와 Top-down feedback 구조를 결합했다고 합니다. 여기서, fast feed-forward는 입력 특징 맵의 global information를 빠르게 수집하게 되고 top-down feedback은 수집된 global information을 기존의 입력 특징맵에 추가해주는 방식으로 연산이 진행됩니다. CNN에서는 이러한 두 연산을 하나로 합쳐 bottom-up top-down fully convolutional structure로 구현하게 되죠.
그림3은 RAN에서 제안하는 attention module의 전체적인 블록 다이어그램입니다. 방금 설명드렸듯이 soft mask branch에서는 down sample과 up sample을 함께 적용하여 soft mask $M(x)$을 만들어냅니다. 그리고 attention residual learning을 통해 trunk branch의 attention을 수행하게 되죠. 여기서, upsample은 bilinear interpolation을 수행하여 연산량을 줄이게 됩니다. 연속적인 upsample 과정을 통해 입력 특징맵과 동일한 크기로 복구하게 되면 2개의 $1 \times 1$ 합성곱 계층을 적용한 뒤 sigmoid 연산을 적용하여 값의 범위를 $[0, 1]$로 조정합니다. 위 그림에서는 나오지 않았지만 UNet과 유사하게 downsampling되는 bottom-up와 upsampling되는 tom-down 사이에 skip connection을 추가하여 서로 다른 scale에서는 정보를 찾아내줍니다.
4). Spatial Attention and Channel Attention
자, 그렇다면 attention을 어떻게 적용할 수 있을까요? attention을 적용하는 방법도 다양하게 시도해볼 수 있습니다. RAN에서는 아래와 같이 3개의 attention 기법을 적용해봅니다.
(1). Mixed Attention $f_{1} = \frac{1}{1 + e^{-x_{i, c}}}$
(2). Channel Attention $f_{2} = \frac{x_{i, c}}{|| x_{i} ||}$
(3). Spatial Attention $f_{3} = \frac{1}{1 + e^{-(x_{i, c} - \text{mean}_{c}) / \text{std}_{c}}}$
여기서, $i$는 공간 인덱스, $c \in \{1, \dots, C\}$는 채널의 인덱스를 의미합니다. $f_{1}$은 단순히 sigmoid 활성화 함수를 취한 뒤 spatial과 channel 모두에게 attention을 적용하는 방식입니다. 반면에 $f_{2}$와 $f_{3}$는 각각 channel과 spatial에만 attention을 적용하는 방식이죠.
표1은 각 attention 방식 별로 CIFAR10으로 학습했을 때 실험 결과입니다. 결과적으로 mixed attention $f_{1}$이 가장 좋은 성능을 보이고 있습니다.
Residual Attention Network Architecture
그림2와 표2는 각각 RAN의 블록 다이어그램과 상세 파라미터를 보여주고 있습니다. RAN은 크게 3개의 stage로 구성되어있으며 각 stage에서 위쪽 branch는 trunk branch, 아래쪽 branch는 mask branch입니다.
Experiment Results
본 논문에서는 CIFAR10, CIFAR100, ImageNet을 이용해서 영상 분류 성능을 측정합니다. 일단, SOTA와 비교하기 전에 CIFAR 데이터셋을 이용해서 간단한 Ablation Study를 시작합니다.
1). Ablation Study
(1). Attention Residual Learning (ARL)
본 논문에서 제안하는 학습 기법 중 하나인 ARL과 Naive Attention Learning (NAL)사이의 성능을 비교합니다.
표3은 ARL과 NAL 사이의 성능을 차이를 보여주고 있습니다. 그림4와 함께 보시면 더욱 쉽게 이해가 가실겁니다. 표3에서 깊이가 깊어질수록 NAL의 성능은 오히려 떨어지고 있습니다. 왜냐하면 처음에 말씀드렸듯이 0 ~ 1 사이의 값을 가지는 마스크를 이용하기 때문에 특징 맵들이 점점 0에 수렴하기 때문이라는 말을 하였습니다. 이러한 경향성은 그림4에서도 나타납니다. Stage1에서는 NAL이 어느정도 feature response를 가지고 있지만 깊어질수록 이미 특징 맵들이 0으로 수렴을 해버렸기 때문에 Stage2부터는 feature response가 거의 0이 되고 Stage3에서는 아예 남아잇지 않은 모습입니다. 하지만, ARL은 기존의 ResNet-164와 유사한 feature response를 가지고 있으며 깊어질수록 성능이 향상되는 것을 볼 수 있습니다. 이러한 결과는 ARL의 효율성을 입증합니다.
(2). Comparison of Different Mask Structure
다음으로 실험하는 것은 mask branch의 구조를 바꾸어 실험을 하게 됩니다. RAN에서는 bottom-up top-down 구조의 모델을 사용하였죠. 이를 Encoder-Decoder 구조라고 하겠습니다. 그리고 단순히 합성곱 계층을 여러 개 쌓은 것을 Local Convolution이라고 하겠습니다.
표4는 두 구조 사이의 성능 차이를 보여주고 있습니다. 약 0.94%의 높은 마진으로 성능이 향상되었으며 mult-scale을 정보를 취합하여 trunk branch에 전달하는 Encoder-Decoder 구조의 mask branch의 효율성을 입증하고 있습니다.
(3). Noisy Label Robustness
마지막으로 본 논문에서 주장하던 Noisy label에 대한 강건성을 실험합니다. label에 약간의 노이즈를 추가했을 때 RAN은 ResNet에 비해 훨씬 강건한 것을 볼 수 있습니다.
2). CIFAR Classification Results
다음으로 CIFAR에서는 ResNet, WRN과의 성능을 비교합니다. RAN은 비교적 적은 파라미터로도 ResNet과 WRN 보다 훨씬 높은 성능 향상을 보이고 있어 효율적인 모델이라고 볼 수 있습니다.
3). ImageNet-1K Classification Results
마지막으로 ImageNet-1K에서 실험을 진행합니다. 이때, 본 논문에서 제안하는 Attention 모듈을 각 네트워크에 추가하여 학습을 수행하였습니다. 그 결과 더 적은 파라미터로도 유사한 성능이 나오거나 향상된 결과를 볼 수 있습니다.
Implementation Code
import torch.nn as nn
class ResidualUnit(nn.Module):
"""Residual block for resnet over 50 layers
"""
def __init__(self, in_channels, out_channels, stride=1):
super(ResidualUnit, self).__init__()
self.residual_function = nn.Sequential(
nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True),
nn.Conv2d(in_channels, out_channels // 4, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels // 4), nn.ReLU(inplace=True),
nn.Conv2d(out_channels // 4, out_channels // 4, stride=stride, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels // 4), nn.ReLU(inplace=True),
nn.Conv2d(out_channels // 4, out_channels, kernel_size=1, bias=False))
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels))
def forward(self, x):
return self.residual_function(x) + self.shortcut(x)
class AttentionModuleStage1(nn.Module):
def __init__(self, in_channels, out_channels, size1=(56, 56), size2=(28, 28), size3=(14, 14)):
super(AttentionModuleStage1, self).__init__()
self.first_residual_block = ResidualUnit(in_channels, out_channels)
self.trunk_branch = nn.Sequential(
ResidualUnit(in_channels, out_channels),
ResidualUnit(in_channels, out_channels))
self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.softmax1_blocks = ResidualUnit(in_channels, out_channels)
self.skip1_connection_residual_block = ResidualUnit(in_channels, out_channels)
self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.softmax2_blocks = ResidualUnit(in_channels, out_channels)
self.skip2_connection_residual_block = ResidualUnit(in_channels, out_channels)
self.mpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.softmax3_blocks = nn.Sequential(
ResidualUnit(in_channels, out_channels),
ResidualUnit(in_channels, out_channels))
self.interpolation3 = nn.UpsamplingBilinear2d(size=size3)
self.softmax4_blocks = ResidualUnit(in_channels, out_channels)
self.interpolation2 = nn.UpsamplingBilinear2d(size=size2)
self.softmax5_blocks = ResidualUnit(in_channels, out_channels)
self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)
self.softmax6_blocks = nn.Sequential(
nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False),
nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False),
nn.Sigmoid())
self.last_block = ResidualUnit(in_channels, out_channels)
def forward(self, x):
x = self.first_residual_block(x)
trunk_branch = self.trunk_branch(x)
out_mpool1 = self.mpool1(x)
out_softmax1 = self.softmax1_blocks(out_mpool1)
out_skip1_connection = self.skip1_connection_residual_block(out_softmax1)
out_mpool2 = self.mpool2(out_softmax1)
out_softmax2 = self.softmax2_blocks(out_mpool2)
out_skip2_connection = self.skip2_connection_residual_block(out_softmax2)
out_mpool3 = self.mpool3(out_softmax2)
out_softmax3 = self.softmax3_blocks(out_mpool3)
out_interp3 = self.interpolation3(out_softmax3) + out_softmax2
out = out_interp3 + out_skip2_connection
out_softmax4 = self.softmax4_blocks(out)
out_interp2 = self.interpolation2(out_softmax4) + out_softmax1
out = out_interp2 + out_skip1_connection
out_softmax5 = self.softmax5_blocks(out)
out_interp1 = self.interpolation1(out_softmax5) + trunk_branch
out_softmax6 = self.softmax6_blocks(out_interp1)
out = (1 + out_softmax6) * trunk_branch
out_last = self.last_blocks(out)
return out_last
class AttentionModuleStage2(nn.Module):
def __init__(self, in_channels, out_channels, size1=(28, 28), size2=(14, 14)):
super(AttentionModuleStage2, self).__init__()
self.first_residual_block = ResidualUnit(in_channels, out_channels)
self.trunk_branch = nn.Sequential(
ResidualUnit(in_channels, out_channels),
ResidualUnit(in_channels, out_channels))
self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.softmax1_blocks = ResidualUnit(in_channels, out_channels)
self.skip1_connection_residual_block = ResidualUnit(in_channels, out_channels)
self.mpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.softmax2_blocks = nn.Sequential(
ResidualUnit(in_channels, out_channels),
ResidualUnit(in_channels, out_channels))
self.interpolation2 = nn.UpsamplingBilinear2d(size=size2)
self.softmax3_blocks = ResidualUnit(in_channels, out_channels)
self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)
self.softmax4_blocks = nn.Sequential(
nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False),
nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False),
nn.Sigmoid())
self.last_block = ResidualUnit(in_channels, out_channels)
def forward(self, x):
x = self.first_residual_block(x)
trunk_branch = self.trunk_branch(x)
out_mpool1 = self.mpool1(x)
out_softmax1 = self.softmax1_blocks(out_mpool1)
out_skip1_connection = self.skip1_connection_residual_block(out_softmax1)
out_mpool2 = self.mpool2(out_softmax1)
out_softmax2 = self.softmax2_blocks(out_mpool2)
out_interp2 = self.interpolation2(out_softmax2) + out_softmax1
out = out_interp2 + out_skip1_connection
out_softmax3 = self.softmax3_blocks(out)
out_interp1 = self.interpolation1(out_softmax3) + trunk_branch
out_softmax4 = self.softmax4_blocks(out_interp1)
out = (1 + out_softmax4) * trunk_branch
out_last = self.last_blocks(out)
return out_last
class AttentionModuleStage3(nn.Module):
def __init__(self, in_channels, out_channels, size1=(14, 14)):
super(AttentionModuleStage3, self).__init__()
self.first_residual_block = ResidualUnit(in_channels, out_channels)
self.trunk_branch = nn.Sequential(
ResidualUnit(in_channels, out_channels),
ResidualUnit(in_channels, out_channels))
self.mpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.softmax1_blocks = nn.Sequential(
ResidualUnit(in_channels, out_channels),
ResidualUnit(in_channels, out_channels))
self.interpolation1 = nn.UpsamplingBilinear2d(size=size1)
self.softmax2_blocks = nn.Sequential(
nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False),
nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False),
nn.Sigmoid())
self.last_block = ResidualUnit(in_channels, out_channels)
def forward(self, x):
x = self.first_residual_block(x)
trunk_branch = self.trunk_branch(x)
out_mpool1 = self.mpool1(x)
out_softmax1 = self.softmax1_blocks(out_mpool1)
out_interp1 = self.interpolation1(out_softmax1) + trunk_branch
out_softmax2 = self.softmax2_blocks(out_interp1)
out = (1 + out_softmax2) * trunk_branch
out_last = self.last_blocks(out)
return out_last
class RAN92(nn.Module):
def __init__(self, num_channels, num_classes):
super(RAN92, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(num_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
nn.BatchNorm2d(64), nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)))
self.residual_unit1 = ResidualUnit(64, 256)
self.attention_module_stage1_1 = AttentionModuleStage1(256, 256)
self.residual_unit2 = ResidualUnit(256, 512, stride=(2, 2))
self.attention_module_stage2_1 = AttentionModuleStage2(512, 512)
self.attention_module_stage2_2 = AttentionModuleStage2(512, 512)
self.residual_unit3 = ResidualUnit(512, 1024, stride=(2, 2))
self.attention_module_stage3_1 = AttentionModuleStage3(1024, 1024)
self.attention_module_stage3_2 = AttentionModuleStage3(1024, 1024)
self.attention_module_stage3_3 = AttentionModuleStage3(1024, 1024)
self.residual_unit4 = ResidualUnit(1024, 2048, stride=(2, 2))
self.residual_unit5 = ResidualUnit(2048, 2048)
self.residual_unit6 = ResidualUnit(2048, 2048)
self.mpool2 = nn.Sequential(
nn.BatchNorm2d(2048), nn.ReLU(inplace=True),
nn.AvgPool2d(kernel_size=7, stride=1))
self.fc = nn.Linear(2048, num_classes)
def forward(self, x):
out = self.conv1(x)
out = self.residual_unit1(out)
out = self.attention_module_stage1_1(out)
out = self.residual_unit2(out)
out = self.attention_module_stage2_1(out)
out = self.attention_module_stage2_2(out)
out = self.residual_unit3(out)
out = self.attention_module_stage3_1(out)
out = self.attention_module_stage3_2(out)
out = self.attention_module_stage3_3(out)
out = self.residual_unit4(out)
out = self.residual_unit5(out)
out = self.residual_unit6(out)
out = self.mpool2(out)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out