Browse Source

Refactor logistic_regression for TF1.0

Signed-off-by: Norman Heckscher <norman.heckscher@gmail.com>
Norman Heckscher 8 years ago
parent
commit
7839ba225b

+ 1 - 1
examples/2_BasicModels/logistic_regression.py

@@ -38,7 +38,7 @@ cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))
 optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
 
 # Initializing the variables
-init = tf.initialize_all_variables()
+init = tf.global_variables_initializer()
 
 # Launch the graph
 with tf.Session() as sess:

+ 38 - 39
notebooks/2_BasicModels/logistic_regression.ipynb

@@ -18,7 +18,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 1,
    "metadata": {
     "collapsed": false
    },
@@ -27,10 +27,10 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Extracting /tmp/data/train-images-idx3-ubyte.gz\n",
-      "Extracting /tmp/data/train-labels-idx1-ubyte.gz\n",
-      "Extracting /tmp/data/t10k-images-idx3-ubyte.gz\n",
-      "Extracting /tmp/data/t10k-labels-idx1-ubyte.gz\n"
+      "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
+      "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
+      "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
+      "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n"
      ]
     }
    ],
@@ -39,14 +39,14 @@
     "\n",
     "# Import MINST data\n",
     "from tensorflow.examples.tutorials.mnist import input_data\n",
-    "mnist = input_data.read_data_sets(\"/tmp/data/\", one_hot=True)"
+    "mnist = input_data.read_data_sets(\"MNIST_data/\", one_hot=True)"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 3,
    "metadata": {
-    "collapsed": true
+    "collapsed": false
    },
    "outputs": [],
    "source": [
@@ -73,12 +73,12 @@
     "optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)\n",
     "\n",
     "# Initializing the variables\n",
-    "init = tf.initialize_all_variables()"
+    "init = tf.global_variables_initializer()"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": null,
    "metadata": {
     "collapsed": false
    },
@@ -87,33 +87,23 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Epoch: 0001 cost= 1.182138961\n",
-      "Epoch: 0002 cost= 0.664670898\n",
-      "Epoch: 0003 cost= 0.552613988\n",
-      "Epoch: 0004 cost= 0.498497931\n",
-      "Epoch: 0005 cost= 0.465418769\n",
-      "Epoch: 0006 cost= 0.442546219\n",
-      "Epoch: 0007 cost= 0.425473814\n",
-      "Epoch: 0008 cost= 0.412171735\n",
-      "Epoch: 0009 cost= 0.401359516\n",
-      "Epoch: 0010 cost= 0.392401536\n",
-      "Epoch: 0011 cost= 0.384750201\n",
-      "Epoch: 0012 cost= 0.378185581\n",
-      "Epoch: 0013 cost= 0.372401533\n",
-      "Epoch: 0014 cost= 0.367302442\n",
-      "Epoch: 0015 cost= 0.362702316\n",
-      "Epoch: 0016 cost= 0.358568827\n",
-      "Epoch: 0017 cost= 0.354882155\n",
-      "Epoch: 0018 cost= 0.351430912\n",
-      "Epoch: 0019 cost= 0.348316068\n",
-      "Epoch: 0020 cost= 0.345392556\n",
-      "Epoch: 0021 cost= 0.342737278\n",
-      "Epoch: 0022 cost= 0.340264994\n",
-      "Epoch: 0023 cost= 0.337890242\n",
-      "Epoch: 0024 cost= 0.335708558\n",
-      "Epoch: 0025 cost= 0.333686476\n",
-      "Optimization Finished!\n",
-      "Accuracy: 0.889667\n"
+      "Epoch: 0001 cost= 1.182138959\n",
+      "Epoch: 0002 cost= 0.664778162\n",
+      "Epoch: 0003 cost= 0.552686284\n",
+      "Epoch: 0004 cost= 0.498628905\n",
+      "Epoch: 0005 cost= 0.465469866\n",
+      "Epoch: 0006 cost= 0.442537872\n",
+      "Epoch: 0007 cost= 0.425462044\n",
+      "Epoch: 0008 cost= 0.412185303\n",
+      "Epoch: 0009 cost= 0.401311587\n",
+      "Epoch: 0010 cost= 0.392326203\n",
+      "Epoch: 0011 cost= 0.384736038\n",
+      "Epoch: 0012 cost= 0.378137191\n",
+      "Epoch: 0013 cost= 0.372363752\n",
+      "Epoch: 0014 cost= 0.367308579\n",
+      "Epoch: 0015 cost= 0.362704660\n",
+      "Epoch: 0016 cost= 0.358588599\n",
+      "Epoch: 0017 cost= 0.354823110\n"
      ]
     }
    ],
@@ -146,6 +136,15 @@
     "    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
     "    print \"Accuracy:\", accuracy.eval({x: mnist.test.images[:3000], y: mnist.test.labels[:3000]})"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {
@@ -157,16 +156,16 @@
   "language_info": {
    "codemirror_mode": {
     "name": "ipython",
-    "version": 2.0
+    "version": 2
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython2",
-   "version": "2.7.11"
+   "version": "2.7.13"
   }
  },
  "nbformat": 4,
  "nbformat_minor": 0
-}
+}