안녕하세요. 오늘은 CNN과 Transformer를 layer-wise cascading한 방식으로 혼합하고자 했던 UTNet에 대한 소개를 하도록 하겠습니다.
UTNet: A Hybrid Transformer Architecture for Medical Image Segmentation
Transformer architecture has emerged to be successful in a number of natural language processing tasks. However, its applications to medical vision remain largely unexplored. In this study, we present UTNet, a simple yet powerful hybrid Transformer archite
arxiv.org
Background
의료 영상 분할을 대장내시경, 현미경, 초음파, CT, MRI 등 다양한 모달리티 영상에서 폴립, 세포, 유방암, 폐 감염 등 관심 영역을 분할하는 것을 목표로 하고 있습니다. 이를 통해, 진단 보조 시스템과 함께 향후 치료 계획을 수립할 수 있으며 로봇 수술 시에도 중요한 역할을 하고 향후 종양의 성장세를 파악하여 얼마나 약화시킬 수 있는지도 예측할 수 있습니다.
이러한 중요성으로 이전부터 딥 러닝을 이용한 의료 영상 분할을 꾸준히 연구되어 왔습니다. 가장 대표적인 모델이 이전에 제가 소개했던 UNet, UNet++, AttentionUNet일 것 입니다. 이 모델들의 가장 중요한 특징 중 하나는 모든 계층이 합성곱 계층으로 구성되어 있다는 점 입니다. 최신 연구에 따르면 합성곱 계층은 주로 입력 영상 내의 텍스쳐 및 엣지 정보를 기반으로 최종판단을 진행하는 것으로 알려져있습니다. 하지만, 영상의 전체적인 정보 (global context information)을 고려하지 않기 때문에 최적이 아닌 성능을 가지게 되죠. 이러한 문제점을 해결하기 위해 최근 Transformer의 self-attention 매커니즘을 활용하는 모델들도 많이 제시되었습니다. 가장 대표적인 모델이 SwinUNet으로 모든 계층을 오직 Swin Transformer로만 구성한 모델입니다. 이를 통해, 모든 계층에서 성공적으로 입력 영상의 전체적인 정보를 활용할 수 있었지만 CNN이 가지고 있는 중요한 역할인 지역적 특징을 추출하지 못하는 문제점이 있습니다.
이러한 문제점이 입각하여 TransUNet은 CNN에서 추출한 정보를 ViT에게 넘겨 지역적 특징과 전역적 특징을 동시에 고려하였습니다. 하지만, 이는 마지막 계층만 Transformer의 전역적인 문맥 정보를 활용할 수 있다는 점에서 큰 문제가 있었습니다. 이러한 문제를 해결하기 위해 오늘 소개하는 UTNet은 CNN과 Transformer를 UNet기반으로 설계하여 layer-wise로 cascading하게 수행하여 모든 계층에서 전역적인 문맥정보를 얻을 수 있게 되었습니다.
Proposed Method: UTNet
1) Overall Architecture

그림 1은 UTNet의 전체적인 구조를 보여주고 있습니다. 언뜻보면 기존의 UNet과 별차이가 없는 것을 볼 수 있습니다. 실제로도 그 차이는 기존의 UNet은 각 계층별로 합성곱 연산을 두번씩 수행하는 반면 UTNet은 합성곱 + Self-attention을 한번씩 번갈아가면서 진행하게 됩니다.
2) Revisiting Self-Attention Mechanism
기본적으로 Transformer는 multi-head self-attention (MHSA)와 Feed-Forward Network (FFN)의 반복으로 구성됩니다. 특히, MHSA 같은 경우에는 단일 헤드 self-attention보다 다양한 sub-space에서 표현능력을 향상시키기 때문에 어텐션 능력을 더욱 향상시키게 됩니다. 이 과정을 간단하게 설명하면 다음과 같습니다.
STEP1. 입력 특징맵 $X \in \mathbb{R}^{C \times H \times W}$를 query ($\mathbf{Q}$), key ($\mathbf{K}$) , value ($\mathbf{V}$) 로 변환한다. 이때, 각 토큰들은 $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{d \times H \times W}$로 형상을 얻게 된다. 여기서, $d$는 hidden dimension의 크기를 의미한다.
STEP2. $\mathbf{Q}, \mathbf{K}, \mathbf{V}$은 flatten하고 두 차원의 위치를 변경하여 $\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{n \times d}$가 되고 여기서 $n = HW$이다.
STEP3. $\mathbf{Q}, \mathbf{K}, \mathbf{V}$를 이용해서 self-attention을 적용한다.
$$\text{Attention} (\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{Softmax} \left( \frac{\mathbf{Q}\mathbf{K}^{T}}{\sqrt{d}} \right) \mathbf{V}$$
위 self-attention 수식에서 $\text{Softmax} \left( \frac{\mathbf{Q}\mathbf{K}^{T}}{\sqrt{d}} \right) = P \in \mathbb{R}^{n \times n}$로 정의해두면 $P$를 context aggregation matrix 또는 similarity matrix라고 부를 수 있습니다.
여기서 $i$번째 행의 query에 대응되는 context aggregation matrix는 $\text{Softmax} \left( \frac{\text{q}_{i}\mathbf{K}^{T}}{\sqrt{d}} \right) = P_{i} \in \mathbb{R}^{1 \times n}$로 볼 수 있으며 이는 query $\text{q}_{i}$와 각 key 값들 사이의 similarity를 측정한 것과 동일합니다. 최종적으로 context aggregation matrix는 value $\mathbf{V}$에 적용되어 context information을 저장하는 데 활용됩니다.
3) Efficient Self-Attention Mechanism
여기서 중요한 점은 self-attention을 수행하게 되면 $n$개의 토큰들 사이의 모든 similarity를 계산해야하기 때문에 굉장히 비효율적입니다. 또한, self-attention은 long-sequence 데이터에 대해서 low-rank 연산이기 때문에 가장 중요한 정보는 가장 큰 singular value에 모여있게 됩니다.

