Bladeren bron

another xrange change + change to concat_v2

Yaroslav Bulatov 8 jaren geleden
bovenliggende
commit
4ec3452d95
2 gewijzigde bestanden met toevoegingen van 3 en 11 verwijderingen
  1. 1 10
      resnet/cifar_input.py
  2. 2 1
      resnet/resnet_main.py

+ 1 - 10
resnet/cifar_input.py

@@ -18,15 +18,6 @@
 
 import tensorflow as tf
 
-# backward compatible concat (arg order changed in head)
-import inspect
-def concat(values, axis):
-    if 'axis' in inspect.signature(tf.concat).parameters.keys():
-        return tf.concat(values=values, axis=axis)
-    else:
-        assert 'concat_dim' in inspect.signature(tf.concat).parameters.keys()
-        return tf.concat(concat_dim=axis, values=values)
-
 def build_input(dataset, data_path, batch_size, mode):
   """Build CIFAR image and labels.
 
@@ -109,7 +100,7 @@ def build_input(dataset, data_path, batch_size, mode):
   labels = tf.reshape(labels, [batch_size, 1])
   indices = tf.reshape(tf.range(0, batch_size, 1), [batch_size, 1])
   labels = tf.sparse_to_dense(
-      tf.concat(values=[indices, labels], axis=1),
+      tf.concat_v2(values=[indices, labels], axis=1),
       [batch_size, num_classes], 1.0, 0.0)
 
   assert len(images.get_shape()) == 4

+ 2 - 1
resnet/resnet_main.py

@@ -16,6 +16,7 @@
 """ResNet Train/Eval module.
 """
 import time
+import six
 import sys
 
 import cifar_input
@@ -140,7 +141,7 @@ def evaluate(hps):
     saver.restore(sess, ckpt_state.model_checkpoint_path)
 
     total_prediction, correct_prediction = 0, 0
-    for _ in xrange(FLAGS.eval_batch_count):
+    for _ in six.moves.range(FLAGS.eval_batch_count):
       (summaries, loss, predictions, truth, train_step) = sess.run(
           [model.summaries, model.cost, model.predictions,
            model.labels, model.global_step])