# nearest_neighbor.py (1.6 KB)
  1. '''
  2. A nearest neighbor learning algorithm example using TensorFlow library.
  3. This example is using the MNIST database of handwritten digits (http://yann.lecun.com/exdb/mnist/)
  4. Author: Aymeric Damien
  5. Project: https://github.com/aymericdamien/TensorFlow-Examples/
  6. '''
  7. import numpy as np
  8. import tensorflow as tf
  9. # Import MINST data
  10. import input_data
  11. mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
  12. # In this example, we limit mnist data
  13. Xtr, Ytr = mnist.train.next_batch(5000) #5000 for training (nn candidates)
  14. Xte, Yte = mnist.test.next_batch(200) #200 for testing
  15. # Reshape images to 1D
  16. Xtr = np.reshape(Xtr, newshape=(-1, 28*28))
  17. Xte = np.reshape(Xte, newshape=(-1, 28*28))
  18. # tf Graph Input
  19. xtr = tf.placeholder("float", [None, 784])
  20. xte = tf.placeholder("float", [784])
  21. # Nearest Neighbor calculation using L1 Distance
  22. # Calculate L1 Distance
  23. distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.neg(xte))), reduction_indices=1)
  24. # Predict: Get min distance index (Nearest neighbor)
  25. pred = tf.arg_min(distance, 0)
  26. accuracy = 0.
  27. # Initializing the variables
  28. init = tf.initialize_all_variables()
  29. # Launch the graph
  30. with tf.Session() as sess:
  31. sess.run(init)
  32. # loop over test data
  33. for i in range(len(Xte)):
  34. # Get nearest neighbor
  35. nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i,:]})
  36. # Get nearest neighbor class label and compare it to its true label
  37. print "Test", i, "Prediction:", np.argmax(Ytr[nn_index]), "True Class:", np.argmax(Yte[i])
  38. # Calculate accuracy
  39. if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]):
  40. accuracy += 1./len(Xte)
  41. print "Done!"
  42. print "Accuracy:", accuracy