Преглед изворни кода

Update resnet model API + README

Neal Wu пре 8 година
родитељ
комит
0b9c1d302c
3 измењених фајлова са 7 додато и 4 уклоњено
  1. 3 0
      resnet/README.md
  2. 1 1
      resnet/cifar_input.py
  3. 3 3
      resnet/resnet_model.py

+ 3 - 0
resnet/README.md

@@ -93,6 +93,9 @@ bazel-bin/resnet/resnet_main --train_data_path=cifar10/data_batch* \
                              --dataset='cifar10' \
                              --num_gpus=1
 
+# Note that training takes about 5 hours on a TITAN X GPU, but the training script will not produce any output. In the meantime you can check on progress using tensorboard:
+tensorboard --logdir=/tmp/resnet_model
+
 # Evaluate the model.
 # Avoid running on the same GPU as the training job at the same time,
 # otherwise, you might run out of memory.

+ 1 - 1
resnet/cifar_input.py

@@ -100,7 +100,7 @@ def build_input(dataset, data_path, batch_size, mode):
   labels = tf.reshape(labels, [batch_size, 1])
   indices = tf.reshape(tf.range(0, batch_size, 1), [batch_size, 1])
   labels = tf.sparse_to_dense(
-      tf.concat_v2(values=[indices, labels], axis=1),
+      tf.concat(values=[indices, labels], axis=1),
       [batch_size, num_classes], 1.0, 0.0)
 
   assert len(images.get_shape()) == 4

+ 3 - 3
resnet/resnet_model.py

@@ -183,8 +183,8 @@ class ResNet(object):
             'moving_variance', params_shape, tf.float32,
             initializer=tf.constant_initializer(1.0, tf.float32),
             trainable=False)
-        tf.histogram_summary(mean.op.name, mean)
-        tf.histogram_summary(variance.op.name, variance)
+        tf.summary.histogram(mean.op.name, mean)
+        tf.summary.histogram(variance.op.name, variance)
       # elipson used to be 1e-5. Maybe 0.001 solves NaN problem in deeper net.
       y = tf.nn.batch_normalization(
           x, mean, variance, beta, gamma, 0.001)
@@ -265,7 +265,7 @@ class ResNet(object):
     for var in tf.trainable_variables():
       if var.op.name.find(r'DW') > 0:
         costs.append(tf.nn.l2_loss(var))
-        # tf.histogram_summary(var.op.name, var)
+        # tf.summary.histogram(var.op.name, var)
 
     return tf.multiply(self.hps.weight_decay_rate, tf.add_n(costs))