
fix gpu support

aymericdamien 9 years ago
commit e43ca9293a
1 changed file with 27 additions and 30 deletions

+ 27 - 30
convolutional_network.py

@@ -6,18 +6,19 @@ import tensorflow as tf
 
 # Parameters
 learning_rate = 0.001
-training_epochs = 3
-batch_size = 64
-display_batch = 200 #set to 0 to turn off
-display_step = 1
+training_iters = 100000
+batch_size = 128
+display_step = 10
 
 #Network Parameters
 n_input = 784 #MNIST data input
 n_classes = 10 #MNIST total classes
+dropout = 0.75
 
 # Create model
-x = tf.placeholder("float", [None, n_input])
-y = tf.placeholder("float", [None, n_classes])
+x = tf.placeholder(tf.types.float32, [None, n_input])
+y = tf.placeholder(tf.types.float32, [None, n_classes])
+keep_prob = tf.placeholder(tf.types.float32) #dropout
 
 def conv2d(img, w, b):
     return tf.nn.relu(tf.nn.conv2d(img, w, strides=[1, 1, 1, 1], padding='SAME') + b)
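The hunk above swaps epoch-based training for an iteration budget (training_iters) and turns the dropout keep probability into a keep_prob placeholder fed at run time. Below is a minimal, self-contained sketch of that feed-time dropout pattern, assuming TensorFlow's v1 compatibility API (later releases dropped the early tf.types.float32 alias in favor of tf.float32); the toy tensor is illustrative only and not part of the commit:

import tensorflow.compat.v1 as tf   # assumption: TF 2.x with the v1 compatibility layer
tf.disable_v2_behavior()

inputs = tf.placeholder(tf.float32, [None, 4])   # illustrative toy tensor, not from the commit
keep_prob = tf.placeholder(tf.float32)           # dropout keep probability, fed per run
dropped = tf.nn.dropout(inputs, keep_prob)       # v1 signature: second argument is keep_prob

with tf.Session() as sess:
    data = [[1., 2., 3., 4.]]
    # training-style run: roughly 25% of units zeroed, survivors scaled by 1/0.75
    print(sess.run(dropped, feed_dict={inputs: data, keep_prob: 0.75}))
    # evaluation-style run: dropout disabled, values pass through unchanged
    print(sess.run(dropped, feed_dict={inputs: data, keep_prob: 1.0}))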
@@ -25,20 +26,20 @@ def conv2d(img, w, b):
 def max_pool(img, k):
     return tf.nn.max_pool(img, ksize=[1, k, k, 1], strides=[1, k, k, 1], padding='SAME')
 
-def conv_net(_X, _weights, _biases):
+def conv_net(_X, _weights, _biases, _dropout):
     _X = tf.reshape(_X, shape=[-1, 28, 28, 1])
 
     conv1 = conv2d(_X, _weights['wc1'], _biases['bc1'])
     conv1 = max_pool(conv1, k=2)
-    conv1 = tf.nn.dropout(conv1, 0.75)
+    conv1 = tf.nn.dropout(conv1, _dropout)
 
     conv2 = conv2d(conv1, _weights['wc2'], _biases['bc2'])
     conv2 = max_pool(conv2, k=2)
-    conv2 = tf.nn.dropout(conv2, 0.75)
+    conv2 = tf.nn.dropout(conv2, _dropout)
 
     dense1 = tf.reshape(conv2, [-1, _weights['wd1'].get_shape().as_list()[0]])
     dense1 = tf.nn.relu(tf.matmul(dense1, _weights['wd1']) + _biases['bd1'])
-    dense1 = tf.nn.dropout(dense1, 0.75)
+    dense1 = tf.nn.dropout(dense1, _dropout)
 
     out = tf.matmul(dense1, _weights['out']) + _biases['out']
     return out
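conv_net now takes the keep probability as _dropout instead of hard-coding 0.75 in three places. The reshape before the dense layer depends on the flattened size of the second pooling output; a quick sketch of that arithmetic follows, where the 64-filter count of the second conv layer is an assumption based on the usual setup of this example (the weights dict sits outside this hunk):

# Shape bookkeeping behind tf.reshape(conv2, [-1, weights['wd1'].get_shape().as_list()[0]]).
# Assumes 'SAME' padding, stride-1 convolutions and 2x2 max-pooling, as in the hunks above.
side = 28
for _ in range(2):          # two conv + max_pool(k=2) stages
    side = side // 2        # 28 -> 14 -> 7
flat_size = side * side * 64   # 64 output filters of conv2 is an assumed value, not shown in this diff
print(flat_size)               # 3136, the expected first dimension of wd1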
@@ -57,32 +58,28 @@ biases = {
     'out': tf.Variable(tf.random_normal([n_classes]))
 }
 
-pred = conv_net(x, weights, biases)
+pred = conv_net(x, weights, biases, keep_prob)
 cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
 optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
 
+#Evaluate model
+correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
+accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.types.float32))
+
 # Train
 #load mnist data
 init = tf.initialize_all_variables()
 with tf.Session() as sess:
     sess.run(init)
-    #one epoch can take a long time on CPU
-    for epoch in range(training_epochs):
-        avg_cost = 0.
-        total_batch = int(mnist.train.num_examples/batch_size)
-        for i in range(total_batch):
-            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
-            sess.run(optimizer, feed_dict={x: batch_xs, y: batch_ys})
-            avg_cost += sess.run(cost, feed_dict={x: batch_xs, y: batch_ys})/total_batch
-            if i % display_batch == 0 and display_batch > 0:
-                print "Epoch:", '%04d' % (epoch+1), "Batch " + str(i) + "/" + str(total_batch), "cost=", \
-                    "{:.9f}".format(sess.run(cost, feed_dict={x: batch_xs, y: batch_ys}))
-        if epoch % display_step == 0:
-            print "Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost)
-
+    step = 1
+    avg_cost = 0.
+    while step * batch_size < training_iters:
+        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
+        sess.run(optimizer, feed_dict={x: batch_xs, y: batch_ys, keep_prob: dropout})
+        if step % display_step == 0:
+            avg_cost += sess.run(cost, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 1.})/batch_size
+            print "Iter", str(step*batch_size), "cost=", "{:.9f}".format(avg_cost/step)
+        step += 1
     print "Optimization Finished!"
-
-    # Test trained model
-    correct_prediction = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
-    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
-    print "Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels})
+    #Accuracy on 256 mnist test images
+    print "Accuracy:", sess.run(accuracy, feed_dict={x: mnist.test.images[:256], y: mnist.test.labels[:256], keep_prob: 1.})