pytorch_model.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import cv2
  2. import onnx
  3. import torch
  4. from albumentations import (Compose,Resize,)
  5. from albumentations.augmentations.transforms import Normalize
  6. from albumentations.pytorch.transforms import ToTensor
  7. from torchvision import models
  8. def preprocess_image(img_path):
  9. # transformations for the input data
  10. transforms = Compose([
  11. Resize(224, 224, interpolation=cv2.INTER_NEAREST),
  12. Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  13. ToTensor(),
  14. ])
  15. # read input image
  16. input_img = cv2.imread(img_path)
  17. # do transformations
  18. input_data = transforms(image=input_img)["image"]
  19. # prepare batch
  20. batch_data = torch.unsqueeze(input_data, 0)
  21. return batch_data
  22. def postprocess(output_data):
  23. # get class names
  24. with open("imagenet_classes.txt") as f:
  25. classes = [line.strip() for line in f.readlines()]
  26. # calculate human-readable value by softmax
  27. confidences = torch.nn.functional.softmax(output_data, dim=1)[0] * 100
  28. # find top predicted classes
  29. _, indices = torch.sort(output_data, descending=True)
  30. i = 0
  31. # print the top classes predicted by the model
  32. while confidences[indices[0][i]] > 0.5:
  33. class_idx = indices[0][i]
  34. print(
  35. "class:",
  36. classes[class_idx],
  37. ", confidence:",
  38. confidences[class_idx].item(),
  39. "%, index:",
  40. class_idx.item(),
  41. )
  42. i += 1
  43. def main():
  44. # load pre-trained model -------------------------------------------------------------------------------------------
  45. model = models.resnet50(pretrained=True)
  46. # preprocessing stage ----------------------------------------------------------------------------------------------
  47. input = preprocess_image("turkish_coffee.jpg").cuda()
  48. # inference stage --------------------------------------------------------------------------------------------------
  49. model.eval()
  50. model.cuda()
  51. output = model(input)
  52. # post-processing stage --------------------------------------------------------------------------------------------
  53. postprocess(output)
  54. # convert to ONNX --------------------------------------------------------------------------------------------------
  55. ONNX_FILE_PATH = "resnet50.onnx"
  56. torch.onnx.export(model, input, ONNX_FILE_PATH, input_names=["input"], output_names=["output"], export_params=True)
  57. onnx_model = onnx.load(ONNX_FILE_PATH)
  58. # check that the model converted fine
  59. onnx.checker.check_model(onnx_model)
  60. print("Model was successfully converted to ONNX format.")
  61. print("It was saved to", ONNX_FILE_PATH)
  62. if __name__ == '__main__':
  63. main()