variational_autoencoder.py

  1. """ Variational Auto-Encoder Example.
  2. Using a variational auto-encoder to generate digits images from noise.
  3. MNIST handwritten digits are used as training examples.
  4. References:
  5. - Auto-Encoding Variational Bayes The International Conference on Learning
  6. Representations (ICLR), Banff, 2014. D.P. Kingma, M. Welling
  7. - Understanding the difficulty of training deep feedforward neural networks.
  8. X Glorot, Y Bengio. Aistats 9, 249-256
  9. - Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. "Gradient-based
  10. learning applied to document recognition." Proceedings of the IEEE,
  11. 86(11):2278-2324, November 1998.
  12. Links:
  13. - [VAE Paper] https://arxiv.org/abs/1312.6114
  14. - [Xavier Glorot Init](www.cs.cmu.edu/~bhiksha/courses/deeplearning/Fall.../AISTATS2010_Glorot.pdf).
  15. - [MNIST Dataset] http://yann.lecun.com/exdb/mnist/
  16. Author: Aymeric Damien
  17. Project: https://github.com/aymericdamien/TensorFlow-Examples/
  18. """
from __future__ import division, print_function, absolute_import

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
import tensorflow as tf

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
# Parameters
learning_rate = 0.001
num_steps = 30000
batch_size = 64

# Network Parameters
image_dim = 784  # MNIST images are 28x28 pixels
hidden_dim = 512
latent_dim = 2
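# Note: latent_dim = 2 keeps the latent space two-dimensional so the learned
# manifold can be visualized directly on a 2-D grid at the end of the script.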
# A custom initialization (see Xavier Glorot init)
def glorot_init(shape):
    return tf.random_normal(shape=shape, stddev=1. / tf.sqrt(shape[0] / 2.))
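# The scale above is stddev = sqrt(2 / fan_in): shape[0] is the layer's fan-in,
# so 1 / sqrt(fan_in / 2) keeps activation variance roughly constant across
# layers, in the spirit of the Glorot/Xavier initialization referenced above.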
# Variables
weights = {
    'encoder_h1': tf.Variable(glorot_init([image_dim, hidden_dim])),
    'z_mean': tf.Variable(glorot_init([hidden_dim, latent_dim])),
    'z_std': tf.Variable(glorot_init([hidden_dim, latent_dim])),
    'decoder_h1': tf.Variable(glorot_init([latent_dim, hidden_dim])),
    'decoder_out': tf.Variable(glorot_init([hidden_dim, image_dim]))
}
biases = {
    'encoder_b1': tf.Variable(glorot_init([hidden_dim])),
    'z_mean': tf.Variable(glorot_init([latent_dim])),
    'z_std': tf.Variable(glorot_init([latent_dim])),
    'decoder_b1': tf.Variable(glorot_init([hidden_dim])),
    'decoder_out': tf.Variable(glorot_init([image_dim]))
}
# Building the encoder
input_image = tf.placeholder(tf.float32, shape=[None, image_dim])
encoder = tf.matmul(input_image, weights['encoder_h1']) + biases['encoder_b1']
encoder = tf.nn.tanh(encoder)
z_mean = tf.matmul(encoder, weights['z_mean']) + biases['z_mean']
z_std = tf.matmul(encoder, weights['z_std']) + biases['z_std']
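# Despite its name, 'z_std' is the log-variance log(sigma^2) of the approximate
# posterior: the sampler below uses exp(z_std / 2) as the standard deviation and
# the KL term uses exp(z_std) as the variance. Writing z = mean + sigma * eps
# with eps ~ N(0, 1) (the reparameterization trick) keeps the sampling step
# differentiable with respect to the encoder outputs.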
# Sampler: Normal (Gaussian) random distribution
eps = tf.random_normal(tf.shape(z_std), dtype=tf.float32, mean=0., stddev=1.0,
                       name='epsilon')
z = z_mean + tf.exp(z_std / 2) * eps
# Building the decoder (the same weights are reused later to generate images)
decoder = tf.matmul(z, weights['decoder_h1']) + biases['decoder_b1']
decoder = tf.nn.tanh(decoder)
decoder = tf.matmul(decoder, weights['decoder_out']) + biases['decoder_out']
decoder = tf.nn.sigmoid(decoder)
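# The loss below is the negative ELBO for a Bernoulli decoder and a diagonal
# Gaussian posterior: a per-pixel cross-entropy reconstruction term plus the
# closed-form KL divergence KL(N(mu, sigma^2) || N(0, 1)) =
# -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2).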
# Define VAE Loss
def vae_loss(x_reconstructed, x_true):
    # Reconstruction loss
    encode_decode_loss = x_true * tf.log(1e-10 + x_reconstructed) \
                         + (1 - x_true) * tf.log(1e-10 + 1 - x_reconstructed)
    encode_decode_loss = -tf.reduce_sum(encode_decode_loss, 1)
    # KL Divergence loss
    kl_div_loss = 1 + z_std - tf.square(z_mean) - tf.exp(z_std)
    kl_div_loss = -0.5 * tf.reduce_sum(kl_div_loss, 1)
    return tf.reduce_mean(encode_decode_loss + kl_div_loss)

loss_op = vae_loss(decoder, input_image)
optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate)
train_op = optimizer.minimize(loss_op)
# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()

# Start training
with tf.Session() as sess:

    # Run the initializer
    sess.run(init)

    for i in range(1, num_steps + 1):
        # Prepare Data
        # Get the next batch of MNIST data (only images are needed, not labels)
        batch_x, _ = mnist.train.next_batch(batch_size)

        # Train
        feed_dict = {input_image: batch_x}
        _, l = sess.run([train_op, loss_op], feed_dict=feed_dict)
        if i % 1000 == 0 or i == 1:
            print('Step %i, Loss: %f' % (i, l))
    # Testing
    # Generator takes noise as input
    noise_input = tf.placeholder(tf.float32, shape=[None, latent_dim])
    # Rebuild the decoder to create images from noise
    decoder = tf.matmul(noise_input, weights['decoder_h1']) + biases['decoder_b1']
    decoder = tf.nn.tanh(decoder)
    decoder = tf.matmul(decoder, weights['decoder_out']) + biases['decoder_out']
    decoder = tf.nn.sigmoid(decoder)
    # Building a manifold of generated digits
    n = 20
    x_axis = np.linspace(-3, 3, n)
    y_axis = np.linspace(-3, 3, n)
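    # The linear grid above sweeps the latent prior directly; an alternative is
    # to space the grid by prior quantiles, e.g. norm.ppf(np.linspace(0.05, 0.95, n)),
    # which is what the scipy.stats.norm import above can be used for.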
    canvas = np.empty((28 * n, 28 * n))
    for i, yi in enumerate(x_axis):
        for j, xi in enumerate(y_axis):
            z_mu = np.array([[xi, yi]] * batch_size)
            x_mean = sess.run(decoder, feed_dict={noise_input: z_mu})
            canvas[(n - i - 1) * 28:(n - i) * 28, j * 28:(j + 1) * 28] = \
                x_mean[0].reshape(28, 28)

    plt.figure(figsize=(8, 10))
    Xi, Yi = np.meshgrid(x_axis, y_axis)
    plt.imshow(canvas, origin="upper", cmap="gray")
    plt.show()
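    # A quick sanity check of the learned representation (a minimal sketch that
    # reuses the session and tensors defined above): round-trip a few test
    # digits through the encoder (z_mean) and the rebuilt decoder, then plot
    # the originals next to their reconstructions.
    test_x = mnist.test.images[:4]
    codes = sess.run(z_mean, feed_dict={input_image: test_x})
    recon = sess.run(decoder, feed_dict={noise_input: codes})
    fig, axes = plt.subplots(2, 4, figsize=(8, 4))
    for k in range(4):
        axes[0, k].imshow(test_x[k].reshape(28, 28), cmap="gray")
        axes[0, k].axis('off')
        axes[1, k].imshow(recon[k].reshape(28, 28), cmap="gray")
        axes[1, k].axis('off')
    plt.show()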