머신러닝

1-5 TensorFlow 1.15.0 과 2.0 버전 사이에서

coding art 2020. 1. 25. 15:33
728x90


TensorfFlow 2.0 으로의 업그레이드에 앞서 텐서플로우 1.15.0 버전에서 실행되는 MNIST 코드를 어떻게 사용할 것인지 살펴보자.

 

다음의 헤더영역의 명령 구조를 관찰해 보자. 첫줄은 tensorflow 2.0 으로 선언하여 tensorflow 라이브러리를 불러들이는 명령으로서 앞으로 2.0을 사용하게 되면 자주 사용하게될 명령이다. 한편 두 번째 줄은 tensorflow 2.0 지원을 억제하는 명령으로서 tensorflow 2.0 이하의 버전에서 실행되는 MNIST 파이선 코드 사용을 지원한다. 현재 설치된 tensorflow 버전이 1.15.0 임에도 불구하고 2.0 버전에서 사용하는 첫 번째 줄 명령을 쓸 수 있다는 것은 과도기적인 사용법을 지원한다는 의미인 듯하다.


from tesorflow.exaples.tutorials.mnist ∙∙∙ 명령들은 LeCUN MNSIT 데이터베이스를 불러들이는 명령이지만 tensorflow 2.0 에서는 더 이상 사용이 불가능한 부분이다. 2.0 에서는 keras를 기준으로 MNIST 뉴럴 네트워크 코딩을 지원하기 위한 별도의 명령이 준비되어 있다. 뉴럴 네트워크이든지 아니면 CNN (Convolutionary Neural Network) 코딩을 위한 보다 높은 레벨의 APT 즉 정교한 명령 체계가 제공되고 있다.

하지만 필자처럼 cost 함수를 구성히는 과정에서 랜덤 웨이트과 바이아스를 사용하여 hypothesis 모델링을 연구하는 경우에는 여전히 낮은 레벨의 API 에 해당하는 tensorflow1.15.0 를 사용을 선호할 수밖에 없다.

  

이러한 과도기적인 방식으로 MNIST 코드를 1.15.0 버전에서 실행하면 다음과 같이 경고가 포함된 메시지를 받아 볼 수 있다. 하지만 출력은 성공적으로 된다.



#mnist_tf_1.15_01

# -*- coding: utf-8 -*-

# MNIST data download
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# tensors
x = tf.placeholder(tf.float32, [None, 784])
y = tf.nn.softmax(tf.matmul(x, W) + b)
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

# cross-entropy
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# training
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

# test accuracy
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))