Mobilenetv2ToOnnx.py

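# Example invocation (assuming imagenet_classes.txt and the test image are in
# the working directory, as the defaults below expect):
#
#     python Mobilenetv2ToOnnx.py --input_image test_img_cup.jpg
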
import argparse
import os

import cv2
import numpy as np
import onnx
import onnxruntime
import torch
from albumentations import (
    CenterCrop,
    Compose,
    Normalize,
    Resize,
)
from torchvision import models


def compare_pytorch_onnx(
    original_model_preds, onnx_model_path, input_image,
):
    # get the ONNX model predictions with ONNX Runtime
    session = onnxruntime.InferenceSession(onnx_model_path)
    input_name = session.get_inputs()[0].name
    # run the model; an empty output-name list returns all outputs
    onnx_result = session.run([], {input_name: input_image})
    # session.run returns a list of outputs; squeeze away that outer dimension
    onnx_result = np.squeeze(onnx_result, axis=0)

    print("Checking PyTorch model and converted ONNX model outputs ... ")
    for test_onnx_result, gold_result in zip(onnx_result, original_model_preds):
        np.testing.assert_almost_equal(
            gold_result, test_onnx_result, decimal=3,
        )
    print("PyTorch and ONNX output values are equal! \n")


def get_onnx_model(
    original_model, input_image, model_path="models", model_name="pytorch_mobilenet",
):
    # create the model root directory
    os.makedirs(model_path, exist_ok=True)
    model_name = os.path.join(model_path, model_name + ".onnx")

    # trace the model with the sample input and export it to ONNX format
    torch.onnx.export(
        original_model, torch.Tensor(input_image), model_name, verbose=True,
    )
    print("ONNX model was successfully generated: {} \n".format(model_name))
    return model_name


def get_preprocessed_image(image_name):
    # read the image
    original_image = cv2.imread(image_name)
    # convert the original image from BGR to RGB format
    image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

    # transform the input image:
    # 1. resize the image to 256x256
    # 2. center-crop it to 224x224
    # 3. normalize: subtract the ImageNet mean and divide by the standard deviation
    transform = Compose(
        [
            Resize(height=256, width=256),
            CenterCrop(224, 224),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ],
    )
    image = transform(image=image)["image"]

    # change the channel order from HWC to CHW and add a batch dimension
    image = image.transpose(2, 0, 1)
    return np.expand_dims(image, axis=0)


def get_predicted_class(pytorch_preds):
    # read the ImageNet class id to name mapping
    with open("imagenet_classes.txt") as f:
        labels = [line.strip() for line in f.readlines()]

    # find the class with the maximum score
    pytorch_class_idx = np.argmax(pytorch_preds, axis=1)
    predicted_pytorch_label = labels[pytorch_class_idx[0]]
    # print the top predicted class
    print("Predicted class by PyTorch model: ", predicted_pytorch_label)


def get_execution_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_image",
        type=str,
        help="Define the full input image path, including its name",
        default="test_img_cup.jpg",
    )
    return parser.parse_args()


if __name__ == "__main__":
    # get the test case parameters
    args = get_execution_arguments()

    # read and preprocess the input image
    image = get_preprocessed_image(image_name=args.input_image)

    # obtain the original pretrained PyTorch model
    pytorch_model = models.mobilenet_v2(pretrained=True)

    # run inference with the original PyTorch model
    pytorch_model.eval()
    pytorch_predictions = pytorch_model(torch.Tensor(image)).detach().numpy()

    # convert the PyTorch model to ONNX format
    onnx_model_path = get_onnx_model(original_model=pytorch_model, input_image=image)

    # check that the conversion produced a well-formed ONNX model
    onnx_model = onnx.load(onnx_model_path)
    onnx.checker.check_model(onnx_model)

    # compare the ONNX model output with the PyTorch predictions
    compare_pytorch_onnx(
        pytorch_predictions, onnx_model_path, image,
    )

    # decode the classification results
    get_predicted_class(pytorch_preds=pytorch_predictions)
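
# A minimal follow-up sketch (assuming an OpenCV build with DNN support): the
# exported ONNX file can also be loaded and run with OpenCV's DNN module, e.g.
#
#     opencv_net = cv2.dnn.readNetFromONNX(onnx_model_path)
#     opencv_net.setInput(image)  # "image" is already a 1x3x224x224 float blob
#     opencv_preds = opencv_net.forward()
#     get_predicted_class(pytorch_preds=opencv_preds)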