dcgan.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. """ Deep Convolutional Generative Adversarial Network (DCGAN).
  2. Using deep convolutional generative adversarial networks (DCGAN) to generate
  3. digit images from a noise distribution.
  4. References:
  5. - Unsupervised representation learning with deep convolutional generative
  6. adversarial networks. A Radford, L Metz, S Chintala. arXiv:1511.06434.
  7. Links:
  8. - [DCGAN Paper](https://arxiv.org/abs/1511.06434).
  9. - [MNIST Dataset](http://yann.lecun.com/exdb/mnist/).
  10. Author: Aymeric Damien
  11. Project: https://github.com/aymericdamien/TensorFlow-Examples/
  12. """
  13. from __future__ import division, print_function, absolute_import
  14. import matplotlib.pyplot as plt
  15. import numpy as np
  16. import tensorflow as tf
  17. # Import MNIST data
  18. from tensorflow.examples.tutorials.mnist import input_data
  # Load the MNIST dataset object from /tmp/data/. Labels are requested
  # one-hot encoded, although the training loop below only uses the images.
  mnist = input_data.read_data_sets(train_dir="/tmp/data/", one_hot=True)
  # Training hyper-parameters
  num_steps = 20000  # number of training iterations
  batch_size = 32  # real images (and noise vectors) drawn per iteration

  # Network hyper-parameters
  image_dim = 28 * 28 * 1  # height * width * channels of one image (784)
  gen_hidden_dim = 256
  disc_hidden_dim = 256
  noise_dim = 200  # length of each noise vector fed to the generator
  28. # Generator Network
  29. # Input: Noise, Output: Image
  def generator(x, reuse=False):
      """Build the generator network: noise vectors in, images out.

      Args:
          x: Noise tensor. In this script it is fed with shape
              (batch, noise_dim).
          reuse: Forwarded to ``tf.variable_scope``; pass True to share the
              variables created by an earlier call.

      Returns:
          Tensor of shape (batch, 28, 28, 1) squashed into (0, 1) by a
          sigmoid.
      """
      with tf.variable_scope('Generator', reuse=reuse):
          # Project the noise onto 6 * 6 * 128 units; the layer infers its
          # weight shape from the input.
          projected = tf.layers.dense(x, units=6 * 6 * 128)
          projected = tf.nn.tanh(projected)

          # View the flat projection as a stack of feature maps:
          # (batch, height, width, channels) = (batch, 6, 6, 128).
          feature_maps = tf.reshape(projected, shape=[-1, 6, 6, 128])

          # First transposed convolution -> (batch, 14, 14, 64).
          upsampled = tf.layers.conv2d_transpose(feature_maps, 64, 4, strides=2)

          # Second transposed convolution -> (batch, 28, 28, 1).
          image_logits = tf.layers.conv2d_transpose(upsampled, 1, 2, strides=2)

          # Sigmoid keeps every pixel value between 0 and 1.
          return tf.nn.sigmoid(image_logits)
  46. # Discriminator Network
  47. # Input: Image, Output: Prediction Real/Fake Image
  def discriminator(x, reuse=False):
      """Build the discriminator network: images in, real/fake logits out.

      Args:
          x: Image tensor. In this script it is either the real-image
              placeholder of shape (batch, 28, 28, 1) or the generator output.
          reuse: Forwarded to ``tf.variable_scope``; pass True to share the
              variables created by an earlier call.

      Returns:
          Tensor of shape (batch, 2) holding unnormalised class scores. The
          training code labels fake images 0 and real images 1.
      """
      with tf.variable_scope('Discriminator', reuse=reuse):
          # Two convolution -> tanh -> average-pool stages.
          features = tf.layers.conv2d(x, 64, 5)
          features = tf.nn.tanh(features)
          features = tf.layers.average_pooling2d(features, 2, 2)

          features = tf.layers.conv2d(features, 128, 5)
          features = tf.nn.tanh(features)
          features = tf.layers.average_pooling2d(features, 2, 2)

          # Fully connected classifier head on the flattened feature maps.
          flat = tf.contrib.layers.flatten(features)
          hidden = tf.layers.dense(flat, 1024)
          hidden = tf.nn.tanh(hidden)

          # Two output classes: real and fake.
          return tf.layers.dense(hidden, 2)
  # ---- Graph construction ----

  # Inputs: a batch of noise vectors and a batch of real images.
  noise_input = tf.placeholder(dtype=tf.float32, shape=[None, noise_dim])
  real_image_input = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1])

  # Generator: noise -> generated images.
  gen_sample = generator(noise_input)

  # Discriminator applied to real and to generated images. The first call
  # creates the variables; the later calls reuse them.
  disc_real = discriminator(real_image_input)
  disc_fake = discriminator(gen_sample, reuse=True)
  # Real predictions first, generated predictions second.
  disc_concat = tf.concat([disc_real, disc_fake], axis=0)

  # Generator stacked with the discriminator, used for the generator loss.
  stacked_gan = discriminator(gen_sample, reuse=True)

  # Integer class targets (real or fake) for each loss.
  disc_target = tf.placeholder(dtype=tf.int32, shape=[None])
  gen_target = tf.placeholder(dtype=tf.int32, shape=[None])

  # Losses: mean softmax cross-entropy over the batch.
  disc_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=disc_target, logits=disc_concat)
  disc_loss = tf.reduce_mean(disc_xent)
  gen_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=gen_target, logits=stacked_gan)
  gen_loss = tf.reduce_mean(gen_xent)

  # One Adam optimizer per network.
  optimizer_gen = tf.train.AdamOptimizer(learning_rate=0.001)
  optimizer_disc = tf.train.AdamOptimizer(learning_rate=0.001)

  # Each optimizer must only touch its own network. The generator loss also
  # flows through the discriminator, so the variables to update are stated
  # explicitly for each training op.
  gen_vars = tf.get_collection(
      tf.GraphKeys.TRAINABLE_VARIABLES, scope='Generator')
  disc_vars = tf.get_collection(
      tf.GraphKeys.TRAINABLE_VARIABLES, scope='Discriminator')

  # Training operations.
  train_gen = optimizer_gen.minimize(gen_loss, var_list=gen_vars)
  train_disc = optimizer_disc.minimize(disc_loss, var_list=disc_vars)

  # Op that assigns every variable its initial value.
  init = tf.global_variables_initializer()
  # Train the GAN, then display a 4 x 10 grid of generated digits.
  with tf.Session() as sess:

      # Run the initializer
      sess.run(init)

      # The targets are the same at every step, so build them once instead of
      # on each iteration. They are created as int32 to match the target
      # placeholders rather than relying on a float64 -> int32 conversion at
      # feed time.
      # Real image: 1, fake image: 0. The first half of the data seen by the
      # discriminator is real images, the second half generated ones.
      batch_disc_y = np.concatenate(
          [np.ones([batch_size], dtype=np.int32),
           np.zeros([batch_size], dtype=np.int32)], axis=0)
      # The generator tries to fool the discriminator, so its targets are 1.
      batch_gen_y = np.ones([batch_size], dtype=np.int32)

      for i in range(1, num_steps + 1):
          # Next batch of MNIST data (only images are needed, not labels),
          # reshaped to (batch, height, width, channels).
          batch_x, _ = mnist.train.next_batch(batch_size)
          batch_x = np.reshape(batch_x, [-1, 28, 28, 1])

          # Noise to feed to the generator.
          z = np.random.uniform(-1., 1., size=[batch_size, noise_dim])

          # One update of both networks, fetching the losses for logging.
          feed_dict = {real_image_input: batch_x, noise_input: z,
                       disc_target: batch_disc_y, gen_target: batch_gen_y}
          _, _, gl, dl = sess.run([train_gen, train_disc, gen_loss, disc_loss],
                                  feed_dict=feed_dict)
          if i % 100 == 0 or i == 1:
              print('Step %i: Generator Loss: %f, Discriminator Loss: %f' % (i, gl, dl))

      # Generate images from noise, using the generator network.
      f, a = plt.subplots(4, 10, figsize=(10, 4))
      for col in range(10):
          # Four fresh noise vectors per column of the figure.
          z = np.random.uniform(-1., 1., size=[4, noise_dim])
          g = sess.run(gen_sample, feed_dict={noise_input: z})
          for row in range(4):
              # g[row] is (28, 28, 1); repeat the single channel three times
              # so matplotlib receives a (28, 28, 3) image.
              img = np.repeat(g[row], 3, axis=2)
              a[row][col].imshow(img)

      f.show()
      plt.draw()
      plt.waitforbuttonpress()