Programming/Pytorch&Tensorflow
[Pytorch] 생초보의 파이토치 일기 - MNIST 손글씨 데이터 분류 99% 달성하기 2
지난 포스팅에서 CNN 모델을 이용해서 99%의 분류 성능을 달성하였다. 원하는 목표를 달성하였으니 좀 더 코드를 정리하고 내 나름대로 이쁘게 정리해보도록 하자. 일단 사진과 같이 코드 파일을 신경망을 구성하는 model.py와 학습을 하는 main.py을 나누자. 1. model.py 이 코드에서는 신경망을 정의하는 클래스들을 모아놓았다. 신경망은 지난 포스팅에서 정의한 모델 DNN, CNN 그대로 사용하였다. 먼저, DNN 모델 클래스이다. # DNN 신경망 구성 class DNN(nn.Module) : def __init__(self): super(DNN, self).__init__() # Input shape = (?, 28, 28, 1) -> (?, 28 * 28) = (?, 784) # Den..