# segmodel_to_onnx.py (4.3 KB)
  1. import torch
  2. from torch import nn
  3. from torchvision import models
  4. import torchvision.transforms as T
  5. import numpy as np
  6. import cv2
  7. import time
  8. from segcolors import colors
  9. class SegModel(nn.Module):
  10. def __init__(self):
  11. super().__init__()
  12. self.net= models.segmentation.fcn_resnet50(pretrained=True, aux_loss=False).cuda()
  13. self.ppmean=torch.Tensor([0.485, 0.456, 0.406])
  14. self.ppstd=torch.Tensor([0.229, 0.224, 0.225])
  15. self.preprocessor=T.Compose([T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])])
  16. self.cmap=torch.from_numpy(colors[:,::-1].copy())
  17. def forward(self, x):
  18. """x is a pytorch tensor"""
  19. #x=(x-self.ppmean)/self.ppstd #uncomment if you want onnx to include pre-processing
  20. isize=x.shape[-2:]
  21. x=self.net.backbone(x)['out']
  22. x=self.net.classifier(x)
  23. #x=nn.functional.interpolate(x, isize, mode='bilinear') #uncomment if you want onnx to include interpolation
  24. return x
  25. def export_onnx(self, onnxpath):
  26. """onnxpath: string, path of output onnx file"""
  27. x=torch.randn(1,3,360,640).cuda() #360p size
  28. input=['image']
  29. output=['probabilities']
  30. torch.onnx.export(self, x, onnxpath, verbose=False, input_names=input, output_names=output, opset_version=11)
  31. print('Exported to onnx')
  32. def infervideo(self, fname, view=True, savepath=None):
  33. """
  34. fname: path of input video file/camera index
  35. view(bool): whether or not to display results
  36. savepath (string or None): if path specified, output video is saved
  37. """
  38. src=cv2.VideoCapture(fname)
  39. ret,frame=src.read()
  40. if not ret:
  41. print(f'Cannot read input file/camera {fname}')
  42. quit()
  43. self.net.eval()
  44. dst=None
  45. fps=0.0
  46. if savepath is not None:
  47. dst=self.getvideowriter(savepath, src)
  48. with torch.no_grad(): #we just inferring, no need to calculate gradients
  49. while ret:
  50. outf, cfps=self.inferframe(frame, benchmark=True)
  51. if view:
  52. cv2.imshow('segmentation', outf)
  53. k=cv2.waitKey(1)
  54. if k==ord('q'):
  55. break
  56. if dst:
  57. dst.write(outf)
  58. fps=0.9*fps+0.1*cfps
  59. print(fps)
  60. ret,frame=src.read()
  61. src.release()
  62. if dst:
  63. dst.release()
  64. def inferframe(self, frame, benchmark=True):
  65. """
  66. frame: numpy array containing un-pre-processed video frame (dtype is uint8)
  67. benchamrk: bool, whether or not to calculate inference time
  68. """
  69. rgb=frame[...,::-1].copy()
  70. processed=self.preprocessor(rgb)[None]
  71. start, end = 1e6, 0
  72. if benchmark:
  73. start=time.time()
  74. processed=processed.cuda() #transfer to GPU <-- does not use zero copy
  75. inferred= self(processed) #infer
  76. if benchmark:
  77. end=time.time()
  78. inferred=inferred.argmax(dim=1)
  79. overlaid=self.overlay(frame, inferred)
  80. return overlaid, 1.0/(end-start)
  81. def overlay(self, bgr, mask):
  82. """
  83. overlay pixel-wise predictions on input frame
  84. bgr: (numpy array) original video frame read from video/camera
  85. mask: (numpy array) class mask containing one of 21 classes for each pixel
  86. """
  87. colored = self.cmap[mask].to('cpu').numpy()[0,...]
  88. colored=cv2.resize(colored, (bgr.shape[1], bgr.shape[0]), interpolation=cv2.INTER_CUBIC)
  89. oved = cv2.addWeighted(bgr, 0.7, colored, 0.3, 0.0)
  90. return oved
  91. def getvideowriter(self, savepath, srch):
  92. """
  93. Simple utility function for getting video writer
  94. savepath: string, path of output file
  95. src: a cv2.VideoCapture object
  96. """
  97. fps=srch.get(cv2.CAP_PROP_FPS)
  98. width=int(srch.get(cv2.CAP_PROP_FRAME_WIDTH))
  99. height=int(srch.get(cv2.CAP_PROP_FRAME_HEIGHT))
  100. fourcc=int(srch.get(cv2.CAP_PROP_FOURCC))
  101. dst=cv2.VideoWriter(savepath, fourcc, fps, (width, height))
  102. return dst
  103. if __name__=='__main__':
  104. model=SegModel()
  105. model.export_onnx('./segmodel.onnx')
  106. #model.infervideo('../may20/cam_2.mp4') #uncomment to infer on a video or camera