안녕하세요. 지난 포스팅의 [IC2D] BAM: Bottleneck Attention Module (BMVC2018)에서는 기존의 SE Block에서 제안한 Channel Attention을 확장하여 합성곱 연산을 통한 Spatial Attention을 병렬적으로 적용한 BAM에 대해서 알아보았습니다. 오늘은 BAM을 확장한 CBAM에 대해서 알아보도록 하겠습니다.
Background
지금까지 많은 CNN 구조들이 깊이, 너비, cardinality와 같은 차원을 제안함으로써 모델의 성능 향상을 얻어냈습니다. 특히, ResNeck 기반의 모델들이 많이 제안되었죠. 대표적으로 ResNet, WRN, Xception, ResNext 등이 있었습니다. 그 중에서도 Xception과 ResNext에서는 cardinality의 중요성을 강조하며 기존에 집중하던 깊이나 너비보다 성능 향상에 있어 큰 영향을 끼친다는 것을 실험적으로 검증하였습니다.
하지만, 새로운 학습 모델로써 자연어 처리에서 자주 활용되던 어텐션 기법이 CNN 구조에 병합되기 시작하면서 더 높은 성능 향상을 달성하였습니다. 대표적으로 RAN, SE Block, BAM 등이 있었습니다. 이러한 방법들은 모두 ResNet 구조에 새로운 모듈을 추가하는 방식으로 적은 파라미터로 높은 성능 향상을 얻는 것을 목표로 합니다. 특히, 지난 포스팅에서 보았던 BAM에서는 ResNet 내부의 스테이지가 진행될수록 객체에 집중하고 있는 것을 볼 수 있습니다. 본 논문에서는 이러한 BAM 구조를 변형한 CBAM을 제안하였습니다.
Convolutional Block Attention Module
그림1은 본 논문에서 제안하는 CBAM의 전체적인 개략도입니다. 지난 포스팅의 BAM과 비교해보시면 기본적으로 채널 어텐션 및 공간 어텐션은 동일하게 하는 것을 볼 수 있습니다. 다만 다른 점은 이를 병렬적으로 하는 것이 아닌 cascade한 방식으로 진행하는 점 입니다. 이를 수식으로 표현하면 다음과 같습니다.
$$\mathbf{F}^{'} = \mathbf{M}_{\mathbf{c}} (\mathbf{F}) \otimes \mathbf{F}$$
$$\mathbf{F}^{''} = \mathbf{M}_{\mathbf{s}} (\mathbf{F}^{'}) \otimes \mathbf{F^{'}}$$
여기서, $\mathbf{F} \in \mathbb{R}^{C \times H \times W}$는 입력 특징 맵을 의미합니다. 그리고 $\mathbf{M}_{\mathbf{c}} \in \mathbb{R}^{C \times 1 \times 1}$과 $\mathbf{M}_{\mathbf{s}} \in \mathbb{R}^{1 \times H \times W}$는 각각 1차원 채널 어텐션 맵 및 2차원 공간 어텐션 맵을 의미합니다. 마지막으로 $\otimes$는 원소별 곱셈을 의미합니다. 이 과정을 통해 최종 refined 특징 맵인 $\mathbf{F}^{''}$을 얻을 수 있습니다. 이제 각 모듈의 구조를 좀 더 자세히 보도록 하겠습니다.
1). Channel Attention Module
위 그림은 Channel Attention Module의 내부 구조를 상세하게 보여주고 있습니다. 기본적으로 기존의 SE Block과 비교해보면 재밌을 거 같네요. SE Block에서는 squeeze 연산을 위해 GAP (Global Average Pooling)을 적용하였습니다. 하지만, CBAM에서는 GMP (Global Max Pooling)도 함께 적용하여 채널 어텐션 맵을 추출합니다. 기본적으로 어텐션 맵을 추출하기 위해 사용하는 MLP는 두 특징이 공유하게 됩니다. 그러면 2개의 추출된 특징을 얻을 수 있을텐데, 두 개를 더한 뒤 시그모이드 연산을 수행하여 최종적으로 채널 어텐션 맵을 추출하게 되는 것이죠. 이를 수식으로 적으면 다음과 같습니다.
$$\mathbf{M}_{\mathbf{c}} (\mathbf{F}) = \sigma \left( \mathbf{W}_{\mathbf{1}} \left( \mathbf{W}_{\mathbf{0}} \left( \mathbf{F}_{\mathbf{avg}}^{\mathbf{c}} \right) \right ) + \mathbf{W}_{\mathbf{1}} \left( \mathbf{W}_{\mathbf{0}} \left( \mathbf{F}_{\mathbf{max}}^{\mathbf{c}} \right) \right)\right)$$
여기서, $\mathbf{F}_{\mathbf{avg}}^{\mathbf{c}}$와 $\mathbf{F}_{\mathbf{max}}^{\mathbf{c}}$는 각각 GAP와 GMP를 이용하여 추출한 1차원 채널 설명자 (channel descriptor)입니다. 다음으로 $\mathbf{W}_{\mathbf{0}} \in \mathbb{R}^{C / r \times C}$와 $\mathbf{W}_{\mathbf{1}} \in \mathbb{R}^{C \times C / r}$은 각각 MLP의 파라미터를 의미하죠. 마지막으로 $\sigma$는 시그모이드 함수를 의미합니다. 이를 통해, $[0, 1]$ 사이의 범위를 가지는 채널 어텐션 맵을 얻을 수 있습니다.
2). Spatial Attention Module
다음으로 위 그림은 Spatial Attention Module의 내부 구조를 상세하게 보여주고 있습니다. 이 역시 Channel Attention Module과 동일하게 Average 및 Max를 통해 얻은 2차원 특징 맵을 활용하는 것을 볼 수 있습니다. 다만, 다른 점은 공간 정보에 대한 어텐션 맵을 추출하기 위해 채널 단위로 Average와 Max를 취했는 점이죠. 추출된 두 특징은 채널 단위로 합쳐진 뒤 $7 \times 7$ 크기의 합성곱 계층과 시그모이드 연산을 적용하여 최종 공간 어텐션 맵을 얻게 됩니다. 이를 수식으로 적으면 다음과 같습니다.
$$\mathbf{M}_{\mathbf{s}} (\mathbf{F}) = \sigma \left( f^{7 \times 7} \left( \left[ \mathbf{F}^{\mathbf{s}}_{\mathbf{avg}}, \mathbf{F}^{\mathbf{s}}_{\mathbf{max}} \right] \right) \right)$$
여기서, $\mathbf{F}_{\mathbf{avg}}^{\mathbf{s}}$와 $\mathbf{F}_{\mathbf{max}}^{\mathbf{s}}$는 각각 Average Pooling과 Max Pooling를 이용하여 추출한 2차원 공간 설명자 (spatial descriptor)입니다.
3). Integration with ResNet
본 논문에서 제안한 CBAM은 그림3과 같이 기존의 ResNet에 쉽게 추가하여 구현될 수 있습니다. 여기서, 기존의 BAM과의 차이점이 나오게 되죠. BAM은 다음 stage로 넘어갈 때 수행됩니다. 하지만, CBAM은 모든 Residual block 내부에 적용할 수 있습니다. 이 부분이 또다른 차이점이 되겠네요.
Experiment Results
본 논문에서는 CBAM을 추가한 CNN에 대해 ImageNet-1K에서의 분류 성능 및 MS COCO와 VOC2007에서의 객체 탐지 능력을 평가합니다. 일단, 본 논문에서 수행한 절제 연구부터 보도록 하겠습니다.
1). Ablation Study
CBAM에서 보셨을 때 몇 가지 궁금한 점이 있으시지 않으셨나요? 대표적으로 아래와 같은 것들이 있을 겁니다.
- GAP 또는 GMP만 사용했을 때 성능 변화
- Channel Attention Module 및 Spatial Attention Module을 하나만 사용했을 때
- Channel Attention Module 및 Spatial Attention Module의 순서에 따른 성능 변화
하나하나 분석해보도록 하죠.
(1). Channel Attention
제일 먼저 확인한 것은 Channel Attention Module 내에서 GAP 및 GMP만 사용했을 때 성능 변화입니다. 일단, GAP만 사용했을 때는 사실 SE Block과 동일합니다. 이를 GMP만 사용했을 때와 비교해보면 성능이 살짝 떨어지는 것을 볼 수 있습니다. 결과적으로 GAP와 GMP를 함께 사용했을 때 성능 향상을 얻을 수 있었네요.
(2). Spatial Attention
표2는 channel attention에 spatial attention을 적용했을 때 다양한 방식으로 수행했을 때 실험결과를 보여주고 있습니다. 결과적으로 Average Pooling과 Max Pooling을 함께 적용했을 때 가장 좋은 성능을 보이고 있습니다.
그렇다면 왜 Max Pooling을 수행했을 때 성능이 향상될까요? 일반적으로 특징 맵에서 강하게 활성화된 영역은 중요하다고 모델이 해당 영상을 판단할 때 중요하다는 것과 동일한 말입니다. 이때, Average Pooling을 적용하면 해당 영역의 활성화 정도가 감소하기 때문에 잘 추출한 특징이 어느정도 삭제될 수 있습니다. 하지만, 여기서 Max Pooling을 통해 추출한 정보도 함께 적용하면 흔히 saliency region이라고 하는 부분을 함께 사용하게 됩니다. 여기서 saliency란 "돌출된"이라는 뜻으로 특징 맵 또는 입력 영상 내에 중요한 영역이라고 보시면 될 거 같습니다. 즉, Average Pooling을 통해 잃어버리는 정보를 Max Pooling이 어느정도 상쇄해주기 때문에 성능 향상을 얻을 수 있다고 해석할 수 있습니다.
표3은 두 어텐션 모듈의 순서 및 적용하는 방식을 병렬로 바꾸었을 때 성능을 보여주고 있습니다. 결과적으로 channel attention을 적용한 뒤 spatial attention을 적용하는 것이 성능이 가장 높습니다.
2). ImageNet-1K Classification Results
Implementation Code
"""resnet in pytorch
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
Deep Residual Learning for Image Recognition
https://arxiv.org/abs/1512.03385v1
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers import BasicConv
class ChannelAttentionModule(nn.Module):
def __init__(self, in_channels, reduction=16):
super(ChannelAttentionModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc = nn.Sequential(
nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1, stride=1, padding=0, bias=False),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels // reduction, in_channels, kernel_size=1, stride=1, padding=0, bias=False))
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = self.fc(self.avg_pool(x))
max_out = self.fc(self.max_pool(x))
return self.sigmoid(avg_out + max_out)
class SpatialAttentionModule(nn.Module):
def __init__(self):
super(SpatialAttentionModule, self).__init__()
self.conv1 = nn.Conv2d(2, 1, kernel_size=7, stride=1, padding=3, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
x = torch.cat([avg_out, max_out], dim=1)
x = self.conv1(x)
return self.sigmoid(x)
class CBAM(nn.Module):
def __init__(self, in_channels):
super(CBAM, self).__init__()
self.channel_attention = ChannelAttentionModule(in_channels)
self.spatial_attention = SpatialAttentionModule()
def forward(self, x):
channel_att_map = self.channel_attention(x)
x = x * channel_att_map
spatial_att_map = self.spatial_attention(x)
x = x * spatial_att_map
return x
class BottleNeck(nn.Module):
"""Residual block for resnet over 50 layers
"""
expansion = 4
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
self.residual_function = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels * BottleNeck.expansion),
CBAM(out_channels * BottleNeck.expansion)
)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels * BottleNeck.expansion:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels * BottleNeck.expansion)
)
def forward(self, x):
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
class BasicBlock(nn.Module) :
"""Basic Block for resnet 18 and resnet 34
"""
#BasicBlock and BottleNeck block
#have different output size
#we use class attribute expansion
#to distinct
expansion = 1
def __init__(self, in_channels, out_channels, stride=1):
super(BasicBlock, self).__init__()
# residual function
self.residual_function = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), stride=(stride, stride), padding=(1, 1), bias=False),
nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=(3, 3), padding=(1, 1), bias=False),
nn.BatchNorm2d(out_channels * BasicBlock.expansion),
CBAM(out_channels * BasicBlock.expansion)
)
# shortcut
self.shortcut = nn.Sequential()
#the shortcut output dimension is not the same with residual function
#use 1*1 convolution to match the dimension
if stride != 1 or in_channels != BasicBlock.expansion * out_channels :
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=(1, 1), stride=(stride, stride), bias=False),
nn.BatchNorm2d(out_channels * BasicBlock.expansion)
)
def forward(self, x):
return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))
class ResNet(nn.Module) :
def __init__(self, block, num_block, num_classes=100, num_channels=3):
super(ResNet, self).__init__()
self.in_channels = 64
self.conv1 = nn.Sequential(
nn.Conv2d(num_channels, 64, kernel_size=(3, 3), padding=(1, 1), bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
)
self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
def _make_layer(self, block, out_channels, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides :
layers.append(block(self.in_channels, out_channels, stride))
self.in_channels = out_channels * block.expansion
return nn.Sequential(*layers)
def forward(self, x, feature=False):
output = self.conv1(x)
output = self.conv2_x(output)
output = self.conv3_x(output)
output = self.conv4_x(output)
output = self.conv5_x(output)
if feature: return output
output = self.avg_pool(output)
output = output.view(output.size(0), -1)
output = self.fc(output)
return output
def resnet18_cbam(num_classes, num_channels) :
""" return a ResNet 18 object
[?, 3, 224, 224] -> [?, 64, 224, 224]
[?, 64, 224, 224] -> [?, 64, 224, 224]
[?, 64, 224, 224] -> [?, 128, 112, 112]
[?, 128, 112, 112] -> [?, 256, 56, 56]
[?, 256, 56, 56] -> [?, 512, 28, 28]
[?, 512, 28, 28] -> [?, 512, 1, 1]
"""
return ResNet(BasicBlock, [2, 2, 2, 2], num_classes, num_channels)
def resnet34_cbam(num_classes, num_channels) :
""" return a ResNet 34 object
[?, 3, 224, 224] -> [?, 64, 224, 224]
[?, 64, 224, 224] -> [?, 64, 224, 224]
[?, 64, 224, 224] -> [?, 128, 112, 112]
[?, 128, 112, 112] -> [?, 256, 56, 56]
[?, 256, 56, 56] -> [?, 512, 28, 28]
[?, 512, 28, 28] -> [?, 512, 1, 1]
"""
return ResNet(BasicBlock, [2, 4, 6, 3], num_classes, num_channels)
def resnet50_cbam(num_classes, num_channels) :
""" return a ResNet 50 object
[?, 3, 224, 224] -> [?, 64, 224, 224]
[?, 64, 224, 224] -> [?, 256, 224, 224]
[?, 256, 224, 224] -> [?, 512, 112, 112]
[?, 512, 112, 112] -> [?, 1024, 56, 56]
[?, 1024, 56, 56] -> [?, 2048, 28, 28]
[?, 2048, 28, 28] -> [?, 2048, 1, 1]
"""
return ResNet(BottleNeck, [2, 4, 6, 3], num_classes, num_channels)
def resnet101_cbam(num_classes, num_channels) :
""" return a ResNet 101 object
[?, 3, 224, 224] -> [?, 64, 224, 224]
[?, 64, 224, 224] -> [?, 256, 224, 224]
[?, 256, 224, 224] -> [?, 512, 112, 112]
[?, 512, 112, 112] -> [?, 1024, 56, 56]
[?, 1024, 56, 56] -> [?, 2048, 28, 28]
[?, 2048, 28, 28] -> [?, 2048, 1, 1]
"""
return ResNet(BottleNeck, [3, 4, 23, 3], num_classes, num_channels)
def resnet152_cbam(num_classes, num_channels):
""" return a ResNet 152 object
[?, 3, 224, 224] -> [?, 64, 224, 224]
[?, 64, 224, 224] -> [?, 256, 224, 224]
[?, 256, 224, 224] -> [?, 512, 112, 112]
[?, 512, 112, 112] -> [?, 1024, 56, 56]
[?, 1024, 56, 56] -> [?, 2048, 28, 28]
[?, 2048, 28, 28] -> [?, 2048, 1, 1]
"""
return ResNet(BottleNeck, [3, 8, 36, 3], num_classes, num_channels)
if __name__ == '__main__':
model = resnet152_cbam(num_classes=1000, num_channels=3)
inp = torch.rand(2, 3, 224, 224)
oup = model(inp)
print(oup.shape)