Bläddra i källkod

Changes for TF 1.0 compatibility

Yaroslav Bulatov 8 år sedan
förälder
incheckning
10340bf52f
2 ändrade filer med 16 tillägg och 7 borttagningar
  1. 9 1
      resnet/cifar_input.py
  2. 7 6
      resnet/resnet_model.py

+ 9 - 1
resnet/cifar_input.py

@@ -18,6 +18,14 @@
 
 import tensorflow as tf
 
+# backward compatible concat (arg order changed in head)
+import inspect
+def concat(values, axis):
+    if 'axis' in inspect.signature(tf.concat).parameters.keys():
+        return tf.concat(values=values, axis=axis)
+    else:
+        assert 'concat_dim' in inspect.signature(tf.concat).parameters.keys()
+        return tf.concat(concat_dim=axis, values=values)
 
 def build_input(dataset, data_path, batch_size, mode):
   """Build CIFAR image and labels.
@@ -101,7 +109,7 @@ def build_input(dataset, data_path, batch_size, mode):
   labels = tf.reshape(labels, [batch_size, 1])
   indices = tf.reshape(tf.range(0, batch_size, 1), [batch_size, 1])
   labels = tf.sparse_to_dense(
-      tf.concat(1, [indices, labels]),
+      tf.concat(values=[indices, labels], axis=1),
       [batch_size, num_classes], 1.0, 0.0)
 
   assert len(images.get_shape()) == 4

+ 7 - 6
resnet/resnet_model.py

@@ -24,6 +24,7 @@ from collections import namedtuple
 
 import numpy as np
 import tensorflow as tf
+import six
 
 from tensorflow.python.training import moving_averages
 
@@ -89,21 +90,21 @@ class ResNet(object):
     with tf.variable_scope('unit_1_0'):
       x = res_func(x, filters[0], filters[1], self._stride_arr(strides[0]),
                    activate_before_residual[0])
-    for i in xrange(1, self.hps.num_residual_units):
+    for i in six.moves.range(1, self.hps.num_residual_units):
       with tf.variable_scope('unit_1_%d' % i):
         x = res_func(x, filters[1], filters[1], self._stride_arr(1), False)
 
     with tf.variable_scope('unit_2_0'):
       x = res_func(x, filters[1], filters[2], self._stride_arr(strides[1]),
                    activate_before_residual[1])
-    for i in xrange(1, self.hps.num_residual_units):
+    for i in six.moves.range(1, self.hps.num_residual_units):
       with tf.variable_scope('unit_2_%d' % i):
         x = res_func(x, filters[2], filters[2], self._stride_arr(1), False)
 
     with tf.variable_scope('unit_3_0'):
       x = res_func(x, filters[2], filters[3], self._stride_arr(strides[2]),
                    activate_before_residual[2])
-    for i in xrange(1, self.hps.num_residual_units):
+    for i in six.moves.range(1, self.hps.num_residual_units):
       with tf.variable_scope('unit_3_%d' % i):
         x = res_func(x, filters[3], filters[3], self._stride_arr(1), False)
 
@@ -118,7 +119,7 @@ class ResNet(object):
 
     with tf.variable_scope('costs'):
       xent = tf.nn.softmax_cross_entropy_with_logits(
-          logits, self.labels)
+          logits=logits, labels=self.labels)
       self.cost = tf.reduce_mean(xent, name='xent')
       self.cost += self._decay()
 
@@ -266,7 +267,7 @@ class ResNet(object):
         costs.append(tf.nn.l2_loss(var))
         # tf.histogram_summary(var.op.name, var)
 
-    return tf.mul(self.hps.weight_decay_rate, tf.add_n(costs))
+    return tf.multiply(self.hps.weight_decay_rate, tf.add_n(costs))
 
   def _conv(self, name, x, filter_size, in_filters, out_filters, strides):
     """Convolution."""
@@ -280,7 +281,7 @@ class ResNet(object):
 
   def _relu(self, x, leakiness=0.0):
     """Relu, with optional leaky support."""
-    return tf.select(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu')
+    return tf.where(tf.less(x, 0.0), leakiness * x, x, name='leaky_relu')
 
   def _fully_connected(self, x, out_dim):
     """FullyConnected layer for final output."""