'''
A Convolutional Network implementation example using the TensorFlow library.
This example uses the MNIST database of handwritten digits
(http://yann.lecun.com/exdb/mnist/).
Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
'''
import tensorflow as tf

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
# Parameters
learning_rate = 0.001
training_iters = 100000
batch_size = 128
display_step = 10

# Network Parameters
n_input = 784  # MNIST data input (img shape: 28*28)
n_classes = 10  # MNIST total classes (0-9 digits)
dropout = 0.75  # Dropout, probability to keep units

# tf Graph input
x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.float32, [None, n_classes])
keep_prob = tf.placeholder(tf.float32)  # dropout (keep probability)
# Create model
def conv2d(img, w, b):
    return tf.nn.relu(tf.nn.bias_add(
        tf.nn.conv2d(img, w, strides=[1, 1, 1, 1], padding='SAME'), b))

def max_pool(img, k):
    return tf.nn.max_pool(img, ksize=[1, k, k, 1], strides=[1, k, k, 1],
                          padding='SAME')
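
# Note on the helpers above: conv2d applies a stride-1 convolution with
# 'SAME' padding (output height/width equal the input's), then adds the bias
# and applies ReLU. max_pool with k=2 halves each spatial dimension.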
def conv_net(_X, _weights, _biases, _dropout):
    # Reshape flat input vector into a batch of 28x28 grayscale images
    _X = tf.reshape(_X, shape=[-1, 28, 28, 1])

    # Convolution Layer
    conv1 = conv2d(_X, _weights['wc1'], _biases['bc1'])
    # Max Pooling (down-sampling)
    conv1 = max_pool(conv1, k=2)
    # Apply Dropout
    conv1 = tf.nn.dropout(conv1, _dropout)

    # Convolution Layer
    conv2 = conv2d(conv1, _weights['wc2'], _biases['bc2'])
    # Max Pooling (down-sampling)
    conv2 = max_pool(conv2, k=2)
    # Apply Dropout
    conv2 = tf.nn.dropout(conv2, _dropout)

    # Fully connected layer
    # Reshape conv2 output to fit the dense layer input
    dense1 = tf.reshape(conv2, [-1, _weights['wd1'].get_shape().as_list()[0]])
    # ReLU activation
    dense1 = tf.nn.relu(tf.add(tf.matmul(dense1, _weights['wd1']), _biases['bd1']))
    # Apply Dropout
    dense1 = tf.nn.dropout(dense1, _dropout)

    # Output, class prediction
    out = tf.add(tf.matmul(dense1, _weights['out']), _biases['out'])
    return out
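
# Shape trace through conv_net (batch dimension first):
#   input            [?, 28, 28, 1]
#   conv1 + pool  -> [?, 14, 14, 32]
#   conv2 + pool  -> [?, 7, 7, 64]
#   flatten       -> [?, 7*7*64]
#   dense         -> [?, 1024]
#   out           -> [?, n_classes]
# This is why 'wd1' below has 7*7*64 input rows.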
# Store layers weight & bias
weights = {
    # 5x5 conv, 1 input channel, 32 output channels
    'wc1': tf.Variable(tf.random_normal([5, 5, 1, 32])),
    # 5x5 conv, 32 input channels, 64 output channels
    'wc2': tf.Variable(tf.random_normal([5, 5, 32, 64])),
    # fully connected, 7*7*64 inputs, 1024 outputs
    'wd1': tf.Variable(tf.random_normal([7*7*64, 1024])),
    # 1024 inputs, 10 outputs (class prediction)
    'out': tf.Variable(tf.random_normal([1024, n_classes]))
}
biases = {
    'bc1': tf.Variable(tf.random_normal([32])),
    'bc2': tf.Variable(tf.random_normal([64])),
    'bd1': tf.Variable(tf.random_normal([1024])),
    'out': tf.Variable(tf.random_normal([n_classes]))
}
# Construct model
pred = conv_net(x, weights, biases, keep_prob)

# Define loss and optimizer
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Evaluate model: softmax is monotonic, so taking argmax over the raw logits
# yields the same predicted class as taking argmax over the probabilities
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Initializing the variables
init = tf.global_variables_initializer()
# Launch the graph
with tf.Session() as sess:
    sess.run(init)
    step = 1
    # Keep training until max iterations reached
    while step * batch_size < training_iters:
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        # Fit training using batch data
        sess.run(optimizer, feed_dict={x: batch_xs, y: batch_ys,
                                       keep_prob: dropout})
        if step % display_step == 0:
            # Calculate batch accuracy and loss (dropout disabled: keep_prob=1)
            acc = sess.run(accuracy, feed_dict={x: batch_xs, y: batch_ys,
                                                keep_prob: 1.})
            loss = sess.run(cost, feed_dict={x: batch_xs, y: batch_ys,
                                             keep_prob: 1.})
            print("Iter " + str(step * batch_size) +
                  ", Minibatch Loss= " + "{:.6f}".format(loss) +
                  ", Training Accuracy= " + "{:.5f}".format(acc))
        step += 1
    print("Optimization Finished!")

    # Calculate accuracy for the first 256 MNIST test images
    print("Testing Accuracy:",
          sess.run(accuracy, feed_dict={x: mnist.test.images[:256],
                                        y: mnist.test.labels[:256],
                                        keep_prob: 1.}))
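
    # Optional extension (not in the original script): a minimal sketch of
    # evaluating the full 10,000-image test set in fixed-size batches, since
    # feeding every test image at once can exhaust GPU memory. Assumes the
    # same session, placeholders, and accuracy op defined above; the last
    # n_test % batch_size images are skipped for simplicity.
    n_test = len(mnist.test.images)
    n_batches = n_test // batch_size
    total_acc = 0.
    for i in range(n_batches):
        lo, hi = i * batch_size, (i + 1) * batch_size
        total_acc += sess.run(accuracy,
                              feed_dict={x: mnist.test.images[lo:hi],
                                         y: mnist.test.labels[lo:hi],
                                         keep_prob: 1.})
    print("Full Test Accuracy:", total_acc / n_batches)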