model_meta.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  2. # Full license terms provided in LICENSE.md file.
  3. import numpy as np
  4. import sys
  5. sys.path.append("third_party/models/research/")
  6. sys.path.append("third_party/models")
  7. sys.path.append("third_party/")
  8. sys.path.append("third_party/models/research/slim/")
  9. import tensorflow.contrib.slim as tf_slim
  10. import slim.nets as nets
  11. import slim.nets.vgg
  12. import slim.nets.inception
  13. import slim.nets.resnet_v1
  14. import slim.nets.resnet_v2
  15. import slim.nets.mobilenet_v1
  16. def create_label_map(label_file='data/imagenet_labels_1001.txt'):
  17. label_map = {}
  18. with open(label_file, 'r') as f:
  19. labels = f.readlines()
  20. for i, label in enumerate(labels):
  21. label_map[i] = label
  22. return label_map
# Index -> label table, built once at import time from create_label_map()'s
# default label file ('data/imagenet_labels_1001.txt', a relative path).
# NOTE(review): the name looks like a misspelling of "IMAGENET"; it is kept
# as-is because the post-processing functions look it up under this name.
IMAGNET2012_LABEL_MAP = create_label_map()
  24. def preprocess_vgg(image):
  25. return np.array(image, dtype=np.float32) - np.array([123.68, 116.78, 103.94])
  26. def postprocess_vgg(output):
  27. output = output.flatten()
  28. predictions_top5 = np.argsort(output)[::-1][0:5]
  29. labels_top5 = [IMAGNET2012_LABEL_MAP[p + 1] for p in predictions_top5]
  30. return labels_top5
  31. def preprocess_inception(image):
  32. return 2.0 * (np.array(image, dtype=np.float32) / 255.0 - 0.5)
  33. def postprocess_inception(output):
  34. output = output.flatten()
  35. predictions_top5 = np.argsort(output)[::-1][0:5]
  36. labels_top5 = [IMAGNET2012_LABEL_MAP[p] for p in predictions_top5]
  37. return labels_top5
  38. def mobilenet_v1_1p0_224(*args, **kwargs):
  39. kwargs['depth_multiplier'] = 1.0
  40. return nets.mobilenet_v1.mobilenet_v1(*args, **kwargs)
  41. def mobilenet_v1_0p5_160(*args, **kwargs):
  42. kwargs['depth_multiplier'] = 0.5
  43. return nets.mobilenet_v1.mobilenet_v1(*args, **kwargs)
  44. def mobilenet_v1_0p25_128(*args, **kwargs):
  45. kwargs['depth_multiplier'] = 0.25
  46. return nets.mobilenet_v1.mobilenet_v1(*args, **kwargs)
