
Merge pull request #959 from yaroslavvb/resnet-tf10-compatibility

Changes to Resnet for TF 1.0 compatibility
Xin Pan, 8 years ago
parent
commit
ea364a9ef5
3 changed files with 10 additions and 9 deletions
  1. resnet/cifar_input.py (+1 -2)
  2. resnet/resnet_main.py (+2 -1)
  3. resnet/resnet_model.py (+7 -6)

+ 1 - 2
resnet/cifar_input.py

@@ -18,7 +18,6 @@
 
 import tensorflow as tf
 
-
 def build_input(dataset, data_path, batch_size, mode):
   """Build CIFAR image and labels.
 
@@ -101,7 +100,7 @@ def build_input(dataset, data_path, batch_size, mode):
   labels = tf.reshape(labels, [batch_size, 1])
   indices = tf.reshape(tf.range(0, batch_size, 1), [batch_size, 1])
   labels = tf.sparse_to_dense(
-      tf.concat(1, [indices, labels]),
+      tf.concat_v2(values=[indices, labels], axis=1),
       [batch_size, num_classes], 1.0, 0.0)
 
   assert len(images.get_shape()) == 4
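
For reference, TF 1.0 swaps the argument order of tf.concat from (axis, values) to (values, axis); tf.concat_v2 is the transitional name that already takes the new order. A minimal standalone sketch of the equivalent call against the final TF 1.0 API (the batch size 128 and the zero labels are illustrative placeholders, not values from build_input):

import tensorflow as tf

# Illustrative stand-ins for the tensors built in build_input().
indices = tf.reshape(tf.range(0, 128, 1), [128, 1])
labels = tf.zeros([128, 1], dtype=tf.int32)

# Pre-1.0 call: tf.concat(1, [indices, labels]), with the axis first.
# TF 1.0 takes the values first, so the tf.concat_v2 line above is equivalent to:
merged = tf.concat([indices, labels], axis=1)   # shape [128, 2]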

+ 2 - 1
resnet/resnet_main.py

@@ -16,6 +16,7 @@
 """ResNet Train/Eval module.
 """
 import time
+import six
 import sys
 
 import cifar_input
@@ -140,7 +141,7 @@ def evaluate(hps):
     saver.restore(sess, ckpt_state.model_checkpoint_path)
 
     total_prediction, correct_prediction = 0, 0
-    for _ in xrange(FLAGS.eval_batch_count):
+    for _ in six.moves.range(FLAGS.eval_batch_count):
       (summaries, loss, predictions, truth, train_step) = sess.run(
           [model.summaries, model.cost, model.predictions,
            model.labels, model.global_step])
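
For reference, xrange does not exist in Python 3; six.moves.range resolves to xrange on Python 2 and to the built-in range on Python 3, so the evaluation loop stays lazy under both interpreters. A minimal standalone sketch:

import six

# Resolves to xrange on Python 2 and range on Python 3;
# the loop never materializes a full index list.
for step in six.moves.range(3):
    print(step)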

+ 7 - 6
resnet/resnet_model.py

@@ -24,6 +24,7 @@ from collections import namedtuple
 
 import numpy as np
 import tensorflow as tf
+import six
 
 from tensorflow.python.training import moving_averages
 
@@ -89,21 +90,21 @@ class ResNet(object):
     with tf.variable_scope('unit_1_0'):
       x = res_func(x, filters[0], filters[1], self._stride_arr(strides[0]),
                    activate_before_residual[0])
-    for i in xrange(1, self.hps.num_residual_units):
+    for i in six.moves.range(1, self.hps.num_residual_units):
       with tf.variable_scope('unit_1_%d' % i):
         x = res_func(x, filters[1], filters[1], self._stride_arr(1), False)
 
     with tf.variable_scope('unit_2_0'):
       x = res_func(x, filters[1], filters[2], self._stride_arr(strides[1]),
                    activate_before_residual[1])
-    for i in xrange(1, self.hps.num_residual_units):
+    for i in six.moves.range(1, self.hps.num_residual_units):
       with tf.variable_scope('unit_2_%d' % i):
         x = res_func(x, filters[2], filters[2], self._stride_arr(1), False)
 
     with tf.variable_scope('unit_3_0'):
       x = res_func(x, filters[2], filters[3], self._stride_arr(strides[2]),
                    activate_before_residual[2])
-    for i in xrange(1, self.hps.num_residual_units):
+    for i in six.moves.range(1, self.hps.num_residual_units):
       with tf.variable_scope('unit_3_%d' % i):
         x = res_func(x, filters[3], filters[3], self._stride_arr(1), False)
 
@@ -118,7 +119,7 @@ class ResNet(object):
 
     with tf.variable_scope('costs'):
       xent = tf.nn.softmax_cross_entropy_with_logits(
-          logits, self.labels)
+          logits=logits, labels=self.labels)
       self.cost = tf.reduce_mean(xent, name='xent')
       self.cost += self._decay()
 
@@ -266,7 +267,7 @@ class ResNet(object):
         costs.append(tf.nn.l2_loss(var))
         # tf.histogram_summary(var.op.name, var)
 
-    return tf.mul(self.hps.weight_decay_rate, tf.add_n(costs))
+    return tf.multiply(self.hps.weight_decay_rate, tf.add_n(costs))
 
   def _conv(self, name, x, filter_size, in_filters, out_filters, strides):
     """Convolution."""
@@ -280,7 +281,7 @@ class ResNet(object):
 
   def _relu(self, x, leakiness=0.0):
     """Relu, with optional leaky support."""
-    return tf.select(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu')
+    return tf.where(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu')
 
   def _fully_connected(self, x, out_dim):
     """FullyConnected layer for final output."""