안녕하세요. 오늘도 Medium 포스트를 번역하는 포스팅을 해보도록 하겠습니다. 오늘의 포스팅은 제가 가장 처음에 했던 everyday-image-processing.tistory.com/95와 꽤나 유사합니다. 하지만 차이점은 3D 영상을 그대로 사용하는 지 안하는 지가 가장 큰 차이점입니다.
오늘 포스팅의 기본 아이디어는 GAN(Generative Adversial Network)입니다. 이 개념을 잘 아시는 분들은 그 중에서도 가장 기초적인 개념 중에 Pix2Pix에 대해서도 매우 잘 아실것이라고 생각합니다. 의료 분야에서는 일반적으로 segmentation을 위해서 사용되는 분야입니다.
1. 기본 Pix2Pix
기본적으로 Pix2Pix는 GAN과 마찬가지로 Generative Network와 Discriminant Network로 구성됩니다. 이때, Generative Network는 제가 가장 처음에 정리했던 포스팅의 U-net을 그대로 사용하게 됩니다. 애초에 U-net 자체가 Biomedical 분야에서 Segmentation을 위해서 사용된다는 사실을 생각하시면 그리 놀랍지 않은 사실입니다. U-net 논문에 대해서 정리한 포스트는 아래에 띄울테니 한번씩 보고 오시는 것을 추천드립니다!아티클 정리 - Medical images segmentation with Keras: U-net architecture
이제 U-net을 이용해서 새로운 이미지를 생성했다고 가정하겠습니다. 그러면 Discriminant Network는 무작위로 주어지는 이미지가 원본 데이터셋에서 온것인지, Generative Network에 의해서 생성된 이미지인지를 1과 0으로 판별합니다. 아주 단순한 기법으로 지금까지 수많은 분야에서 응용되고 있습니다!!
하지만 기본적으로 Pix2Pix의 단점은 MRI, CT와 같은 여러장의 이미지가 쌓인 "체적" 데이터를 다루기에는 적절하지 않다는 것입니다. 이는 medical AI의 큰 장애물이 될 것입니다. 그러나 이후에 새롭게 등장한 기법이 바로 Vox2Vox입니다. 이것은 단순한 2D 이미지만을 생성하는 Pix2Pix와는 달리 체적 데이터를 다룰 수 있다는 장점이 있습니다. 그러나... Vox2Vox에서도 큰 장애물이 있는 데 그것은 하드웨어적인 문제였습니다. 2D 이미지를 다룰 때도 꽤나 많은 연산량을 소비하였으나 3D 이미지를 다룰 때의 연산량은 말 그대로 어마무시 할 것입니다. 이것은 최근의 하드웨어의 발전으로 많이 해소가 된 문제이지만 그럼에도 불구하고 여전히 연산량 자체가 너무 많다는 점입니다.
2 Vox2Vox의 구현
원래는 파이토치 버전의 코드가 존재하지 않았다고 합니다. 하지만 이 게시글의 원작자가 직접 코드를 깃허브에 구현했다고 하니 수많은 star로 혼내줘야겠습니다.
먼저 코드를 깃허브로부터 불러오도록 하겠습니다. 원하는 폴더로 가셔서 터미널을 열고 아래의 명령어를 입력해주시면 됩니다.(참고로 제 환경은 Ubuntu18.04입니다.)
git clone https://github.com/chinokenochkan/vox2vox.git
그리고 환경을 맞추어줘야하는 데 해당 코드를 돌리기 위해서는 아래의 라이브러리들이 필요하다고 합니다.
- Python 3.7
- PyTorch>=0.4.0
- Torchvision
- Matplotlib
- Numpy, Scipy
- Pillow
- Scikit-image
- Pyvista
- h5py
일단 있는 것도 있고 없는 것도 있으니 저는 이중에서 필요한 것만 설치하도록 하겠습니다. 물론 만약 그런거 귀찮고 한방에 다 설치하시고 싶으신 분은 가상환경을 열고 아래의 명령어를 입력해주시면 됩니다.
pip install -r requirements.txt
이 명령어는 requirement.txt 문서에 들어있는 모든 라이브러리들을 자동으로 설치해줍니다. 그리고 저희가 바꿔야할 것은 Generator만 바꿔주면 됩니다. 원래는 U-net은 2D 이미지만 입력으로 받았지만 3D U-net이라는 신경망은 3D 이미지를 입력으로 받을 수 있는 신경망입니다. 그림을 기존의 U-net과 유사해보이지만 Bottle neck 부분에서 조금 변경이 생긴 것을 볼 수 있습니다. skip connection을 추가함으로써 convolution 연산이나 maxpooling 연산으로 인해 생기는 원본 데이터의 손실을 줄인 것 같습니다.
이제 이 Generator를 어떻게 구현했는 지 보도록 하겠습니다. 먼저, 전체 코드를 본 뒤 조금씩 쪼개어 보도록 하겠습니다.
#***********************#
#***Code by:************#
#***Chi Nok Enoch Kan***#
#***********************#
#*******<(^.^)>*********#
#***********************#
#*****Encoder Block*****#
#***********************#
#***********************#
#***********************#
class UNetDown(nn.Module):
def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
super(UNetDown, self).__init__()
layers = [nn.Conv3d(in_size, out_size, 4, 2, 1, bias=False)]
if normalize:
layers.append(nn.InstanceNorm3d(out_size))
layers.append(nn.LeakyReLU(0.2))
if dropout:
layers.append(nn.Dropout(dropout))
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
#*****Bottleneck Block*****#
class UNetMid(nn.Module):
def __init__(self, in_size, out_size, dropout=0.0):
super(UNetMid, self).__init__()
layers = [
nn.Conv3d(in_size, out_size, 4, 1, 1, bias=False),
nn.InstanceNorm3d(out_size),
nn.LeakyReLU(0.2)
]
if dropout:
layers.append(nn.Dropout(dropout))
self.model = nn.Sequential(*layers)
def forward(self, x, skip_input):
# print(x.shape)
x = torch.cat((x, skip_input), 1)
x = self.model(x)
x = nn.functional.pad(x, (1,0,1,0,1,0))
return x
#*****Decoder Block*****#
class UNetUp(nn.Module):
def __init__(self, in_size, out_size, dropout=0.0):
super(UNetUp, self).__init__()
layers = [
nn.ConvTranspose3d(in_size, out_size, 4, 2, 1, bias=False),
nn.InstanceNorm3d(out_size),
nn.ReLU(inplace=True),
]
if dropout:
layers.append(nn.Dropout(dropout))
self.model = nn.Sequential(*layers)
def forward(self, x, skip_input):
x = self.model(x)
x = torch.cat((x, skip_input), 1)
return x
일단 기본적으로 파이토치 코드이기 때문에 클래스를 이용해서 모듈을 정의하고 클래스의 메서드로 forward 함수를 정의하는 것을 볼 수 있습니다. 이는 파이토치의 공식과도 같기 때문에 외워두셔야합니다. 이제 조금씩 뜯어서 확인해보도록 하겠습니다.
class UNetDown(nn.Module):
def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
super(UNetDown, self).__init__()
layers = [nn.Conv3d(in_size, out_size, 4, 2, 1, bias=False)]
if normalize:
layers.append(nn.InstanceNorm3d(out_size))
layers.append(nn.LeakyReLU(0.2))
if dropout:
layers.append(nn.Dropout(dropout))
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
이 코드는 기존 U-net의 contracting path의 한 블럭을 의미합니다. 굉장히 유사하지만 다른 점은 Conv2d가 아닌 Conv3d를 사용하고 있는 모습을 볼 수 있습니다. 이와 같이 매우 간단하게 코드를 바꿈으로써 3D 이미지도 입력으로 받을 수 있음을 볼 수 있습니다.
#*****Bottleneck Block*****#
class UNetMid(nn.Module):
def __init__(self, in_size, out_size, dropout=0.0):
super(UNetMid, self).__init__()
layers = [
nn.Conv3d(in_size, out_size, 4, 1, 1, bias=False),
nn.InstanceNorm3d(out_size),
nn.LeakyReLU(0.2)
]
if dropout:
layers.append(nn.Dropout(dropout))
self.model = nn.Sequential(*layers)
def forward(self, x, skip_input):
# print(x.shape)
x = torch.cat((x, skip_input), 1)
x = self.model(x)
x = nn.functional.pad(x, (1,0,1,0,1,0))
return x
다음은 Bottle neck 블럭입니다. 이 부분 역시 기존의 U-net과 매우 유사하지만 Conv2d가 아닌 Conv3가 사용되었습니다. 또한 torch의 cat 함수를 통해서 하나로 합침으로써 skip connection이 구현되었습니다.
#*****Decoder Block*****#
class UNetUp(nn.Module):
def __init__(self, in_size, out_size, dropout=0.0):
super(UNetUp, self).__init__()
layers = [
nn.ConvTranspose3d(in_size, out_size, 4, 2, 1, bias=False),
nn.InstanceNorm3d(out_size),
nn.ReLU(inplace=True),
]
if dropout:
layers.append(nn.Dropout(dropout))
self.model = nn.Sequential(*layers)
def forward(self, x, skip_input):
x = self.model(x)
x = torch.cat((x, skip_input), 1)
return x
마지막으로 expanding path의 블럭입니다. 이 역시 차이점은 동일합니다. 이제 저희가 해야되는 것은 이 블럭들을 전부 하나의 형태로 묶어줘야한다는 것입니다. 아래의 Generative network를 보시면 됩니다.
class GeneratorUNet(nn.Module):
def __init__(self, in_channels=1, out_channels=1):
super(GeneratorUNet, self).__init__()
self.down1 = UNetDown(in_channels, 64, normalize=False)
self.down2 = UNetDown(64, 128)
self.down3 = UNetDown(128, 256)
self.down4 = UNetDown(256, 512)
self.mid1 = UNetMid(1024, 512, dropout=0.2)
self.mid2 = UNetMid(1024, 512, dropout=0.2)
self.mid3 = UNetMid(1024, 512, dropout=0.2)
self.mid4 = UNetMid(1024, 256, dropout=0.2)
self.up1 = UNetUp(256, 256)
self.up2 = UNetUp(512, 128)
self.up3 = UNetUp(256, 64)
# self.us = nn.Upsample(scale_factor=2)
self.final = nn.Sequential(
nn.ConvTranspose3d(128, out_channels, 4, 2, 1),
nn.Tanh()
)
def forward(self, x):
# U-Net generator with skip connections from encoder to decoder
d1 = self.down1(x)
d2 = self.down2(d1)
d3 = self.down3(d2)
d4 = self.down4(d3)
m1 = self.mid1(d4, d4)
m2 = self.mid2(m1, m1)
m3 = self.mid3(m2, m2)
m4 = self.mid4(m3, m3)
u1 = self.up1(m4, d3)
u2 = self.up2(u1, d2)
u3 = self.up3(u2, d1)
return self.final(u3)
view raw
위에서부터 보시면, contracting path 내에 블럭이 4개(입력 채널 -> 64 -> 128 -> 256 -> 512)로 점점 채널의 개수가 늘어나게 됩니다. 그 다음에는 bottle neck에서도 블랙이 4개가 사용되었습니다. 다만 dropout과 함께 skip connection이 적용되어 있음을 볼 수 있습니다. 그 다음으로 expanding path 내에 블럭이 또 4개가 있고 이때 contracting path에서 expanding path로의 skip connection이 존재하여 매번 channel의 개수가 유지됩니다.
이제 Generator는 이와 같이 매우 간단하게 구현할 수 있습니다. 다음으로 구현해야할 것은 Discriminator입니다. 사실 저는 GAN을 정확하게는 모르지만 이전에 잠깐 공부했을 때는 Generator와 Discriminator 사이의 학습 균형을 잘 이루어야하는 것이 중요하다고 설명을 들었습니다. 이 포스팅에서도 이 부분을 특히나 강조하고 있습니다. 아래의 코드를 보도록 하겠습니다.
##############################
# Discriminator
##############################
class Discriminator(nn.Module):
def __init__(self, in_channels=1):
super(Discriminator, self).__init__()
def discriminator_block(in_filters, out_filters, normalization=True):
"""Returns downsampling layers of each discriminator block"""
layers = [nn.Conv3d(in_filters, out_filters, 4, stride=2, padding=1)]
if normalization:
layers.append(nn.InstanceNorm3d(out_filters))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
self.model = nn.Sequential(
*discriminator_block(in_channels * 2, 64, normalization=False),
*discriminator_block(64, 128),
*discriminator_block(128, 256),
*discriminator_block(256, 512),
# nn.ZeroPad3d((1, 0, 1, 0)),
)
self.final = nn.Conv3d(512, 1, 4, padding=1, bias=False)
def forward(self, img_A, img_B):
# Concatenate image and condition image by channels to produce input
img_input = torch.cat((img_A, img_B), 1)
intermediate = self.model(img_input)
pad = nn.functional.pad(intermediate, pad=(1,0,1,0,1,0))
return self.final(pad)
방금 설명드렸다싶이 Discriminator는 입력받은 데이터가 기존 데이터셋에 있는 데이터인지, Generator에 의해서 생성된 데이터인지를 1과 0으로 판별한다고 하였습니다. 이 부분만 잘 기억하시면 됩니다. 먼저, discriminator block을 보시면 Conv3d를 이용해서 3D 이미지를 데이터로 받고 있습니다. 그리고 이를 여러개로 쌓아서 총 5개의 layer(in_channel*2 -> 64 -> 128 -> 256 -> 512->1)를 쌓았습니다. 마지막 layer에서 1이 sigmoid를 이용해서 0과 1을 최종적으로 판단하는 layer라고 보면 됩니다. 그 다음에는 공식처럼 따라오는 forward 메서드입니다. 이 메서드는 2개의 이미지를 받습니다. 아마도 생성된 이미지와 기존의 데이터셋인 거 같네요. torch의 concat 함수를 이용해서 하나로 합칩니다. 그 다음에 discriminator에 집어넣어서 결과를 얻습니다.
3. Train
# ---------------------
# Train Discriminator, only update every disc_update batches
# ---------------------
# Real loss
fake_B = generator(real_A)
pred_real = discriminator(real_B, real_A)
loss_real = criterion_GAN(pred_real, valid)
# Fake loss
pred_fake = discriminator(fake_B.detach(), real_A)
loss_fake = criterion_GAN(pred_fake, fake)
# Total loss
loss_D = 0.5 * (loss_real + loss_fake)
d_real_acu = torch.ge(pred_real.squeeze(), 0.5).float()
d_fake_acu = torch.le(pred_fake.squeeze(), 0.5).float()
d_total_acu = torch.mean(torch.cat((d_real_acu, d_fake_acu), 0))
if d_total_acu <= opt.d_threshold:
optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()
discriminator_update = 'True'
# ------------------
# Train Generators
# ------------------
optimizer_D.zero_grad()
optimizer_G.zero_grad()
# GAN loss
fake_B = generator(real_A)
pred_fake = discriminator(fake_B, real_A)
loss_GAN = criterion_GAN(pred_fake, valid)
# Voxel-wise loss
loss_voxel = criterion_voxelwise(fake_B, real_B)
# Total loss
loss_G = loss_GAN + lambda_voxel * loss_voxel
loss_G.backward()
optimizer_G.step()
이제 모델들을 정의했으니 학습할 시간입니다. 위의 코드를 참고해주시길 바랍니다. 먼저, generator와 실제 데이터를 이용해서 가짜 데이터, 즉 생성된 데이터는 fake_B를 생성합니다. 그리고 discriminator한테 실제 데이터 2개 real_A와 real_B를 넘겨줍니다. 그러면 discriminator는 이 데이터들이 가짜인지 진짜인지 판별하여 이를 확률적으로 저희에게 알려줄 것입니다. 그리고 torch에서 제공하는 MSE 손실함수를 GAN의 손실로 사용합니다. 이를 real loss라고 정의하도록 하겠습니다(5~7line). 그 다음에는 fake loss를 정의해야합니다(9~10line). 이제 모든 손실들을 하나로 합치면 됩니다(12line).
이제!! 코드를 이용해서 학습을 진행하면 되지만,,, MICCAI에서 제공하는 BraTS 데이터셋을 요청한 지 얼마 안되서 아직 데이터를 받지 못하였습니다 ㅠㅠ 데이터셋이 도착하면!! 바로 학습을 진행하여 결과를 보도록 하겠습니다!!
'인공지능 > 아티클 정리' 카테고리의 다른 글
아티클 정리 - How to Do Hyperparameter Tuning on Any Python Script in 3 Easy Steps (0) | 2020.11.19 |
---|---|
아티클 정리 - DCGAN Under 100 Lines of Code (0) | 2020.10.16 |
아티클 정리 - Policy Gradient Reinforcement Learning in PyTorch (0) | 2020.09.09 |
아티클 정리 - Building a Face Recognizer in Python (0) | 2020.09.06 |
아티클 정리 - Medical images segmentation with Keras: U-net architecture (0) | 2020.09.04 |