|
@@ -0,0 +1,389 @@
|
|
|
+{
|
|
|
+ "cells": [
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 1,
|
|
|
+ "metadata": {
|
|
|
+ "collapsed": true
|
|
|
+ },
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# A nearest neighbor learning algorithm example using TensorFlow library.\n",
|
|
|
+ "# This example is using the MNIST database of handwritten digits \n",
|
|
|
+ "# (http://yann.lecun.com/exdb/mnist/)\n",
|
|
|
+ "\n",
|
|
|
+ "# Author: Aymeric Damien\n",
|
|
|
+ "# Project: https://github.com/aymericdamien/TensorFlow-Examples/"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 2,
|
|
|
+ "metadata": {
|
|
|
+ "collapsed": true
|
|
|
+ },
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "import numpy as np\n",
|
|
|
+ "import tensorflow as tf"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 3,
|
|
|
+ "metadata": {
|
|
|
+ "collapsed": false
|
|
|
+ },
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "Extracting /tmp/data/train-images-idx3-ubyte.gz\n",
|
|
|
+ "Extracting /tmp/data/train-labels-idx1-ubyte.gz\n",
|
|
|
+ "Extracting /tmp/data/t10k-images-idx3-ubyte.gz\n",
|
|
|
+ "Extracting /tmp/data/t10k-labels-idx1-ubyte.gz\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# Import MINST data\n",
|
|
|
+ "import input_data\n",
|
|
|
+ "mnist = input_data.read_data_sets(\"/tmp/data/\", one_hot=True)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 4,
|
|
|
+ "metadata": {
|
|
|
+ "collapsed": true
|
|
|
+ },
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# In this example, we limit mnist data\n",
|
|
|
+ "Xtr, Ytr = mnist.train.next_batch(5000) #5000 for training (nn candidates)\n",
|
|
|
+ "Xte, Yte = mnist.test.next_batch(200) #200 for testing"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 5,
|
|
|
+ "metadata": {
|
|
|
+ "collapsed": true
|
|
|
+ },
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# Reshape images to 1D\n",
|
|
|
+ "Xtr = np.reshape(Xtr, newshape=(-1, 28*28))\n",
|
|
|
+ "Xte = np.reshape(Xte, newshape=(-1, 28*28))"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 6,
|
|
|
+ "metadata": {
|
|
|
+ "collapsed": true
|
|
|
+ },
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# tf Graph Input\n",
|
|
|
+ "xtr = tf.placeholder(\"float\", [None, 784])\n",
|
|
|
+ "xte = tf.placeholder(\"float\", [784])"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 8,
|
|
|
+ "metadata": {
|
|
|
+ "collapsed": true
|
|
|
+ },
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# Nearest Neighbor calculation using L1 Distance\n",
|
|
|
+ "# Calculate L1 Distance\n",
|
|
|
+ "distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.neg(xte))), reduction_indices=1)\n",
|
|
|
+ "# Predict: Get min distance index (Nearest neighbor)\n",
|
|
|
+ "pred = tf.arg_min(distance, 0)\n",
|
|
|
+ "\n",
|
|
|
+ "accuracy = 0."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 9,
|
|
|
+ "metadata": {
|
|
|
+ "collapsed": true
|
|
|
+ },
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# Initializing the variables\n",
|
|
|
+ "init = tf.initialize_all_variables()"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 10,
|
|
|
+ "metadata": {
|
|
|
+ "collapsed": false
|
|
|
+ },
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "Test 0 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 1 Prediction: 2 True Class: 2\n",
|
|
|
+ "Test 2 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 3 Prediction: 0 True Class: 0\n",
|
|
|
+ "Test 4 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 5 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 6 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 7 Prediction: 9 True Class: 9\n",
|
|
|
+ "Test 8 Prediction: 8 True Class: 5\n",
|
|
|
+ "Test 9 Prediction: 9 True Class: 9\n",
|
|
|
+ "Test 10 Prediction: 0 True Class: 0\n",
|
|
|
+ "Test 11 Prediction: 0 True Class: 6\n",
|
|
|
+ "Test 12 Prediction: 9 True Class: 9\n",
|
|
|
+ "Test 13 Prediction: 0 True Class: 0\n",
|
|
|
+ "Test 14 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 15 Prediction: 5 True Class: 5\n",
|
|
|
+ "Test 16 Prediction: 4 True Class: 9\n",
|
|
|
+ "Test 17 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 18 Prediction: 3 True Class: 3\n",
|
|
|
+ "Test 19 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 20 Prediction: 9 True Class: 9\n",
|
|
|
+ "Test 21 Prediction: 6 True Class: 6\n",
|
|
|
+ "Test 22 Prediction: 6 True Class: 6\n",
|
|
|
+ "Test 23 Prediction: 5 True Class: 5\n",
|
|
|
+ "Test 24 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 25 Prediction: 0 True Class: 0\n",
|
|
|
+ "Test 26 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 27 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 28 Prediction: 0 True Class: 0\n",
|
|
|
+ "Test 29 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 30 Prediction: 3 True Class: 3\n",
|
|
|
+ "Test 31 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 32 Prediction: 3 True Class: 3\n",
|
|
|
+ "Test 33 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 34 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 35 Prediction: 2 True Class: 2\n",
|
|
|
+ "Test 36 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 37 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 38 Prediction: 2 True Class: 2\n",
|
|
|
+ "Test 39 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 40 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 41 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 42 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 43 Prediction: 1 True Class: 2\n",
|
|
|
+ "Test 44 Prediction: 3 True Class: 3\n",
|
|
|
+ "Test 45 Prediction: 5 True Class: 5\n",
|
|
|
+ "Test 46 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 47 Prediction: 2 True Class: 2\n",
|
|
|
+ "Test 48 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 49 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 50 Prediction: 6 True Class: 6\n",
|
|
|
+ "Test 51 Prediction: 3 True Class: 3\n",
|
|
|
+ "Test 52 Prediction: 5 True Class: 5\n",
|
|
|
+ "Test 53 Prediction: 5 True Class: 5\n",
|
|
|
+ "Test 54 Prediction: 6 True Class: 6\n",
|
|
|
+ "Test 55 Prediction: 0 True Class: 0\n",
|
|
|
+ "Test 56 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 57 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 58 Prediction: 9 True Class: 9\n",
|
|
|
+ "Test 59 Prediction: 5 True Class: 5\n",
|
|
|
+ "Test 60 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 61 Prediction: 8 True Class: 8\n",
|
|
|
+ "Test 62 Prediction: 9 True Class: 9\n",
|
|
|
+ "Test 63 Prediction: 3 True Class: 3\n",
|
|
|
+ "Test 64 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 65 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 66 Prediction: 6 True Class: 6\n",
|
|
|
+ "Test 67 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 68 Prediction: 3 True Class: 3\n",
|
|
|
+ "Test 69 Prediction: 0 True Class: 0\n",
|
|
|
+ "Test 70 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 71 Prediction: 0 True Class: 0\n",
|
|
|
+ "Test 72 Prediction: 2 True Class: 2\n",
|
|
|
+ "Test 73 Prediction: 7 True Class: 9\n",
|
|
|
+ "Test 74 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 75 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 76 Prediction: 3 True Class: 3\n",
|
|
|
+ "Test 77 Prediction: 7 True Class: 2\n",
|
|
|
+ "Test 78 Prediction: 9 True Class: 9\n",
|
|
|
+ "Test 79 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 80 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 81 Prediction: 6 True Class: 6\n",
|
|
|
+ "Test 82 Prediction: 2 True Class: 2\n",
|
|
|
+ "Test 83 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 84 Prediction: 8 True Class: 8\n",
|
|
|
+ "Test 85 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 86 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 87 Prediction: 3 True Class: 3\n",
|
|
|
+ "Test 88 Prediction: 6 True Class: 6\n",
|
|
|
+ "Test 89 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 90 Prediction: 3 True Class: 3\n",
|
|
|
+ "Test 91 Prediction: 6 True Class: 6\n",
|
|
|
+ "Test 92 Prediction: 9 True Class: 9\n",
|
|
|
+ "Test 93 Prediction: 3 True Class: 3\n",
|
|
|
+ "Test 94 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 95 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 96 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 97 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 98 Prediction: 6 True Class: 6\n",
|
|
|
+ "Test 99 Prediction: 9 True Class: 9\n",
|
|
|
+ "Test 100 Prediction: 6 True Class: 6\n",
|
|
|
+ "Test 101 Prediction: 0 True Class: 0\n",
|
|
|
+ "Test 102 Prediction: 5 True Class: 5\n",
|
|
|
+ "Test 103 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 104 Prediction: 9 True Class: 9\n",
|
|
|
+ "Test 105 Prediction: 9 True Class: 9\n",
|
|
|
+ "Test 106 Prediction: 2 True Class: 2\n",
|
|
|
+ "Test 107 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 108 Prediction: 9 True Class: 9\n",
|
|
|
+ "Test 109 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 110 Prediction: 8 True Class: 8\n",
|
|
|
+ "Test 111 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 112 Prediction: 3 True Class: 3\n",
|
|
|
+ "Test 113 Prediction: 9 True Class: 9\n",
|
|
|
+ "Test 114 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 115 Prediction: 9 True Class: 4\n",
|
|
|
+ "Test 116 Prediction: 9 True Class: 4\n",
|
|
|
+ "Test 117 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 118 Prediction: 9 True Class: 9\n",
|
|
|
+ "Test 119 Prediction: 7 True Class: 2\n",
|
|
|
+ "Test 120 Prediction: 5 True Class: 5\n",
|
|
|
+ "Test 121 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 122 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 123 Prediction: 6 True Class: 6\n",
|
|
|
+ "Test 124 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 125 Prediction: 9 True Class: 9\n",
|
|
|
+ "Test 126 Prediction: 0 True Class: 0\n",
|
|
|
+ "Test 127 Prediction: 5 True Class: 5\n",
|
|
|
+ "Test 128 Prediction: 8 True Class: 8\n",
|
|
|
+ "Test 129 Prediction: 5 True Class: 5\n",
|
|
|
+ "Test 130 Prediction: 6 True Class: 6\n",
|
|
|
+ "Test 131 Prediction: 6 True Class: 6\n",
|
|
|
+ "Test 132 Prediction: 5 True Class: 5\n",
|
|
|
+ "Test 133 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 134 Prediction: 8 True Class: 8\n",
|
|
|
+ "Test 135 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 136 Prediction: 0 True Class: 0\n",
|
|
|
+ "Test 137 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 138 Prediction: 6 True Class: 6\n",
|
|
|
+ "Test 139 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 140 Prediction: 6 True Class: 6\n",
|
|
|
+ "Test 141 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 142 Prediction: 2 True Class: 3\n",
|
|
|
+ "Test 143 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 144 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 145 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 146 Prediction: 8 True Class: 8\n",
|
|
|
+ "Test 147 Prediction: 2 True Class: 2\n",
|
|
|
+ "Test 148 Prediction: 0 True Class: 0\n",
|
|
|
+ "Test 149 Prediction: 1 True Class: 2\n",
|
|
|
+ "Test 150 Prediction: 9 True Class: 9\n",
|
|
|
+ "Test 151 Prediction: 9 True Class: 9\n",
|
|
|
+ "Test 152 Prediction: 5 True Class: 5\n",
|
|
|
+ "Test 153 Prediction: 5 True Class: 5\n",
|
|
|
+ "Test 154 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 155 Prediction: 5 True Class: 5\n",
|
|
|
+ "Test 156 Prediction: 6 True Class: 6\n",
|
|
|
+ "Test 157 Prediction: 0 True Class: 0\n",
|
|
|
+ "Test 158 Prediction: 3 True Class: 3\n",
|
|
|
+ "Test 159 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 160 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 161 Prediction: 6 True Class: 6\n",
|
|
|
+ "Test 162 Prediction: 5 True Class: 5\n",
|
|
|
+ "Test 163 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 164 Prediction: 6 True Class: 6\n",
|
|
|
+ "Test 165 Prediction: 5 True Class: 5\n",
|
|
|
+ "Test 166 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 167 Prediction: 5 True Class: 5\n",
|
|
|
+ "Test 168 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 169 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 170 Prediction: 9 True Class: 4\n",
|
|
|
+ "Test 171 Prediction: 7 True Class: 7\n",
|
|
|
+ "Test 172 Prediction: 2 True Class: 2\n",
|
|
|
+ "Test 173 Prediction: 3 True Class: 3\n",
|
|
|
+ "Test 174 Prediction: 2 True Class: 2\n",
|
|
|
+ "Test 175 Prediction: 1 True Class: 7\n",
|
|
|
+ "Test 176 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 177 Prediction: 8 True Class: 8\n",
|
|
|
+ "Test 178 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 179 Prediction: 8 True Class: 8\n",
|
|
|
+ "Test 180 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 181 Prediction: 8 True Class: 8\n",
|
|
|
+ "Test 182 Prediction: 5 True Class: 5\n",
|
|
|
+ "Test 183 Prediction: 0 True Class: 0\n",
|
|
|
+ "Test 184 Prediction: 2 True Class: 8\n",
|
|
|
+ "Test 185 Prediction: 9 True Class: 9\n",
|
|
|
+ "Test 186 Prediction: 2 True Class: 2\n",
|
|
|
+ "Test 187 Prediction: 5 True Class: 5\n",
|
|
|
+ "Test 188 Prediction: 0 True Class: 0\n",
|
|
|
+ "Test 189 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 190 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 191 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 192 Prediction: 0 True Class: 0\n",
|
|
|
+ "Test 193 Prediction: 4 True Class: 9\n",
|
|
|
+ "Test 194 Prediction: 0 True Class: 0\n",
|
|
|
+ "Test 195 Prediction: 1 True Class: 3\n",
|
|
|
+ "Test 196 Prediction: 1 True Class: 1\n",
|
|
|
+ "Test 197 Prediction: 6 True Class: 6\n",
|
|
|
+ "Test 198 Prediction: 4 True Class: 4\n",
|
|
|
+ "Test 199 Prediction: 2 True Class: 2\n",
|
|
|
+ "Done!\n",
|
|
|
+ "Accuracy: 0.92\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# Launch the graph\n",
|
|
|
+ "with tf.Session() as sess:\n",
|
|
|
+ " sess.run(init)\n",
|
|
|
+ "\n",
|
|
|
+ " # loop over test data\n",
|
|
|
+ " for i in range(len(Xte)):\n",
|
|
|
+ " # Get nearest neighbor\n",
|
|
|
+ " nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i,:]})\n",
|
|
|
+ " # Get nearest neighbor class label and compare it to its true label\n",
|
|
|
+ " print \"Test\", i, \"Prediction:\", np.argmax(Ytr[nn_index]), \\\n",
|
|
|
+ " \"True Class:\", np.argmax(Yte[i])\n",
|
|
|
+ " # Calculate accuracy\n",
|
|
|
+ " if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]):\n",
|
|
|
+ " accuracy += 1./len(Xte)\n",
|
|
|
+ " print \"Done!\"\n",
|
|
|
+ " print \"Accuracy:\", accuracy"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {
|
|
|
+ "collapsed": true
|
|
|
+ },
|
|
|
+ "outputs": [],
|
|
|
+ "source": []
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "metadata": {
|
|
|
+ "kernelspec": {
|
|
|
+ "display_name": "IPython (Python 2.7)",
|
|
|
+ "language": "python",
|
|
|
+ "name": "python2"
|
|
|
+ },
|
|
|
+ "language_info": {
|
|
|
+ "codemirror_mode": {
|
|
|
+ "name": "ipython",
|
|
|
+ "version": 2
|
|
|
+ },
|
|
|
+ "file_extension": ".py",
|
|
|
+ "mimetype": "text/x-python",
|
|
|
+ "name": "python",
|
|
|
+ "nbconvert_exporter": "python",
|
|
|
+ "pygments_lexer": "ipython2",
|
|
|
+ "version": "2.7.8"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "nbformat": 4,
|
|
|
+ "nbformat_minor": 0
|
|
|
+}
|