gan.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
""" Generative Adversarial Networks (GAN).

Using generative adversarial networks (GAN) to generate digit images from a
noise distribution.

References:
    - Generative adversarial nets. I Goodfellow, J Pouget-Abadie, M Mirza,
    B Xu, D Warde-Farley, S Ozair, Y. Bengio. Advances in neural information
    processing systems, 2672-2680.
    - Understanding the difficulty of training deep feedforward neural networks.
    X Glorot, Y Bengio. Aistats 9, 249-256

Links:
    - [GAN Paper](https://arxiv.org/pdf/1406.2661.pdf).
    - [MNIST Dataset](http://yann.lecun.com/exdb/mnist/).
    - [Xavier Glorot Init](www.cs.cmu.edu/~bhiksha/courses/deeplearning/Fall.../AISTATS2010_Glorot.pdf).

Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
"""
  17. from __future__ import division, print_function, absolute_import
  18. import matplotlib.pyplot as plt
  19. import numpy as np
  20. import tensorflow as tf
  21. # Import MNIST data
  22. from tensorflow.examples.tutorials.mnist import input_data
  23. mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
  24. # Training Params
  25. num_steps = 100000
  26. batch_size = 128
  27. learning_rate = 0.0002
  28. # Network Params
  29. image_dim = 784 # 28*28 pixels
  30. gen_hidden_dim = 256
  31. disc_hidden_dim = 256
  32. noise_dim = 100 # Noise data points
  33. # A custom initialization (see Xavier Glorot init)
  34. def glorot_init(shape):
  35. return tf.random_normal(shape=shape, stddev=1. / tf.sqrt(shape[0] / 2.))
  36. # Store layers weight & bias
  37. weights = {
  38. 'gen_hidden1': tf.Variable(glorot_init([noise_dim, gen_hidden_dim])),
  39. 'gen_out': tf.Variable(glorot_init([gen_hidden_dim, image_dim])),
  40. 'disc_hidden1': tf.Variable(glorot_init([image_dim, disc_hidden_dim])),
  41. 'disc_out': tf.Variable(glorot_init([disc_hidden_dim, 1])),
  42. }
  43. biases = {
  44. 'gen_hidden1': tf.Variable(tf.zeros([gen_hidden_dim])),
  45. 'gen_out': tf.Variable(tf.zeros([image_dim])),
  46. 'disc_hidden1': tf.Variable(tf.zeros([disc_hidden_dim])),
  47. 'disc_out': tf.Variable(tf.zeros([1])),
  48. }
  49. # Generator
  50. def generator(x):
  51. hidden_layer = tf.matmul(x, weights['gen_hidden1'])
  52. hidden_layer = tf.add(hidden_layer, biases['gen_hidden1'])
  53. hidden_layer = tf.nn.relu(hidden_layer)
  54. out_layer = tf.matmul(hidden_layer, weights['gen_out'])
  55. out_layer = tf.add(out_layer, biases['gen_out'])
  56. out_layer = tf.nn.sigmoid(out_layer)
  57. return out_layer
  58. # Discriminator
  59. def discriminator(x):
  60. hidden_layer = tf.matmul(x, weights['disc_hidden1'])
  61. hidden_layer = tf.add(hidden_layer, biases['disc_hidden1'])
  62. hidden_layer = tf.nn.relu(hidden_layer)
  63. out_layer = tf.matmul(hidden_layer, weights['disc_out'])
  64. out_layer = tf.add(out_layer, biases['disc_out'])
  65. out_layer = tf.nn.sigmoid(out_layer)
  66. return out_layer
  67. # Build Networks
  68. # Network Inputs
  69. gen_input = tf.placeholder(tf.float32, shape=[None, noise_dim], name='input_noise')
  70. disc_input = tf.placeholder(tf.float32, shape=[None, image_dim], name='disc_input')
  71. # Build Generator Network
  72. gen_sample = generator(gen_input)
  73. # Build 2 Discriminator Networks (one from noise input, one from generated samples)
  74. disc_real = discriminator(disc_input)
  75. disc_fake = discriminator(gen_sample)
  76. # Build Loss
  77. gen_loss = -tf.reduce_mean(tf.log(disc_fake))
  78. disc_loss = -tf.reduce_mean(tf.log(disc_real) + tf.log(1. - disc_fake))
  79. # Build Optimizers
  80. optimizer_gen = tf.train.AdamOptimizer(learning_rate=learning_rate)
  81. optimizer_disc = tf.train.AdamOptimizer(learning_rate=learning_rate)
  82. # Training Variables for each optimizer
  83. # By default in TensorFlow, all variables are updated by each optimizer, so we
  84. # need to precise for each one of them the specific variables to update.
  85. # Generator Network Variables
  86. gen_vars = [weights['gen_hidden1'], weights['gen_out'],
  87. biases['gen_hidden1'], biases['gen_out']]
  88. # Discriminator Network Variables
  89. disc_vars = [weights['disc_hidden1'], weights['disc_out'],
  90. biases['disc_hidden1'], biases['disc_out']]
  91. # Create training operations
  92. train_gen = optimizer_gen.minimize(gen_loss, var_list=gen_vars)
  93. train_disc = optimizer_disc.minimize(disc_loss, var_list=disc_vars)
  94. # Initialize the variables (i.e. assign their default value)
  95. init = tf.global_variables_initializer()
  96. # Start training
  97. with tf.Session() as sess:
  98. # Run the initializer
  99. sess.run(init)
  100. for i in range(1, num_steps+1):
  101. # Prepare Data
  102. # Get the next batch of MNIST data (only images are needed, not labels)
  103. batch_x, _ = mnist.train.next_batch(batch_size)
  104. # Generate noise to feed to the generator
  105. z = np.random.uniform(-1., 1., size=[batch_size, noise_dim])
  106. # Train
  107. feed_dict = {disc_input: batch_x, gen_input: z}
  108. _, _, gl, dl = sess.run([train_gen, train_disc, gen_loss, disc_loss],
  109. feed_dict=feed_dict)
  110. if i % 1000 == 0 or i == 1:
  111. print('Step %i: Generator Loss: %f, Discriminator Loss: %f' % (i, gl, dl))
  112. # Generate images from noise, using the generator network.
  113. f, a = plt.subplots(4, 10, figsize=(10, 4))
  114. for i in range(10):
  115. # Noise input.
  116. z = np.random.uniform(-1., 1., size=[4, noise_dim])
  117. g = sess.run([gen_sample], feed_dict={gen_input: z})
  118. g = np.reshape(g, newshape=(4, 28, 28, 1))
  119. # Reverse colours for better display
  120. g = -1 * (g - 1)
  121. for j in range(4):
  122. # Generate image from noise. Extend to 3 channels for matplot figure.
  123. img = np.reshape(np.repeat(g[j][:, :, np.newaxis], 3, axis=2),
  124. newshape=(28, 28, 3))
  125. a[j][i].imshow(img)
  126. f.show()
  127. plt.draw()
  128. plt.waitforbuttonpress()