cnn_fmnist.py

#!/usr/bin/env python
# coding: utf-8
# %%
import argparse
import time

import tensorflow as tf
import horovod.tensorflow.keras as hvd


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=256, help="Batch size")
    args = parser.parse_args()
    return args


args = parse_args()
batch_size = args.batch_size
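
# Example launch (assuming Horovod's horovodrun CLI is available), with one
# process per GPU:
#   horovodrun -np 4 python cnn_fmnist.py --batch-size 256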

# Horovod: initialize Horovod.
hvd.init()

# Horovod: pin GPU to be used to process local rank (one GPU per process).
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')
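
# Note: hvd.local_rank() is the per-node process index, so each process is
# pinned to exactly one visible GPU; hvd.rank() used below is the global index
# across all nodes.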

(mnist_images, mnist_labels), _ = tf.keras.datasets.mnist.load_data(path='mnist-%d.npz' % hvd.rank())

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32),
     tf.cast(mnist_labels, tf.int64))
)
dataset = dataset.repeat().shuffle(10000).batch(batch_size)
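
# Every worker loads the full dataset (the per-rank npz path only prevents
# download collisions) and shuffles independently; repeat() is needed because
# fit() below bounds each epoch with steps_per_epoch.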

mnist_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, [3, 3], activation='relu'),
    tf.keras.layers.Conv2D(64, [3, 3], activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10, activation='softmax')
])
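
# No Input layer or input_shape is given, so Keras infers the (28, 28, 1)
# input shape from the first batch it sees.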

# Horovod: adjust learning rate based on number of GPUs.
opt = tf.optimizers.Adam(0.001 * hvd.size())

# Horovod: add Horovod DistributedOptimizer.
opt = hvd.DistributedOptimizer(opt, backward_passes_per_step=1, average_aggregated_gradients=True)
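
# Note: the wrapped optimizer allreduce-averages gradients across workers
# before each weight update, which is why the base learning rate above is
# scaled by hvd.size().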

# Horovod: Specify `experimental_run_tf_function=False` to ensure TensorFlow
# uses hvd.DistributedOptimizer() to compute gradients.
mnist_model.compile(loss=tf.losses.SparseCategoricalCrossentropy(),
                    optimizer=opt,
                    metrics=['accuracy'],
                    experimental_run_tf_function=False)


class PrintLR(tf.keras.callbacks.Callback):
    """Report per-epoch wall time and throughput on rank 0."""

    def __init__(self, total_images=0):
        super().__init__()
        self.total_images = total_images

    def on_train_begin(self, logs=None):
        self.train_start_time = time.time()

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        if hvd.rank() == 0:
            epoch_time = time.time() - self.epoch_start_time
            print('Epoch time: {}'.format(epoch_time))
            images_per_sec = round(self.total_images / epoch_time, 2)
            print('Images/sec: {}'.format(images_per_sec))
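
# Since steps_per_epoch below is len(mnist_labels) // (batch_size * hvd.size()),
# one epoch is roughly a single global pass over the data, so total_images /
# epoch_time approximates aggregate (all-worker) throughput.
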
callbacks = [
    # Horovod: broadcast initial variable states from rank 0 to all other processes.
    # This is necessary to ensure consistent initialization of all workers when
    # training is started with random weights or restored from a checkpoint.
    hvd.callbacks.BroadcastGlobalVariablesCallback(0),

    # Horovod: average metrics among workers at the end of every epoch.
    #
    # Note: This callback must be in the list before the ReduceLROnPlateau,
    # TensorBoard or other metrics-based callbacks.
    hvd.callbacks.MetricAverageCallback(),

    # Throughput calculator.
    PrintLR(total_images=len(mnist_labels)),
]

# Horovod: write logs on worker 0.
verbose = 2 if hvd.rank() == 0 else 0

# Train the model.
# Horovod: adjust number of steps based on number of GPUs.
mnist_model.fit(dataset,
                steps_per_epoch=len(mnist_labels) // (batch_size * hvd.size()),
                callbacks=callbacks,
                epochs=4,
                verbose=verbose)
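
# Checkpointing sketch (an assumption, not part of the original script): save
# only on rank 0 so workers do not overwrite each other's files; the file name
# 'cnn_fmnist.keras' is a hypothetical example.
# if hvd.rank() == 0:
#     mnist_model.save('cnn_fmnist.keras')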