pytrt.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. import argparse
  2. from segcolors import colors
  3. import numpy as np
  4. import tensorrt as trt
  5. import pycuda.driver as cuda
  6. import pycuda.autoinit
  7. import pdb
  8. import os
  9. import cv2
  10. import time
  11. class TRTSegmentor(object):
  12. def __init__(self,
  13. onnxpath,
  14. colors,
  15. insize=(640,360),
  16. maxworkspace=(1<<25),
  17. precision='FP16',
  18. device='GPU',
  19. max_batch_size=1,
  20. calibrator=None,
  21. dla_core=0
  22. ):
  23. self.onnxpath=onnxpath
  24. self.enginepath=onnxpath+f'.{precision}.{device}.{dla_core}.{max_batch_size}.trt'
  25. #filename to be used for saving and reading engines
  26. self.nclasses=21
  27. self.pp_mean=np.array([0.485, 0.456, 0.406]).reshape((1,1,3))
  28. self.pp_stdev=np.array([0.229, 0.224, 0.225]).reshape((1,1,3))
  29. #mean and stdev for pre-processing images, see torchvision documentation
  30. self.colors=colors #colormap for 21 classes of Pascal VOC
  31. self.in_w=insize[0]
  32. self.in_h=insize[1] #width, height of input images
  33. #here we specify very important engine build flags
  34. self.maxworkspace=maxworkspace
  35. self.max_batch_size=max_batch_size
  36. self.precision_str=precision
  37. self.precision={'FP16':0, 'INT8':1, 'FP32': -1}[precision]
  38. #mapping strings to tensorrt precision flags
  39. self.device={'GPU':trt.DeviceType.GPU, 'DLA': trt.DeviceType.DLA}[device]
  40. #mapping strings to tensorrt device types
  41. self.dla_core=dla_core #used only if DLA device is selected
  42. self.calibrator=calibrator #used only for INT8 precision
  43. self.allowGPUFallback=3 #used only if DLA is selected
  44. self.engine, self.logger= self.parse_or_load()
  45. self.context=self.engine.create_execution_context()
  46. self.trt2np_dtype={'FLOAT':np.float32, 'HALF':np.float16, 'INT8':np.int8}
  47. self.dtype = self.trt2np_dtype[self.engine.get_binding_dtype(0).name]
  48. self.allocate_buffers(np.zeros((1,3,self.in_h,self.in_w), dtype=self.dtype))
  49. def allocate_buffers(self, image):
  50. pass
  51. insize=image.shape[-2:]
  52. outsize=[insize[0] >> 3, insize[1] >> 3]
  53. self.output=np.empty((self.nclasses,outsize[0],outsize[1]), dtype=self.dtype)
  54. self.d_input=cuda.mem_alloc(image.nbytes)
  55. self.d_output=cuda.mem_alloc(self.output.nbytes)
  56. self.bindings=[int(self.d_input), int(self.d_output)]
  57. #print(self.bindings)
  58. self.stream=cuda.Stream()
  59. def preprocess(self, img):
  60. img=cv2.resize(img,(self.in_w,self.in_h))
  61. img=img[...,::-1]
  62. img=img.astype(np.float32)/255
  63. img=(img-self.pp_mean)/self.pp_stdev
  64. img=np.transpose(img,(2,0,1))
  65. img=np.ascontiguousarray(img[None,...]).astype(self.dtype)
  66. return img
  67. def infer(self, image, benchmark=False):
  68. """
  69. image: unresized,
  70. """
  71. intensor=self.preprocess(image)
  72. start=time.time()
  73. cuda.memcpy_htod_async(self.d_input, intensor, self.stream)
  74. self.context.execute_async_v2(self.bindings, self.stream.handle, None)
  75. cuda.memcpy_dtoh_async(self.output, self.d_output, self.stream)
  76. self.stream.synchronize()
  77. if benchmark:
  78. duration=(time.time()-start)
  79. return duration
  80. def infer_async(self, intensor):
  81. #intensor should be preprocessed tensor
  82. cuda.memcpy_htod_async(self.d_input, intensor, self.stream)
  83. self.context.execute_async_v2(self.bindings, self.stream.handle, None)
  84. cuda.memcpy_dtoh_async(self.output, self.d_output, self.stream)
  85. def draw(self, img):
  86. shape=(img.shape[1],img.shape[0])
  87. segres=np.transpose(self.output,(1,2,0)).astype(np.float32)
  88. segres=cv2.resize(segres, shape)
  89. mask=segres.argmax(axis=-1)
  90. colored=self.colors[mask]
  91. drawn=cv2.addWeighted(img, 0.5, colored, 0.5, 0.0)
  92. return drawn
  93. def infervideo(self, infile):
  94. src=cv2.VideoCapture(infile)
  95. ret,frame=src.read()
  96. fps=0.0
  97. if not ret:
  98. print('Cannot read file/camera: {}'.format(infile))
  99. while ret:
  100. duration=self.infer(frame, benchmark=True)
  101. drawn=self.draw(frame)
  102. cv2.imshow('segmented', drawn)
  103. k=cv2.waitKey(1)
  104. if k==ord('q'):
  105. break
  106. fps=0.9*fps+0.1/(duration)
  107. print('FPS=:{:.2f}'.format(fps))
  108. ret,frame=src.read()
  109. def parse_or_load(self):
  110. logger= trt.Logger(trt.Logger.INFO)
  111. #we want to show logs of type info and above (warnings, errors)
  112. if os.path.exists(self.enginepath):
  113. logger.log(trt.Logger.INFO, 'Found pre-existing engine file')
  114. with open(self.enginepath, 'rb') as f:
  115. rt=trt.Runtime(logger)
  116. engine=rt.deserialize_cuda_engine(f.read())
  117. return engine, logger
  118. else: #parse and build if no engine found
  119. with trt.Builder(logger) as builder:
  120. builder.max_batch_size=self.max_batch_size
  121. #setting max_batch_size isn't strictly necessary in this case
  122. #since the onnx file already has that info, but its a good practice
  123. network_flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
  124. #since the onnx file was exported with an explicit batch dim,
  125. #we need to tell this to the builder. We do that with EXPLICIT_BATCH flag
  126. with builder.create_network(network_flag) as net:
  127. with trt.OnnxParser(net, logger) as p:
  128. #create onnx parser which will read onnx file and
  129. #populate the network object `net`
  130. with open(self.onnxpath, 'rb') as f:
  131. if not p.parse(f.read()):
  132. for err in range(p.num_errors):
  133. print(p.get_error(err))
  134. else:
  135. logger.log(trt.Logger.INFO, 'Onnx file parsed successfully')
  136. net.get_input(0).dtype=trt.DataType.HALF
  137. net.get_output(0).dtype=trt.DataType.HALF
  138. #we set the inputs and outputs to be float16 type to enable
  139. #maximum fp16 acceleration. Also helps for int8
  140. config=builder.create_builder_config()
  141. #we specify all the important parameters like precision,
  142. #device type, fallback in config object
  143. config.max_workspace_size = self.maxworkspace
  144. if self.precision_str in ['FP16', 'INT8']:
  145. config.flags = ((1<<self.precision)|(1<<self.allowGPUFallback))
  146. config.DLA_core=self.dla_core
  147. # DLA core (0 or 1 for Jetson AGX/NX/Orin) to be used must be
  148. # specified at engine build time. An engine built for DLA0 will
  149. # not work on DLA1. As such, to use two DLA engines simultaneously,
  150. # we must build two different engines.
  151. config.default_device_type=self.device
  152. #if device is set to GPU, DLA_core has no effect
  153. config.profiling_verbosity = trt.ProfilingVerbosity.VERBOSE
  154. #building with verbose profiling helps debug the engine if there are
  155. #errors in inference output. Does not impact throughput.
  156. if self.precision_str=='INT8' and self.calibrator is None:
  157. logger.log(trt.Logger.ERROR, 'Please provide calibrator')
  158. #can't proceed without a calibrator
  159. quit()
  160. elif self.precision_str=='INT8' and self.calibrator is not None:
  161. config.int8_calibrator=self.calibrator
  162. logger.log(trt.Logger.INFO, 'Using INT8 calibrator provided by user')
  163. logger.log(trt.Logger.INFO, 'Checking if network is supported...')
  164. if builder.is_network_supported(net, config):
  165. logger.log(trt.Logger.INFO, 'Network is supported')
  166. #tensorRT engine can be built only if all ops in network are supported.
  167. #If ops are not supported, build will fail. In this case, consider using
  168. #torch-tensorrt integration. We might do a blog post on this in the future.
  169. else:
  170. logger.log(trt.Logger.ERROR, 'Network contains operations that are not supported by TensorRT')
  171. logger.log(trt.Logger.ERROR, 'QUITTING because network is not supported')
  172. quit()
  173. if self.device==trt.DeviceType.DLA:
  174. dla_supported=0
  175. logger.log(trt.Logger.INFO, 'Number of layers in network: {}'.format(net.num_layers))
  176. for idx in range(net.num_layers):
  177. if config.can_run_on_DLA(net.get_layer(idx)):
  178. dla_supported+=1
  179. logger.log(trt.Logger.INFO, f'{dla_supported} of {net.num_layers} layers are supported on DLA')
  180. logger.log(trt.Logger.INFO, 'Building inference engine...')
  181. engine=builder.build_engine(net, config)
  182. #this will take some time
  183. logger.log(trt.Logger.INFO, 'Inference engine built successfully')
  184. with open(self.enginepath, 'wb') as s:
  185. s.write(engine.serialize())
  186. logger.log(trt.Logger.INFO, f'Inference engine saved to {self.enginepath}')
  187. return engine, logger
  188. class Calibrator(trt.IInt8EntropyCalibrator2):
  189. def __init__(self, imgdir, n_samples,input_size=(640,360), batch_size=1, iotype=np.float16):
  190. super().__init__()
  191. self.imgdir=imgdir
  192. self.n_samples=n_samples
  193. self.input_size=input_size
  194. self.batch_size=batch_size
  195. self.iotype=iotype
  196. self.pp_mean=np.array([0.485, 0.456, 0.406]).reshape((1,1,3))
  197. self.pp_stdev=np.array([0.229, 0.224, 0.225]).reshape((1,1,3))
  198. self.cache_path='cache.ich'
  199. self.setup()
  200. self.images_read=0
  201. def setup(self):
  202. all_images=sorted([f for f in os.listdir(self.imgdir) if f.endswith('.jpg')])
  203. assert len(all_images)>=self.n_samples, f'Not enough images available. Requested {self.n_samples} images for calibration but only {len(all_images)} are avialable in {self.imgdir}'
  204. used=all_images[:self.n_samples]
  205. self.images=[os.path.join(self.imgdir,f) for f in used]
  206. nbytes=self.batch_size*3*self.input_size[0]*self.input_size[1]*self.iotype(1).nbytes
  207. self.buffer=cuda.mem_alloc(nbytes)
  208. def preprocess(self, img):
  209. img=cv2.resize(img,self.input_size)
  210. img=img[...,::-1] #bgr2rgb
  211. img=img.astype(np.float32)/255
  212. img=(img-self.pp_mean)/self.pp_stdev #normalize
  213. img=np.transpose(img,(2,0,1)) #HWC to CHW format
  214. img=np.ascontiguousarray(img[None,...]).astype(self.iotype)
  215. #NCHW data of type used by engine input
  216. return img
  217. def get_batch(self, names):
  218. if self.images_read+self.batch_size < self.n_samples:
  219. batch=[]
  220. for idx in range(self.images_read,self.images_read+self.batch_size):
  221. img=cv2.imread(self.images[idx],1)
  222. intensor=self.preprocess(img)
  223. batch.append(intensor)
  224. batch=np.concatenate(batch, axis=0)
  225. cuda.memcpy_htod(self.buffer, batch)
  226. self.images_read+=self.batch_size
  227. return [int(self.buffer)]
  228. else:
  229. return None
  230. def get_batch_size(self):
  231. return self.batch_size
  232. def read_calibration_cache(self):
  233. if os.path.exists(self.cache_path):
  234. with open(self.cache_path, "rb") as f:
  235. return f.read()
  236. def write_calibration_cache(self, cache):
  237. with open(self.cache_path, 'wb') as f:
  238. f.write(cache)
  239. def infervideo_2DLAs(infile, onnxpath, calibrator=None, precision='INT8',display=False):
  240. src=cv2.VideoCapture(infile)
  241. seg1=TRTSegmentor(onnxpath, colors, device='DLA', precision=precision ,calibrator=calibrator, dla_core=0)
  242. seg2=TRTSegmentor(onnxpath, colors, device='DLA', precision=precision ,calibrator=calibrator, dla_core=1)
  243. ret1,frame1=src.read()
  244. ret2,frame2=src.read()
  245. fps=0.0
  246. while ret1 and ret2:
  247. intensor1=seg1.preprocess(frame1)
  248. intensor2=seg2.preprocess(frame2)
  249. start=time.time()
  250. cuda.memcpy_htod_async(seg1.d_input, intensor1, seg1.stream)
  251. cuda.memcpy_htod_async(seg2.d_input, intensor2, seg2.stream)
  252. seg1.context.execute_async_v2(seg1.bindings, seg1.stream.handle, None)
  253. seg2.context.execute_async_v2(seg2.bindings, seg2.stream.handle, None)
  254. cuda.memcpy_dtoh_async(seg1.output, seg1.d_output, seg1.stream)
  255. cuda.memcpy_dtoh_async(seg2.output, seg2.d_output, seg2.stream)
  256. seg1.stream.synchronize()
  257. seg2.stream.synchronize()
  258. end=time.time()
  259. if display:
  260. drawn1=seg1.draw(frame1)
  261. drawn2=seg2.draw(frame2)
  262. cv2.imshow('segmented1', drawn1)
  263. cv2.imshow('segmented2', drawn2)
  264. k=cv2.waitKey(1)
  265. if k==ord('q'):
  266. break
  267. fps=0.9*fps+0.1*(2.0/(end-start))
  268. print('FPS = {:.3f}'.format(fps))
  269. ret1,frame1=src.read()
  270. ret2,frame2=src.read()
  271. if __name__ == '__main__':
  272. parser=argparse.ArgumentParser(description='TensorRT python tutorial')
  273. parser.add_argument('--precision', type=str,
  274. default='fp16', choices=['int8', 'fp16', 'fp32'],
  275. help='precision FP32, FP16 or INT8')
  276. parser.add_argument('--device', type=str,
  277. default='gpu', choices=['gpu', 'dla', 'dla0', 'dla1', '2DLAs'],
  278. help='GPU, DLA or 2DLAs')
  279. parser.add_argument('--infile', type=str, required=True,
  280. help='path of input video file to infer on')
  281. args=parser.parse_args()
  282. calibrator=Calibrator('./val2017/', 5000)
  283. if args.device=='2DLAs':
  284. precision=args.precision.upper()
  285. infervideo_2DLAs(args.infile, './segmodel.onnx', calibrator, precision)
  286. else:
  287. device=args.device.upper()
  288. precision=args.precision.upper()
  289. dla_core=int(device[3:]) if len(device)>3 else 0
  290. device=device[:3]
  291. seg=TRTSegmentor('./segmodel.onnx', colors,
  292. device=device,
  293. precision=precision,
  294. calibrator=calibrator,
  295. dla_core=dla_core)
  296. seg.infervideo(args.infile)
  297. print('Inferred successfully')