{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Deep Convolutional Generative Adversarial Network Example\n", "\n", "Build a deep convolutional generative adversarial network (DCGAN) to generate digit images from a noise distribution with TensorFlow v2.\n", "\n", "- Author: Aymeric Damien\n", "- Project: https://github.com/aymericdamien/TensorFlow-Examples/" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## DCGAN Overview\n", "\n", "\"dcgan\"\n", "\n", "References:\n", "- [Unsupervised representation learning with deep convolutional generative adversarial networks](https://arxiv.org/pdf/1511.06434). A Radford, L Metz, S Chintala, 2016.\n", "- [Understanding the difficulty of training deep feedforward neural networks](http://proceedings.mlr.press/v9/glorot10a.html). X Glorot, Y Bengio. Aistats 9, 249-256\n", "- [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/abs/1502.03167). Sergey Ioffe, Christian Szegedy. 2015.\n", "\n", "## MNIST Dataset Overview\n", "\n", "This example is using MNIST handwritten digits. The dataset contains 60,000 examples for training and 10,000 examples for testing. The digits have been size-normalized and centered in a fixed-size image (28x28 pixels) with values from 0 to 255. 
\n", "\n", "In this example, each image will be converted to float32 and normalized from [0, 255] to [0, 1].\n", "\n", "![MNIST Dataset](http://neuralnetworksanddeeplearning.com/images/mnist_100_digits.png)\n", "\n", "More info: http://yann.lecun.com/exdb/mnist/" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "from __future__ import absolute_import, division, print_function\n", "\n", "import tensorflow as tf\n", "from tensorflow.keras import Model, layers\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# MNIST Dataset parameters.\n", "num_features = 784 # data features (img shape: 28*28).\n", "\n", "# Training parameters.\n", "lr_generator = 0.0002\n", "lr_discriminator = 0.0002\n", "training_steps = 20000\n", "batch_size = 128\n", "display_step = 500\n", "\n", "# Network parameters.\n", "noise_dim = 100 # Noise data points." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Prepare MNIST data.\n", "from tensorflow.keras.datasets import mnist\n", "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", "# Convert to float32.\n", "x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)\n", "# Normalize images value from [0, 255] to [0, 1].\n", "x_train, x_test = x_train / 255., x_test / 255." 
# Use tf.data API to shuffle and batch data.
train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_data = train_data.repeat().shuffle(10000).batch(batch_size).prefetch(1)


# Generator Network
# Input: Noise, Output: Image
# Note that batch normalization has different behavior at training and inference time,
# so `call` takes an `is_training` flag that is forwarded to every BatchNormalization layer.
class Generator(Model):
    """Maps a noise vector to a 28x28x1 image with values in [-1, 1] (tanh output)."""

    # Set layers.
    def __init__(self):
        super(Generator, self).__init__()
        self.fc1 = layers.Dense(7 * 7 * 128)
        self.bn1 = layers.BatchNormalization()
        self.conv2tr1 = layers.Conv2DTranspose(64, 5, strides=2, padding='SAME')
        self.bn2 = layers.BatchNormalization()
        self.conv2tr2 = layers.Conv2DTranspose(1, 5, strides=2, padding='SAME')

    # Set forward pass.
    def call(self, x, is_training=False):
        """Generate images from noise.

        Args:
            x: noise tensor; callers in this notebook pass shape (batch, noise_dim).
            is_training: forwarded as `training=` to the batch-norm layers.

        Returns:
            Tensor of shape (batch, 28, 28, 1), squashed by tanh.
        """
        x = self.fc1(x)
        x = self.bn1(x, training=is_training)
        x = tf.nn.leaky_relu(x)
        # Reshape to a 4-D array of images: (batch, height, width, channels)
        # New shape: (batch, 7, 7, 128)
        x = tf.reshape(x, shape=[-1, 7, 7, 128])
        # Deconvolution, image shape: (batch, 14, 14, 64)
        x = self.conv2tr1(x)
        x = self.bn2(x, training=is_training)
        x = tf.nn.leaky_relu(x)
        # Deconvolution, image shape: (batch, 28, 28, 1)
        x = self.conv2tr2(x)
        x = tf.nn.tanh(x)
        return x


# Discriminator Network
# Input: Image, Output: 2-class logits (the losses below use class 1 for real
# images and class 0 for generated ones).
class Discriminator(Model):
    """Classifies 28x28 images as real or generated; returns unnormalised logits."""

    # Set layers.
    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = layers.Conv2D(64, 5, strides=2, padding='SAME')
        self.bn1 = layers.BatchNormalization()
        self.conv2 = layers.Conv2D(128, 5, strides=2, padding='SAME')
        self.bn2 = layers.BatchNormalization()
        self.flatten = layers.Flatten()
        self.fc1 = layers.Dense(1024)
        self.bn3 = layers.BatchNormalization()
        self.fc2 = layers.Dense(2)

    # Set forward pass.
    def call(self, x, is_training=False):
        """Score a batch of images.

        Args:
            x: image tensor; it is reshaped to (batch, 28, 28, 1) first, so
                (batch, 28, 28) and (batch, 28, 28, 1) inputs both work.
            is_training: forwarded as `training=` to the batch-norm layers.

        Returns:
            Logits tensor of shape (batch, 2).
        """
        x = tf.reshape(x, [-1, 28, 28, 1])
        x = self.conv1(x)
        x = self.bn1(x, training=is_training)
        x = tf.nn.leaky_relu(x)
        x = self.conv2(x)
        x = self.bn2(x, training=is_training)
        x = tf.nn.leaky_relu(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.bn3(x, training=is_training)
        x = tf.nn.leaky_relu(x)
        return self.fc2(x)


# Build neural network model.
generator = Generator()
discriminator = Discriminator()


# Losses.
def generator_loss(reconstructed_image):
    """Generator loss: cross-entropy of the discriminator output against "real".

    Args:
        reconstructed_image: despite its name, this is the discriminator's
            logits, shape (batch, 2), computed on generated images.

    Returns:
        Scalar tensor, the mean cross-entropy against class 1 (real).
    """
    # Size the label vector from the logits rather than the global batch_size,
    # so the loss also works for batches of any other size.
    real_labels = tf.ones(tf.shape(reconstructed_image)[:1], dtype=tf.int32)
    gen_loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=reconstructed_image, labels=real_labels))
    return gen_loss


def discriminator_loss(disc_fake, disc_real):
    """Discriminator loss: real images -> class 1, generated images -> class 0.

    Args:
        disc_fake: discriminator logits, shape (batch, 2), for generated images.
        disc_real: discriminator logits, shape (batch, 2), for real images.

    Returns:
        Scalar tensor, the sum of the two mean cross-entropies.
    """
    # Labels follow each logits batch, so the two batches may differ in size.
    real_labels = tf.ones(tf.shape(disc_real)[:1], dtype=tf.int32)
    fake_labels = tf.zeros(tf.shape(disc_fake)[:1], dtype=tf.int32)
    disc_loss_real = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=disc_real, labels=real_labels))
    disc_loss_fake = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=disc_fake, labels=fake_labels))
    return disc_loss_real + disc_loss_fake


