Преглед на файлове

Refactor nearest_neighbor for TF1.0

Signed-off-by: Norman Heckscher <norman.heckscher@gmail.com>
Norman Heckscher преди 8 години
родител
ревизия
8e03823181
променени са 2 файла, в които са добавени 25 реда и са изтрити 16 реда
  1. 2 2
      examples/2_BasicModels/nearest_neighbor.py
  2. 23 14
      notebooks/2_BasicModels/nearest_neighbor.ipynb

+ 2 - 2
examples/2_BasicModels/nearest_neighbor.py

@@ -26,14 +26,14 @@ xte = tf.placeholder("float", [784])
 
 # Nearest Neighbor calculation using L1 Distance
 # Calculate L1 Distance
-distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.neg(xte))), reduction_indices=1)
+distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices=1)
 # Prediction: Get min distance index (Nearest neighbor)
 pred = tf.arg_min(distance, 0)
 
 accuracy = 0.
 
 # Initializing the variables
-init = tf.initialize_all_variables()
+init = tf.global_variables_initializer()
 
 # Launch the graph
 with tf.Session() as sess:

+ 23 - 14
notebooks/2_BasicModels/nearest_neighbor.ipynb

@@ -18,7 +18,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 1,
    "metadata": {
     "collapsed": false
    },
@@ -27,10 +27,10 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Extracting /tmp/data/train-images-idx3-ubyte.gz\n",
-      "Extracting /tmp/data/train-labels-idx1-ubyte.gz\n",
-      "Extracting /tmp/data/t10k-images-idx3-ubyte.gz\n",
-      "Extracting /tmp/data/t10k-labels-idx1-ubyte.gz\n"
+      "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
+      "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
+      "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
+      "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n"
      ]
     }
    ],
@@ -40,14 +40,14 @@
     "\n",
     "# Import MINST data\n",
     "from tensorflow.examples.tutorials.mnist import input_data\n",
-    "mnist = input_data.read_data_sets(\"/tmp/data/\", one_hot=True)"
+    "mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=True)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 2,
    "metadata": {
-    "collapsed": true
+    "collapsed": false
    },
    "outputs": [],
    "source": [
@@ -61,19 +61,19 @@
     "\n",
     "# Nearest Neighbor calculation using L1 Distance\n",
     "# Calculate L1 Distance\n",
-    "distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.neg(xte))), reduction_indices=1)\n",
+    "distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.negative(xte))), reduction_indices=1)\n",
     "# Prediction: Get min distance index (Nearest neighbor)\n",
     "pred = tf.arg_min(distance, 0)\n",
     "\n",
     "accuracy = 0.\n",
     "\n",
     "# Initializing the variables\n",
-    "init = tf.initialize_all_variables()"
+    "init = tf.global_variables_initializer()"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 3,
    "metadata": {
     "collapsed": false
    },
@@ -305,6 +305,15 @@
     "    print \"Done!\"\n",
     "    print \"Accuracy:\", accuracy"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {
@@ -316,16 +325,16 @@
   "language_info": {
    "codemirror_mode": {
     "name": "ipython",
-    "version": 2.0
+    "version": 2
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython2",
-   "version": "2.7.11"
+   "version": "2.7.13"
   }
  },
  "nbformat": 4,
  "nbformat_minor": 0
-}
+}