안녕하세요. 지난 포스팅의 [IC2D] PolyNet: A Pursuit of Structural Diversity in Very Deep Networks (CVPR2017)에서는 CNN 모델의 diversity를 강조하여 Inception 모델의 새로운 변형 구조인 PolyNet을 제안하였습니다. 오늘은 Inception 모델의 최종버전이라고 할 수 있는 Xception에 대해서 소개해드리도록 하겠습니다.
Background
지금까지 저희가 보았던 Inception 기반의 모델들을 보면 GoogLeNet (Inception V1, CVPR2015), Inception V2 ~ V3 (CVPR2016), Inception V4, Inception-ResNet-V1, Inception-ResNet-V2 (CVPR2017)이 있었습니다. 이러한 모델들은 기본적으로 Inception Module이라고 하는 multi-path 합성곱 블록을 사용하고 있습니다. 그렇다면, Inception Module이 학습 단계에서 어떤 효과를 주는 것일까요?
이를 알기 위해서는 3D 공간에서 합성곱이 어떤 의미인지 알아봐야합니다. 합성곱의 정의를 보시면 2개의 공간 차원에 대한 correlation과 1개의 깊이 차원에 대한 correlation을 계산하는 과정임을 알 수 있습니다. Inception Module은 이러한 연산을 좀 더 쉽게 만들고 연산을 분해하여 학습할 수 있기 때문에 효율적이라고 언급하고 있습니다.
위 그림에서 가장 왼쪽은 기존의 Inception Module입니다. 저희가 알고 있듯이 여러 개의 path로 구성되어 각 path 별로 다양한 크기의 필터를 가지는 합성곱 블록을 적용하여 학습을 하게 됩니다. 중간 그림은 기존의 Inception Module을 재해석하여 좀 더 효율적으로 만든 구조입니다. 왼쪽 그림에서 4개의 path 중 3개의 path가 $1 \times 1$ 합성곱 계층을 사용하고 있기 때문에 이 path만 가져온 뒤 각 path가 동일한 구조를 가지도록 만든것이죠. 이 부분은 어떻게 보면 ResNext와 유사한 해석이라고 볼 수 있습니다. 중간 그림을 더 단순화하여 처음에 3개의 특징 맵을 split 하여 $1 \times 1$ 합성곱을 적용하는 것이 아니라 일단 $1 \times 1$ 합성곱을 적용한 뒤 split 하는 것을 볼 수 있죠.
처음 그림을 좀 더 극단적으로 잘게 쪼개서 학습하는 것을 생각해볼 수도 있습니다. 이와 같이 극단적 (Extremely)으로 잘게 쪼개서 multi-path 블록을 구성하는 모델이 바로 Xception 입니다. 이와 같이 입력 특징 맵에 대해 몇 개의 채널로 나누어 학습하는 구조를 보신 적 있지 않으신가요? 바로 MobileNet입니다. 이를 통해, MobileNet에서는 성능은 최대한 유지하고 효율성을 증가시키게 되었죠. Xception 역시 기존의 Inception 모델들보다 효율성을 크게 증가시킨 모델이라고 보시면 될 거 같습니다. Xception도 Depthwise Separable Convolution을 적용하여 학습을 진행하게 되는데, MobileNet과는 크게 다른 점이 있습니다.
1). MobileNet에서는 Depthwise Convolution을 적용한 뒤 Pointwise Convolution을 적용하지만 Xception 에서는 반대로 Pointwise Convolution을 적용한 뒤 Depthwise Convolution을 적용합니다. 본 논문의 저자는 이 부분은 크게 문제될 것이 없다라고 하네요.
2). MobileNet에서는 Depthwise Separable Convolution을 적용한 뒤 ReLU를 적용하여 non-linearity를 추가합니다. 하지만, Xception에서는 적용하지 않습니다. 이는 생각보다 큰 성능 차이가 있다라고 하네요. Experiment Results 부분에서 보여드리도록 하겠습니다.
Xception Architecture
보시면 Xception은 Inception V3에 비해 약 1백만개 정도 적은 파라미터를 가지고 3초 정도 더 빠른 학습 속도를 가지고 있다고 합니다.
Experiment Results
본 논문에서는 2가지 대규모 영상 분류 데이터셋으로 ImageNet-1K와 JFT-300M이라는 데이터셋을 사용하여 학습하였다고 합니다. JFT-300M 데이터셋은 ImageNet-1K보다 훨씬 더 큰 데이터셋으로 구글에서 비공개하였기 때문에 접근할 수는 없습니다. 간략하게 설명드리면 JFT-300M은 약 3개의 영상으로 3만개의 레이블로 구성된 초대형 데이터셋입니다.
1). ImangeNet-1K Classification Results
당연하지만 ImageNet에 대해서 Inception V3보다도 훨씬 좋은 성능을 보이고 있습니다.
2). JFT-300M Classification Results
JFT-300M에서도 마찬가지로 Inception V3보다도 훨씬 좋은 성능을 보이고 있습니다. 본문을 보면 Xception 모델을 JFT-300M 모델에 학습시키기 위해 약 1달의 시간이 걸렸다고 하는 데 정말 엄두도 못할정도의 규모입니다.
3). Ablation Study
(1). Residual Connections
(2). Non-linearity (ReLU vs ELU vs Linear function)
Implementation Code
import torch.nn as nn
from .layers import BasicConv
class SeparableConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(SeparableConv, self).__init__()
self.depthwise_conv = nn.Conv2d(in_channels, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=in_channels)
self.pointwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
def forward(self, x):
out = self.depthwise_conv(x)
out = self.pointwise_conv(out)
return out
class ResidualConv(nn.Module):
def __init__(self, in_channels, out_channels, activation=True):
super(ResidualConv, self).__init__()
self.residual_conv = nn.Sequential(
SeparableConv(in_channels, out_channels),
nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True),
SeparableConv(out_channels, out_channels),
nn.BatchNorm2d(out_channels),
nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)))
self.shortcut_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), stride=(2, 2), padding=(0, 0)),
nn.BatchNorm2d(out_channels))
self.activation = activation
def forward(self, x):
if self.activation:
x = nn.ReLU(inplace=True)(x)
residual = self.residual_conv(x)
identity = self.shortcut_conv(x)
return residual + identity
class EntryFlow(nn.Module):
def __init__(self, num_channels):
super(EntryFlow, self).__init__()
self.conv1 = BasicConv(num_channels, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.conv2 = BasicConv(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.residual_conv1 = ResidualConv(64, 128, activation=False)
self.residual_conv2 = ResidualConv(128, 256, activation=True)
self.residual_conv3 = ResidualConv(256, 728, activation=True)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.residual_conv1(x)
x = self.residual_conv2(x)
x = self.residual_conv3(x)
return x
class MiddleFlow(nn.Module):
def __init__(self, in_channels):
super(MiddleFlow, self).__init__()
self.residual_conv = nn.Sequential(
nn.ReLU(inplace=True),
SeparableConv(in_channels),
nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True),
SeparableConv(in_channels),
nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True),
SeparableConv(in_channels),
nn.BatchNorm2d(in_channels))
self.shortcut_conv = nn.Sequential()
def forward(self, x):
return self.residual_conv(x) + self.shortcut_conv(x)
class ExitFlow(nn.Module):
def __init__(self):
super(ExitFlow, self).__init__()
self.residual_conv = ResidualConv(728, 1024, activation=True)
self.conv = nn.Sequential(
SeparableConv(1024, 1536),
nn.BatchNorm2d(1536), nn.ReLU(inplace=True),
SeparableConv(1536, 2048),
nn.BatchNorm2d(2048), nn.ReLU(inplace=True))
self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
def forward(self, x):
x = self.residual_conv(x)
x = self.conv(x)
x = self.avg_pool(x)
return x
class Xception(nn.Module):
def __init__(self, num_channels, num_classes):
super(Xception, self).__init__()
self.entry_flow = EntryFlow(num_channels)
self.middle_flow = self._make_middle_flow_blocks()
self.exit_flow = ExitFlow()
self.linear = nn.Linear(2048, num_classes)
def forward(self, x):
x = self.entry(x)
x = self.middle(x)
x = self.exit(x)
x = x.view(x.size(0), -1)
x = self.linear(x)
return x
def _make_middle_flow_blocks(self):
middle = nn.Sequential()
for i in range(8):
middle.add_module('middle_flow_block_{}'.format(i), MiddleFlow(728))
return middle