안녕하세요. 오늘은 간단한 구조이지만 매우 강력한 성능을 자랑하는 segmentation 구조 중에 하나인 U-net에 대해서 다루어보겠습니다. 참고로 이전에 리뷰한 FCN과 연결되니 읽어보고 오시는 것을 추천드립니다.
Network Architecture
U-net의 네트워크 구조는 위와 같습니다. 실제로 U자형을 그리는 것을 볼 수 있습니다. 구조를 한번 뜯어보겠습니다. 본 논문에서는 크게 2개의 서브 네트워크로 나누어 설명합니다. 각각 contracting path, expansive path라고 합니다. contracting path는 왼쪽의 채널의 개수가 점점 많아지는 네트워크이고 expansive path는 오른쪽의 채널의 개수가 점점 줄어들어 최종적으로는 입력 이미지와의 shape이 동일해지는 네트워크입니다.
contracting path에서 각 block은 2번의 3x3 conv과 2x2 maxpooling으로 이루어져있습니다. 이를 통해서 점차 downsampling이 되는 효과를 얻을 수 있습니다. 각 downsampling 스텝은 feature map의 채널을 2배씩 늘리는 효과를 얻고 있습니다. 그 순서를 보면 128 $\rightarrow$ 256 $\rightarrow$ 512 $\rightarrow$ 1024로 구성되는 것을 볼 수 있습니다.
expanding path에서는 contracting path로부터 얻은 feature map을 2배씩 upsampling(up-convolution)을 적용합니다. 이를 통해서 feature map의 가로, 세로의 크기를 다시 늘릴 수 있습니다. 여기서 이전에 리뷰한 논문인 FCN에서 적용한 skip connection을 적용합니다. upsampling의 최대 단점은 feature map의 정보를 잃어버린다는 것에 있었습니다. 그래서 이전 layer의 정보를 concatenate하는 방법을 사용했습니다. U-net에서는 expanding path의 각 block의 시작마다 그에 대응되는 contracting path의 block의 maxpooling을 적용하기 전 output을 하나로 concatenate하여 upsampling할 때마다 잃어버리는 정보를 보존하도록 하였습니다. upsampling을 적용한 뒤에는 3x3 conv를 2번 적용하면서 다시 점차 입력 이미지와 shape을 맞추어갑니다.
Training
본 논문에서는 SGD 최적화 기법을 사용하여 학습을 진행하였습니다. 또한 네트워크의 마지막 특징 맵(feature map)에서 각 픽셀이 어떤 클래스에 들어가는 지에 대한 예측값을 소프트맥스(softmax) 함수를 이용하여 계산합니다.
따라서 loss function으로 cross entropy 함수가 사용됩니다.
여기서 $w(\mathbf{x})$는 동일한 클래스에 대한 세포의 경계를 분리하기 위한 Weight map loss를 의미합니다.
위의 그림의 (d)에 주목해주시길 바랍니다. 두 세포 사이의 간격을 잘 포착해야 더 정확하게 나눌 수 있기 때문에 이 loss를 추가적으로 적용하게 되는 것입니다.
Implementation Code
import torch
import torch.nn as nn
import torch.nn.functional as F
class Unet(nn.Module) :
def __init__(self, in_dim, n_class, num_filters):
super(Unet, self).__init__()
self.criterion = nn.BCEWithLogitsLoss()
self.in_dim = in_dim
self.n_class = n_class
self.num_filters = num_filters
act_fn = nn.LeakyReLU(0.2, inplace=True)
# Encoding Parts
self.down1 = conv_block_2(self.in_dim, self.num_filters, act_fn)
self.pool1 = maxpool()
self.down2 = conv_block_2(self.num_filters * 1, self.num_filters * 2, act_fn)
self.pool2 = maxpool()
self.down3 = conv_block_2(self.num_filters * 2, self.num_filters * 4, act_fn)
self.pool3 = maxpool()
self.down4 = conv_block_2(self.num_filters * 4, self.num_filters * 8, act_fn)
self.pool4 = maxpool()
self.bridge = conv_block_2(self.num_filters * 8, self.num_filters * 16, act_fn)
# Decoding Parts
self.trans1 = conv_trans_block(self.num_filters * 16, self.num_filters * 8, act_fn)
self.up1 = conv_block_2(self.num_filters * 16, self.num_filters * 8, act_fn)
self.trans2 = conv_trans_block(self.num_filters * 8, self.num_filters * 4, act_fn)
self.up2 = conv_block_2(self.num_filters * 8, self.num_filters * 4, act_fn)
self.trans3 = conv_trans_block(self.num_filters * 4, self.num_filters * 2, act_fn)
self.up3 = conv_block_2(self.num_filters * 4, self.num_filters * 2, act_fn)
self.trans4 = conv_trans_block(self.num_filters * 2, self.num_filters * 1, act_fn)
self.up4 = conv_block_2(self.num_filters * 2, self.num_filters * 1, act_fn)
# output block
self.out = nn.Sequential(
nn.Conv2d(self.num_filters, self.n_class, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
)
def forward(self, x, mode):
# feature encoding
down1 = self.down1(x)
pool1 = self.pool1(down1)
down2 = self.down2(pool1)
pool2 = self.pool2(down2)
down3 = self.down3(pool2)
pool3 = self.pool3(down3)
down4 = self.down4(pool3)
pool4 = self.pool4(down4)
bridge = self.bridge(pool4)
# feature decoding
trans1 = self.trans1(bridge)
concat1 = torch.cat([trans1, down4], dim=1)
up1 = self.up1(concat1)
trans2 = self.trans2(up1)
concat2 = torch.cat([trans2, down3], dim=1)
up2 = self.up2(concat2)
trans3 = self.trans3(up2)
concat3 = torch.cat([trans3, down2], dim=1)
up3 = self.up3(concat3)
trans4 = self.trans4(up3)
concat4 = torch.cat([trans4, down1], dim=1)
up4 = self.up4(concat4)
out = self.out(up4)
return out
def _calculate_criterion(self, criterion, y_pred, y_true, mode):
loss = criterion(y_pred, y_true)
return loss
def conv_block_2(in_dim, out_dim, act_fn) :
model = nn.Sequential(
conv_block(in_dim, out_dim, act_fn),
nn.Conv2d(out_dim, out_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
nn.BatchNorm2d(out_dim)
)
return model
def conv_trans_block(in_dim, out_dim, act_fn) :
model = nn.Sequential(
nn.ConvTranspose2d(in_dim, out_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1)),
nn.BatchNorm2d(out_dim), act_fn
)
return model
def maxpool() :
pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0))
return pool
def conv_block(in_dim, out_dim, act_fn) :
model = nn.Sequential(
nn.Conv2d(in_dim, out_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
nn.BatchNorm2d(out_dim),
act_fn
)
return model