@@ -2,9 +2,7 @@
"cells": [
{
"cell_type": "markdown",
- "metadata": {
- "collapsed": false
- },
+ "metadata": {},
"source": [
"# TensorFlow Dataset API\n",
"\n",
@@ -19,9 +17,7 @@
{
"cell_type": "code",
"execution_count": 1,
- "metadata": {
- "collapsed": false
- },
+ "metadata": {},
"outputs": [
{
"name": "stdout",
@@ -64,21 +60,21 @@
"sess = tf.Session()\n",
"\n",
"# Create a dataset tensor from the images and the labels\n",
- "dataset = tf.contrib.data.Dataset.from_tensor_slices(\n",
+ "dataset = tf.data.Dataset.from_tensor_slices(\n",
" (mnist.train.images, mnist.train.labels))\n",
+ "# Automatically refill the data queue when empty\n",
+ "dataset = dataset.repeat()\n",
"# Create batches of data\n",
"dataset = dataset.batch(batch_size)\n",
- "# Create an iterator, to go over the dataset\n",
+ "# Prefetch data for faster\n",
+ "dataset = dataset.prefetch(batch_size)\n",
+ "\n",
+ "# Create an iterator over the dataset\n",
"iterator = dataset.make_initializable_iterator()\n",
- "# It is better to use 2 placeholders, to avoid to load all data into memory,\n",
- "# and avoid the 2Gb restriction length of a tensor.\n",
- "_data = tf.placeholder(tf.float32, [None, n_input])\n",
- "_labels = tf.placeholder(tf.float32, [None, n_classes])\n",
"# Initialize the iterator\n",
- "sess.run(iterator.initializer, feed_dict={_data: mnist.train.images,\n",
- " _labels: mnist.train.labels})\n",
+ "sess.run(iterator.initializer)\n",
"\n",
- "# Neural Net Input\n",
+ "# Neural Net Input (images, labels)\n",
"X, Y = iterator.get_next()"
]
},
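For reference, this hunk swaps the deprecated tf.contrib.data.Dataset for tf.data.Dataset, adds repeat() and prefetch() to the pipeline, and drops the two feed placeholders so the iterator can be initialized without a feed_dict. A minimal standalone sketch of the resulting pipeline follows (TF 1.x Session API, as in the notebook); the MNIST loader path and the batch_size value are assumptions here, since they are defined in cells outside this hunk.

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data

    mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)  # assumed path
    batch_size = 128  # assumed value; defined elsewhere in the notebook

    sess = tf.Session()

    # Build the dataset directly from the in-memory numpy arrays
    dataset = tf.data.Dataset.from_tensor_slices(
        (mnist.train.images, mnist.train.labels))
    dataset = dataset.repeat()              # refill automatically when exhausted
    dataset = dataset.batch(batch_size)     # group samples into batches
    dataset = dataset.prefetch(batch_size)  # overlap data preparation with training
    iterator = dataset.make_initializable_iterator()
    sess.run(iterator.initializer)          # no feed_dict needed any more
    X, Y = iterator.get_next()              # input tensors for the network

Note that the removed placeholder comments pointed at a real trade-off: feeding the arrays through placeholders avoids embedding them in the graph (and the 2 GB graph limit), whereas from_tensor_slices on the arrays embeds them as constants, which is acceptable for a dataset of MNIST's size.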
@@ -155,7 +151,6 @@
"cell_type": "code",
"execution_count": 4,
"metadata": {
- "collapsed": false,
"scrolled": false
},
"outputs": [
@@ -188,15 +183,8 @@
"# Training cycle\n",
"for step in range(1, num_steps + 1):\n",
" \n",
- " try:\n",
- " # Run optimization\n",
- " sess.run(train_op)\n",
- " except tf.errors.OutOfRangeError:\n",
- " # Reload the iterator when it reaches the end of the dataset\n",
- " sess.run(iterator.initializer, \n",
- " feed_dict={_data: mnist.train.images,\n",
- " _labels: mnist.train.labels})\n",
- " sess.run(train_op)\n",
+ " # Run optimization\n",
+ " sess.run(train_op)\n",
" \n",
" if step % display_step == 0 or step == 1:\n",
" # Calculate batch loss and accuracy\n",
@@ -212,7 +200,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python [default]",
+ "display_name": "Python 2",
"language": "python",
"name": "python2"
},
@@ -226,7 +214,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
- "version": "2.7.12"
+ "version": "2.7.14"
}
},
"nbformat": 4,