cifar10_cnn_100epochs.py

'''Train a simple deep CNN on the CIFAR10 small images dataset.
It gets to 75% validation accuracy in 25 epochs, and 79% after 50 epochs
(though it is still underfitting at that point).
https://github.com/keras-team/keras/blob/master/examples/cifar10_cnn.py
'''
from __future__ import print_function
import keras
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
import os
import pickle
from numpy.random import seed

seed(7)  # seed NumPy's RNG; note this does not seed the TensorFlow backend

batch_size = 32
num_classes = 10
epochs = 100
data_augmentation = True
num_predictions = 20
save_dir = os.path.join(os.getcwd(), 'saved_models_noBn_100_s7')
model_name = 'keras_cifar10_trained_model.h5'

# The data, split between train and test sets:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert class vectors to binary class matrices.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
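
# Illustrative sanity check (not in the original script): to_categorical
# turns an integer class label such as 3 into a one-hot row of length
# num_classes, e.g. [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.].
print('one-hot for label 3:', keras.utils.to_categorical([3], num_classes))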

model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same',
                 input_shape=x_train.shape[1:]))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))
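
# Optional check (not in the original script): print the layer stack and
# parameter counts to confirm the architecture built as intended.
model.summary()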

# initiate RMSprop optimizer
opt = keras.optimizers.RMSprop(lr=0.0001, decay=1e-6)

# Let's train the model using RMSprop
model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])

# Scale pixel values from [0, 255] to [0, 1].
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

if not data_augmentation:
    print('Not using data augmentation.')
    history = model.fit(x_train, y_train,
                        batch_size=batch_size,
                        epochs=epochs,
                        validation_data=(x_test, y_test),
                        shuffle=True)
else:
    print('Using real-time data augmentation.')
    # This will do preprocessing and realtime data augmentation:
    datagen = ImageDataGenerator(
        featurewise_center=False,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of the dataset
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA whitening
        zca_epsilon=1e-06,  # epsilon for ZCA whitening
        rotation_range=0,  # randomly rotate images in the range (degrees, 0 to 180)
        width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
        height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
        shear_range=0.,  # set range for random shear
        zoom_range=0.,  # set range for random zoom
        channel_shift_range=0.,  # set range for random channel shifts
        fill_mode='nearest',  # set mode for filling points outside the input boundaries
        cval=0.,  # value used for fill_mode = "constant"
        horizontal_flip=True,  # randomly flip images horizontally
        vertical_flip=False,  # randomly flip images vertically
        rescale=None,  # set rescaling factor (applied before any other transformation)
        preprocessing_function=None,  # set function that will be applied on each input
        data_format=None,  # image data format, either "channels_first" or "channels_last"
        validation_split=0.0)  # fraction of images reserved for validation (strictly between 0 and 1)

    # Compute quantities required for feature-wise normalization
    # (std, mean, and principal components if ZCA whitening is applied).
    datagen.fit(x_train)
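
    # Illustrative check (not in the original script): draw one augmented
    # batch from the generator and confirm its shape before training starts.
    x_batch, y_batch = next(datagen.flow(x_train, y_train, batch_size=batch_size))
    print('augmented batch:', x_batch.shape, y_batch.shape)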

    # Fit the model on the batches generated by datagen.flow().
    history = model.fit_generator(
        datagen.flow(x_train, y_train, batch_size=batch_size),
        epochs=epochs,
        validation_data=(x_test, y_test),
        workers=4)

with open('./trainHistoryDictNoBn50', 'wb') as file_pi:
    pickle.dump(history.history, file_pi)

# Save model and weights
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, model_name)
model.save(model_path)
print('Saved trained model at %s ' % model_path)

# Score trained model.
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
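
# Round-trip sketch (not in the original script): reload the saved model and
# the pickled history to confirm both artifacts are usable; file names match
# those written above, and num_predictions picks the first 20 test images.
reloaded = keras.models.load_model(model_path)
with open('./trainHistoryDictNoBn50', 'rb') as file_pi:
    saved_history = pickle.load(file_pi)
print('reloaded model output shape:', reloaded.predict(x_test[:num_predictions]).shape)
print('history keys:', sorted(saved_history.keys()))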