Преглед на файлове

Convert tf.GraphKeys.VARIABLES -> tf.GraphKeys.GLOBAL_VARIABLES

Neal Wu преди 8 години
родител
ревизия
211ee00a3b
променени са 3 файла, в които са добавени 5 реда и са изтрити 5 реда
  1. 2 2
      slim/nets/inception_resnet_v2_test.py
  2. 2 2
      slim/nets/inception_v4_test.py
  3. 1 1
      video_prediction/prediction_train.py

+ 2 - 2
slim/nets/inception_resnet_v2_test.py

@@ -65,9 +65,9 @@ class InceptionTest(tf.test.TestCase):
         inception.inception_resnet_v2(inputs, num_classes)
       with tf.variable_scope('on_gpu'), tf.device('/gpu:0'):
         inception.inception_resnet_v2(inputs, num_classes)
-      for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_cpu'):
+      for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_cpu'):
         self.assertDeviceEqual(v.device, '/cpu:0')
-      for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_gpu'):
+      for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_gpu'):
         self.assertDeviceEqual(v.device, '/gpu:0')
 
   def testHalfSizeImages(self):

+ 2 - 2
slim/nets/inception_v4_test.py

@@ -146,9 +146,9 @@ class InceptionTest(tf.test.TestCase):
       inception.inception_v4(inputs, num_classes)
     with tf.variable_scope('on_gpu'), tf.device('/gpu:0'):
       inception.inception_v4(inputs, num_classes)
-    for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_cpu'):
+    for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_cpu'):
       self.assertDeviceEqual(v.device, '/cpu:0')
-    for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_gpu'):
+    for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_gpu'):
       self.assertDeviceEqual(v.device, '/gpu:0')
 
   def testHalfSizeImages(self):

+ 1 - 1
video_prediction/prediction_train.py

@@ -196,7 +196,7 @@ def main(unused_argv):
   print 'Constructing saver.'
   # Make saver.
   saver = tf.train.Saver(
-      tf.get_collection(tf.GraphKeys.VARIABLES), max_to_keep=0)
+      tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), max_to_keep=0)
 
   # Make training session.
   sess = tf.InteractiveSession()