# Optimizers.
# Adam with its default beta values. beta_1=0.5, beta_2=0.999 is an alternative
# that was previously left commented out here; it is not enabled.
optimizer_gen = tf.optimizers.Adam(learning_rate=lr_generator)
optimizer_disc = tf.optimizers.Adam(learning_rate=lr_discriminator)
# Optimization process. Inputs: real image and noise.
def run_optimization(real_images):
    """Run one discriminator update followed by one generator update.

    Args:
        real_images: batch of real images, expected in [0, 1]; they are
            rescaled to [-1, 1] here to match the generator's tanh output.

    Returns:
        Tuple (gen_loss, disc_loss) of scalar tensors from this step.
    """
    noise_shape = [batch_size, noise_dim]

    # Rescale to [-1, 1], the input range of the discriminator.
    scaled_real = real_images * 2. - 1.

    # ---- Discriminator step ----
    # np.random.normal(loc, scale): this draws from a normal with mean -1 and
    # standard deviation 1.
    # NOTE(review): a zero-mean prior may have been intended; kept as-is because
    # the sampling cell of this notebook draws from the same distribution.
    z_disc = np.random.normal(-1., 1., size=noise_shape).astype(np.float32)

    with tf.GradientTape() as disc_tape:
        generated = generator(z_disc, is_training=True)
        logits_fake = discriminator(generated, is_training=True)
        logits_real = discriminator(scaled_real, is_training=True)
        disc_loss = discriminator_loss(logits_fake, logits_real)

    # Only the discriminator's variables are updated in this step.
    disc_vars = discriminator.trainable_variables
    disc_grads = disc_tape.gradient(disc_loss, disc_vars)
    optimizer_disc.apply_gradients(zip(disc_grads, disc_vars))

    # ---- Generator step ----
    # Fresh noise, scored by the discriminator that was just updated.
    z_gen = np.random.normal(-1., 1., size=noise_shape).astype(np.float32)

    with tf.GradientTape() as gen_tape:
        generated = generator(z_gen, is_training=True)
        logits_fake = discriminator(generated, is_training=True)
        gen_loss = generator_loss(logits_fake)

    # Only the generator's variables are updated in this step.
    gen_vars = generator.trainable_variables
    gen_grads = gen_tape.gradient(gen_loss, gen_vars)
    optimizer_gen.apply_gradients(zip(gen_grads, gen_vars))

    return gen_loss, disc_loss
