cifar10_cnn_bn_100epochs.py

'''
This code was originally written by the Keras team. It has been modified by
Sunita Nayak at BigVision LLC. to include Batch Normalization in the architecture.

Train a simple deep CNN on the CIFAR10 small images dataset using Batch Normalization.
It reaches a maximum of 87% validation accuracy, hitting 79% in only 7 epochs. Note
that the Keras team's maximum accuracy was 79% in 50 epochs. With Batch Normalization,
the model exceeds 85% in just 21 epochs and reaches 87% in 39 epochs.
'''
from __future__ import print_function

import os
import pickle

import keras
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten, BatchNormalization
from keras.layers import Conv2D, MaxPooling2D

from numpy.random import seed
seed(7)  # seed NumPy's RNG for reproducibility
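# Note: seeding NumPy alone does not make Keras training fully deterministic,
# since the backend keeps its own RNG. With a TensorFlow 1.x backend one could
# additionally seed it like this (version-dependent; shown only as a sketch):
#import tensorflow as tf
#tf.set_random_seed(7)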
batch_size = 32
num_classes = 10
epochs = 100
data_augmentation = True
num_predictions = 20  # unused in this script; retained from the original Keras example
save_dir = os.path.join(os.getcwd(), 'saved_models_bn_100_s7')
model_name = 'keras_cifar10_trained_model.h5'

# The data, split between train and test sets:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert class vectors to binary class matrices,
# e.g. label 3 -> [0, 0, 0, 1, 0, 0, 0, 0, 0, 0].
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
model = Sequential()

# Block 1: two 32-filter conv layers, each followed by BatchNorm + ReLU
model.add(Conv2D(32, (3, 3), padding='same',
                 input_shape=x_train.shape[1:]))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
#model.add(Dropout(0.25))

# Block 2: two 64-filter conv layers, each followed by BatchNorm + ReLU
model.add(Conv2D(64, (3, 3), padding='same'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
#model.add(Dropout(0.25))

# Classifier head: dense layer with BatchNorm + ReLU, then the softmax output
# (BatchNorm before softmax is part of this script's modification)
model.add(Flatten())
model.add(Dense(512))
model.add(BatchNormalization())
model.add(Activation('relu'))
#model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(BatchNormalization())
model.add(Activation('softmax'))
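# Optionally print a layer-by-layer summary with parameter counts
# (standard Keras API):
#model.summary()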
# Initiate the RMSprop optimizer
opt = keras.optimizers.RMSprop(lr=0.001, decay=1e-6)

# Let's train the model using RMSprop
model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])
# Scale pixel values from [0, 255] to [0, 1]
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
if not data_augmentation:
    print('Not using data augmentation.')
    history = model.fit(x_train, y_train,
                        batch_size=batch_size,
                        epochs=epochs,
                        validation_data=(x_test, y_test),
                        shuffle=True)
else:
    print('Using real-time data augmentation.')
    # This will do preprocessing and realtime data augmentation:
    datagen = ImageDataGenerator(
        featurewise_center=False,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of the dataset
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA whitening
        zca_epsilon=1e-06,  # epsilon for ZCA whitening
        rotation_range=0,  # randomly rotate images in the range (degrees, 0 to 180)
        width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
        height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
        shear_range=0.,  # set range for random shear
        zoom_range=0.,  # set range for random zoom
        channel_shift_range=0.,  # set range for random channel shifts
        fill_mode='nearest',  # set mode for filling points outside the input boundaries
        cval=0.,  # value used for fill_mode = "constant"
        horizontal_flip=True,  # randomly flip images horizontally
        vertical_flip=False,  # randomly flip images vertically
        rescale=None,  # set rescaling factor (applied before any other transformation)
        preprocessing_function=None,  # set function that will be applied on each input
        data_format=None,  # image data format, either "channels_first" or "channels_last"
        validation_split=0.0)  # fraction of images reserved for validation (strictly between 0 and 1)

    # Compute quantities required for feature-wise normalization
    # (std, mean, and principal components if ZCA whitening is applied).
    # With all feature-wise options disabled above, this is effectively a no-op.
    datagen.fit(x_train)

    # Fit the model on the batches generated by datagen.flow().
    history = model.fit_generator(
        datagen.flow(x_train, y_train, batch_size=batch_size),
        epochs=epochs,
        validation_data=(x_test, y_test),
        workers=4)
# Save the training history for later plotting/analysis
with open('./trainHistoryDictWithBn1', 'wb') as file_pi:
    pickle.dump(history.history, file_pi)
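# To inspect the saved curves later, something like this sketch works
# (the metric key is 'val_acc' or 'val_accuracy' depending on the Keras version):
#with open('./trainHistoryDictWithBn1', 'rb') as file_pi:
#    hist = pickle.load(file_pi)
#    print('best validation accuracy:', max(hist['val_acc']))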
# Save model and weights
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, model_name)
model.save(model_path)
print('Saved trained model at %s ' % model_path)
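# The saved HDF5 file can be restored later with the standard Keras loader:
#from keras.models import load_model
#model = load_model(model_path)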
# Score trained model.
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])