convert_plan.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  2. # Full license terms provided in LICENSE.md file.
  3. import os
  4. import subprocess
  5. import uff
  6. import pdb
  7. import sys
  8. UFF_TO_PLAN_EXE_PATH = 'build/src/uff_to_plan'
  9. TMP_UFF_FILENAME = 'data/tmp.uff'
  10. def frozenToPlan(frozen_graph_filename, plan_filename, input_name, input_height,
  11. input_width, output_name, max_batch_size, max_workspace_size, data_type):
  12. # generate uff from frozen graph
  13. uff_model = uff.from_tensorflow_frozen_model(
  14. frozen_file=frozen_graph_filename,
  15. output_nodes=[output_name],
  16. output_filename=TMP_UFF_FILENAME,
  17. text=False,
  18. )
  19. # convert frozen graph to engine (plan)
  20. args = [
  21. TMP_UFF_FILENAME,
  22. plan_filename,
  23. input_name,
  24. str(input_height),
  25. str(input_width),
  26. output_name,
  27. str(max_batch_size),
  28. str(max_workspace_size),
  29. data_type # float / half
  30. ]
  31. subprocess.call([UFF_TO_PLAN_EXE_PATH] + args)
  32. # cleanup tmp file
  33. os.remove(TMP_UFF_FILENAME)
  34. if __name__ == '__main__':
  35. if not os.path.exists('data/plans'):
  36. os.makedirs('data/plans')
  37. if len(sys.argv) is not 10:
  38. print("usage: python convert_plan.py <frozen_graph_path> <output_plan_path> <input_name> <input_height>"
  39. " <input_width> <output_name> <max_batch_size> <max_workspace_size> <data_type>")
  40. exit()
  41. frozen_graph_filename = sys.argv[1]
  42. plan_filename = sys.argv[2]
  43. input_name = sys.argv[3]
  44. input_height = sys.argv[4]
  45. input_width = sys.argv[5]
  46. output_name = sys.argv[6]
  47. max_batch_size = sys.argv[7]
  48. max_workspace_size = sys.argv[8]
  49. data_type = sys.argv[9]
  50. frozenToPlan(frozen_graph_filename,
  51. plan_filename,
  52. input_name,
  53. input_height,
  54. input_width,
  55. output_name,
  56. max_batch_size,
  57. max_workspace_size,
  58. data_type
  59. )