
Align the slim/resnet models with slim/inception (set spatial_squeeze and default_image_size), and add a sample script for training resnet_v1_50 on the Flowers dataset

Stefan Zechner, 8 years ago
commit 7e2435e509
3 changed files with 102 additions and 4 deletions
  1. slim/nets/resnet_v1.py (+7 -2)
  2. slim/nets/resnet_v2.py (+6 -2)
  3. slim/scripts/finetune_resnet_v1_50_on_flowers.sh (+89 -0)

+ 7 - 2
slim/nets/resnet_v1.py

@@ -119,6 +119,7 @@ def resnet_v1(inputs,
               global_pool=True,
               output_stride=None,
               include_root_block=True,
+              spatial_squeeze=True,
               reuse=None,
               scope=None):
   """Generator for v1 ResNet models.
@@ -158,6 +159,8 @@ def resnet_v1(inputs,
       ratio of input to output spatial resolution.
     include_root_block: If True, include the initial convolution followed by
       max-pooling, if False excludes it.
+    spatial_squeeze: if True, logits is of shape [B, C], if False logits is
+        of shape [B, 1, 1, C], where B is batch_size and C is the number of classes.
     reuse: whether or not the network and its variables should be reused. To be
       able to reuse 'scope' must be given.
     scope: Optional variable_scope.
@@ -197,11 +200,13 @@ def resnet_v1(inputs,
         if num_classes is not None:
           net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
                             normalizer_fn=None, scope='logits')
+        if spatial_squeeze:
+          logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
         # Convert end_points_collection into a dictionary of end_points.
         end_points = slim.utils.convert_collection_to_dict(end_points_collection)
         if num_classes is not None:
-          end_points['predictions'] = slim.softmax(net, scope='predictions')
-        return net, end_points
+          end_points['predictions'] = slim.softmax(logits, scope='predictions')
+        return logits, end_points
 resnet_v1.default_image_size = 224
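
For context, the new squeeze only drops the two singleton spatial dimensions left behind by global pooling and the 1x1 logits convolution. A minimal standalone sketch of that shape change (the batch size of 32 and the 5 classes are made-up values, not anything from this commit):

import numpy as np
import tensorflow as tf

# Shape produced by global pooling followed by the 1x1 'logits' conv: [B, 1, 1, C].
net = tf.constant(np.zeros((32, 1, 1, 5), dtype=np.float32))

# spatial_squeeze=True removes the singleton spatial dimensions.
logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
print(net.get_shape())     # (32, 1, 1, 5)
print(logits.get_shape())  # (32, 5)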
 
 

+ 6 - 2
slim/nets/resnet_v2.py

@@ -117,6 +117,7 @@ def resnet_v2(inputs,
               global_pool=True,
               output_stride=None,
               include_root_block=True,
+              spatial_squeeze=True,
               reuse=None,
               scope=None):
   """Generator for v2 (preactivation) ResNet models.
@@ -157,6 +158,8 @@ def resnet_v2(inputs,
     include_root_block: If True, include the initial convolution followed by
       max-pooling, if False excludes it. If excluded, `inputs` should be the
       results of an activation-less convolution.
+    spatial_squeeze: if True, logits is of shape [B, C], if False logits is
+        of shape [B, 1, 1, C], where B is batch_size and C is the number of classes.
     reuse: whether or not the network and its variables should be reused. To be
       able to reuse 'scope' must be given.
     scope: Optional variable_scope.
@@ -206,11 +209,12 @@ def resnet_v2(inputs,
         if num_classes is not None:
           net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
                             normalizer_fn=None, scope='logits')
-        logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
+        if spatial_squeeze:
+          logits = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
         # Convert end_points_collection into a dictionary of end_points.
         end_points = slim.utils.convert_collection_to_dict(end_points_collection)
         if num_classes is not None:
-          end_points['predictions'] = slim.softmax(net, scope='predictions')
+          end_points['predictions'] = slim.softmax(logits, scope='predictions')
         return logits, end_points
 resnet_v2.default_image_size = 224
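
With the matching change above, `end_points['predictions']` is now computed from the squeezed logits, so for a classification head both returned tensors are rank 2. A rough usage sketch, assuming the slim directory is on PYTHONPATH; the 5-class head is an arbitrary example:

import tensorflow as tf
from nets import resnet_v2  # resnet_v2.py as patched above

images = tf.placeholder(tf.float32, [None, 224, 224, 3])
# spatial_squeeze defaults to True, so logits comes back as [batch, classes].
logits, end_points = resnet_v2.resnet_v2_50(images, num_classes=5)

print(logits.get_shape())                     # (?, 5)
print(end_points['predictions'].get_shape())  # (?, 5)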
 

+ 89 - 0
slim/scripts/finetune_resnet_v1_50_on_flowers.sh

@@ -0,0 +1,89 @@
+#!/bin/bash
+#
+# This script performs the following operations:
+# 1. Downloads the Flowers dataset
+# 2. Fine-tunes a ResNetV1-50 model on the Flowers training set.
+# 3. Evaluates the model on the Flowers validation set.
+#
+# Usage:
+# cd slim
+# ./scripts/finetune_resnet_v1_50_on_flowers.sh
+
+# Where the pre-trained ResNetV1-50 checkpoint is saved to.
+PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints
+
+# Where the training (fine-tuned) checkpoint and logs will be saved to.
+TRAIN_DIR=/tmp/flowers-models/resnet_v1_50
+
+# Where the dataset is saved to.
+DATASET_DIR=/tmp/flowers
+
+# Download the pre-trained checkpoint.
+if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then
+  mkdir ${PRETRAINED_CHECKPOINT_DIR}
+fi
+if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt ]; then
+  wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz
+  tar -xvf resnet_v1_50_2016_08_28.tar.gz
+  mv resnet_v1_50.ckpt ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt
+  rm resnet_v1_50_2016_08_28.tar.gz
+fi
+
+# Download the dataset
+python download_and_convert_data.py \
+  --dataset_name=flowers \
+  --dataset_dir=${DATASET_DIR}
+
+# Fine-tune only the new layers for 3000 steps.
+python train_image_classifier.py \
+  --train_dir=${TRAIN_DIR} \
+  --dataset_name=flowers \
+  --dataset_split_name=train \
+  --dataset_dir=${DATASET_DIR} \
+  --model_name=resnet_v1_50 \
+  --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt \
+  --checkpoint_exclude_scopes=resnet_v1_50/logits \
+  --trainable_scopes=resnet_v1_50/logits \
+  --max_number_of_steps=3000 \
+  --batch_size=32 \
+  --learning_rate=0.01 \
+  --save_interval_secs=60 \
+  --save_summaries_secs=60 \
+  --log_every_n_steps=100 \
+  --optimizer=rmsprop \
+  --weight_decay=0.00004
+
+# Run evaluation.
+python eval_image_classifier.py \
+  --checkpoint_path=${TRAIN_DIR} \
+  --eval_dir=${TRAIN_DIR} \
+  --dataset_name=flowers \
+  --dataset_split_name=validation \
+  --dataset_dir=${DATASET_DIR} \
+  --model_name=resnet_v1_50
+
+# Fine-tune all layers for 1000 steps.
+python train_image_classifier.py \
+  --train_dir=${TRAIN_DIR}/all \
+  --dataset_name=flowers \
+  --dataset_split_name=train \
+  --dataset_dir=${DATASET_DIR} \
+  --checkpoint_path=${TRAIN_DIR} \
+  --model_name=resnet_v1_50 \
+  --max_number_of_steps=1000 \
+  --batch_size=32 \
+  --learning_rate=0.001 \
+  --save_interval_secs=60 \
+  --save_summaries_secs=60 \
+  --log_every_n_steps=100 \
+  --optimizer=rmsprop \
+  --weight_decay=0.00004
+
+# Run evaluation.
+python eval_image_classifier.py \
+  --checkpoint_path=${TRAIN_DIR}/all \
+  --eval_dir=${TRAIN_DIR}/all \
+  --dataset_name=flowers \
+  --dataset_split_name=validation \
+  --dataset_dir=${DATASET_DIR} \
+  --model_name=resnet_v1_50
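
For reference, the two stage-one flags above translate, in TF-Slim terms, roughly to the calls sketched below. This is only an illustrative sketch of the mechanism (the 5-class flowers head and the checkpoint path mirror the script), not the exact code inside train_image_classifier.py:

import tensorflow as tf
from nets import resnet_v1  # assumes the slim directory is on PYTHONPATH

slim = tf.contrib.slim

# Build ResNetV1-50 with a fresh 5-class head for the flowers dataset.
images = tf.placeholder(tf.float32, [None, 224, 224, 3])
logits, _ = resnet_v1.resnet_v1_50(images, num_classes=5)

# --checkpoint_exclude_scopes=resnet_v1_50/logits: restore every variable from
# the ImageNet checkpoint except the newly initialized logits layer.
variables_to_restore = slim.get_variables_to_restore(
    exclude=['resnet_v1_50/logits'])
init_fn = slim.assign_from_checkpoint_fn(
    '/tmp/checkpoints/resnet_v1_50.ckpt', variables_to_restore)

# --trainable_scopes=resnet_v1_50/logits: only the new logits variables receive
# gradient updates during the first 3000-step stage; the backbone stays frozen.
variables_to_train = tf.get_collection(
    tf.GraphKeys.TRAINABLE_VARIABLES, scope='resnet_v1_50/logits')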