train_model.py

# import required modules
from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers
import matplotlib.pyplot as plt
# import the model architecture defined in net.py
from net import Net
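# For reference, a hypothetical sketch of the interface this script assumes
# from net.py (the real Net class is not shown here): Net.build should return
# an uncompiled Keras model that takes (height, width, depth) inputs and ends
# in a softmax layer. The layer sizes and the classes parameter below are
# placeholders, not the author's architecture.
#
#     from keras.models import Sequential
#     from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
#
#     class Net:
#         @staticmethod
#         def build(width, height, depth, classes=36):
#             model = Sequential()
#             model.add(Conv2D(32, (3, 3), activation='relu',
#                              input_shape=(height, width, depth)))
#             model.add(MaxPooling2D(pool_size=(2, 2)))
#             model.add(Flatten())
#             model.add(Dense(128, activation='relu'))
#             model.add(Dense(classes, activation='softmax'))
#             return model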
# dimensions of our images
img_width, img_height = 32, 32
# 3-channel (RGB) images
no_of_channels = 3
# training data directory
train_data_dir = 'train/'
# validation data directory
validation_data_dir = 'test/'
epochs = 80
batch_size = 32
# initialize the model
model = Net.build(width=img_width, height=img_height, depth=no_of_channels)
print('building done')
# compile the model
# (newer Keras releases use learning_rate= in place of lr=)
rms = optimizers.RMSprop(lr=0.001, rho=0.9, epsilon=None, decay=0.0)
print('optimizing done')
model.compile(loss='categorical_crossentropy',
              optimizer=rms,
              metrics=['accuracy'])
print('compiling done')
# this is the augmentation configuration used for training;
# horizontal_flip=False, as we need to retain the characters' orientation
train_datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rescale=1. / 255,
    shear_range=0.1,
    zoom_range=0.1,
    rotation_range=5,
    width_shift_range=0.05,
    height_shift_range=0.05,
    horizontal_flip=False)
# the configuration used for validation: no augmentation, only the same
# rescaling and featurewise normalization as the training data
test_datagen = ImageDataGenerator(featurewise_center=True,
                                  featurewise_std_normalization=True,
                                  rescale=1. / 255)
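# NOTE: featurewise_center/featurewise_std_normalization compute statistics
# over the whole dataset, so both generators need fit() to be called before
# use; otherwise Keras warns and skips the normalization. A minimal sketch,
# assuming a sample of the training set fits in memory (the batch size of
# 1000 below is an arbitrary placeholder):
sample_flow = ImageDataGenerator(rescale=1. / 255).flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=1000,
    class_mode='categorical')
sample_images, _ = next(sample_flow)
train_datagen.fit(sample_images)
test_datagen.fit(sample_images)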
train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical')
validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical')
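# Optional sanity check (an addition, not in the original script): confirm
# that both generators map the class folders to the same label indices.
print(train_generator.class_indices)
print(validation_generator.class_indices)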
# fit the model (integer division so steps_per_epoch is a whole number)
history = model.fit_generator(
    train_generator,
    steps_per_epoch=train_generator.samples // batch_size,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=validation_generator.samples // batch_size)
# evaluate on the validation dataset and report the scores
scores = model.evaluate_generator(validation_generator)
print('validation loss: %.4f, validation accuracy: %.4f' % (scores[0], scores[1]))
# save the trained weights to a file
model.save_weights('trained_weights.h5')
print(history.history)
# loss curves
plt.figure(figsize=[8, 6])
plt.plot(history.history['loss'], 'r', linewidth=3.0)
plt.plot(history.history['val_loss'], 'b', linewidth=3.0)
plt.legend(['Training Loss', 'Validation Loss'], fontsize=18)
plt.xlabel('Epochs', fontsize=16)
plt.ylabel('Loss', fontsize=16)
plt.title('Loss Curves', fontsize=16)
# accuracy curves
# (older Keras versions store these under 'acc'/'val_acc';
# newer releases use 'accuracy'/'val_accuracy')
plt.figure(figsize=[8, 6])
plt.plot(history.history['acc'], 'r', linewidth=3.0)
plt.plot(history.history['val_acc'], 'b', linewidth=3.0)
plt.legend(['Training Accuracy', 'Validation Accuracy'], fontsize=18)
plt.xlabel('Epochs', fontsize=16)
plt.ylabel('Accuracy', fontsize=16)
plt.title('Accuracy Curves', fontsize=16)
plt.show()