|
@@ -38,45 +38,40 @@ if __name__ == '__main__':
|
|
|
tf.reset_default_graph()
|
|
|
tf_config = tf.ConfigProto()
|
|
|
tf_config.gpu_options.allow_growth = True
|
|
|
- tf_sess = tf.Session(config=tf_config)
|
|
|
- tf_input = tf.placeholder(
|
|
|
- tf.float32,
|
|
|
- (
|
|
|
- None,
|
|
|
- net_meta['input_height'],
|
|
|
- net_meta['input_width'],
|
|
|
- net_meta['input_channels']
|
|
|
- ),
|
|
|
- name=net_meta['input_name']
|
|
|
- )
|
|
|
-
|
|
|
- with tf_slim.arg_scope(net_meta['arg_scope']()):
|
|
|
- tf_net, tf_end_points = net_meta['model'](
|
|
|
- tf_input,
|
|
|
- is_training=False,
|
|
|
- num_classes=net_meta['num_classes']
|
|
|
+ with tf.Session(config=tf_config) as tf_sess:
|
|
|
+ tf_sess = tf.Session(config=tf_config)
|
|
|
+ tf_input = tf.placeholder(
|
|
|
+ tf.float32,
|
|
|
+ (
|
|
|
+ None,
|
|
|
+ net_meta['input_height'],
|
|
|
+ net_meta['input_width'],
|
|
|
+ net_meta['input_channels']
|
|
|
+ ),
|
|
|
+ name=net_meta['input_name']
|
|
|
)
|
|
|
|
|
|
- tf_saver = tf.train.Saver()
|
|
|
- tf_saver.restore(
|
|
|
- save_path=net_meta['checkpoint_filename'],
|
|
|
- sess=tf_sess
|
|
|
- )
|
|
|
- frozen_graph = tf.graph_util.convert_variables_to_constants(
|
|
|
- tf_sess,
|
|
|
- tf_sess.graph_def,
|
|
|
- output_node_names=net_meta['output_names']
|
|
|
- )
|
|
|
+ with tf_slim.arg_scope(net_meta['arg_scope']()):
|
|
|
+ tf_net, tf_end_points = net_meta['model'](
|
|
|
+ tf_input,
|
|
|
+ is_training=False,
|
|
|
+ num_classes=net_meta['num_classes']
|
|
|
+ )
|
|
|
|
|
|
- frozen_graph = convertRelu6(frozen_graph)
|
|
|
+ tf_saver = tf.train.Saver()
|
|
|
+ tf_saver.restore(
|
|
|
+ save_path=net_meta['checkpoint_filename'],
|
|
|
+ sess=tf_sess
|
|
|
+ )
|
|
|
+ frozen_graph = tf.graph_util.convert_variables_to_constants(
|
|
|
+ tf_sess,
|
|
|
+ tf_sess.graph_def,
|
|
|
+ output_node_names=net_meta['output_names']
|
|
|
+ )
|
|
|
|
|
|
- with open(net_meta['frozen_graph_filename'], 'wb') as f:
|
|
|
- f.write(frozen_graph.SerializeToString())
|
|
|
-
|
|
|
- f.close()
|
|
|
+ frozen_graph = convertRelu6(frozen_graph)
|
|
|
|
|
|
- del tf_config
|
|
|
- del tf_sess
|
|
|
- del tf_input
|
|
|
- del tf_saver
|
|
|
- del frozen_graph
|
|
|
+ with open(net_meta['frozen_graph_filename'], 'wb') as f:
|
|
|
+ f.write(frozen_graph.SerializeToString())
|
|
|
+
|
|
|
+ f.close()
|