autoencoder.py

# -*- coding: utf-8 -*-
""" Auto Encoder Example.

Using an auto encoder on MNIST handwritten digits.

References:
    Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. "Gradient-based
    learning applied to document recognition." Proceedings of the IEEE,
    86(11):2278-2324, November 1998.

Links:
    [MNIST Dataset] http://yann.lecun.com/exdb/mnist/
"""
from __future__ import division, print_function, absolute_import

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Import MNIST data (input_data helper bundled with the TF 1.x tutorials)
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
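
# mnist.train.images has shape (55000, 784); read_data_sets scales the pixel
# values to [0, 1], and the one_hot labels are loaded but not used below.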

# Parameters
learning_rate = 0.01
training_epochs = 20
batch_size = 256
display_step = 1
examples_to_show = 10

# Network Parameters
n_hidden_1 = 256  # 1st layer num features
n_hidden_2 = 128  # 2nd layer num features
n_input = 784  # MNIST data input (img shape: 28*28)
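
# The resulting architecture is symmetric: 784 -> 256 -> 128 on the encoder
# side and 128 -> 256 -> 784 on the decoder side, so each image is compressed
# into a 128-dimensional bottleneck representation.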

# tf Graph input (only pictures)
X = tf.placeholder("float", [None, n_input])

weights = {
    'encoder_h1': tf.Variable(tf.random_normal([n_input, n_hidden_1])),
    'encoder_h2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),
    'decoder_h1': tf.Variable(tf.random_normal([n_hidden_2, n_hidden_1])),
    'decoder_h2': tf.Variable(tf.random_normal([n_hidden_1, n_input])),
}
biases = {
    'encoder_b1': tf.Variable(tf.random_normal([n_hidden_1])),
    'encoder_b2': tf.Variable(tf.random_normal([n_hidden_2])),
    'decoder_b1': tf.Variable(tf.random_normal([n_hidden_1])),
    'decoder_b2': tf.Variable(tf.random_normal([n_input])),
}
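
# Note: encoder and decoder use separate (untied) weight matrices, each
# initialized from a standard normal distribution via tf.random_normal.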


# Building the encoder
def encoder(x):
    # Encoder hidden layer with sigmoid activation #1
    layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['encoder_h1']),
                                   biases['encoder_b1']))
    # Encoder hidden layer with sigmoid activation #2
    layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['encoder_h2']),
                                   biases['encoder_b2']))
    return layer_2


# Building the decoder
def decoder(x):
    # Decoder hidden layer with sigmoid activation #1
    layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['decoder_h1']),
                                   biases['decoder_b1']))
    # Decoder hidden layer with sigmoid activation #2
    layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['decoder_h2']),
                                   biases['decoder_b2']))
    return layer_2
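
# The final sigmoid keeps reconstructed values in [0, 1], which matches the
# scale of the MNIST pixel inputs fed in through X.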

# Construct model
encoder_op = encoder(X)
decoder_op = decoder(encoder_op)
y_pred = decoder_op
y_true = X

# Define loss and optimizer, minimize the squared error
cost = tf.reduce_mean(tf.pow(y_true - y_pred, 2))
optimizer = tf.train.RMSPropOptimizer(learning_rate).minimize(cost)
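# Because the target is the input itself (y_true = X), training is fully
# unsupervised: only the mean squared reconstruction error is minimized.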

# Initializing the variables
init = tf.global_variables_initializer()

# Launch the graph
with tf.Session() as sess:
    sess.run(init)
    total_batch = int(mnist.train.num_examples / batch_size)
    # Training cycle
    for epoch in range(training_epochs):
        # Loop over all batches
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Fit training using batch data
            _, cost_value = sess.run([optimizer, cost], feed_dict={X: batch_xs})
        # Display logs per epoch step
        if epoch % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(cost_value))
    print("Optimization Finished!")

    # Applying encode and decode over the test set
    encode_decode = sess.run(
        y_pred, feed_dict={X: mnist.test.images[:examples_to_show]})
    # Compare original images with their reconstructions
    f, a = plt.subplots(2, 10, figsize=(10, 2))
    for i in range(examples_to_show):
        a[0][i].imshow(np.reshape(mnist.test.images[i], (28, 28)))
        a[1][i].imshow(np.reshape(encode_decode[i], (28, 28)))
    f.show()
    plt.draw()
    plt.waitforbuttonpress()

# Commented-out TFLearn variant of the same auto encoder, kept for reference.
# # Regression, with mean square error
# net = tflearn.regression(decoder, optimizer='adam', learning_rate=0.001,
#                          loss='mean_square', metric=None)
#
# # Training the auto encoder
# model = tflearn.DNN(net, tensorboard_verbose=0)
# model.fit(X, X, n_epoch=10, validation_set=(testX, testX),
#           run_id="auto_encoder", batch_size=256)
#
# # Encoding X[0] for test
# print("\nTest encoding of X[0]:")
# # New model, re-using the same session, for weights sharing
# encoding_model = tflearn.DNN(encoder, session=model.session)
# print(encoding_model.predict([X[0]]))
#
# # Testing the image reconstruction on new data (test set)
# print("\nVisualizing results after being encoded and decoded:")
# testX = tflearn.data_utils.shuffle(testX)[0]
# # Applying encode and decode over test set
# encode_decode = model.predict(testX)
# # Compare original images with their reconstructions
# f, a = plt.subplots(2, 10, figsize=(10, 2))
# for i in range(10):
#     a[0][i].imshow(np.reshape(testX[i], (28, 28)))
#     a[1][i].imshow(np.reshape(encode_decode[i], (28, 28)))
# f.show()
# plt.draw()
# plt.waitforbuttonpress()