안녕하세요. 지난 포스팅의 [IC2D] EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks (PMLR2019)에서는 3개의 차원에 대한 모델 스케일링을 적용하는 compound scaling과 MNASNet을 조합한 EfficientNet에 대해서 알아보았습니다. 오늘은 BAM이라는 어텐션 모듈에 대해서 소개시켜드리도록 하겠습니다.
Background
최근 많은 CNN 기반의 모델들이 파라미터의 개수가 많아지고 구조가 복잡해지면서 모델의 표현력이 증가하고 있는 추세입니다. 즉, 모델의 성능을 올리는 가장 쉬운 방법은 깊이를 증가시키는 방법으로 대표적으로 VGGNet, ResNet, InceptionNet 등과 같은 모델들이 있었습니다.
오늘 소개할 모델인 BAM은 간단하지만 아주 효율적인 어텐션 모듈입니다. 본 논문에서는 BAM은 "bells and whistles"이 필요없다라는 표현을 사용하였습니다. 벨과 호루라기라는 뜻으로 추가적인 장치를 의미하는 관용적인 표현입니다. 즉, 추가적인 장치없이 어떠한 CNN 구조에도 쉽게 BAM을 적용할 수 있다는 뜻으로 범용성이 매우 높다라고 말할 수 있습니다.
Bottleneck Attention Module
그림1은 어떤 CNN 구조에 BAM을 적용한 모습을 보여주고 있습니다. 그림만 봤을 때 중간 특징맵이 BAM을 통과하면 특정 객체가 강조되는 특징 맵을 얻을 수 있는 것을 볼 수 있습니다. 그리고 이 BAM을 각 stage 사이사이마다 적용하여 입력된 특징맵에 대해 중요한 특징을 더욱 빠르게 추출할 수 있게 만들어 줄 수 있습니다. 이제 수식을 보도록 하겠습니다. 먼저, $\mathbf{F} \in \mathbb{R}^{C \times H \times W}$를 입력 특징맵, $\mathbf{M}(\mathbf{F}) \in \mathbb{R}^{C \times H \times W}$를 3D 어텐션 맵이라고 하겠습니다. 그러면 정제된 특징맵 (refined feature map) $\mathbf{F}^{'} $은 입력 특징맵과 어텐션 맵을 곱해서 다음과 같이 얻을 수 있습니다.
$$\mathbf{F}^{'} = \mathbf{F} + \mathbf{F} \otimes \mathbf{M}(\mathbf{F})$$
여기서, $\otimes$는 원소별 곱을 의미합니다. 본 논문에서는 잔차 학습 (residual learning)을 적용하기 위해 입력 특징맵을 한 번더 더해주는 것을 볼 수 있습니다. 이를 통해, 계산 그래프 상에서 기울기의 흐름이 원할하게 흐를 수 있도록 만들 수 있습니다. 그렇다면 다음 질문은 어텐션 맵 $\mathbf{M}(\mathbf{F})$을 어떻게 만드냐는 것입니다. 본 논문에서는 다음과 같이 두 가지 단계를 통해서 어텐션 맵을 추출합니다.
$$\mathbf{M} (\mathbf{F}) = \sigma (\mathbf{M}_{\mathbf{c}} (\mathbf{F}) + \mathbf{M}_{\mathbf{s}} (\mathbf{F}))$$
여기서, 각각 채널 어텐션 $\mathbf{M}_{\mathbf{c}} (\mathbf{F}) \in \mathbb{R}^{C}$와 공간 어텐션 $\mathbf{M}_{\mathbf{s}} (\mathbf{F}) \in \mathbb{R}^{H \times W}$입니다. 그리고 $\sigma$는 시그모이드 활성화 함수를 의미합니다.
1). Channel Attention Branch
그림2에서 파란색 계열이 채널 어텐션을 수행하는 부분입니다. 어디서 많이 본것같은 구조 아닌가요? 바로 SE Block입니다. 실제로 SE Block에서도 어텐션을 구현하기 위해 GAP를 적용한 뒤 완전연결계층에 적용하여 채널 별 어텐션 값을 추출하는 것을 볼 수 있었습니다. BAM에서도 채널 어텐션을 수행하기 위해 동일한 방법을 적용하는 것을 알 수 있습니다. 수식은 다음과 같이 쓸 수 있습니다.
$$\begin{align*} \mathbf{M}_{\mathbf{c}} (\mathbf{F}) &= BN (MLP (GAP (\mathbf{F}))) \\ &= BN (\mathbf{W}_{1}(\mathbf{W}_{0} GAP (\mathbf{F}) + \mathbf{b}_{0}) + \mathbf{b}_{1}) \end{align*}$$
여기서, $\mathbf{W}_{0} \in \mathbb{R}^{\frac{C}{r} \times C}, \mathbf{b}_{0} \in \mathbb{R}^{\frac{C}{r}}$ 그리고 $\mathbf{W}_{1} \in \mathbb{R}^{C \times \frac{C}{r}}, \mathbf{b}_{1} \in \mathbb{R}^{C}$의 모양을 가지고 $r$은 reduction ratio로 완전 연결 계층의 파라미터의 개수를 줄여주는 역할을 해줍니다.
2). Spatial Attention Branch
그림2에서 주황색 계열이 공간 어텐션을 수행하는 부분입니다. SE Block과 가장 큰 차이가 나는 부분이 바로 이 부분입니다. 기존의 SE Block은 채널 어텐션만 수행하였지만 BAM에서는 공간 어텐션도 함께 적용하는 것이죠. 이를 통해, 입력 특징 맵에서 서로 다른 위치에 대한 어텐션을 수행하여 특정 위치를 강조하거나 제한할 수 있습니다.
이를 위해서는 공간 어텐션을 수행하는 부분에서 위치 정보에 대한 이해가 더 필요합니다. 본 논문에서는 dilated convolution을 이용해서 해결합니다. 이는 기존의 합성곱 계층보다 적은 파라미터로 더욱 넓은 수용 영역 (receptive field)를 가지게 만들기 때문에 훨씬 효율적이라고 할 수 있죠. 추가적으로 여기에 bottleneck 구조를 적용합니다. 기존의 ResNet에서 bottleneck을 적용할 때 $1 \times 1$ 합성곱을 통해 특징 맵의 채널 개수를 줄인 뒤 $3 \times 3$ 합성곱을 수행하고 다시 $1 \times 1$ 합성곱을 적용하여 다시 특징 맵의 채널 개수를 복원합니다. BAM에서도 이 과정을 수행합니다. 그림2를 보시면 $1 \times 1$을 통해 reduction ratio $r$을 도입하여 채널의 개수를 줄여줍니다. 다음으로 2번의 dilated convolution을 적용하여 보다 넓은 수용영역을 기반으로 공간에 대한 이해를 상승시킨 뒤 마지막으로 $1 \times 1$ 합성곱을 더 적용합니다. 이때, 마지막 합성곱에서는 $\frac{C}{r}$개의 채널을 가지는 특징 맵을 1개로 압축해버립니다. 이를 수식으로 표현하면 다음과 같죠.
$$\mathbf{M}_{\mathbf{s}} (\mathbf{F}) = BN (f^{1 \times 1}_{3} (f^{3 \times 3}_{2} (f^{3 \times 3}_{1} (f^{1 \times 1}_{0} (\mathbf{F})))))$$
Experiment Results
1). CIFAR Classification Results
2). ImageNet Classification Results
Implementation Code
"""resnet in pytorch
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
Deep Residual Learning for Image Recognition
https://arxiv.org/abs/1512.03385v1
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers import BasicConv
class ChannelAttentionModule(nn.Module):
def __init__(self, in_channels, reduction=16):
super(ChannelAttentionModule, self).__init__()
self.squeeze = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(in_channels, in_channels // reduction, bias=True),
nn.BatchNorm1d(in_channels // reduction), nn.ReLU(inplace=True),
nn.Linear(in_channels // reduction, in_channels, bias=True),
nn.BatchNorm1d(in_channels), nn.ReLU(inplace=True))
def forward(self, x):
batch_size, channel, _, _ = x.size()
y = self.squeeze(x).view(batch_size, channel)
y = self.fc(y).view(batch_size, channel, 1, 1)
return y.expand_as(x)
class SpatialAttentionModule(nn.Module):
def __init__(self, in_channels, dilation=4, reduction=16):
super(SpatialAttentionModule, self).__init__()
self.conv1x1_1 = BasicConv(in_channels, in_channels // reduction, kernel_size=1, stride=1, padding=0, activation=nn.ReLU(inplace=True))
self.dilated_conv3x3_1 = BasicConv(in_channels // reduction, in_channels // reduction, kernel_size=3, stride=1, padding=dilation, dilation=dilation, activation=nn.ReLU(inplace=True))
self.dilated_conv3x3_2 = BasicConv(in_channels // reduction, in_channels // reduction, kernel_size=3, stride=1, padding=dilation, dilation=dilation, activation=nn.ReLU(inplace=True))
self.conv1x1_2 = BasicConv(in_channels // reduction, 1, kernel_size=1, stride=1, padding=0, activation=nn.ReLU(inplace=True))
def forward(self, x):
x = self.conv1x1_1(x)
x = self.dilated_conv3x3_1(x)
x = self.dilated_conv3x3_2(x)
x = self.conv1x1_2(x)
return x
class BAM(nn.Module):
def __init__(self, in_channels):
super(BAM, self).__init__()
self.channel_attention = ChannelAttentionModule(in_channels)
self.spatial_attention = SpatialAttentionModule(in_channels)
def forward(self, x):
channel_att_map = self.channel_attention(x)
spatial_att_map = self.spatial_attention(x)
att_map = 1 + F.sigmoid(channel_att_map + spatial_att_map)
return x * att_map
class BottleNeck(nn.Module):
"""Residual block for resnet over 50 layers
"""
expansion = 4
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
self.residual_function = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels * BottleNeck.expansion),
)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels * BottleNeck.expansion)
)
def forward(self, x):
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
class BasicBlock(nn.Module) :
"""Basic Block for resnet 18 and resnet 34
"""
#BasicBlock and BottleNeck block
#have different output size
#we use class attribute expansion
#to distinct
expansion = 1
def __init__(self, in_channels, out_channels, stride=1):
super(BasicBlock, self).__init__()
# residual function
self.residual_function = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(stride, stride), padding=(1, 1), bias=False),
nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=(3, 3), padding=(1, 1), bias=False),
nn.BatchNorm2d(out_channels * BasicBlock.expansion)
)
# shortcut
self.shortcut = nn.Sequential()
#the shortcut output dimension is not the same with residual function
#use 1*1 convolution to match the dimension
if stride != 1 or in_channels != BasicBlock.expansion * out_channels :
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=(1, 1), stride=(stride, stride), bias=False),
nn.BatchNorm2d(out_channels * BasicBlock.expansion)
)
def forward(self, x):
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
class ResNet(nn.Module) :
def __init__(self, block, num_block, num_classes=100, num_channels=3):
super(ResNet, self).__init__()
self.in_channels = 64
self.conv1 = nn.Sequential(
nn.Conv2d(num_channels, 64, kernel_size=(3, 3), padding=(1, 1), bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
BAM(64)
)
self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
def _make_layer(self, block, out_channels, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides :
layers.append(block(self.in_channels, out_channels, stride))
self.in_channels = out_channels * block.expansion
if out_channels != 512:
layers += [BAM(self.in_channels)]
return nn.Sequential(*layers)
def forward(self, x, feature=False):
output = self.conv1(x)
output = self.conv2_x(output)
output = self.conv3_x(output)
output = self.conv4_x(output)
output = self.conv5_x(output)
if feature: return output
output = self.avg_pool(output)
output = output.view(output.size(0), -1)
output = self.fc(output)
return output
def resnet18_bam(num_classes, num_channels) :
""" return a ResNet 18 object
[?, 3, 224, 224] -> [?, 64, 224, 224]
[?, 64, 224, 224] -> [?, 64, 224, 224]
[?, 64, 224, 224] -> [?, 128, 112, 112]
[?, 128, 112, 112] -> [?, 256, 56, 56]
[?, 256, 56, 56] -> [?, 512, 28, 28]
[?, 512, 28, 28] -> [?, 512, 1, 1]
"""
return ResNet(BasicBlock, [2, 2, 2, 2], num_classes, num_channels)
def resnet34_bam(num_classes, num_channels) :
""" return a ResNet 34 object
[?, 3, 224, 224] -> [?, 64, 224, 224]
[?, 64, 224, 224] -> [?, 64, 224, 224]
[?, 64, 224, 224] -> [?, 128, 112, 112]
[?, 128, 112, 112] -> [?, 256, 56, 56]
[?, 256, 56, 56] -> [?, 512, 28, 28]
[?, 512, 28, 28] -> [?, 512, 1, 1]
"""
return ResNet(BasicBlock, [2, 4, 6, 3], num_classes, num_channels)
def resnet50_bam(num_classes, num_channels) :
""" return a ResNet 50 object
[?, 3, 224, 224] -> [?, 64, 224, 224]
[?, 64, 224, 224] -> [?, 256, 224, 224]
[?, 256, 224, 224] -> [?, 512, 112, 112]
[?, 512, 112, 112] -> [?, 1024, 56, 56]
[?, 1024, 56, 56] -> [?, 2048, 28, 28]
[?, 2048, 28, 28] -> [?, 2048, 1, 1]
"""
return ResNet(BottleNeck, [2, 4, 6, 3], num_classes, num_channels)
def resnet101_bam(num_classes, num_channels) :
""" return a ResNet 101 object
[?, 3, 224, 224] -> [?, 64, 224, 224]
[?, 64, 224, 224] -> [?, 256, 224, 224]
[?, 256, 224, 224] -> [?, 512, 112, 112]
[?, 512, 112, 112] -> [?, 1024, 56, 56]
[?, 1024, 56, 56] -> [?, 2048, 28, 28]
[?, 2048, 28, 28] -> [?, 2048, 1, 1]
"""
return ResNet(BottleNeck, [3, 4, 23, 3], num_classes, num_channels)
def resnet152_bam(num_classes, num_channels):
""" return a ResNet 152 object
[?, 3, 224, 224] -> [?, 64, 224, 224]
[?, 64, 224, 224] -> [?, 256, 224, 224]
[?, 256, 224, 224] -> [?, 512, 112, 112]
[?, 512, 112, 112] -> [?, 1024, 56, 56]
[?, 1024, 56, 56] -> [?, 2048, 28, 28]
[?, 2048, 28, 28] -> [?, 2048, 1, 1]
"""
return ResNet(BottleNeck, [3, 8, 36, 3], num_classes, num_channels)
if __name__ == '__main__':
model = resnet152_bam(num_classes=1000, num_channels=3)
inp = torch.rand(2, 3, 224, 224)
oup = model(inp)
print(oup.shape)