# Run training for the given number of steps.
# One extra batch is taken: step 0 only reports the losses of the untrained
# networks, steps 1..training_steps each run one optimization step.
for step, (batch_x, _) in enumerate(train_data.take(training_steps + 1)):

    if step == 0:
        # Generate noise.
        noise = np.random.normal(-1., 1., size=[batch_size, noise_dim]).astype(np.float32)
        # Score the fake batch once and reuse it for both losses (inference
        # mode, so a second forward pass would give the same logits).
        disc_fake = discriminator(generator(noise))
        # Rescale real images to [-1, 1], exactly as run_optimization does.
        disc_real = discriminator(batch_x * 2. - 1.)
        gen_loss = generator_loss(disc_fake)
        # Argument order is (disc_fake, disc_real).
        disc_loss = discriminator_loss(disc_fake, disc_real)
        print("initial: gen_loss: %f, disc_loss: %f" % (gen_loss, disc_loss))
        continue

    # Run the optimization.
    gen_loss, disc_loss = run_optimization(batch_x)

    if step % display_step == 0:
        print("step: %i, gen_loss: %f, disc_loss: %f" % (step, gen_loss, disc_loss))

# Visualize predictions.
import matplotlib.pyplot as plt

# Testing
# Generate images from noise, using the generator network.
n = 6
# Blank canvas holding an n x n grid of 28x28 tiles.
canvas = np.empty((28 * n, 28 * n))
for i in range(n):
    # Noise input: same distribution as used during training (mean -1, std 1).
    z = np.random.normal(-1., 1., size=[n, noise_dim]).astype(np.float32)
    # Generate image from noise (inference mode: is_training defaults to False).
    g = generator(z).numpy()
    # Rescale from the tanh range [-1, 1] to the original [0, 1].
    g = (g + 1.) / 2
    # Reverse colours for better display
    g = -1 * (g - 1)
    for j in range(n):
        # Draw the generated digits: row i of the grid, column j.
        canvas[i * 28:(i + 1) * 28, j * 28:(j + 1) * 28] = g[j].reshape([28, 28])

plt.figure(figsize=(n, n))
plt.imshow(canvas, origin="upper", cmap="gray")
plt.show()