# models_to_frozen_graphs.py (2.4 KB)
  1. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  2. # Full license terms provided in LICENSE.md file.
  3. import tensorflow as tf
  4. import sys
  5. sys.path.append("third_party/models/research/")
  6. sys.path.append("third_party/models")
  7. sys.path.append("third_party/")
  8. sys.path.append("third_party/models/research/slim/")
  9. sys.path.append("scripts")
  10. import tensorflow.contrib.slim as tf_slim
  11. import slim.nets as nets
  12. import slim.nets.vgg
  13. from model_meta import NETS, CHECKPOINT_DIR, FROZEN_GRAPHS_DIR
  14. from convert_relu6 import convertRelu6
  15. import os
  16. if __name__ == '__main__':
  17. if not os.path.exists(CHECKPOINT_DIR):
  18. print("%s does not exist. Exiting." % CHECKPOINT_DIR)
  19. exit()
  20. if not os.path.exists(FROZEN_GRAPHS_DIR):
  21. print("%s does not exist. Creating it now." % FROZEN_GRAPHS_DIR)
  22. os.makedirs(FROZEN_GRAPHS_DIR)
  23. for net_name, net_meta in NETS.items():
  24. if 'exclude' in net_meta.keys() and net_meta['exclude'] is True:
  25. continue
  26. print("Converting %s" % net_name)
  27. print(net_meta)
  28. tf.reset_default_graph()
  29. tf_config = tf.ConfigProto()
  30. tf_config.gpu_options.allow_growth = True
  31. with tf.Session(config=tf_config) as tf_sess:
  32. tf_sess = tf.Session(config=tf_config)
  33. tf_input = tf.placeholder(
  34. tf.float32,
  35. (
  36. None,
  37. net_meta['input_height'],
  38. net_meta['input_width'],
  39. net_meta['input_channels']
  40. ),
  41. name=net_meta['input_name']
  42. )
  43. with tf_slim.arg_scope(net_meta['arg_scope']()):
  44. tf_net, tf_end_points = net_meta['model'](
  45. tf_input,
  46. is_training=False,
  47. num_classes=net_meta['num_classes']
  48. )
  49. tf_saver = tf.train.Saver()
  50. tf_saver.restore(
  51. save_path=net_meta['checkpoint_filename'],
  52. sess=tf_sess
  53. )
  54. frozen_graph = tf.graph_util.convert_variables_to_constants(
  55. tf_sess,
  56. tf_sess.graph_def,
  57. output_node_names=net_meta['output_names']
  58. )
  59. frozen_graph = convertRelu6(frozen_graph)
  60. with open(net_meta['frozen_graph_filename'], 'wb') as f:
  61. f.write(frozen_graph.SerializeToString())
  62. f.close()