CIFAR-10은 image classification 작업을 위한 출발점으로서 PyTorch 예제로 제공되고 있다. 실행 방법은 구글 Colab에서 직접 실행하든지 아니면 Jupyter Notebook 용 코드를 다운 받아 실행하는 방법이 있을 수 있다. 현 크롬 인터넷을 열어 로그인 한 상태라면 Run in Google Colab 링크 열기를 클릭해서 실행해 보자.
첫 줄의 inline command인 [1] %matplotlib inline은 magic function으로서 cell 별로 plot 결과를 출력 시킨다. 이는 아나콘다 편집기인 spyder에서 코드 전체를 한꺼번에 실행시키면 끝 부분에 출력들이 몰리게 되는데 반하여 cell 별로 실행하게 될 경우 이점을 제공한다.
PyTorch에서 image classification 작업을 하려면 torch 외에도 이미지 dataset을 제공하는 torchvision 및 이에 딸린 이미지 편집용 transforms 가 필요하다.
torchvision 의 dataset 이 제공하는 이미지 픽셀들의 값은 0.0~1.0 범위에서 평균값이 0.5 이며 표준 편차가 0.5 범위로 각 RGB 채널별로 설정되어 있다. 이미지 이 데이터는 비행기, 차, 새, 고양이, 사슴, 개, 개구리, 말, 배, 트럭을 포함하는 10종의 클라스로 구성된다.
학습용 trainset 과 테스트용 testset을 읽어 들이면 Colabo 의 content 하에 data 폴더를 형성한다.
cell 실행 결과를 보면 https://www.cs.toronto.edu/~kriz/∙∙∙폴더를 볼 수 있는데 터론토 대학 컴퓨터학과의 예전 박사과정이었으면서 Alexnet 으로 유명한 Alex Krizhevsky가 만든 사이트이다.
학습 전에 불러 온 이미지 중에 4가지를 출력해 보자. iter() 명령과 next() 명령은 텐서플로우에서 batch 데이터를 부르는 과정과 유사하다.
뉴럴 네트워크 라이브러리 nn과 nn.functional을 불러 오자. nn 은 뉴럴 네트워크 모델을 지원하며 nn.functional은 activation을 담당하는 Relu 함수처리를 지원한다.
클라스 Net 의 initialization 과정에서 외부의 Parent 클라스 대신 Net 자체를 super(Net, self)를 사용하여 처리하면서 초기화 하도록 한다. CNN 에서는 특히 랜덤 웨이트 매트릭스를 비롯하여 초기화해야 하는 오브젝트들이 있다.
아울러 instance 계산을 위해 입력 데이터 x 가 주어지면 method인 forward 가 계산 후 최종 값을 돌려준다.
torch.optim 라이브러리를 불러오고 Cross Entropy cost 함수를 설정한다. 아울러 Stochastic Gradient Descent 옵티마이저를 설정하는데 back-propagation알고리듬과 관련된 net.parameters( )와 Adam optimizer를 사용할 때처럼 learning rate 값과 모멘텀 값을 Default 로 설정한다.
epoch값을 2로 하여 학습을 실행한다. 소요시간은 GPU를 사용하여 4분이 걸린다. loss.backward() 명령에 따른 Back-propagation 루프 과정에서 loss 함수 계산을 위해 사용하는 웨이트 좌표 값들 위치에서 loss 함수의 웨이트 변수에 대한 편미분 값 즉 gradient값을 계산하여 learning rate를 곱하여 뺀 새 웨이트 좌표 값에 대해 loss 함수를 계산해 나가면서 최소값을 찾아내게 된다. 따라서 매번 계산했던 gradient 값을 optimizer.zero_grad() 명령을 사용하여 0.0으로 두고 새 웨이트 좌표 값에 대해 loss 함수 값을 계산해야 한다. 마지막의 optimizer.step()는 한 번의 back-propagation 알고리듬에 의해서 loss 함수를 계산한 후 다음 스텝으로 넘어가기 위한 준비를 뜻한다.
batch용 테스트 데이터 중에서 무작위호 batch_size=4에 해당하는 한 묶음 데이터 4개를 램덤하게 불러내어 그래픽 출력 확인해 보자.
앞 단계에서 학습한 결과를 사용하여 임의로 추출한 4개의 이미지 데이터에 대한 instance를 계산하자.
테스트용 샘플 이미지 데이터 1만개에 대해서 accuracy를 테스트해 보면 55% 가 얻어진다. 현재의 코드는 보편적인 CNN 코드를 사용했기 때문에 이정도 이며 CNN을 기반으로 보다 복잡한 GoogleNet 이라든지 ResNet을 사용한다면 인식률이 훨씬 더 높아 질 수 있을 것이다.
'PyTorch' 카테고리의 다른 글
PyTorch 초보자를 위한 Transfer Learning Tutorial 예제 구글 Colabo GPU 처리 (0) | 2020.06.03 |
---|---|
1-16 PyTorch 코딩 선형회귀법 예제 (0) | 2020.03.08 |
1-15 Anaconda PyTorch (base) 설치 (0) | 2020.02.02 |
PyTorch MNIST CPU 코딩 (0) | 2019.12.07 |