autoencoder.py

  1. """ Auto Encoder Example.
  2. Build a 2 layers auto-encoder with TensorFlow to compress images to a
  3. lower latent space and then reconstruct them.
  4. References:
  5. Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. "Gradient-based
  6. learning applied to document recognition." Proceedings of the IEEE,
  7. 86(11):2278-2324, November 1998.
  8. Links:
  9. [MNIST Dataset] http://yann.lecun.com/exdb/mnist/
  10. Author: Aymeric Damien
  11. Project: https://github.com/aymericdamien/TensorFlow-Examples/
  12. """
  13. from __future__ import division, print_function, absolute_import
  14. import tensorflow as tf
  15. import numpy as np
  16. import matplotlib.pyplot as plt
  17. # Import MNIST data
  18. from tensorflow.examples.tutorials.mnist import input_data
  19. mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
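# NOTE: This script targets TensorFlow 1.x. The `tensorflow.examples.tutorials`
# input pipeline and APIs such as tf.placeholder/tf.Session were removed in
# TF 2.x; there, one would presumably run it through `tf.compat.v1` after
# calling tf.compat.v1.disable_eager_execution().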
# Training Parameters
learning_rate = 0.01
num_steps = 30000
batch_size = 256
display_step = 1000
examples_to_show = 10

# Network Parameters
num_hidden_1 = 256  # 1st layer num features
num_hidden_2 = 128  # 2nd layer num features (the latent dim)
num_input = 784  # MNIST data input (img shape: 28*28)
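# The network is symmetric: the encoder maps 784 -> 256 -> 128 and the
# decoder mirrors it, 128 -> 256 -> 784, so every image is squeezed through
# a 128-dimensional bottleneck before being reconstructed.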
# tf Graph input (only pictures)
X = tf.placeholder("float", [None, num_input])

weights = {
    'encoder_h1': tf.Variable(tf.random_normal([num_input, num_hidden_1])),
    'encoder_h2': tf.Variable(tf.random_normal([num_hidden_1, num_hidden_2])),
    'decoder_h1': tf.Variable(tf.random_normal([num_hidden_2, num_hidden_1])),
    'decoder_h2': tf.Variable(tf.random_normal([num_hidden_1, num_input])),
}
biases = {
    'encoder_b1': tf.Variable(tf.random_normal([num_hidden_1])),
    'encoder_b2': tf.Variable(tf.random_normal([num_hidden_2])),
    'decoder_b1': tf.Variable(tf.random_normal([num_hidden_1])),
    'decoder_b2': tf.Variable(tf.random_normal([num_input])),
}
# Building the encoder
def encoder(x):
    # Encoder Hidden layer with sigmoid activation #1
    layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['encoder_h1']),
                                   biases['encoder_b1']))
    # Encoder Hidden layer with sigmoid activation #2
    layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['encoder_h2']),
                                   biases['encoder_b2']))
    return layer_2


# Building the decoder
def decoder(x):
    # Decoder Hidden layer with sigmoid activation #1
    layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['decoder_h1']),
                                   biases['decoder_b1']))
    # Decoder Hidden layer with sigmoid activation #2
    layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['decoder_h2']),
                                   biases['decoder_b2']))
    return layer_2
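# Shape check: encoder() maps [batch, 784] -> [batch, 128] and decoder() maps
# [batch, 128] -> [batch, 784]. The final sigmoid keeps reconstructed pixel
# values in (0, 1), matching the [0, 1]-scaled MNIST inputs.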
# Construct model
encoder_op = encoder(X)
decoder_op = decoder(encoder_op)

# Prediction
y_pred = decoder_op
# Targets (Labels) are the input data.
y_true = X

# Define loss and optimizer, minimize the squared error
loss = tf.reduce_mean(tf.pow(y_true - y_pred, 2))
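# The loss is the pixel-wise mean squared error,
#   loss = mean((x - x_hat)^2),
# averaged over both the batch and the 784 pixels of each image.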
optimizer = tf.train.RMSPropOptimizer(learning_rate).minimize(loss)

# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()
# Start Training
# Start a new TF session
with tf.Session() as sess:

    # Run the initializer
    sess.run(init)

    # Training
    for i in range(1, num_steps + 1):
        # Prepare Data
        # Get the next batch of MNIST data (only images are needed, not labels)
        batch_x, _ = mnist.train.next_batch(batch_size)

        # Run optimization op (backprop) and cost op (to get loss value)
        _, l = sess.run([optimizer, loss], feed_dict={X: batch_x})
        # Display logs per step
        if i % display_step == 0 or i == 1:
            print('Step %i: Minibatch Loss: %f' % (i, l))

    # Testing
    # Encode and decode images from the test set and visualize their
    # reconstruction.
    n = 4
    canvas_orig = np.empty((28 * n, 28 * n))
    canvas_recon = np.empty((28 * n, 28 * n))
    for i in range(n):
        # MNIST test set
        batch_x, _ = mnist.test.next_batch(n)
        # Encode and decode the digit image
        g = sess.run(decoder_op, feed_dict={X: batch_x})

        # Display original images
        for j in range(n):
            # Draw the original digits
            canvas_orig[i * 28:(i + 1) * 28, j * 28:(j + 1) * 28] = \
                batch_x[j].reshape([28, 28])
        # Display reconstructed images
        for j in range(n):
            # Draw the reconstructed digits
            canvas_recon[i * 28:(i + 1) * 28, j * 28:(j + 1) * 28] = \
                g[j].reshape([28, 28])

    print("Original Images")
    plt.figure(figsize=(n, n))
    plt.imshow(canvas_orig, origin="upper", cmap="gray")
    plt.show()

    print("Reconstructed Images")
    plt.figure(figsize=(n, n))
    plt.imshow(canvas_recon, origin="upper", cmap="gray")
    plt.show()
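    # Illustrative addition, not in the original script: the trained encoder
    # can also be queried on its own to obtain the 128-dim latent codes.
    # The slice size of 5 below is an arbitrary choice for this sketch.
    codes = sess.run(encoder_op, feed_dict={X: mnist.test.images[:5]})
    print("Latent code shape:", codes.shape)  # expected: (5, 128)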