안녕하세요. 지난 포스팅의 [IC2D] Res2Net: A New Multi-Scale Backbone Architecture (IEEE TPAMI2019)에서는 2019년에 인공지능 최고 저널 중 하나인 TPAMI에 억셉된 모델인 Res2Net에 대해서 소개시켜드렸습니다. Res2Net은 multi-scale 정보를 활용하기 위해 입력 특징 맵을 group convolution을 이용해서 쪼갠 뒤 각 그룹 별로 계층적 residual-like connection을 추가하였습니다. 이를 통해, 기존 ResNet보다 훨씬 더 넓은 receptive field를 가지게 됨을 알 수 있었으며 새로운 모델의 차원인 scale을 도입하였습니다. 오늘은 CondenseNet이라는 효율성을 기반으로 한 새로운 모델에 대해서 소개시켜드리도록 하겠습니다.
Background
심층 신경망에서 모델 설계는 크게 두 가지 방향으로 나눌 수 있습니다. 첫번째는 컴퓨팅 자원은 신경쓰지 않고 오로지 성능 향상에만 초점을 맞추는 것으로 CNN의 최대 인지 능력을 연구하는 것을 목표로 합니다. 두번째로는 컴퓨팅 자원을 고려하고 성능 하락을 방어하는 것에 초점을 맞추는 것입니다. 특히, 두번째 방향성 같은 경우에는 대표적으로 MobileNet V1, ShuffleNet, MobileNet V2 등이 있었습니다. 이 모델들의 공통점은 group convolution을 사용하여 합성곱에서 발생하는 연산량을 줄이게 됩니다.
위와 같이 새로운 모델을 개발할 수도 있지만 필요없는 필터를 제거하여 파라미터의 개수를 줄일 수도 있고 (prunning) 적은 precision 시스템 및 가중치 양자화 (Low-Precision & Quantized Weight)를 통해 모델을 압축할 수도 있습니다. 오늘 소개할 CondenseNet은 제목에서 보시다싶이 기존의 DenseNet에 존재하는 수많은 필터 연결들 중 필요한 것만 살리는 prunning 기법에 group convolution을 추가한 방법입니다. 이를 통해, 기존의 DenseNet보다 약 10배 적은 연산량으로 높은 성능을 보이게 됩니다.
CondenseNets
1). Motivations
위 그림은 PreAct ResNet의 residual path 입니다. 일반적으로 많은 모델들에서 해당 블록을 기반으로 설계를 하죠. 이때, MobileNet V1에서 보았듯이 group convolution을 적용하게 되면 성능이 크게 감소하게 됩니다. 그렇다면 왜 group convolution을 적용하면 성능이 떨어지게 될까요? 본 논문에서는 2가지 문제점을 지적합니다. 각각 (1). intrinsic order과 (2). high diversity입니다. 이는 group convolution을 수행할 때 입력 특징 맵들을 group으로 분리하게 될텐데, 이는 서로 관련성이 높은 특징 맵들이 분리되어 오히려 효율성인 feature re-use를 방해하기 때문입니다.
CondenseNet에서는 이 문제를 해결하기 위해 채널단위로 permutting을 수행합니다. 여기서, ShuffleNet을 떠올리셨다면 아주 좋습니다. 해당 모델도 채널들을 shuffling하는 과정이 존재합니다. 두 모델의 큰 차이점은 shuffling하는 단위입니다. ShuffleNet에서는 shuffling 연산을 수행할 때 각 그룹에서 한번 더 서브그룹으로 쪼갠뒤 그룹 간의 shuffling을 수행하는 데 반해, CondenseNet에서는 채널 단위로 permuting을 수행한다는 점이 다릅니다. CondenseNet에서는 이를 통해 group convolution에 의한 부정적인 효과를 어느정도 억제할 수 있다고 합니다.
위 그림은 DenseNet에서 제안한 하나의 블록입니다. 블록 내에서 깊은 계층일수록 굉장히 많은 양의 connection이 존재하기 때문에 조밀한 연결 (dense connection)이라고 부릅니다. 그런데, 이렇게 많이 연결된 필터 중에서 정말 다 필요한 것만 있을까요? CondenseNet에서는 이러한 점을 지적하여 필요없는 연결을 삭제 (prunning)을 통해 필요한 연결만을 가져오도록 하는 group convolution인 Learned Group Convolution (LGC)을 제안하였습니다.
2). Learned Group Convolution
그림3은 CondenseNet에서 제안하는 Learned Group Convolution의 전체적인 모습을 보여주고 있습니다. 기본적으로 LGC는 두 개의 stage로 구성되어 있습니다. 각각 Condensing Stage와 Optimization Stage입니다. Condensing Stage에서는 미리 정의된 반복횟수만큼 반복하면서 모델의 sparsity를 증가시키기 위해 적은 magnitude를 가지는 가중치 필터를 삭제하는 과정입니다. 그리고 Optimization Stage에서는 이제 grouping된 필터들을 고정시킨 뒤 전체 모델을 학습하게 됩니다. 즉, Condensing Stage는 prunning 과정이고 Optimization Stage는 모델을 학습하는 과정입니다.
(1). Filter Groups
그림2는 standard convolution과 group convolution 사이의 차이를 시각화하여 보여주고 있습니다. 저희가 standard convolution을 수행하면 입력 및 출력 특징 맵의 채널 개수와 필터의 크기가 주요한 파라미터입니다. 이를 각각 $O, R, H, W$라고 정의하면 standard convolution을 수행할 때 필터의 모양은 $O \times R \times H \times W$입니다. 이때, CondenseNet에서 LGC를 적용하는 부분은 $1 \times 1$ 필터 크기의 합성곱입니다. 따라서, $H = W = 1$이 되므로 $1 \times 1$ 크기의 standard convolution의 필터 모양은 $O \times R$이 됩니다.
여기서, group convolution을 적용한다고 가정하면 $G$개의 필터들 $\{\mathbf{F}^{1}, \dots, \mathbf{F}^{G}\}$가 각각 $\frac{O}{G} \times R$의 필터 크기를 가지게 됩니다. 여기까지가 group convolution의 정의였습니다.
(2). Condensation Criterion
이제 문제는 무슨 기준으로 필터를 삭제할 것 인지 입니다. CondenseNet에서는 L1 노름을 이용합니다.
$$\sum_{i = 1}^{\frac{O}{G}} \left| \mathbf{F}^{g}_{i, j} \right|$$
여기서 $i$는 row index, $j$는 column index를 의미하는 것으로 각각 출력 특징 맵의 $i$번째 채널과 입력 특징 맵의 $j$번째 채널로 매핑됩니다. 다음으로 각 필터 별로 위 L1 노름을 계산합니다.
$$\{ \sum_{i = 1}^{\frac{O}{G}} \left| \mathbf{F}^{g}_{i, 1} \right|, \dots, \sum_{i = 1}^{\frac{O}{G}} \left| \mathbf{F}^{g}_{i, R} \right| \}$$
이제 위의 필터들 중에서 L1 노름이 작은 경우는 0으로 바꾸어버립니다. 그러면 해당 특징 맵의 출력은 0이 나오기 때문에 삭제되는 것과 동일한 효과를 내는 것이죠.
(3). Group Lasso
Condensing Stage에서 필터들을 삭제하게 되면 아무래도 성능이 감소할 수 밖에 없습니다. 이러한 부작용을 감소시키기 위해 Group Lasso Regularizer를 적용하여 추가적으로 학습을 진행하였습니다. lasso의 특성 상 의미없는 특징들을 0으로 강제하게 되는데 이를 통해 모델의 sparsity를 강화할 수 있다고 합니다.
(4). Condensation Factor
Condensing Stage에서 필터를 삭제할 때 얼마나 삭제할 지에 대한 파라미터로 $C$를 도입하였습니다. 그림3에서도 볼 수 있다 싶이 $C = 3$이면 2개의 Condensing Stage가 존재합니다. 이제, 각 Stage에서 $\frac{1}{C}$ 만큼의 필터를 삭제하기 때문에 $\frac{R}{C}$개의 연결만 남기게 됩니다. 이때, Condensation Factor $C$를 기반으로 Condensing Stage의 반복횟수와 Optimization Stage의 반복횟수를 $\frac{M}{2(C - 1)}$로 고정하여 학습을 진행합니다. 여기서, $M$은 에폭 횟수입니다.
(5). Learning Rate
그림4는 Condensation Factor $C = 4$라고 했을 때 학습률과 학습 손실함수의 경향성을 그렸습니다. 보시면 150 에폭에서 크게 손실함수가 뛰게 되는 데 이는 150 에폭에 도달하면 거의 절반 정도의 필터들을 삭제하기 때문입니다. 이 과정에서 CondenseNet은 Cosine-Shape Learning Rate Scheduler를 활용하였습니다.
(6). Index Layer
위 그림에서 왼쪽 그림은 학습 시, 오른쪽 그림은 평가 시의 모델 구조 입니다. 차이점은 Index Layer의 유무입니다. 해당 Layer는 학습 과정에서 얻은 필요한 connection을 선택하고 재배열하는 연산을 수행하는 계층입니다.
CondensNet Architecture
(1). Exponentially Increasing Growth Rate (IGH)
DenseNet에서 Bottleneck에 있던 Growth Rate를 기억하시나요? DenseNet에서는 이를 상수값인 $k$로 고정하여 모든 블록에서 동일하게 사용되었습니다. 반면에 CondenseNet에서는 어차피 필터를 삭제하는 김에 깊어질수록 더 많은 필터들을 처음에 만들게 합니다. 이를 IGH라고 하며 CondenseNet에서 Growth Rate는 $k = 2^{m - 1}k_{0}$가 됩니다. 여기서, $m$는 Dense Block의 인덱스를 의미합니다.
(2). Fully Dense Connectivity (FDC)
(1)번의 이유와 비슷하게 어차피 삭제될 것이기 때문에 의미있는 특징들을 더 많은 추출하기 위해 블록 내에서 Dense Connection과 함께 모든 블록 사이의 Dense Connection을 추가하였으며 이때 resolution이 다르기 때문에 이를 맞추기 위해 풀링 연산을 한 뒤 concat을 수행하였습니다.
(3). CondenseNet Architecture
Experiment Results
CondenseNet은 CIFAR와 ImageNet-1K (2012) 데이터셋에 대해서 성능을 측정하였습니다.
(1). CIFAR Classification Results
(2). ImageNet-1K (2012) Classification Results
(3). Ablation Study
(1). Pruning Strategy
(2). Effectiveness of Condensation Factor and Group
Implementation Code
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from layers import BasicConv
class LearnedGroupConv(nn.Module):
global_progress = 0.0
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int=1,
padding: int=0,
dilation: int=1,
groups: int=1,
condensation_factor: int=4,
dropout_rate: float=0.):
super(LearnedGroupConv, self).__init__()
if dropout_rate > 0.: self.dropout = nn.Dropout(dropout_rate, inplace=False)
if condensation_factor is None: condensation_factor = groups
self.in_channels = in_channels
self.out_channels = out_channels
self.condensation_factor = condensation_factor
self.dropout_rate = dropout_rate
self.groups = groups
self.bn = nn.BatchNorm2d(in_channels)
self.relu = nn.ReLU(inplace=True)
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
stride=stride, padding=padding, dilation=dilation, groups=1, bias=False)
### Parameters that should be carefully used
self.register_buffer('_count', torch.zeros(1))
self.register_buffer('_stage', torch.zeros(1))
self.register_buffer('_mask', torch.ones(self.conv.weight.size()))
### Check if arguments are valid
assert in_channels % groups == 0, "group number can not be divided by input channels"
assert in_channels % condensation_factor == 0, "condensation factor can not be divided by input channels"
assert out_channels % groups == 0, "group number can not be divided by output channels"
def forward(self, x):
self._check_drop()
x = self.bn(x)
x = self.relu(x)
if self.dropout_rate > 0.: x = self.dropout(x)
### Masked Output
weight = self.conv.weight * self.mask
return F.conv2d(x, weight, None, self.conv.stride, self.conv.padding, self.conv.dilation, groups=1)
def _check_drop(self):
progress = LearnedGroupConv.global_progress
delta = 0
### Get Current Stage
for i in range(self.condensation_factor - 1):
if progress * 2 < (i + 1) / (self.condensation_factor - 1):
stage = i
break
else:
stage = self.condensation_factor - 1
### Check for dropping
if not self._reach_stage(stage):
self.stage = stage
delta = self.in_channels // self.condense_factor
if delta > 0: self._dropping(delta)
return
def _reach_stage(self, stage):
return (self._stage >= stage).all()
@property
def mask(self):
return Variable(self._mask)
def lasso_loss(self):
if self._reach_stage(self.groups - 1): return 0
weight = self.conv.weight * self.mask
### Assume only apply to 1x1 conv to speed up
assert weight.size()[-1] == 1
weight = weight.squeeze().pow(2)
d_out = self.out_channels // self.groups
### Shuffle weight
weight = weight.view(d_out, self.groups, self.in_channels)
weight = weight.sum(0).clamp(min=1e-6).sqrt()
return weight.sum()
class _DenseLayer(nn.Module):
expansion = 4
def __init__(self,
in_channels: int,
growth_rate: int,
group1x1: int,
group3x3: int,
condensation_factor: int=4,
dropout_rate: float=0.):
super(_DenseLayer, self).__init__()
### 1x1 conv i --> b*k
self.conv1x1 = LearnedGroupConv(in_channels, _DenseLayer.expansion * growth_rate,
kernel_size=1, stride=1, padding=0,
groups=group1x1,
condensation_factor=condensation_factor,
dropout_rate=dropout_rate)
### 3x3 conv b*k --> k
self.conv3x3 = BasicConv(_DenseLayer.expansion * growth_rate, growth_rate,
kernel_size=3, stride=1, padding=1,
groups=group3x3,
bias=False,
activation=nn.ReLU(inplace=True),
preact=True)
def forward(self, x):
x_ = x
x = self.conv1x1(x)
x = self.conv3x3(x)
return torch.cat([x, x_], dim=1)
class _DenseBlock(nn.Sequential):
def __init__(self,
num_layers: int,
in_channels: int,
growth_rate: int,
group1x1: int,
group3x3: int,
condensation_factor: int=4,
dropout_rate: float=0.):
super(_DenseBlock, self).__init__()
for i in range(num_layers):
layer = _DenseLayer(in_channels + i * growth_rate, growth_rate,
group1x1=group1x1, group3x3=group3x3,
condensation_factor=condensation_factor,
dropout_rate=dropout_rate)
self.add_module('dense_layer_{}'.format(i + 1), layer)
class _Transition(nn.Module):
def __init__(self):
super(_Transition, self).__init__()
self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
def forward(self, x):
return self.pool(x)
class CondenseNet(nn.Module):
def __init__(self,
num_channels: int = 3,
num_classes: int = 1000,
stages: List[int] = [4, 6, 8, 10, 8],
growth_rate: List[int] = [8, 16, 32, 64, 128]):
super(CondenseNet, self).__init__()
self.stages = stages
self.growth_rate = growth_rate
assert len(self.stages) == len(self.growth_rate)
self.init_stride = 2
self.pool_size = 7
self.features = nn.Sequential()
### Initial nChannels should be 3
self.num_features = 2 * self.growth_rate[0]
### Dense Block 1 (224 x 224)
self.features.add_module('init_conv', nn.Conv2d(num_channels, self.num_features, kernel_size=3, stride=self.init_stride, padding=1, bias=False))
for i in range(len(self.stages)):
### Dense Block i
self.add_block(i)
self.classifier = nn.Linear(self.num_features, num_classes)
def add_block(self, stage):
### Check if i th stage is the last one
last = (stage == len(self.stages) - 1)
block = _DenseBlock(num_layers=self.stages[stage],
in_channels=self.num_features,
growth_rate=self.growth_rate[stage],
group1x1=4, group3x3=4,
condensation_factor=4,
dropout_rate=0.)
self.features.add_module('dense_block_{}'.format(stage + 1), block)
self.num_features += self.stages[stage] * self.growth_rate[stage]
if not last:
trans = _Transition()
self.features.add_module('transition_block_{}'.format(stage + 1), trans)
else:
self.features.add_module('norm_last', nn.BatchNorm2d(self.num_features))
self.features.add_module('relu_last', nn.ReLU(inplace=True))
self.features.add_module('pool_last', nn.AvgPool2d(self.pool_size))
def forward(self, x, progress=None):
if progress: LearnedGroupConv.global_progress = progress
features = self.features(x)
out = features.view(features.size(0), -1)
out = self.classifier(out)
return out
if __name__=='__main__':
# model = CondenseNet()
# print(model)
# inp = torch.randn(2, 3, 224, 224)
# oup = model(inp)
import os
import time
import math
import tqdm
import argparse
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from IC2D_models import model_to_device
from IC2D_Experiment import dataset_argument
from utils.get_functions import get_deivce
parser = argparse.ArgumentParser(description='Following are the arguments that can be passed form the terminal itself!')
parser.add_argument('--data_path', type=str, default='/media/jhnam0514/68334fe0-2b83-45d6-98e3-76904bf08127/home/namjuhyeon/Desktop/LAB/AwesomeDeepLearning/dataset/IC2D_dataset')
parser.add_argument('--data_type', type=str, default='ImageNet')
parser.add_argument('--num_workers', type=int, default=4, help='number of workers')
parser.add_argument('--seed', type=int, default=4321, help='random seed')
parser.add_argument('--train', default=False, action='store_true')
# Multi-Processing parameters
parser.add_argument('--multiprocessing-distributed', action='store_true',
help='Use multi-processing distributed training to launch '
'N processes per node, which has N GPUs. This is the '
'fastest way to use PyTorch for either single node or '
'multi node data parallel training')
parser.add_argument('--world-size', default=-1, type=int, help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int, help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str, help='distributed backend')
# Train parameter
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--criterion', type=str, default='CCE', choices=['CCE', 'BCE'])
parser.add_argument('--start_epoch', type=int, default=1)
parser.add_argument('--final_epoch', type=int, default=200)
parser.add_argument('--linear_node', type=int, default=4096)
# CondenseNet parameter
parser.add_argument('--group-lasso-lambda', default=0.1, type=float, metavar='LASSO',
help='group lasso loss weight (default: 0)')
# Optimizer Configuration
parser.add_argument('--optimizer_name', type=str, default='SGD')
parser.add_argument('--lr', type=float, default=1e-1)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--weight_decay', type=float, default=0.0001)
# Learning Rate Scheduler (LRS) Configuration
parser.add_argument('--LRS_name', type=str, default=None)
# Print parameter
parser.add_argument('--step', type=int, default=10)
args = parser.parse_args()
args = dataset_argument(args)
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def transform_generator(mode):
if mode == 'train' :
train_transform_list = [
transforms.RandomResizedCrop(args.image_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
]
transform_train_target = None
return transforms.Compose(train_transform_list), transform_train_target
elif mode == 'test' :
test_transform_list = [
transforms.Resize(256),
transforms.CenterCrop(args.image_size),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
]
transform_test_target = None
return transforms.Compose(test_transform_list), transform_test_target
def adjust_learning_rate(optimizer, epoch, args, batch=None,
nBatch=None, method='cosine'):
if method == 'cosine':
T_total = args.final_epoch * nBatch
T_cur = (epoch % args.final_epoch) * nBatch + batch
lr = 0.5 * args.lr * (1 + math.cos(math.pi * T_cur / T_total))
elif method == 'multistep':
if args.data in ['cifar10', 'cifar100']:
lr, decay_rate = args.lr, 0.1
if epoch >= args.final_epoch * 0.75:
lr *= decay_rate ** 2
elif epoch >= args.final_epoch * 0.5:
lr *= decay_rate
else:
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
lr = args.lr * (0.1 ** (epoch // 30))
for param_group in optimizer.param_groups:
param_group['lr'] = lr
return lr
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
def train_epoch(train_loader, model, criterion, optimizer, epoch):
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
learned_module_list = []
### Switch to train mode
model.train()
### Find all learned convs to prepare for group lasso loss
for m in model.modules():
if m.__str__().startswith('LearnedGroupConv'):
learned_module_list.append(m)
running_lr = None
end = time.time()
for i, (image, target) in tqdm.tqdm(enumerate(train_loader)):
progress = float(epoch * len(train_loader) + i) / (args.final_epoch * len(train_loader))
args.progress = progress
### Adjust learning rate
lr = adjust_learning_rate(optimizer, epoch, args, batch=i, nBatch=len(train_loader), method='cosine')
if running_lr is None: running_lr = lr
### Measure data loading time
data_time.update(time.time() - end)
image, target = image.to(args.device), target.to(args.device)
### Compute output
output = model(image, progress)
loss = criterion(output, target)
### Add group lasso loss
if args.group_lasso_lambda > 0:
lasso_loss = 0
for m in learned_module_list:
lasso_loss = lasso_loss + m.lasso_loss()
loss = loss + args.group_lasso_lambda * lasso_loss
### Measure accuracy and record loss
prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
losses.update(loss.item(), image.size(0))
top1.update(prec1.item(), image.size(0))
top5.update(prec5.item(), image.size(0))
### Compute gradient and do SGD step
optimizer.zero_grad()
loss.backward()
optimizer.step()
### Measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % args.step == 0:
print('Epoch: [{0}][{1}/{2}]\t'
'Time {batch_time.val:.3f}\t' # ({batch_time.avg:.3f}) '
'Data {data_time.val:.3f}\t' # ({data_time.avg:.3f}) '
'Loss {loss.val:.4f}\t' # ({loss.avg:.4f}) '
'Prec@1 {top1.val:.3f}\t' # ({top1.avg:.3f}) '
'Prec@5 {top5.val:.3f}\t' # ({top5.avg:.3f})'
'lr {lr: .4f}'.format(
epoch, i, len(train_loader), batch_time=batch_time,
data_time=data_time, loss=losses, top1=top1, top5=top5, lr=lr))
return 100. - top1.avg, 100. - top5.avg, losses.avg, running_lr
args.distributed = False
args.device = get_deivce()
### Load Dataset (ImageNet)
train_dir = os.path.join(args.dataset_dir, '2012', 'train')
test_dir = os.path.join(args.dataset_dir, '2012', 'val')
train_image_transform, train_target_transform = transform_generator('train')
test_image_transform, test_target_transform = transform_generator('test')
train_dataset = datasets.ImageFolder(train_dir, transform=train_image_transform, target_transform=train_target_transform)
test_dataset = datasets.ImageFolder(test_dir, transform=test_image_transform, target_transform=test_target_transform)
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)
test_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=True)
### Load CondenseNet (ImageNet)
model = CondenseNet()
model = model_to_device(args, model)
### Define Loss function and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), args.lr,
momentum=args.momentum,
weight_decay=args.weight_decay,
nesterov=True)
for epoch in range(args.start_epoch, args.final_epoch + 1):
### Train for one epoch
tr_prec1, tr_prec5, loss, lr = train_epoch(train_loader, model, criterion, optimizer, epoch)