cifar_base.py

#!/usr/bin/env python
# coding: utf-8
# %%
import argparse
import time

import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Add, Dense, Activation, ZeroPadding2D, BatchNormalization,
    Flatten, Conv2D, AveragePooling2D, MaxPooling2D,
)
from tensorflow.keras.models import Model
from tensorflow.keras.initializers import glorot_uniform
import horovod.tensorflow.keras as hvd

tf.random.set_seed(1330)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-size", type=int, default=256, help="Batch size")
    return parser.parse_args()


args = parse_args()
batch_size = args.batch_size

# Horovod: initialize Horovod.
hvd.init()

# Horovod: pin the GPU used by this process to its local rank (one GPU per process).
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')

# Load the CIFAR-10 training split and scale pixel values to [0, 1].
(images, labels), _ = tf.keras.datasets.cifar10.load_data()
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.cast(images / 255.0, tf.float32),
     tf.cast(labels, tf.int64))
)
dataset = dataset.repeat().shuffle(10000).batch(batch_size)
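
# Note (assumption, not in the original script): with repeat().shuffle() every
# worker samples from the full training set. To give each Horovod worker a
# distinct subset instead, the dataset could be sharded before shuffling:
#   dataset = dataset.shard(hvd.size(), hvd.rank())
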

def convolutional_block(X, f, filters, stage, block, s=2):
    # Defining name basis
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    # Retrieve filters
    F1, F2, F3 = filters
    # Save the input value
    X_shortcut = X

    ##### MAIN PATH #####
    # First component of main path
    X = Conv2D(filters=F1, kernel_size=(1, 1), strides=(s, s), padding='valid',
               name=conv_name_base + '2a', kernel_initializer=glorot_uniform(seed=0))(X)
    # X = BatchNormalization(axis=3, name=bn_name_base + '2a')(X)
    X = Activation('relu')(X)
    # Second component of main path
    X = Conv2D(filters=F2, kernel_size=(f, f), strides=(1, 1), padding='same',
               name=conv_name_base + '2b', kernel_initializer=glorot_uniform(seed=0))(X)
    # X = BatchNormalization(axis=3, name=bn_name_base + '2b')(X)
    X = Activation('relu')(X)
    # Third component of main path
    X = Conv2D(filters=F3, kernel_size=(1, 1), strides=(1, 1), padding='valid',
               name=conv_name_base + '2c', kernel_initializer=glorot_uniform(seed=0))(X)
    # X = BatchNormalization(axis=3, name=bn_name_base + '2c')(X)

    ##### SHORTCUT PATH #####
    X_shortcut = Conv2D(filters=F3, kernel_size=(1, 1), strides=(s, s), padding='valid',
                        name=conv_name_base + '1', kernel_initializer=glorot_uniform(seed=0))(X_shortcut)
    # X_shortcut = BatchNormalization(axis=3, name=bn_name_base + '1')(X_shortcut)

    # Final step: add the shortcut value to the main path and pass it through a ReLU activation
    X = Add()([X, X_shortcut])
    X = Activation('relu')(X)
    return X


def ResNet(input_shape=(32, 32, 3), classes=10):
    # Define the input as a tensor with shape input_shape
    X_input = Input(shape=input_shape)
    # Zero-padding
    X = ZeroPadding2D((3, 3))(X_input)
    # Stage 1
    X = Conv2D(64, (7, 7), strides=(2, 2), name='conv1',
               kernel_initializer=glorot_uniform(seed=0))(X)
    # X = BatchNormalization(axis=3, name='bn_conv1')(X)
    X = Activation('relu')(X)
    X = MaxPooling2D((3, 3), strides=(2, 2))(X)
    # Stage 2
    X = convolutional_block(X, f=3, filters=[64, 64, 256], stage=2, block='a', s=1)
    # Stage 3
    X = convolutional_block(X, f=3, filters=[128, 128, 512], stage=3, block='a', s=2)
    # AVGPOOL
    X = AveragePooling2D(pool_size=(2, 2), padding='same')(X)
    # Output layer
    X = Flatten()(X)
    X = Dense(classes, activation='softmax', name='fc' + str(classes),
              kernel_initializer=glorot_uniform(seed=0))(X)
    # Create model
    model = Model(inputs=X_input, outputs=X, name='ResNet')
    return model


model = ResNet(input_shape=(32, 32, 3), classes=10)

# %%
# Horovod: adjust learning rate based on number of GPUs.
# (0.001 is the Adam default; scaling by hvd.size() follows the standard
# Horovod recipe, since the effective batch size grows with the worker count.)
scaled_lr = 0.001 * hvd.size()
opt = tf.optimizers.Adam(scaled_lr)

# Horovod: add Horovod DistributedOptimizer.
opt = hvd.DistributedOptimizer(
    opt, backward_passes_per_step=1, average_aggregated_gradients=True)

# Horovod: specify `experimental_run_tf_function=False` to ensure TensorFlow
# uses hvd.DistributedOptimizer() to compute gradients.
model.compile(loss=tf.losses.SparseCategoricalCrossentropy(),
              optimizer=opt,
              metrics=['accuracy'],
              experimental_run_tf_function=False)


# Despite its name, this callback reports per-epoch wall time and throughput
# on rank 0 (it does not print the learning rate).
class PrintLR(tf.keras.callbacks.Callback):
    def __init__(self, total_images=0):
        super().__init__()
        self.total_images = total_images

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        if hvd.rank() == 0:
            epoch_time = time.time() - self.epoch_start_time
            print('Epoch time: {}'.format(epoch_time))
            images_per_sec = round(self.total_images / epoch_time, 2)
            print('Images/sec: {}'.format(images_per_sec))


callbacks = [
    # Horovod: broadcast initial variable states from rank 0 to all other processes.
    # This is necessary to ensure consistent initialization of all workers when
    # training is started with random weights or restored from a checkpoint.
    hvd.callbacks.BroadcastGlobalVariablesCallback(0),
    # Horovod: average metrics among workers at the end of every epoch.
    #
    # Note: this callback must come before ReduceLROnPlateau, TensorBoard, or
    # other metrics-based callbacks in the list.
    hvd.callbacks.MetricAverageCallback(),
    PrintLR(total_images=len(labels)),
]
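
# Horovod convention (assumption, mirrors the official Keras examples): save
# checkpoints only on worker 0 so other workers don't corrupt them, e.g.:
#   if hvd.rank() == 0:
#       callbacks.append(tf.keras.callbacks.ModelCheckpoint('./checkpoint-{epoch}.h5'))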

# model.summary()

# Horovod: write logs on worker 0 only.
verbose = 1 if hvd.rank() == 0 else 0

# Train the model.
# Horovod: adjust the number of steps per epoch based on the number of GPUs,
# so one epoch still covers the full dataset across all workers.
model.fit(dataset,
          steps_per_epoch=len(labels) // (batch_size * hvd.size()),
          callbacks=callbacks,
          epochs=12,
          verbose=verbose)
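
# Example launch (assumption; adjust the process count to the available GPUs):
#   horovodrun -np 4 python cifar_base.py --batch-size 256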