autoencoder.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. # -*- coding: utf-8 -*-
  2. """ Auto Encoder Example.
  3. Using an auto encoder on MNIST handwritten digits.
  4. References:
  5. Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. "Gradient-based
  6. learning applied to document recognition." Proceedings of the IEEE,
  7. 86(11):2278-2324, November 1998.
  8. Links:
  9. [MNIST Dataset] http://yann.lecun.com/exdb/mnist/
  10. """
  11. from __future__ import division, print_function, absolute_import
  12. import tensorflow as tf
  13. import numpy as np
  14. import matplotlib.pyplot as plt
  15. # Import MNIST data
  16. from tensorflow.examples.tutorials.mnist import input_data
  17. mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
  18. # Parameters
  19. learning_rate = 0.01
  20. training_epochs = 20
  21. batch_size = 256
  22. display_step = 1
  23. examples_to_show = 10
  24. # Network Parameters
  25. n_hidden_1 = 256 # 1st layer num features
  26. n_hidden_2 = 128 # 2nd layer num features
  27. n_input = 784 # MNIST data input (img shape: 28*28)
  28. # tf Graph input (only pictures)
  29. X = tf.placeholder("float", [None, n_input])
  30. weights = {
  31. 'encoder_h1': tf.Variable(tf.random_normal([n_input, n_hidden_1])),
  32. 'encoder_h2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),
  33. 'decoder_h1': tf.Variable(tf.random_normal([n_hidden_2, n_hidden_1])),
  34. 'decoder_h2': tf.Variable(tf.random_normal([n_hidden_1, n_input])),
  35. }
  36. biases = {
  37. 'encoder_b1': tf.Variable(tf.random_normal([n_hidden_1])),
  38. 'encoder_b2': tf.Variable(tf.random_normal([n_hidden_2])),
  39. 'decoder_b1': tf.Variable(tf.random_normal([n_hidden_1])),
  40. 'decoder_b2': tf.Variable(tf.random_normal([n_input])),
  41. }
  42. # Building the encoder
  43. def encoder(x):
  44. # Encoder Hidden layer with sigmoid activation #1
  45. layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['encoder_h1']),
  46. biases['encoder_b1']))
  47. # Decoder Hidden layer with sigmoid activation #2
  48. layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['encoder_h2']),
  49. biases['encoder_b2']))
  50. return layer_2
  51. # Building the decoder
  52. def decoder(x):
  53. # Encoder Hidden layer with sigmoid activation #1
  54. layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['decoder_h1']),
  55. biases['decoder_b1']))
  56. # Decoder Hidden layer with sigmoid activation #2
  57. layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['decoder_h2']),
  58. biases['decoder_b2']))
  59. return layer_2
  60. # Construct model
  61. encoder_op = encoder(X)
  62. decoder_op = decoder(encoder_op)
  63. # Prediction
  64. y_pred = decoder_op
  65. # Targets (Labels) are the input data.
  66. y_true = X
  67. # Define loss and optimizer, minimize the squared error
  68. cost = tf.reduce_mean(tf.pow(y_true - y_pred, 2))
  69. optimizer = tf.train.RMSPropOptimizer(learning_rate).minimize(cost)
  70. # Initializing the variables
  71. init = tf.initialize_all_variables()
  72. # Launch the graph
  73. with tf.Session() as sess:
  74. sess.run(init)
  75. total_batch = int(mnist.train.num_examples/batch_size)
  76. # Training cycle
  77. for epoch in range(training_epochs):
  78. # Loop over all batches
  79. for i in range(total_batch):
  80. batch_xs, batch_ys = mnist.train.next_batch(batch_size)
  81. # Run optimization op (backprop) and cost op (to get loss value)
  82. _, c = sess.run([optimizer, cost], feed_dict={X: batch_xs})
  83. # Display logs per epoch step
  84. if epoch % display_step == 0:
  85. print("Epoch:", '%04d' % (epoch+1),
  86. "cost=", "{:.9f}".format(c))
  87. print("Optimization Finished!")
  88. # Applying encode and decode over test set
  89. encode_decode = sess.run(
  90. y_pred, feed_dict={X: mnist.test.images[:examples_to_show]})
  91. # Compare original images with their reconstructions
  92. f, a = plt.subplots(2, 10, figsize=(10, 2))
  93. for i in range(examples_to_show):
  94. a[0][i].imshow(np.reshape(mnist.test.images[i], (28, 28)))
  95. a[1][i].imshow(np.reshape(encode_decode[i], (28, 28)))
  96. f.show()
  97. plt.draw()
  98. plt.waitforbuttonpress()