make_predictions.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import cv2 # for reading and writing or showing image
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. from keras.models import load_model
  5. import keras
  6. from keras.preprocessing import image
  7. from keras.models import load_model
  8. from net import Net
  9. import sys
  10. def load_image(img_path, show=False):
  11. '''
  12. Function: Convert image to tensor
  13. Input: image_path (eg. /home/user/filename.jpg)
  14. (Note prefer having absolute path)
  15. show (default = False), set if you want to visualize the image
  16. Return: tensor format of image
  17. '''
  18. # load image using image module
  19. # convert to (32, 32) - if not already
  20. img = image.load_img(img_path, target_size=(32, 32)) # Path of test image
  21. # show the image if show=True
  22. if show:
  23. plt.imshow(img)
  24. plt.axis('off')
  25. # converting image to a tensor
  26. img_tensor = image.img_to_array(img) # (height, width, channels)
  27. img_tensor = np.expand_dims(img_tensor, axis=0)
  28. img_tensor /= 255.
  29. # return converted image
  30. return img_tensor
  31. def predict(weights_path, image_path):
  32. '''
  33. Function: loads a trained model and predicts the class of given image
  34. Input: weights_path (.h5 file, prefer adding absolute path)
  35. image_path (image to predict, prefer adding absolute path)
  36. Returns: none
  37. '''
  38. model = Net.build(32, 32, 3, weights_path)
  39. image = load_image(image_path, show=True) # load image, rescale to 0 to 1
  40. class_ = model.predict(image) # predict the output, returns 36 length array
  41. print("Detected: ", class_[0]) # print what is predicted
  42. output_indice = -1 # set it initially to -1
  43. # get class index having maximum predicted score
  44. for i in range(36):
  45. if(i == 0):
  46. max = class_[0][i]
  47. output_indice = 0
  48. else:
  49. if(class_[0][i] > max):
  50. max = class_[0][i]
  51. output_indice = i
  52. # append 26 characters (A to Z) to list characters
  53. characters = []
  54. for i in range(65, 65+26):
  55. characters.append(chr(i))
  56. # if output indice > 9 (means characters)
  57. if(output_indice > 9):
  58. final_result = characters[(output_indice - 9) - 1]
  59. print("Predicted: ", final_result)
  60. print("value: ", max) # print predicted score
  61. # else it's a digit, print directly
  62. else:
  63. print("Predicted: ", output_indice)
  64. print("value: ", max) # print it's predicted score
  65. if(len(sys.argv) < 2):
  66. print("Enter test image path as an argument")
  67. sys.exit(0)
  68. test_image = sys.argv[1]
  69. predict("trained_weights.h5", test_image) # Specify weights file and Test image