# models_to_frozen_graphs.py (2.4 KB)
  1. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  2. # Full license terms provided in LICENSE.md file.
  3. import tensorflow as tf
  4. import sys
  5. sys.path.append("third_party/models/research/")
  6. sys.path.append("third_party/models")
  7. sys.path.append("third_party/")
  8. sys.path.append("third_party/models/research/slim/")
  9. sys.path.append("scripts")
  10. import tensorflow.contrib.slim as tf_slim
  11. import slim.nets as nets
  12. import slim.nets.vgg
  13. from model_meta import NETS, CHECKPOINT_DIR, FROZEN_GRAPHS_DIR
  14. from convert_relu6 import convertRelu6
  15. import os
  16. if __name__ == '__main__':
  17. if not os.path.exists(CHECKPOINT_DIR):
  18. print("%s does not exist. Exiting." % CHECKPOINT_DIR)
  19. exit()
  20. if not os.path.exists(FROZEN_GRAPHS_DIR):
  21. print("%s does not exist. Creating it now." % FROZEN_GRAPHS_DIR)
  22. os.makedirs(FROZEN_GRAPHS_DIR)
  23. for net_name, net_meta in NETS.items():
  24. if 'exclude' in net_meta.keys() and net_meta['exclude'] is True:
  25. continue
  26. print("Converting %s" % net_name)
  27. print(net_meta)
  28. tf.reset_default_graph()
  29. tf_config = tf.ConfigProto()
  30. tf_config.gpu_options.allow_growth = True
  31. tf_sess = tf.Session(config=tf_config)
  32. tf_input = tf.placeholder(
  33. tf.float32,
  34. (
  35. None,
  36. net_meta['input_height'],
  37. net_meta['input_width'],
  38. net_meta['input_channels']
  39. ),
  40. name=net_meta['input_name']
  41. )
  42. with tf_slim.arg_scope(net_meta['arg_scope']()):
  43. tf_net, tf_end_points = net_meta['model'](
  44. tf_input,
  45. is_training=False,
  46. num_classes=net_meta['num_classes']
  47. )
  48. tf_saver = tf.train.Saver()
  49. tf_saver.restore(
  50. save_path=net_meta['checkpoint_filename'],
  51. sess=tf_sess
  52. )
  53. frozen_graph = tf.graph_util.convert_variables_to_constants(
  54. tf_sess,
  55. tf_sess.graph_def,
  56. output_node_names=net_meta['output_names']
  57. )
  58. frozen_graph = convertRelu6(frozen_graph)
  59. with open(net_meta['frozen_graph_filename'], 'wb') as f:
  60. f.write(frozen_graph.SerializeToString())
  61. f.close()
  62. del tf_config
  63. del tf_sess
  64. del tf_input
  65. del tf_saver
  66. del frozen_graph