# convert_plan.py
  1. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  2. # Full license terms provided in LICENSE.md file.
  3. import os
  4. import subprocess
  5. import uff
  6. import pdb
  7. import sys
  8. UFF_TO_PLAN_EXE_PATH = 'build/src/uff_to_plan'
  9. TMP_UFF_FILENAME = 'data/tmp.uff'
  10. def frozenToPlan(frozen_graph_filename, plan_filename, input_name, input_height,
  11. input_width, output_name, max_batch_size, max_workspace_size, data_type):
  12. # generate uff from frozen graph
  13. uff_model = uff.from_tensorflow_frozen_model(
  14. frozen_file=frozen_graph_filename,
  15. output_nodes=[output_name],
  16. output_filename=TMP_UFF_FILENAME,
  17. text=False
  18. )
  19. # convert frozen graph to engine (plan)
  20. args = [
  21. TMP_UFF_FILENAME,
  22. plan_filename,
  23. input_name,
  24. str(input_height),
  25. str(input_width),
  26. output_name,
  27. str(max_batch_size),
  28. str(max_workspace_size),
  29. data_type # float / half
  30. ]
  31. subprocess.call([UFF_TO_PLAN_EXE_PATH] + args)
  32. # cleanup tmp file
  33. os.remove(TMP_UFF_FILENAME)
  34. if __name__ == '__main__':
  35. if len(sys.argv) is not 10:
  36. print("usage: python convert_plan.py <frozen_graph_path> <output_plan_path> <input_name> <input_height>"
  37. " <input_width> <output_name> <max_batch_size> <max_workspace_size> <data_type>")
  38. exit()
  39. frozen_graph_filename = sys.argv[1]
  40. plan_filename = sys.argv[2]
  41. input_name = sys.argv[3]
  42. input_height = sys.argv[4]
  43. input_width = sys.argv[5]
  44. output_name = sys.argv[6]
  45. max_batch_size = sys.argv[7]
  46. max_workspace_size = sys.argv[8]
  47. data_type = sys.argv[9]
  48. frozenToPlan(frozen_graph_filename,
  49. plan_filename,
  50. input_name,
  51. input_height,
  52. input_width,
  53. output_name,
  54. max_batch_size,
  55. max_workspace_size,
  56. data_type
  57. )