Browse Source

fixed session resource leak

John Welsh 5 years ago
parent
commit
91437c8a9c
2 changed files with 54 additions and 59 deletions
  1. 32 37
      scripts/models_to_frozen_graphs.py
  2. 22 22
      scripts/test_tf.py

+ 32 - 37
scripts/models_to_frozen_graphs.py

@@ -38,45 +38,40 @@ if __name__ == '__main__':
         tf.reset_default_graph()
         tf_config = tf.ConfigProto()
         tf_config.gpu_options.allow_growth = True
-        tf_sess = tf.Session(config=tf_config)
-        tf_input = tf.placeholder(
-            tf.float32, 
-            (
-                None, 
-                net_meta['input_height'], 
-                net_meta['input_width'], 
-                net_meta['input_channels']
-            ),
-            name=net_meta['input_name']
-        )
-
-        with tf_slim.arg_scope(net_meta['arg_scope']()):
-            tf_net, tf_end_points = net_meta['model'](
-                tf_input, 
-                is_training=False,
-                num_classes=net_meta['num_classes']
+        with tf.Session(config=tf_config) as tf_sess:
+            tf_sess = tf.Session(config=tf_config)
+            tf_input = tf.placeholder(
+                tf.float32, 
+                (
+                    None, 
+                    net_meta['input_height'], 
+                    net_meta['input_width'], 
+                    net_meta['input_channels']
+                ),
+                name=net_meta['input_name']
             )
 
-        tf_saver = tf.train.Saver()
-        tf_saver.restore(
-            save_path=net_meta['checkpoint_filename'], 
-            sess=tf_sess
-        )
-        frozen_graph = tf.graph_util.convert_variables_to_constants(
-            tf_sess,
-            tf_sess.graph_def,
-            output_node_names=net_meta['output_names']
-        )
+            with tf_slim.arg_scope(net_meta['arg_scope']()):
+                tf_net, tf_end_points = net_meta['model'](
+                    tf_input, 
+                    is_training=False,
+                    num_classes=net_meta['num_classes']
+                )
 
-        frozen_graph = convertRelu6(frozen_graph)
+            tf_saver = tf.train.Saver()
+            tf_saver.restore(
+                save_path=net_meta['checkpoint_filename'], 
+                sess=tf_sess
+            )
+            frozen_graph = tf.graph_util.convert_variables_to_constants(
+                tf_sess,
+                tf_sess.graph_def,
+                output_node_names=net_meta['output_names']
+            )
 
-        with open(net_meta['frozen_graph_filename'], 'wb') as f:
-            f.write(frozen_graph.SerializeToString())
-    
-        f.close()
+            frozen_graph = convertRelu6(frozen_graph)
 
-        del tf_config
-        del tf_sess
-        del tf_input
-        del tf_saver
-        del frozen_graph
+            with open(net_meta['frozen_graph_filename'], 'wb') as f:
+                f.write(frozen_graph.SerializeToString())
+        
+            f.close()

+ 22 - 22
scripts/test_tf.py

@@ -39,29 +39,29 @@ if __name__ == '__main__':
             tf_config.gpu_options.allow_growth = True
             tf_config.allow_soft_placement = True
 
-            tf_sess = tf.Session(config=tf_config, graph=graph)
-            tf_input = tf_sess.graph.get_tensor_by_name(net_meta['input_name'] + ':0')
-            tf_output = tf_sess.graph.get_tensor_by_name(net_meta['output_names'][0] + ':0')
+            with tf.Session(config=tf_config, graph=graph) as tf_sess:
+                tf_input = tf_sess.graph.get_tensor_by_name(net_meta['input_name'] + ':0')
+                tf_output = tf_sess.graph.get_tensor_by_name(net_meta['output_names'][0] + ':0')
 
-            # load and preprocess image
-            image = cv2.imread(TEST_IMAGE_PATH)
-            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-            image = cv2.resize(image, (net_meta['input_width'], net_meta['input_height']))
-            image = net_meta['preprocess_fn'](image)
+                # load and preprocess image
+                image = cv2.imread(TEST_IMAGE_PATH)
+                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+                image = cv2.resize(image, (net_meta['input_width'], net_meta['input_height']))
+                image = net_meta['preprocess_fn'](image)
 
 
-            # run network
-            times = []
-            for i in range(NUM_RUNS + 1):
-                t0 = time.time()
-                output = tf_sess.run([tf_output], feed_dict={
-                    tf_input: image[None, ...]
-                })[0]
-                t1 = time.time()
-                times.append(1000 * (t1 - t0))
-            avg_time = np.mean(times[1:]) # don't include first run
+                # run network
+                times = []
+                for i in range(NUM_RUNS + 1):
+                    t0 = time.time()
+                    output = tf_sess.run([tf_output], feed_dict={
+                        tf_input: image[None, ...]
+                    })[0]
+                    t1 = time.time()
+                    times.append(1000 * (t1 - t0))
+                avg_time = np.mean(times[1:]) # don't include first run
 
-            # parse output
-            top5 = net_meta['postprocess_fn'](output)
-            print(top5)
-            test_f.write("%s %s\n" % (net_name, avg_time))
+                # parse output
+                top5 = net_meta['postprocess_fn'](output)
+                print(top5)
+                test_f.write("%s %s\n" % (net_name, avg_time))