#!/usr/bin/env python
# coding: utf-8
# %%
import argparse
import time

import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.layers import (Input, Add, Dense, Activation, ZeroPadding2D,
                                     BatchNormalization, Flatten, Conv2D,
                                     AveragePooling2D, MaxPooling2D)
from tensorflow.keras.models import Model
from tensorflow.keras.initializers import glorot_uniform
from tensorflow_addons.optimizers import LAMB
import horovod.tensorflow.keras as hvd
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=256, help="Batch size")
    return parser.parse_args()


args = parse_args()
batch_size = args.batch_size
# Horovod: initialize Horovod.
hvd.init()

# Horovod: pin GPU to be used to process local rank (one GPU per process).
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')

# Load the CIFAR-10 training split and build the input pipeline.
(images, labels), _ = cifar10.load_data()
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.cast(images / 255.0, tf.float32),
     tf.cast(labels, tf.int64))
)
dataset = dataset.repeat().shuffle(10000).batch(batch_size)
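
# A common addition, not in the original pipeline: prefetch batches so input
# preparation overlaps with training on the GPU.
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)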

def convolutional_block(X, f, filters, stage, block, s=2):
    # Define name bases for the layers in this block.
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    # Retrieve the filter counts for the three convolutions.
    F1, F2, F3 = filters
    # Save the input value for the shortcut path.
    X_shortcut = X

    ##### MAIN PATH #####
    # First component of main path: 1x1 convolution with stride s.
    X = Conv2D(filters=F1, kernel_size=(1, 1), strides=(s, s), padding='valid',
               name=conv_name_base + '2a', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2a')(X)
    X = Activation('relu')(X)
    # Second component of main path: fxf convolution with 'same' padding.
    X = Conv2D(filters=F2, kernel_size=(f, f), strides=(1, 1), padding='same',
               name=conv_name_base + '2b', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2b')(X)
    X = Activation('relu')(X)
    # Third component of main path: 1x1 convolution expanding to F3 channels.
    X = Conv2D(filters=F3, kernel_size=(1, 1), strides=(1, 1), padding='valid',
               name=conv_name_base + '2c', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2c')(X)

    ##### SHORTCUT PATH #####
    # Project the shortcut to F3 channels with a strided 1x1 convolution.
    X_shortcut = Conv2D(filters=F3, kernel_size=(1, 1), strides=(s, s), padding='valid',
                        name=conv_name_base + '1', kernel_initializer=glorot_uniform(seed=0))(X_shortcut)
    X_shortcut = BatchNormalization(axis=3, name=bn_name_base + '1')(X_shortcut)

    # Final step: add the shortcut to the main path and pass it through a ReLU activation.
    X = Add()([X, X_shortcut])
    X = Activation('relu')(X)
    return X
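

# For reference, a minimal sketch of ResNet's companion identity block. It is
# an assumption for illustration, not part of the original script, and is not
# used by the model below: the main path matches convolutional_block, but the
# shortcut is left unprojected, so input and output channel counts must match.
def identity_block(X, f, filters, stage, block):
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    F1, F2, F3 = filters
    X_shortcut = X
    X = Conv2D(filters=F1, kernel_size=(1, 1), strides=(1, 1), padding='valid',
               name=conv_name_base + '2a', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2a')(X)
    X = Activation('relu')(X)
    X = Conv2D(filters=F2, kernel_size=(f, f), strides=(1, 1), padding='same',
               name=conv_name_base + '2b', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2b')(X)
    X = Activation('relu')(X)
    X = Conv2D(filters=F3, kernel_size=(1, 1), strides=(1, 1), padding='valid',
               name=conv_name_base + '2c', kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name=bn_name_base + '2c')(X)
    # Identity shortcut: add the unmodified input back to the main path.
    X = Add()([X, X_shortcut])
    X = Activation('relu')(X)
    return X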

def ResNet(input_shape=(32, 32, 3), classes=10):
    # Define the input as a tensor with shape input_shape.
    X_input = Input(shape=input_shape)
    # Zero-padding
    X = ZeroPadding2D((3, 3))(X_input)
    # Stage 1
    X = Conv2D(64, (7, 7), strides=(2, 2), name='conv1',
               kernel_initializer=glorot_uniform(seed=0))(X)
    X = BatchNormalization(axis=3, name='bn_conv1')(X)
    X = Activation('relu')(X)
    X = MaxPooling2D((3, 3), strides=(2, 2))(X)
    # Stage 2
    X = convolutional_block(X, f=3, filters=[64, 64, 256], stage=2, block='a', s=1)
    # Stage 3
    X = convolutional_block(X, f=3, filters=[128, 128, 512], stage=3, block='a', s=2)
    # Average pooling
    X = AveragePooling2D(pool_size=(2, 2), padding='same')(X)
    # Output layer
    X = Flatten()(X)
    X = Dense(classes, activation='softmax', name='fc' + str(classes),
              kernel_initializer=glorot_uniform(seed=0))(X)
    # Create the model.
    model = Model(inputs=X_input, outputs=X, name='ResNet')
    return model


model = ResNet(input_shape=(32, 32, 3), classes=10)
# %%
# Horovod: adjust learning rate based on number of GPUs.
scaled_lr = 0.001 * hvd.size()
# opt = tf.optimizers.Adam(scaled_lr)
# Replace the Adam optimizer with LAMB, which is designed for large-batch training.
opt = LAMB(learning_rate=scaled_lr)
# Horovod: add Horovod DistributedOptimizer.
opt = hvd.DistributedOptimizer(
    opt, backward_passes_per_step=1, average_aggregated_gradients=True)

# Horovod: Specify `experimental_run_tf_function=False` to ensure TensorFlow
# uses hvd.DistributedOptimizer() to compute gradients.
model.compile(loss=tf.losses.SparseCategoricalCrossentropy(),
              optimizer=opt,
              metrics=['accuracy'],
              experimental_run_tf_function=False)
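
# Optional sanity check, not in the original script: report the world size once
# from the root rank so misconfigured launches are easy to spot.
if hvd.rank() == 0:
    print('Horovod: training across {} process(es)'.format(hvd.size()))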


class PrintThroughput(tf.keras.callbacks.Callback):
    """Report per-epoch wall-clock time and throughput (images/sec) on rank 0."""

    def __init__(self, total_images=0):
        super().__init__()
        self.total_images = total_images

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        if hvd.rank() == 0:
            epoch_time = time.time() - self.epoch_start_time
            print('Epoch time: {}'.format(epoch_time))
            images_per_sec = round(self.total_images / epoch_time, 2)
            print('Images/sec: {}'.format(images_per_sec))


callbacks = [
    # Horovod: broadcast initial variable states from rank 0 to all other processes.
    # This is necessary to ensure consistent initialization of all workers when
    # training is started with random weights or restored from a checkpoint.
    hvd.callbacks.BroadcastGlobalVariablesCallback(0),

    # Horovod: average metrics among workers at the end of every epoch.
    #
    # Note: This callback must be in the list before the ReduceLROnPlateau,
    # TensorBoard, or other metrics-based callbacks.
    hvd.callbacks.MetricAverageCallback(),

    PrintThroughput(total_images=len(labels)),

    # Horovod: gradually ramp the learning rate up to scaled_lr over the first 3 epochs.
    hvd.callbacks.LearningRateWarmupCallback(initial_lr=scaled_lr, warmup_epochs=3, verbose=1),
]
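
# A common addition, not in the original script: checkpoint only on rank 0 so
# workers do not overwrite each other's files. The path is illustrative.
if hvd.rank() == 0:
    callbacks.append(tf.keras.callbacks.ModelCheckpoint('./checkpoint-{epoch}.h5'))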

# model.summary()

# Horovod: write logs on worker 0 only.
verbose = 1 if hvd.rank() == 0 else 0

# Train the model.
# Horovod: adjust number of steps based on number of GPUs.
model.fit(dataset,
          steps_per_epoch=len(labels) // (batch_size * hvd.size()),
          callbacks=callbacks,
          epochs=12,
          verbose=verbose)
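
# A typical follow-up, not in the original script: persist the trained weights
# from rank 0 only (the filename is illustrative). Note the CIFAR-10 test split
# is discarded above, so evaluating would require reloading it first.
if hvd.rank() == 0:
    model.save('cifar_lamb_resnet.h5')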