# Relative directory prefixes for the per-network file names built in NETS.
# They are joined to file names with plain string concatenation, so each one
# must keep its trailing slash.
CHECKPOINT_DIR = 'data/checkpoints/'
FROZEN_GRAPHS_DIR = 'data/frozen_graphs/'
# UFF_DIR = 'data/uff/'
PLAN_DIR = 'data/plans/'
# Registry of the classification networks described by this module, keyed by
# network name. Every entry provides:
#   'model', 'arg_scope'      - network builder and arg-scope function taken
#                               from slim.nets (or a mobilenet wrapper defined
#                               in this module)
#   'num_classes'             - 1000 or 1001. Entries with 1000 use
#                               postprocess_vgg (label index + 1); entries
#                               with 1001 use postprocess_inception.
#                               NOTE(review): presumably 1001 means an extra
#                               background class at index 0 - confirm.
#   'input_name', 'output_names'
#                             - presumably node names in the TensorFlow graph;
#                               verify against the code that consumes them
#   'input_width', 'input_height', 'input_channels'
#                             - input image dimensions
#   'preprocess_fn', 'postprocess_fn'
#                             - functions from this module applied to the
#                               input image / the network output
#   'checkpoint_filename', 'frozen_graph_filename', 'plan_filename'
#                             - paths under CHECKPOINT_DIR, FROZEN_GRAPHS_DIR
#                               and PLAN_DIR respectively
# Optional keys: 'trt_convert_status' (only on the vgg and inception entries)
# and 'exclude' (only on vgg_19). Their consumers are not visible in this
# module, so read them with .get() rather than assuming they are present.
NETS = {
    'vgg_16': {
        'model': nets.vgg.vgg_16,
        'arg_scope': nets.vgg.vgg_arg_scope,
        'num_classes': 1000,
        'input_name': 'input',
        'output_names': ['vgg_16/fc8/BiasAdd'],
        'input_width': 224,
        'input_height': 224,
        'input_channels': 3,
        'preprocess_fn': preprocess_vgg,
        'postprocess_fn': postprocess_vgg,
        'checkpoint_filename': CHECKPOINT_DIR + 'vgg_16.ckpt',
        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'vgg_16.pb',
        'trt_convert_status': "works",
        'plan_filename': PLAN_DIR + 'vgg_16.plan'
    },
    'vgg_19': {
        'model': nets.vgg.vgg_19,
        'arg_scope': nets.vgg.vgg_arg_scope,
        'num_classes': 1000,
        'input_name': 'input',
        'output_names': ['vgg_19/fc8/BiasAdd'],
        'input_width': 224,
        'input_height': 224,
        'input_channels': 3,
        'preprocess_fn': preprocess_vgg,
        'postprocess_fn': postprocess_vgg,
        'checkpoint_filename': CHECKPOINT_DIR + 'vgg_19.ckpt',
        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'vgg_19.pb',
        'trt_convert_status': "works",
        'plan_filename': PLAN_DIR + 'vgg_19.plan',
        'exclude': True
    },
    'inception_v1': {
        'model': nets.inception.inception_v1,
        'arg_scope': nets.inception.inception_v1_arg_scope,
        'num_classes': 1001,
        'input_name': 'input',
        'input_width': 224,
        'input_height': 224,
        'input_channels': 3,
        'output_names': ['InceptionV1/Logits/SpatialSqueeze'],
        'checkpoint_filename': CHECKPOINT_DIR + 'inception_v1.ckpt',
        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'inception_v1.pb',
        'preprocess_fn': preprocess_inception,
        'postprocess_fn': postprocess_inception,
        'trt_convert_status': "works",
        'plan_filename': PLAN_DIR + 'inception_v1.plan'
    },
    'inception_v2': {
        'model': nets.inception.inception_v2,
        'arg_scope': nets.inception.inception_v2_arg_scope,
        'num_classes': 1001,
        'input_name': 'input',
        'input_width': 224,
        'input_height': 224,
        'input_channels': 3,
        'output_names': ['InceptionV2/Logits/SpatialSqueeze'],
        'checkpoint_filename': CHECKPOINT_DIR + 'inception_v2.ckpt',
        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'inception_v2.pb',
        'preprocess_fn': preprocess_inception,
        'postprocess_fn': postprocess_inception,
        'trt_convert_status': "bad results",
        'plan_filename': PLAN_DIR + 'inception_v2.plan'
    },
    'inception_v3': {
        'model': nets.inception.inception_v3,
        'arg_scope': nets.inception.inception_v3_arg_scope,
        'num_classes': 1001,
        'input_name': 'input',
        'input_width': 299,
        'input_height': 299,
        'input_channels': 3,
        'output_names': ['InceptionV3/Logits/SpatialSqueeze'],
        'checkpoint_filename': CHECKPOINT_DIR + 'inception_v3.ckpt',
        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'inception_v3.pb',
        'preprocess_fn': preprocess_inception,
        'postprocess_fn': postprocess_inception,
        'trt_convert_status': "works",
        'plan_filename': PLAN_DIR + 'inception_v3.plan'
    },
    'inception_v4': {
        'model': nets.inception.inception_v4,
        'arg_scope': nets.inception.inception_v4_arg_scope,
        'num_classes': 1001,
        'input_name': 'input',
        'input_width': 299,
        'input_height': 299,
        'input_channels': 3,
        'output_names': ['InceptionV4/Logits/Logits/BiasAdd'],
        'checkpoint_filename': CHECKPOINT_DIR + 'inception_v4.ckpt',
        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'inception_v4.pb',
        'preprocess_fn': preprocess_inception,
        'postprocess_fn': postprocess_inception,
        'trt_convert_status': "works",
        'plan_filename': PLAN_DIR + 'inception_v4.plan'
    },
    'inception_resnet_v2': {
        'model': nets.inception.inception_resnet_v2,
        'arg_scope': nets.inception.inception_resnet_v2_arg_scope,
        'num_classes': 1001,
        'input_name': 'input',
        'input_width': 299,
        'input_height': 299,
        'input_channels': 3,
        'output_names': ['InceptionResnetV2/Logits/Logits/BiasAdd'],
        'checkpoint_filename': CHECKPOINT_DIR + 'inception_resnet_v2_2016_08_30.ckpt',
        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'inception_resnet_v2.pb',
        'preprocess_fn': preprocess_inception,
        'postprocess_fn': postprocess_inception,
        'trt_convert_status': "works",
        'plan_filename': PLAN_DIR + 'inception_resnet_v2.plan'
    },
    'resnet_v1_50': {
        'model': nets.resnet_v1.resnet_v1_50,
        'arg_scope': nets.resnet_v1.resnet_arg_scope,
        'num_classes': 1000,
        'input_name': 'input',
        'input_width': 224,
        'input_height': 224,
        'input_channels': 3,
        'output_names': ['resnet_v1_50/SpatialSqueeze'],
        'checkpoint_filename': CHECKPOINT_DIR + 'resnet_v1_50.ckpt',
        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'resnet_v1_50.pb',
        'preprocess_fn': preprocess_vgg,
        'postprocess_fn': postprocess_vgg,
        'plan_filename': PLAN_DIR + 'resnet_v1_50.plan'
    },
    'resnet_v1_101': {
        'model': nets.resnet_v1.resnet_v1_101,
        'arg_scope': nets.resnet_v1.resnet_arg_scope,
        'num_classes': 1000,
        'input_name': 'input',
        'input_width': 224,
        'input_height': 224,
        'input_channels': 3,
        'output_names': ['resnet_v1_101/SpatialSqueeze'],
        'checkpoint_filename': CHECKPOINT_DIR + 'resnet_v1_101.ckpt',
        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'resnet_v1_101.pb',
        'preprocess_fn': preprocess_vgg,
        'postprocess_fn': postprocess_vgg,
        'plan_filename': PLAN_DIR + 'resnet_v1_101.plan'
    },
    'resnet_v1_152': {
        'model': nets.resnet_v1.resnet_v1_152,
        'arg_scope': nets.resnet_v1.resnet_arg_scope,
        'num_classes': 1000,
        'input_name': 'input',
        'input_width': 224,
        'input_height': 224,
        'input_channels': 3,
        'output_names': ['resnet_v1_152/SpatialSqueeze'],
        'checkpoint_filename': CHECKPOINT_DIR + 'resnet_v1_152.ckpt',
        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'resnet_v1_152.pb',
        'preprocess_fn': preprocess_vgg,
        'postprocess_fn': postprocess_vgg,
        'plan_filename': PLAN_DIR + 'resnet_v1_152.plan'
    },
    'resnet_v2_50': {
        'model': nets.resnet_v2.resnet_v2_50,
        'arg_scope': nets.resnet_v2.resnet_arg_scope,
        'num_classes': 1001,
        'input_name': 'input',
        'input_width': 299,
        'input_height': 299,
        'input_channels': 3,
        'output_names': ['resnet_v2_50/SpatialSqueeze'],
        'checkpoint_filename': CHECKPOINT_DIR + 'resnet_v2_50.ckpt',
        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'resnet_v2_50.pb',
        'preprocess_fn': preprocess_inception,
        'postprocess_fn': postprocess_inception,
        'plan_filename': PLAN_DIR + 'resnet_v2_50.plan'
    },
    'resnet_v2_101': {
        'model': nets.resnet_v2.resnet_v2_101,
        'arg_scope': nets.resnet_v2.resnet_arg_scope,
        'num_classes': 1001,
        'input_name': 'input',
        'input_width': 299,
        'input_height': 299,
        'input_channels': 3,
        'output_names': ['resnet_v2_101/SpatialSqueeze'],
        'checkpoint_filename': CHECKPOINT_DIR + 'resnet_v2_101.ckpt',
        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'resnet_v2_101.pb',
        'preprocess_fn': preprocess_inception,
        'postprocess_fn': postprocess_inception,
        'plan_filename': PLAN_DIR + 'resnet_v2_101.plan'
    },
    'resnet_v2_152': {
        'model': nets.resnet_v2.resnet_v2_152,
        'arg_scope': nets.resnet_v2.resnet_arg_scope,
        'num_classes': 1001,
        'input_name': 'input',
        'input_width': 299,
        'input_height': 299,
        'input_channels': 3,
        'output_names': ['resnet_v2_152/SpatialSqueeze'],
        'checkpoint_filename': CHECKPOINT_DIR + 'resnet_v2_152.ckpt',
        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'resnet_v2_152.pb',
        'preprocess_fn': preprocess_inception,
        'postprocess_fn': postprocess_inception,
        'plan_filename': PLAN_DIR + 'resnet_v2_152.plan'
    },
    #'resnet_v2_200': {
    #},
    # The three MobileNet entries share output node names; they differ in the
    # wrapper used as 'model', the input size and the file names.
    'mobilenet_v1_1p0_224': {
        'model': mobilenet_v1_1p0_224,
        'arg_scope': nets.mobilenet_v1.mobilenet_v1_arg_scope,
        'num_classes': 1001,
        'input_name': 'input',
        'input_width': 224,
        'input_height': 224,
        'input_channels': 3,
        'output_names': ['MobilenetV1/Logits/SpatialSqueeze'],
        'checkpoint_filename': CHECKPOINT_DIR +
            'mobilenet_v1_1.0_224.ckpt',
        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'mobilenet_v1_1p0_224.pb',
        'plan_filename': PLAN_DIR + 'mobilenet_v1_1p0_224.plan',
        'preprocess_fn': preprocess_inception,
        'postprocess_fn': postprocess_inception,
    },
    'mobilenet_v1_0p5_160': {
        'model': mobilenet_v1_0p5_160,
        'arg_scope': nets.mobilenet_v1.mobilenet_v1_arg_scope,
        'num_classes': 1001,
        'input_name': 'input',
        'input_width': 160,
        'input_height': 160,
        'input_channels': 3,
        'output_names': ['MobilenetV1/Logits/SpatialSqueeze'],
        'checkpoint_filename': CHECKPOINT_DIR +
            'mobilenet_v1_0.50_160.ckpt',
        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'mobilenet_v1_0p5_160.pb',
        'plan_filename': PLAN_DIR + 'mobilenet_v1_0p5_160.plan',
        'preprocess_fn': preprocess_inception,
        'postprocess_fn': postprocess_inception,
    },
    'mobilenet_v1_0p25_128': {
        'model': mobilenet_v1_0p25_128,
        'arg_scope': nets.mobilenet_v1.mobilenet_v1_arg_scope,
        'num_classes': 1001,
        'input_name': 'input',
        'input_width': 128,
        'input_height': 128,
        'input_channels': 3,
        'output_names': ['MobilenetV1/Logits/SpatialSqueeze'],
        'checkpoint_filename': CHECKPOINT_DIR +
            'mobilenet_v1_0.25_128.ckpt',
        'frozen_graph_filename': FROZEN_GRAPHS_DIR + 'mobilenet_v1_0p25_128.pb',
        'plan_filename': PLAN_DIR + 'mobilenet_v1_0p25_128.plan',
        'preprocess_fn': preprocess_inception,
        'postprocess_fn': postprocess_inception,
    },
}