
Merge https://github.com/tensorflow/models

Ivan Bogatyy, 8 years ago
Parent commit fe794eabfd

+ 7 - 3
im2txt/README.md

@@ -296,8 +296,12 @@ Your trained *Show and Tell* model can generate captions for any JPEG image! The
 following command line will generate captions for an image from the test set.
 
 ```shell
-# Directory containing model checkpoints.
-CHECKPOINT_DIR="${HOME}/im2txt/model/train"
+# Path to checkpoint file or a directory containing checkpoint files. Passing
+# a directory will only work if there is also a file named 'checkpoint' which
+# lists the available checkpoints in the directory. It will not work if you
+# point to a directory with just a copy of a model checkpoint: in that case,
+# you will need to pass the checkpoint path explicitly.
+CHECKPOINT_PATH="${HOME}/im2txt/model/train"
 
 # Vocabulary file generated by the preprocessing script.
 VOCAB_FILE="${HOME}/im2txt/data/mscoco/word_counts.txt"
@@ -314,7 +318,7 @@ export CUDA_VISIBLE_DEVICES=""
 
 # Run inference to generate captions.
 bazel-bin/im2txt/run_inference \
-  --checkpoint_path=${CHECKPOINT_DIR} \
+  --checkpoint_path=${CHECKPOINT_PATH} \
   --vocab_file=${VOCAB_FILE} \
   --input_files=${IMAGE_FILE}
 ```
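A note on why the rename matters: TensorFlow resolves a checkpoint directory through its `checkpoint` index file, which is what `tf.train.latest_checkpoint` reads; when that file is absent the lookup returns `None`. A minimal sketch of the resolution logic the new comment describes (the helper name and error text are illustrative, not part of the commit):

```python
import tensorflow as tf

def resolve_checkpoint(path):
  """Resolve a --checkpoint_path flag value to a concrete checkpoint prefix."""
  if tf.gfile.IsDirectory(path):
    # Reads the 'checkpoint' index file inside the directory; returns None
    # when the directory holds only a bare copy of the model files.
    latest = tf.train.latest_checkpoint(path)
    if latest is None:
      raise ValueError("No 'checkpoint' index file in %s; pass the "
                       "checkpoint file path explicitly." % path)
    return latest
  return path
```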

+ 2 - 0
im2txt/im2txt/run_inference.py

@@ -39,6 +39,8 @@ tf.flags.DEFINE_string("input_files", "",
                        "File pattern or comma-separated list of file patterns "
                        "File pattern or comma-separated list of file patterns "
                        "of image files.")
                        "of image files.")
 
 
+tf.logging.set_verbosity(tf.logging.INFO)
+
 
 def main(_):
   # Build the inference graph.
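For context: `tf.logging` defaults to the `WARN` threshold, so without the added `set_verbosity` call the `tf.logging.info` messages emitted during inference (for example, the checkpoint-restore notices) never reach the console.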

+ 3 - 1
inception/inception/data/process_bounding_boxes.py

@@ -102,7 +102,9 @@ def GetItem(name, root, index=0):
 
 
 def GetInt(name, root, index=0):
-  return int(GetItem(name, root, index))
+  # In some XML annotation files, the point values are not integers, but floats.
+  # So we add a float function to avoid ValueError.
+  return int(float(GetItem(name, root, index)))
 
 
 def FindNumberBoundingBoxes(root):
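The `int(float(...))` change works because Python's `int()` accepts a float value but rejects a string containing a decimal point:

```python
int("97")           # 97
int(float("97.0"))  # 97.0 -> 97
int("97.0")         # ValueError: invalid literal for int() with base 10: '97.0'
```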

+ 5 - 5
inception/inception/slim/README.md

@@ -445,15 +445,15 @@ defined with just the following snippet:
 
 ```python
 with arg_scope([slim.ops.conv2d, slim.ops.fc], stddev=0.01, weight_decay=0.0005):
-  net = slim.ops.repeat_op(1, inputs, slim.ops.conv2d, 64, [3, 3], scope='conv1')
+  net = slim.ops.repeat_op(2, inputs, slim.ops.conv2d, 64, [3, 3], scope='conv1')
   net = slim.ops.max_pool(net, [2, 2], scope='pool1')
-  net = slim.ops.repeat_op(1, net, slim.ops.conv2d, 128, [3, 3], scope='conv2')
+  net = slim.ops.repeat_op(2, net, slim.ops.conv2d, 128, [3, 3], scope='conv2')
   net = slim.ops.max_pool(net, [2, 2], scope='pool2')
-  net = slim.ops.repeat_op(2, net, slim.ops.conv2d, 256, [3, 3], scope='conv3')
+  net = slim.ops.repeat_op(3, net, slim.ops.conv2d, 256, [3, 3], scope='conv3')
   net = slim.ops.max_pool(net, [2, 2], scope='pool3')
-  net = slim.ops.repeat_op(2, net, slim.ops.conv2d, 512, [3, 3], scope='conv4')
+  net = slim.ops.repeat_op(3, net, slim.ops.conv2d, 512, [3, 3], scope='conv4')
   net = slim.ops.max_pool(net, [2, 2], scope='pool4')
-  net = slim.ops.repeat_op(2, net, slim.ops.conv2d, 512, [3, 3], scope='conv5')
+  net = slim.ops.repeat_op(3, net, slim.ops.conv2d, 512, [3, 3], scope='conv5')
   net = slim.ops.max_pool(net, [2, 2], scope='pool5')
   net = slim.ops.flatten(net, scope='flatten5')
   net = slim.ops.fc(net, 4096, scope='fc6')
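The corrected repetition counts (2, 2, 3, 3, 3) are the VGG-16 configuration: 2 + 2 + 3 + 3 + 3 = 13 convolutional layers which, together with the three fully connected layers, make up the network's 16 weight layers; the old counts described a different, shallower stack. For comparison, a sketch of the same stack in the later `tf.contrib.slim` API (assuming `slim.repeat` and `slim.max_pool2d`, the contrib counterparts of the ops used here):

```python
import tensorflow as tf
slim = tf.contrib.slim

def vgg16_conv_stack(inputs):
  # 2 + 2 + 3 + 3 + 3 = 13 conv layers, as in VGG-16.
  net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
  net = slim.max_pool2d(net, [2, 2], scope='pool1')
  net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
  net = slim.max_pool2d(net, [2, 2], scope='pool2')
  net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
  net = slim.max_pool2d(net, [2, 2], scope='pool3')
  net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')
  net = slim.max_pool2d(net, [2, 2], scope='pool4')
  net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')
  return slim.max_pool2d(net, [2, 2], scope='pool5')
```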

+ 3 - 3
slim/README.md

@@ -13,7 +13,7 @@ converting them
 to TensorFlow's native TFRecord format and reading them in using TF-Slim's
 data reading and queueing utilities. You can easily train any model on any of
 these datasets, as we demonstrate below. We've also included a
-[jupyter notebook](https://github.com/tensorflow/models/blob/master/slim/slim_walkthough.ipynb),
+[jupyter notebook](https://github.com/tensorflow/models/blob/master/slim/slim_walkthrough.ipynb),
 which provides working examples of how to use TF-Slim for image classification.
 
 ## Contacts
@@ -303,8 +303,8 @@ $ python train_image_classifier.py \
     --dataset_split_name=train \
     --model_name=inception_v3 \
     --checkpoint_path=${CHECKPOINT_PATH} \
-    --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits/Logits \
-    --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits/Logits
+    --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
+    --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits
 ```
 
 
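The scope fix reflects the actual variable layout: the auxiliary head's variables live directly under `InceptionV3/AuxLogits` (there is no nested `.../AuxLogits/Logits` scope), so the old flags would neither exclude nor train the auxiliary classifier correctly. A sketch of how such exclusions translate into a restore list, assuming the `tf.contrib.slim` helper:

```python
import tensorflow as tf
slim = tf.contrib.slim

# Mirrors --checkpoint_exclude_scopes: variables under these scopes are
# freshly initialized rather than restored from the ImageNet checkpoint.
exclusions = ['InceptionV3/Logits', 'InceptionV3/AuxLogits']
variables_to_restore = slim.get_variables_to_restore(exclude=exclusions)
```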

+ 1 - 1
slim/eval_image_classifier.py

@@ -183,7 +183,7 @@ def main(_):
         checkpoint_path=checkpoint_path,
         logdir=FLAGS.eval_dir,
         num_evals=num_batches,
-        eval_op=names_to_updates.values(),
+        eval_op=list(names_to_updates.values()),
         variables_to_restore=variables_to_restore)
 
 
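The `list()` wrapper is a Python 3 compatibility fix: `dict.values()` returns a view object in Python 3 rather than a list, and the slim evaluation loop expects a concrete list of update ops. Illustration:

```python
metrics = {'accuracy': 'update_op_a', 'recall': 'update_op_r'}
metrics.values()[0]        # Python 2: fine; Python 3: TypeError, views aren't subscriptable
list(metrics.values())[0]  # works under both
```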

+ 1 - 1
slim/nets/inception_resnet_v2.py

@@ -171,7 +171,7 @@ def inception_resnet_v2(inputs, num_classes=1001, is_training=True,
         end_points['Mixed_5b'] = net
         net = slim.repeat(net, 10, block35, scale=0.17)
 
-        # 17 x 17 x 1024
+        # 17 x 17 x 1088
         with tf.variable_scope('Mixed_6a'):
           with tf.variable_scope('Branch_0'):
             tower_conv = slim.conv2d(net, 384, 3, stride=2, padding='VALID',
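The corrected shape follows from the concatenation in `Mixed_6a`: assuming the standard Inception-ResNet-v2 widths, the 384-channel strided convolution, the 384-channel tail of the second branch, and the max-pooled 320-channel output of the block35 stack concatenate to 384 + 384 + 320 = 1088 channels, so the old 1024 comment was simply wrong.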

+ 2 - 2
slim/nets/inception_resnet_v2_test.py

@@ -65,9 +65,9 @@ class InceptionTest(tf.test.TestCase):
         inception.inception_resnet_v2(inputs, num_classes)
       with tf.variable_scope('on_gpu'), tf.device('/gpu:0'):
         inception.inception_resnet_v2(inputs, num_classes)
-      for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_cpu'):
+      for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_cpu'):
         self.assertDeviceEqual(v.device, '/cpu:0')
-      for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_gpu'):
+      for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_gpu'):
         self.assertDeviceEqual(v.device, '/gpu:0')
 
   def testHalfSizeImages(self):
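`tf.GraphKeys.VARIABLES` was renamed to `tf.GraphKeys.GLOBAL_VARIABLES` in the TensorFlow 1.0 API overhaul (the same change that turned `tf.all_variables()` into `tf.global_variables()`); the identical substitution recurs in `inception_v4_test.py` and `video_prediction/prediction_train.py` below. The underlying collection is unchanged:

```python
import tensorflow as tf

v = tf.Variable(0, name='counter')
# Variables land in the GLOBAL_VARIABLES collection by default; the old
# VARIABLES key was a deprecated alias for the same collection.
assert v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
```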

+ 1 - 1
slim/nets/inception_v1.py

@@ -270,7 +270,7 @@ def inception_v1(inputs,
     is_training: whether is training or not.
     dropout_keep_prob: the percentage of activation values that are retained.
     prediction_fn: a function to get predictions out of logits.
-    spatial_squeeze: if True, logits is of shape is [B, C], if false logits is
+    spatial_squeeze: if True, logits is of shape [B, C], if false logits is
         of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
     reuse: whether or not the network and its variables should be reused. To be
       able to reuse 'scope' must be given.
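The `spatial_squeeze` flag the docstring describes is a single `tf.squeeze` over the two spatial axes; the same wording fix recurs in `inception_v2.py` and `inception_v3.py` below. A self-contained sketch of the shape change:

```python
import tensorflow as tf

logits_4d = tf.zeros([32, 1, 1, 1001])  # [B, 1, 1, C] from a 1x1 conv head
logits = tf.squeeze(logits_4d, [1, 2], name='SpatialSqueeze')
assert logits.get_shape().as_list() == [32, 1001]  # [B, C]
```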

+ 1 - 1
slim/nets/inception_v2.py

@@ -443,7 +443,7 @@ def inception_v2(inputs,
       usage will be to set this value in (0, 1) to reduce the number of
       parameters or computation cost of the model.
     prediction_fn: a function to get predictions out of logits.
-    spatial_squeeze: if True, logits is of shape is [B, C], if false logits is
+    spatial_squeeze: if True, logits is of shape [B, C], if false logits is
         of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
     reuse: whether or not the network and its variables should be reused. To be
       able to reuse 'scope' must be given.

+ 1 - 1
slim/nets/inception_v3.py

@@ -453,7 +453,7 @@ def inception_v3(inputs,
       usage will be to set this value in (0, 1) to reduce the number of
       parameters or computation cost of the model.
     prediction_fn: a function to get predictions out of logits.
-    spatial_squeeze: if True, logits is of shape is [B, C], if false logits is
+    spatial_squeeze: if True, logits is of shape [B, C], if false logits is
         of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
     reuse: whether or not the network and its variables should be reused. To be
       able to reuse 'scope' must be given.

+ 2 - 2
slim/nets/inception_v4_test.py

@@ -146,9 +146,9 @@ class InceptionTest(tf.test.TestCase):
       inception.inception_v4(inputs, num_classes)
     with tf.variable_scope('on_gpu'), tf.device('/gpu:0'):
       inception.inception_v4(inputs, num_classes)
-    for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_cpu'):
+    for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_cpu'):
       self.assertDeviceEqual(v.device, '/cpu:0')
-    for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_gpu'):
+    for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_gpu'):
       self.assertDeviceEqual(v.device, '/gpu:0')
 
   def testHalfSizeImages(self):

+ 11 - 2
slim/nets/resnet_v1.py

@@ -119,6 +119,7 @@ def resnet_v1(inputs,
               global_pool=True,
               output_stride=None,
               include_root_block=True,
+              spatial_squeeze=True,
               reuse=None,
               scope=None):
   """Generator for v1 ResNet models.
@@ -158,6 +159,8 @@ def resnet_v1(inputs,
       ratio of input to output spatial resolution.
     include_root_block: If True, include the initial convolution followed by
       max-pooling, if False excludes it.
+    spatial_squeeze: if True, logits is of shape [B, C], if false logits is
+        of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
     reuse: whether or not the network and its variables should be reused. To be
       able to reuse 'scope' must be given.
     scope: Optional variable_scope.
@@ -197,11 +200,13 @@ def resnet_v1(inputs,
         if num_classes is not None:
           net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
                             normalizer_fn=None, scope='logits')
+        if spatial_squeeze:
+          logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
         # Convert end_points_collection into a dictionary of end_points.
         end_points = slim.utils.convert_collection_to_dict(end_points_collection)
         if num_classes is not None:
-          end_points['predictions'] = slim.softmax(net, scope='predictions')
-        return net, end_points
+          end_points['predictions'] = slim.softmax(logits, scope='predictions')
+        return logits, end_points
 resnet_v1.default_image_size = 224
 
 
@@ -226,6 +231,7 @@ def resnet_v1_50(inputs,
   return resnet_v1(inputs, blocks, num_classes, is_training,
                    global_pool=global_pool, output_stride=output_stride,
                    include_root_block=True, reuse=reuse, scope=scope)
+resnet_v1_50.default_image_size = resnet_v1.default_image_size
 
 
 def resnet_v1_101(inputs,
@@ -249,6 +255,7 @@ def resnet_v1_101(inputs,
   return resnet_v1(inputs, blocks, num_classes, is_training,
                    global_pool=global_pool, output_stride=output_stride,
                    include_root_block=True, reuse=reuse, scope=scope)
+resnet_v1_101.default_image_size = resnet_v1.default_image_size
 
 
 def resnet_v1_152(inputs,
@@ -271,6 +278,7 @@ def resnet_v1_152(inputs,
   return resnet_v1(inputs, blocks, num_classes, is_training,
                    global_pool=global_pool, output_stride=output_stride,
                    include_root_block=True, reuse=reuse, scope=scope)
+resnet_v1_152.default_image_size = resnet_v1.default_image_size
 
 
 def resnet_v1_200(inputs,
@@ -293,3 +301,4 @@ def resnet_v1_200(inputs,
   return resnet_v1(inputs, blocks, num_classes, is_training,
                    global_pool=global_pool, output_stride=output_stride,
                    include_root_block=True, reuse=reuse, scope=scope)
+resnet_v1_200.default_image_size = resnet_v1.default_image_size
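Two notes on this hunk. First, `default_image_size` is a plain attribute hung on the generator functions so callers can query a model's expected input resolution; aliasing each variant to `resnet_v1.default_image_size` keeps the constant in one place, and the `resnet_v2.py` hunk below applies the same pattern. A caller-side sketch (names illustrative):

```python
network_fn = resnet_v1_50
image_size = network_fn.default_image_size  # 224
```

Second, as committed, `logits` is only bound inside `if spatial_squeeze:`, so calling `resnet_v1` (or the reworked `resnet_v2`) with `spatial_squeeze=False` would hit an unbound `logits` at the `return` and raise a `NameError`; a fallback such as `logits = net` in the else branch would be needed for that path.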

+ 12 - 6
slim/nets/resnet_v2.py

@@ -117,6 +117,7 @@ def resnet_v2(inputs,
               global_pool=True,
               output_stride=None,
               include_root_block=True,
+              spatial_squeeze=True,
               reuse=None,
               scope=None):
   """Generator for v2 (preactivation) ResNet models.
@@ -157,6 +158,8 @@ def resnet_v2(inputs,
     include_root_block: If True, include the initial convolution followed by
       max-pooling, if False excludes it. If excluded, `inputs` should be the
       results of an activation-less convolution.
+    spatial_squeeze: if True, logits is of shape [B, C], if false logits is
+        of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
     reuse: whether or not the network and its variables should be reused. To be
       able to reuse 'scope' must be given.
     scope: Optional variable_scope.
@@ -206,12 +209,14 @@ def resnet_v2(inputs,
         if num_classes is not None:
           net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
                             normalizer_fn=None, scope='logits')
-        logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
+        if spatial_squeeze:
+          logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
         # Convert end_points_collection into a dictionary of end_points.
         end_points = slim.utils.convert_collection_to_dict(end_points_collection)
         if num_classes is not None:
-          end_points['predictions'] = slim.softmax(net, scope='predictions')
+          end_points['predictions'] = slim.softmax(logits, scope='predictions')
         return logits, end_points
+resnet_v2.default_image_size = 224
 
 
 def resnet_v2_50(inputs,
@@ -234,7 +239,8 @@ def resnet_v2_50(inputs,
   return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
                    global_pool=global_pool, output_stride=output_stride,
                    include_root_block=True, reuse=reuse, scope=scope)
-resnet_v2_50.default_image_size = 224
+resnet_v2_50.default_image_size = resnet_v2.default_image_size
+
 
 def resnet_v2_101(inputs,
                   num_classes=None,
@@ -256,7 +262,7 @@ def resnet_v2_101(inputs,
   return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
                    global_pool=global_pool, output_stride=output_stride,
                    include_root_block=True, reuse=reuse, scope=scope)
-resnet_v2_101.default_image_size = 224
+resnet_v2_101.default_image_size = resnet_v2.default_image_size
 
 
 def resnet_v2_152(inputs,
@@ -279,7 +285,7 @@ def resnet_v2_152(inputs,
   return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
                    global_pool=global_pool, output_stride=output_stride,
                    include_root_block=True, reuse=reuse, scope=scope)
-resnet_v2_152.default_image_size = 224
+resnet_v2_152.default_image_size = resnet_v2.default_image_size
 
 
 def resnet_v2_200(inputs,
@@ -302,4 +308,4 @@ def resnet_v2_200(inputs,
   return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
                    global_pool=global_pool, output_stride=output_stride,
                    include_root_block=True, reuse=reuse, scope=scope)
-resnet_v2_200.default_image_size = 224
+resnet_v2_200.default_image_size = resnet_v2.default_image_size

+ 89 - 0
slim/scripts/finetune_resnet_v1_50_on_flowers.sh

@@ -0,0 +1,89 @@
+#!/bin/bash
+#
+# This script performs the following operations:
+# 1. Downloads the Flowers dataset
+# 2. Fine-tunes a ResNetV1-50 model on the Flowers training set.
+# 3. Evaluates the model on the Flowers validation set.
+#
+# Usage:
+# cd slim
+# ./slim/scripts/finetune_resnet_v1_50_on_flowers.sh
+
+# Where the pre-trained ResNetV1-50 checkpoint is saved to.
+PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints
+
+# Where the training (fine-tuned) checkpoint and logs will be saved to.
+TRAIN_DIR=/tmp/flowers-models/resnet_v1_50
+
+# Where the dataset is saved to.
+DATASET_DIR=/tmp/flowers
+
+# Download the pre-trained checkpoint.
+if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then
+  mkdir ${PRETRAINED_CHECKPOINT_DIR}
+fi
+if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt ]; then
+  wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz
+  tar -xvf resnet_v1_50_2016_08_28.tar.gz
+  mv resnet_v1_50.ckpt ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt
+  rm resnet_v1_50_2016_08_28.tar.gz
+fi
+
+# Download the dataset
+python download_and_convert_data.py \
+  --dataset_name=flowers \
+  --dataset_dir=${DATASET_DIR}
+
+# Fine-tune only the new layers for 3000 steps.
+python train_image_classifier.py \
+  --train_dir=${TRAIN_DIR} \
+  --dataset_name=flowers \
+  --dataset_split_name=train \
+  --dataset_dir=${DATASET_DIR} \
+  --model_name=resnet_v1_50 \
+  --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt \
+  --checkpoint_exclude_scopes=resnet_v1_50/logits \
+  --trainable_scopes=resnet_v1_50/logits \
+  --max_number_of_steps=3000 \
+  --batch_size=32 \
+  --learning_rate=0.01 \
+  --save_interval_secs=60 \
+  --save_summaries_secs=60 \
+  --log_every_n_steps=100 \
+  --optimizer=rmsprop \
+  --weight_decay=0.00004
+
+# Run evaluation.
+python eval_image_classifier.py \
+  --checkpoint_path=${TRAIN_DIR} \
+  --eval_dir=${TRAIN_DIR} \
+  --dataset_name=flowers \
+  --dataset_split_name=validation \
+  --dataset_dir=${DATASET_DIR} \
+  --model_name=resnet_v1_50
+
+# Fine-tune all the new layers for 1000 steps.
+python train_image_classifier.py \
+  --train_dir=${TRAIN_DIR}/all \
+  --dataset_name=flowers \
+  --dataset_split_name=train \
+  --dataset_dir=${DATASET_DIR} \
+  --checkpoint_path=${TRAIN_DIR} \
+  --model_name=resnet_v1_50 \
+  --max_number_of_steps=1000 \
+  --batch_size=32 \
+  --learning_rate=0.001 \
+  --save_interval_secs=60 \
+  --save_summaries_secs=60 \
+  --log_every_n_steps=100 \
+  --optimizer=rmsprop \
+  --weight_decay=0.00004
+
+# Run evaluation.
+python eval_image_classifier.py \
+  --checkpoint_path=${TRAIN_DIR}/all \
+  --eval_dir=${TRAIN_DIR}/all \
+  --dataset_name=flowers \
+  --dataset_split_name=validation \
+  --dataset_dir=${DATASET_DIR} \
+  --model_name=resnet_v1_50
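The new script mirrors the existing Inception fine-tuning recipes: phase one restores the ImageNet checkpoint minus the classifier (`--checkpoint_exclude_scopes=resnet_v1_50/logits`) and trains only that layer (`--trainable_scopes=resnet_v1_50/logits`) at learning rate 0.01; phase two restores the result from `${TRAIN_DIR}` and, by omitting `--trainable_scopes`, fine-tunes all layers at the 10x smaller rate of 0.001, with an evaluation run after each phase.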

+ 4 - 4
slim/slim_walkthough.ipynb

@@ -232,7 +232,7 @@
    },
    "outputs": [],
    "source": [
-    "# The following snippet trains the regression model using a sum_of_squares loss.\n",
+    "# The following snippet trains the regression model using a mean_squared_error loss.\n",
     "ckpt_dir = '/tmp/regression_model/'\n",
     "ckpt_dir = '/tmp/regression_model/'\n",
     "\n",
     "\n",
     "with tf.Graph().as_default():\n",
     "with tf.Graph().as_default():\n",
@@ -244,7 +244,7 @@
     "    predictions, nodes = regression_model(inputs, is_training=True)\n",
     "    predictions, nodes = regression_model(inputs, is_training=True)\n",
     "\n",
     "\n",
     "    # Add the loss function to the graph.\n",
     "    # Add the loss function to the graph.\n",
-    "    loss = slim.losses.sum_of_squares(predictions, targets)\n",
+    "    loss = tf.losses.mean_squared_error(labels=targets, predictions=predictions)\n",
     "    \n",
     "    \n",
     "    # The total loss is the uers's loss plus any regularization losses.\n",
     "    # The total loss is the uers's loss plus any regularization losses.\n",
     "    total_loss = slim.losses.get_total_loss()\n",
     "    total_loss = slim.losses.get_total_loss()\n",
@@ -289,12 +289,12 @@
     "    predictions, end_points = regression_model(inputs, is_training=True)\n",
     "    predictions, end_points = regression_model(inputs, is_training=True)\n",
     "\n",
     "\n",
     "    # Add multiple loss nodes.\n",
     "    # Add multiple loss nodes.\n",
-    "    sum_of_squares_loss = slim.losses.sum_of_squares(predictions, targets)\n",
+    "    mean_squared_error_loss = tf.losses.mean_squared_error(labels=targets, predictions=predictions)\n",
     "    absolute_difference_loss = slim.losses.absolute_difference(predictions, targets)\n",
     "    absolute_difference_loss = slim.losses.absolute_difference(predictions, targets)\n",
     "\n",
     "\n",
     "    # The following two ways to compute the total loss are equivalent\n",
     "    # The following two ways to compute the total loss are equivalent\n",
     "    regularization_loss = tf.add_n(slim.losses.get_regularization_losses())\n",
     "    regularization_loss = tf.add_n(slim.losses.get_regularization_losses())\n",
-    "    total_loss1 = sum_of_squares_loss + absolute_difference_loss + regularization_loss\n",
+    "    total_loss1 = mean_squared_error_loss + absolute_difference_loss + regularization_loss\n",
     "\n",
     "\n",
     "    # Regularization Loss is included in the total loss by default.\n",
     "    # Regularization Loss is included in the total loss by default.\n",
     "    # This is good for training, but not for testing.\n",
     "    # This is good for training, but not for testing.\n",

+ 1 - 1
tutorials/embedding/README.md

@@ -28,7 +28,7 @@ g++ -std=c++11 -shared word2vec_ops.cc word2vec_kernels.cc -o word2vec_ops.so -f
 
 On Mac, add `-undefined dynamic_lookup` to the g++ command.
 
-(For an explanation of what this is doing, see the tutorial on [Adding a New Op to TensorFlow](https://www.tensorflow.org/how_tos/adding_an_op/#building_the_op_library). The flag `-D_GLIBCXX_USE_CXX11_ABI=0` is included to support newer versions of g++.)
+(For an explanation of what this is doing, see the tutorial on [Adding a New Op to TensorFlow](https://www.tensorflow.org/how_tos/adding_an_op/#building_the_op_library). The flag `-D_GLIBCXX_USE_CXX11_ABI=0` is included to support newer versions of gcc. However, if you compiled TensorFlow from source using gcc 5 or later, you may need to exclude the flag.)
 Then run using:
 
 ```shell
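The expanded parenthetical is about the gcc 5 dual ABI: `-D_GLIBCXX_USE_CXX11_ABI=0` selects the old `libstdc++` ABI that the official TensorFlow pip packages were built with, which a custom op compiled by a newer gcc must match; a TensorFlow built from source with gcc 5+ already uses the new ABI, so keeping the flag there would cause the very symbol mismatches it normally prevents.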

+ 1 - 1
tutorials/embedding/word2vec.py

@@ -210,7 +210,7 @@ class Word2Vec(object):
         tf.zeros([opts.vocab_size, opts.emb_dim]),
         name="sm_w_t")
 
-    # Softmax bias: [emb_dim].
+    # Softmax bias: [vocab_size].
     sm_b = tf.Variable(tf.zeros([opts.vocab_size]), name="sm_b")
 
     # Global step: scalar, i.e., shape [].
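The corrected comment follows from the sampled-softmax shapes: `sm_w_t` holds one `emb_dim`-wide row per vocabulary word (`[vocab_size, emb_dim]`), and the bias contributes one scalar per output word, hence `[vocab_size]`, matching the `tf.zeros([opts.vocab_size])` initializer on the next line.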

+ 11 - 7
tutorials/image/cifar10/cifar10_train.py

@@ -52,6 +52,8 @@ tf.app.flags.DEFINE_integer('max_steps', 1000000,
                             """Number of batches to run.""")
                             """Number of batches to run.""")
 tf.app.flags.DEFINE_boolean('log_device_placement', False,
 tf.app.flags.DEFINE_boolean('log_device_placement', False,
                             """Whether to log device placement.""")
                             """Whether to log device placement.""")
+tf.app.flags.DEFINE_integer('log_frequency', 10,
+                            """How often to log results to the console.""")
 
 
 def train():
@@ -78,19 +80,21 @@ def train():
 
       def begin(self):
         self._step = -1
+        self._start_time = time.time()
 
       def before_run(self, run_context):
         self._step += 1
-        self._start_time = time.time()
         return tf.train.SessionRunArgs(loss)  # Asks for loss value.
 
       def after_run(self, run_context, run_values):
-        duration = time.time() - self._start_time
-        loss_value = run_values.results
-        if self._step % 10 == 0:
-          num_examples_per_step = FLAGS.batch_size
-          examples_per_sec = num_examples_per_step / duration
-          sec_per_batch = float(duration)
+        if self._step % FLAGS.log_frequency == 0:
+          current_time = time.time()
+          duration = current_time - self._start_time
+          self._start_time = current_time
+
+          loss_value = run_values.results
+          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
+          sec_per_batch = float(duration / FLAGS.log_frequency)
 
           format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                         'sec/batch)')
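The rewritten hook also changes what the printed numbers mean. The old version reset `_start_time` in `before_run`, so each log line timed only the single step it landed on; the new version times the whole window since the previous log line, making `examples_per_sec = log_frequency * batch_size / duration` an average over `--log_frequency` steps and `sec_per_batch` the mean step time. A stripped-down sketch of the windowed-timing pattern outside any TensorFlow hook:

```python
import time

log_frequency, batch_size = 10, 128
start = time.time()
for step in range(1, 101):
    time.sleep(0.01)  # stand-in for one training step
    if step % log_frequency == 0:
        now = time.time()
        duration, start = now - start, now  # measure and reset the window
        print('%.1f examples/sec; %.3f sec/batch'
              % (log_frequency * batch_size / duration, duration / log_frequency))
```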

+ 1 - 1
tutorials/rnn/translate/seq2seq_model.py

@@ -100,7 +100,7 @@ class Seq2SeqModel(object):
       b = tf.get_variable("proj_b", [self.target_vocab_size], dtype=dtype)
       output_projection = (w, b)
 
-      def sampled_loss(labels, inputs):
+      def sampled_loss(inputs, labels):
         labels = tf.reshape(labels, [-1, 1])
         # We need to compute the sampled_softmax_loss using 32bit floats to
         # avoid numerical instabilities.
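The swap matters because the legacy seq2seq code invokes `softmax_loss_function` positionally, so the wrapper's parameter order must match the `(inputs, labels)` calling convention of the surrounding library version. Inside such a wrapper, passing everything to `tf.nn.sampled_softmax_loss` by keyword is the safest style, since the positional order of `labels` and `inputs` itself changed across TensorFlow releases; a toy-shapes sketch:

```python
import tensorflow as tf

# Toy shapes: 1000-word vocab, 64-dim decoder state, batch of 8.
weights = tf.zeros([1000, 64])
biases = tf.zeros([1000])
labels = tf.zeros([8, 1], dtype=tf.int64)
inputs = tf.zeros([8, 64])

# Keyword arguments make an argument-order swap impossible to misroute.
loss = tf.nn.sampled_softmax_loss(weights=weights, biases=biases,
                                  labels=labels, inputs=inputs,
                                  num_sampled=20, num_classes=1000)
```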

+ 1 - 1
video_prediction/prediction_train.py

@@ -196,7 +196,7 @@ def main(unused_argv):
   print 'Constructing saver.'
   # Make saver.
   saver = tf.train.Saver(
-      tf.get_collection(tf.GraphKeys.VARIABLES), max_to_keep=0)
+      tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), max_to_keep=0)
 
   # Make training session.
   sess = tf.InteractiveSession()