{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "\"\"\" Auto Encoder Example.\n",
    "Using an auto encoder on MNIST handwritten digits.\n",
    "References:\n",
    "    Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. \"Gradient-based\n",
    "    learning applied to document recognition.\" Proceedings of the IEEE,\n",
    "    86(11):2278-2324, November 1998.\n",
    "Links:\n",
    "    [MNIST Dataset] http://yann.lecun.com/exdb/mnist/\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting /tmp/data/train-images-idx3-ubyte.gz\n",
      "Extracting /tmp/data/train-labels-idx1-ubyte.gz\n",
      "Extracting /tmp/data/t10k-images-idx3-ubyte.gz\n",
      "Extracting /tmp/data/t10k-labels-idx1-ubyte.gz\n"
     ]
    }
   ],
   "source": [
    "from __future__ import division, print_function, absolute_import\n",
    "\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Import MNIST data\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "mnist = input_data.read_data_sets(\"/tmp/data/\", one_hot=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Parameters\n",
    "learning_rate = 0.01\n",
    "training_epochs = 20\n",
    "batch_size = 256\n",
    "display_step = 1\n",
    "examples_to_show = 10  # number of test digits to reconstruct and plot\n",
    "\n",
    "# Network Parameters\n",
    "n_hidden_1 = 256 # 1st layer num features\n",
    "n_hidden_2 = 128 # 2nd layer num features\n",
    "n_input = 784 # MNIST data input (img shape: 28*28)\n",
    "\n",
    "# tf Graph input (only pictures)\n",
    "X = tf.placeholder(\"float\", [None, n_input])\n",
    "\n",
    "weights = {\n",
    "    'encoder_h1': tf.Variable(tf.random_normal([n_input, n_hidden_1])),\n",
    "    'encoder_h2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),\n",
    "    'decoder_h1': tf.Variable(tf.random_normal([n_hidden_2, n_hidden_1])),\n",
    "    'decoder_h2': tf.Variable(tf.random_normal([n_hidden_1, n_input])),\n",
    "}\n",
    "biases = {\n",
    "    'encoder_b1': tf.Variable(tf.random_normal([n_hidden_1])),\n",
    "    'encoder_b2': tf.Variable(tf.random_normal([n_hidden_2])),\n",
    "    'decoder_b1': tf.Variable(tf.random_normal([n_hidden_1])),\n",
    "    'decoder_b2': tf.Variable(tf.random_normal([n_input])),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Building the encoder\n",
    "def encoder(x):\n",
    "    \"\"\"Encode a batch of n_input-pixel images into n_hidden_2 features.\n",
    "\n",
    "    Two fully connected layers, each followed by a sigmoid:\n",
    "    n_input -> n_hidden_1 -> n_hidden_2.\n",
    "    \"\"\"\n",
    "    # Encoder hidden layer #1 with sigmoid activation\n",
    "    layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['encoder_h1']),\n",
    "                                   biases['encoder_b1']))\n",
    "    # Encoder hidden layer #2 with sigmoid activation\n",
    "    layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['encoder_h2']),\n",
    "                                   biases['encoder_b2']))\n",
    "    return layer_2\n",
    "\n",
    "\n",
    "# Building the decoder\n",
    "def decoder(x):\n",
    "    \"\"\"Decode a batch of n_hidden_2 features back into n_input pixels.\n",
    "\n",
    "    Mirrors the encoder: n_hidden_2 -> n_hidden_1 -> n_input, each layer\n",
    "    followed by a sigmoid, so reconstructed pixels lie in (0, 1).\n",
    "    \"\"\"\n",
    "    # Decoder hidden layer #1 with sigmoid activation\n",
    "    layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['decoder_h1']),\n",
    "                                   biases['decoder_b1']))\n",
    "    # Decoder output layer #2 with sigmoid activation\n",
    "    layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['decoder_h2']),\n",
    "                                   biases['decoder_b2']))\n",
    "    return layer_2\n",
    "\n",
    "# Construct model\n",
    "encoder_op = encoder(X)\n",
    "decoder_op = decoder(encoder_op)\n",
    "\n",
    "# Prediction\n",
    "y_pred = decoder_op\n",
    "# Targets (Labels) are the input data.\n",
    "y_true = X\n",
    "\n",
    "# Define loss and optimizer, minimize the squared error\n",
    "cost = tf.reduce_mean(tf.pow(y_true - y_pred, 2))\n",
    "optimizer = tf.train.RMSPropOptimizer(learning_rate).minimize(cost)\n",
    "\n",
    "# Initializing the variables.\n",
    "# tf.initialize_all_variables() is deprecated since TF 0.12; prefer its\n",
    "# replacement and fall back only on releases that predate it.\n",
    "if hasattr(tf, 'global_variables_initializer'):\n",
    "    init = tf.global_variables_initializer()\n",
    "else:\n",
    "    init = tf.initialize_all_variables()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0001 cost= 0.218654603\n",
      "Epoch: 0002 cost= 0.173306286\n",
      "Epoch: 0003 cost= 0.154793650\n",
      "Epoch: 0004 cost= 0.146902516\n",
      "Epoch: 0005 cost= 0.141993478\n",
      "Epoch: 0006 cost= 0.132718414\n",
      "Epoch: 0007 cost= 0.125991374\n",
      "Epoch: 0008 cost= 0.122500181\n",
      "Epoch: 0009 cost= 0.115299642\n",
      "Epoch: 0010 cost= 0.115390278\n",
      "Epoch: 0011 cost= 0.114480168\n",
      "Epoch: 0012 cost= 0.113888472\n",
      "Epoch: 0013 cost= 0.111597553\n",
      "Epoch: 0014 cost= 0.110663064\n",
      "Epoch: 0015 cost= 0.108673096\n",
      "Epoch: 0016 cost= 0.104775786\n",
      "Epoch: 0017 cost= 0.106273368\n",
      "Epoch: 0018 cost= 0.104061618\n",
      "Epoch: 0019 cost= 0.103227913\n",
      "Epoch: 0020 cost= 0.099696413\n",
      "Optimization Finished!\n"
     ]
    }
   ],
   "source": [
    "# Launch the graph\n",
    "# Using InteractiveSession (more convenient while using Notebooks)\n",
    "sess = tf.InteractiveSession()\n",
    "sess.run(init)\n",
    "\n",
    "total_batch = int(mnist.train.num_examples/batch_size)\n",
    "# Training cycle\n",
    "for epoch in range(training_epochs):\n",
    "    # Loop over all batches\n",
    "    for i in range(total_batch):\n",
    "        # Only images are needed: the autoencoder's target is its own input.\n",
    "        batch_xs = mnist.train.next_batch(batch_size)[0]\n",
    "        # Run optimization op (backprop) and cost op (to get loss value)\n",
    "        _, c = sess.run([optimizer, cost], feed_dict={X: batch_xs})\n",
    "    # Display logs per epoch step (c is the loss of the epoch's last minibatch)\n",
    "    if epoch % display_step == 0:\n",
    "        print(\"Epoch:\", '%04d' % (epoch+1),\n",
    "              \"cost=\", \"{:.9f}\".format(c))\n",
    "\n",
    "print(\"Optimization Finished!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Applying encode and decode over test set\n",
    "encode_decode = sess.run(\n",
    "    y_pred, feed_dict={X: mnist.test.images[:examples_to_show]})\n",
    "# Compare original images (top row) with their reconstructions (bottom row).\n",
    "# squeeze=False keeps `a` two-dimensional even when examples_to_show == 1.\n",
    "f, a = plt.subplots(2, examples_to_show, figsize=(examples_to_show, 2),\n",
    "                    squeeze=False)\n",
    "for i in range(examples_to_show):\n",
    "    a[0][i].imshow(np.reshape(mnist.test.images[i], (28, 28)))\n",
    "    a[1][i].imshow(np.reshape(encode_decode[i], (28, 28)))\n",
    "# plt.show() renders inline in Jupyter and opens a window when run as a\n",
    "# script; plt.waitforbuttonpress() needs a GUI backend and can hang the kernel.\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2.0
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}