그림 2는 이러한 비효율성 문제를 해결하기 위한 Efficient Self-Attention 구조입니다. 핵심은 입력 특징맵 $X \in \mathbb{R}^{C \times H \times W}$의 토큰들의 개수를 줄이기 위해 sub-sampling을 수행하는 것 입니다. 사실 그거 말고는 크게 다른 점은 없지만 이를 수식으로 나타내면 다음과 같습니다.
$$\text{Attention} (\mathbf{Q}, \overline{\mathbf{K}}, \overline{\mathbf{V}}) = \text{Softmax} \left( \frac{\mathbf{Q}\overline{\mathbf{K}}^{T}}{\sqrt{d}} \right) \overline{\mathbf{V}}$$
여기서 $\overline{\mathbf{K}}, \overline{\mathbf{V}} \in \mathbb{R}^{k \times d}$로 $k = hw < n$입니다. $h$와 $w$는 기존의 입력 특징 맵에 sub-sampling을 적용했을 때 나오는 더 작은 크기의 해상도로 정의되기 때문에 더 효율적으로 동작하게 됩니다. 따라서, 기존의 self-attention의 복잡도가 $\mathcal{O} (n^{2}d)$ 였다면 efficient self-attention의 복잡도는 $\mathcal{O} (nkd)$가 되어 효율적으로 변하게 됩니다.
4) Relative Positional Encoding
기존의 self-attention의 다른 단점 중 하나는 위치 정보를 무시한다는 점 입니다. ViT에서는 완전 초기에 linear embedding을 시킬 때 말고는 위치 정보를 부여해주지는 않았던 것을 기억하실겁니다. 이를 해결하기 위한 방법은 sinusoidal embedding을 하는 방법도 있지만 이는 2D 정보를 무시하기 때문에 최고의 선택이 아닙니다. 따라서 본 논문에서는 2D 기반의 relative positional embedding을 수행합니다. 이를 위해 softmax 연산을 수행하기 전에 $i = (i_{x}, i_{y})$와 $j = (j_{x}, j_{y})$ 사이의 relative position encoding을 더해줍니다.
$$l_{i, j} = \frac{\text{q}_{i}^{T}}{\sqrt{d}} \left( k_{j} + r^{W}_{j_{x} - i_{x}} + r^{H}_{j_{y} - i_{y}} \right)$$
여기서 $r^{W}_{j_{x} - i_{x}}$와 $r^{H}_{j_{y} - i_{y}}$는 학습가능한 벡터로 정의됩니다. 최종적으로 efficient self-attention과 relative positional encoding을 합쳐서 다음과 같이 쓰게 됩니다.
$$\text{Attention} (\mathbf{Q}, \overline{\mathbf{K}}, \overline{\mathbf{V}}) = \text{Softmax} \left( \frac{\mathbf{Q}\overline{\mathbf{K}}^{T} + \mathbf{S}^{rel}_{H} + \mathbf{S}^{rel}_{W}}{\sqrt{d}} \right) \overline{\mathbf{V}}$$
Experiment Results
1) Quantitative Results


2) Qualitative Results

3) Ablation Study
