# cgan_fashionmnist_tensorflow.py
# Conditional GAN (cGAN) trained on Fashion-MNIST with TensorFlow/Keras.
  1. import imageio
  2. import glob
  3. import os
  4. import time
  5. import cv2
  6. import tensorflow as tf
  7. from tensorflow.keras import layers
  8. from IPython import display
  9. import matplotlib.pyplot as plt
  10. import numpy as np
  11. from tensorflow.keras import backend as K
  12. from sklearn.manifold import TSNE
  13. import matplotlib.pyplot as plt
  14. from tensorflow import keras
  15. from matplotlib import pyplot
  16. from numpy import asarray
  17. from numpy.random import randn
  18. from numpy.random import randint
  19. from numpy import linspace
  20. from matplotlib import pyplot
  21. (x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
  22. x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32')
  23. x_test = x_test.astype('float32')
  24. x_train = (x_train / 127.5) - 1
  25. # Batch and shuffle the data
  26. train_dataset = tf.data.Dataset.from_tensor_slices((x_train,y_train)).\
  27. shuffle(60000).batch(128)
  28. plt.figure(figsize=(10, 10))
  29. for images,_ in train_dataset.take(1):
  30. for i in range(100):
  31. ax = plt.subplot(10, 10, i + 1)
  32. plt.imshow(images[i,:,:,0].numpy().astype("uint8"), cmap='gray')
  33. plt.axis("off")
# Training hyperparameters.
BATCH_SIZE=128
latent_dim = 100
# label input — module-level Input shared by the generator's label branch.
con_label = layers.Input(shape=(1,))
# image generator input — module-level Input shared by the latent branch.
latent_vector = layers.Input(shape=(100,))
  40. def label_conditioned_gen(n_classes=10, embedding_dim=100):
  41. # embedding for categorical input
  42. label_embedding = layers.Embedding(n_classes, embedding_dim)(con_label)
  43. # linear multiplication
  44. n_nodes = 7 * 7
  45. label_dense = layers.Dense(n_nodes)(label_embedding)
  46. # reshape to additional channel
  47. label_reshape_layer = layers.Reshape((7, 7, 1))(label_dense)
  48. return label_reshape_layer
  49. def latent_gen(latent_dim=100):
  50. # image generator input
  51. in_lat = layers.Input(shape=(latent_dim,))
  52. n_nodes = 128 * 7 * 7
  53. latent_dense = layers.Dense(n_nodes)(latent_vector)
  54. latent_dense = layers.LeakyReLU(alpha=0.2)(latent_dense)
  55. latent_reshape = layers.Reshape((7, 7, 128))(latent_dense)
  56. return latent_reshape
  57. def con_generator():
  58. latent_vector_output = label_conditioned_gen()
  59. label_output = latent_gen()
  60. # merge image gen and label input
  61. merge = layers.Concatenate()([latent_vector_output, label_output])
  62. # upsample to 14x14
  63. x = layers.Conv2DTranspose(128, (4,4), strides=(2,2), padding='same')(merge)
  64. x = layers.LeakyReLU(alpha=0.2)(x)
  65. # upsample to 28x28
  66. x = layers.Conv2DTranspose(128, (4,4), strides=(2,2), padding='same')(x)
  67. x = layers.LeakyReLU(alpha=0.2)(x)
  68. # output
  69. out_layer = layers.Conv2D(1, (7,7), activation='tanh', padding='same')(x)
  70. # define model
  71. model = tf.keras.Model([latent_vector, con_label], out_layer)
  72. return model
# Build the conditional generator and print its layer summary.
conditional_gen = con_generator()
conditional_gen.summary()
  75. def label_condition_disc(in_shape=(28, 28, 1), n_classes=10, embedding_dim=100):
  76. # label input
  77. con_label = layers.Input(shape=(1,))
  78. # embedding for categorical input
  79. label_embedding = layers.Embedding(n_classes, embedding_dim)(con_label)
  80. # scale up to image dimensions with linear activation
  81. nodes = in_shape[0] * in_shape[1] * in_shape[2]
  82. label_dense = layers.Dense(nodes)(label_embedding)
  83. # reshape to additional channel
  84. label_reshape_layer = layers.Reshape((in_shape[0], in_shape[1], 1))(label_dense)
  85. # image input
  86. return con_label, label_reshape_layer
  87. def image_disc(in_shape=(28,28, 1)):
  88. inp_image = layers.Input(shape=in_shape)
  89. return inp_image
  90. def con_discriminator():
  91. con_label, label_condition_output = label_condition_disc()
  92. inp_image_output = image_disc()
  93. # concat label as a channel
  94. merge = layers.Concatenate()([inp_image_output, label_condition_output])
  95. # downsample
  96. x = layers.Conv2D(128, (3,3), strides=(2,2), padding='same')(merge)
  97. x = layers.LeakyReLU(alpha=0.2)(x)
  98. # downsample
  99. x = layers.Conv2D(128, (3,3), strides=(2,2), padding='same')(x)
  100. x = layers.LeakyReLU(alpha=0.2)(x)
  101. # flatten feature maps
  102. flattened_out = layers.Flatten()(x)
  103. # dropout
  104. dropout = layers.Dropout(0.4)(flattened_out)
  105. # output
  106. dense_out = layers.Dense(1, activation='sigmoid')(dropout)
  107. # define model
  108. model = tf.keras.Model([inp_image_output, con_label], dense_out)
  109. return model
# Build the conditional discriminator and print its layer summary.
conditional_discriminator = con_discriminator()
conditional_discriminator.summary()
# Shared binary cross-entropy loss used by both generator and discriminator.
binary_cross_entropy = tf.keras.losses.BinaryCrossentropy()
  113. def generator_loss(label, fake_output):
  114. gen_loss = binary_cross_entropy(label, fake_output)
  115. #print(gen_loss)
  116. return gen_loss
  117. def discriminator_loss(label, output):
  118. disc_loss = binary_cross_entropy(label, output)
  119. #print(total_loss)
  120. return disc_loss
  121. learning_rate = 0.0002
  122. generator_optimizer = tf.keras.optimizers.Adam(lr = 0.0002, beta_1 = 0.5, beta_2 = 0.999 )
  123. discriminator_optimizer = tf.keras.optimizers.Adam(lr = 0.0002, beta_1 = 0.5, beta_2 = 0.999 )
  124. num_examples_to_generate = 25
  125. latent_dim = 100
  126. # We will reuse this seed overtime to visualize progress
  127. seed = tf.random.normal([num_examples_to_generate, latent_dim])
  128. print(seed.dtype)
  129. print(conditional_gen.input)
# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images, target):
    """One cGAN update: train the discriminator on a real batch, then on a
    fake batch, then train the generator — all conditioned on `target`."""
    # noise vector sampled from normal distribution, one per batch element
    noise = tf.random.normal([target.shape[0], latent_dim])
    # Train Discriminator with real labels
    with tf.GradientTape() as disc_tape1:
        generated_images = conditional_gen([noise, target], training=True)
        real_output = conditional_discriminator([images, target], training=True)
        real_targets = tf.ones_like(real_output)
        disc_loss1 = discriminator_loss(real_targets, real_output)
    # gradient calculation for discriminator for real labels
    gradients_of_disc1 = disc_tape1.gradient(disc_loss1, conditional_discriminator.trainable_variables)
    # parameters optimization for discriminator for real labels
    discriminator_optimizer.apply_gradients(zip(gradients_of_disc1,
                                                conditional_discriminator.trainable_variables))
    # Train Discriminator with fake labels (reuses `generated_images` from above)
    with tf.GradientTape() as disc_tape2:
        fake_output = conditional_discriminator([generated_images, target], training=True)
        fake_targets = tf.zeros_like(fake_output)
        disc_loss2 = discriminator_loss(fake_targets, fake_output)
    # gradient calculation for discriminator for fake labels
    gradients_of_disc2 = disc_tape2.gradient(disc_loss2, conditional_discriminator.trainable_variables)
    # parameters optimization for discriminator for fake labels
    discriminator_optimizer.apply_gradients(zip(gradients_of_disc2,
                                                conditional_discriminator.trainable_variables))
    # Train Generator with real labels (generator tries to make the
    # discriminator output "real" for its fakes)
    with tf.GradientTape() as gen_tape:
        generated_images = conditional_gen([noise, target], training=True)
        fake_output = conditional_discriminator([generated_images, target], training=True)
        real_targets = tf.ones_like(fake_output)
        gen_loss = generator_loss(real_targets, fake_output)
    # gradient calculation for generator for real labels
    gradients_of_gen = gen_tape.gradient(gen_loss, conditional_gen.trainable_variables)
    # parameters optimization for generator for real labels
    generator_optimizer.apply_gradients(zip(gradients_of_gen,
                                            conditional_gen.trainable_variables))
  168. def train(dataset, epochs):
  169. for epoch in range(epochs):
  170. start = time.time()
  171. i = 0
  172. D_loss_list, G_loss_list = [], []
  173. for image_batch,target in dataset:
  174. i += 1
  175. train_step(image_batch,target)
  176. print(epoch)
  177. display.clear_output(wait=True)
  178. generate_and_save_images(conditional_gen,
  179. epoch + 1,
  180. seed)
  181. # # Save the model every 15 epochs
  182. # if (epoch + 1) % 15 == 0:
  183. # checkpoint.save(file_prefix = checkpoint_prefix)
  184. conditional_gen.save_weights('fashion/training_weights/gen_'+ str(epoch)+'.h5')
  185. conditional_discriminator.save_weights('fashion/training_weights/disc_'+ str(epoch)+'.h5')
  186. print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
  187. # Generate after the final epoch
  188. display.clear_output(wait=True)
  189. generate_and_save_images(conditional_gen,
  190. epochs,
  191. seed)
  192. def label_gen(n_classes=10):
  193. lab = tf.random.uniform((1,), minval=0, maxval=10, dtype=tf.dtypes.int32, seed=None, name=None)
  194. return tf.repeat(lab, [25], axis=None, name=None)
# Create dictionary of target classes
# Fashion-MNIST class index -> human-readable name.
label_dict = {
    0: 'T-shirt/top',
    1: 'Trouser',
    2: 'Pullover',
    3: 'Dress',
    4: 'Coat',
    5: 'Sandal',
    6: 'Shirt',
    7: 'Sneaker',
    8: 'Bag',
    9: 'Ankle boot',
}
def generate_and_save_images(model, epoch, test_input):
    """Generate a 5x5 grid of images from `test_input` (assumed to hold 25
    latent vectors — the module-level `seed`), all conditioned on one random
    class label, then save the figure to disk."""
    # Notice `training` is set to False.
    # This is so all layers run in inference mode (batchnorm).
    labels = label_gen()
    predictions = model([test_input, labels], training=False)
    print(predictions.shape)
    fig = plt.figure(figsize=(4, 4))
    print("Generated Images are Conditioned on Label:", label_dict[np.array(labels)[0]])
    for i in range(predictions.shape[0]):
        # Map tanh output [-1, 1] back to pixel values [0, 255] for display.
        pred = (predictions[i, :, :, 0] + 1) * 127.5
        pred = np.array(pred)
        plt.subplot(5, 5, i + 1)
        plt.imshow(pred.astype(np.uint8), cmap='gray')
        plt.axis('off')
    # NOTE(review): assumes the 'fashion/images/' directory already exists — confirm.
    plt.savefig('fashion/images/image_at_epoch_{:d}.png'.format(epoch))
    plt.show()
# Short demo run (2 epochs), then reload the last saved generator weights.
train(train_dataset, 2)
conditional_gen.load_weights('fashion/training_weights/gen_1.h5')
# example of interpolating between generated images
fig = plt.figure(figsize=(10,10))
# generate points in latent space as input for the generator
  229. def generate_latent_points(latent_dim, n_samples, n_classes=10):
  230. # generate points in the latent space
  231. x_input = randn(latent_dim * n_samples)
  232. # reshape into a batch of inputs for the network
  233. z_input = x_input.reshape(n_samples, latent_dim)
  234. return z_input
  235. # uniform interpolation between two points in latent space
  236. def interpolate_points(p1, p2, n_steps=10):
  237. # interpolate ratios between the points
  238. ratios = linspace(0, 1, num=n_steps)
  239. # linear interpolate vectors
  240. vectors = list()
  241. for ratio in ratios:
  242. v = (1.0 - ratio) * p1 + ratio * p2
  243. vectors.append(v)
  244. return asarray(vectors)
# Sample two random latent endpoints and interpolate 10 points between them.
pts = generate_latent_points(100, 2)
# interpolate points in latent space
interpolated = interpolate_points(pts[0], pts[1])
# generate images: one row of 10 interpolations per class label (10x10 total)
from matplotlib import gridspec
output = None
for label in range(10):
    # All 10 interpolated vectors share the same class label for this row.
    labels = tf.ones(10) * label
    predictions = conditional_gen([interpolated, labels], training=False)
    if output is None:
        output = predictions
    else:
        output = np.concatenate((output, predictions))
k = 0
nrow = 10
ncol = 10
fig = plt.figure(figsize=(15, 15))
# Tight 10x10 grid with no gaps between cells.
gs = gridspec.GridSpec(nrow, ncol, width_ratios=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                       wspace=0.0, hspace=0.0, top=0.95, bottom=0.05, left=0.17, right=0.845)
for i in range(10):
    for j in range(10):
        # Map tanh output [-1, 1] back to pixel values [0, 255].
        pred = (output[k, :, :, :] + 1) * 127.5
        ax = plt.subplot(gs[i, j])
        pred = np.array(pred)
        ax.imshow(pred.astype(np.uint8), cmap='gray')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.axis('off')
        k += 1
plt.savefig('result.png', dpi=300)
plt.show()
print(pred.shape)