Просмотр исходного кода

Merge pull request #586 from panyx0718/master

Update cifar input following data change.
Xin Pan 9 лет назад
Родитель
Коммит
c9c4981d0a
3 изменённых файла: 9 добавлений и 6 удалений
  1. 5 3
      resnet/README.md
  2. 2 1
      resnet/cifar_input.py
  3. 2 2
      resnet/resnet_main.py

+ 5 - 3
resnet/README.md

@@ -71,12 +71,14 @@ curl -o cifar-100-binary.tar.gz https://www.cs.toronto.edu/~kriz/cifar-100-binar
 ```shell
 # cd to the your workspace.
 # It contains an empty WORKSPACE file, resnet codes and cifar10 dataset.
+# Note: User can split 5k from train set for eval set.
 ls -R
   .:
   cifar10  resnet  WORKSPACE
 
   ./cifar10:
-  test.bin  train.bin  validation.bin
+  data_batch_1.bin  data_batch_2.bin  data_batch_3.bin  data_batch_4.bin
+  data_batch_5.bin  test_batch.bin
 
   ./resnet:
   BUILD  cifar_input.py  g3doc  README.md  resnet_main.py  resnet_model.py
@@ -85,7 +87,7 @@ ls -R
 bazel build -c opt --config=cuda resnet/...
 
 # Train the model.
-bazel-bin/resnet/resnet_main --train_data_path=cifar10/train.bin \
+bazel-bin/resnet/resnet_main --train_data_path=cifar10/data_batch* \
                              --log_root=/tmp/resnet_model \
                              --train_dir=/tmp/resnet_model/train \
                              --dataset='cifar10' \
@@ -94,7 +96,7 @@ bazel-bin/resnet/resnet_main --train_data_path=cifar10/train.bin \
 # Evaluate the model.
 # Avoid running on the same GPU as the training job at the same time,
 # otherwise, you might run out of memory.
-bazel-bin/resnet/resnet_main --eval_data_path=cifar10/test.bin \
+bazel-bin/resnet/resnet_main --eval_data_path=cifar10/test_batch.bin \
                              --log_root=/tmp/resnet_model \
                              --eval_dir=/tmp/resnet_model/test \
                              --mode=eval \

+ 2 - 1
resnet/cifar_input.py

@@ -49,7 +49,8 @@ def build_input(dataset, data_path, batch_size, mode):
   image_bytes = image_size * image_size * depth
   record_bytes = label_bytes + label_offset + image_bytes
 
-  file_queue = tf.train.string_input_producer([data_path], shuffle=True)
+  data_files = tf.gfile.Glob(data_path)
+  file_queue = tf.train.string_input_producer(data_files, shuffle=True)
   # Read examples from files in the filename queue.
   reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
   _, value = reader.read(file_queue)

+ 2 - 2
resnet/resnet_main.py

@@ -26,8 +26,8 @@ import tensorflow as tf
 FLAGS = tf.app.flags.FLAGS
 tf.app.flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or cifar100.')
 tf.app.flags.DEFINE_string('mode', 'train', 'train or eval.')
-tf.app.flags.DEFINE_string('train_data_path', '', 'Filename for training data.')
-tf.app.flags.DEFINE_string('eval_data_path', '', 'Filename for eval data')
+tf.app.flags.DEFINE_string('train_data_path', '', 'Filepattern for training data.')
+tf.app.flags.DEFINE_string('eval_data_path', '', 'Filepattern for eval data')
 tf.app.flags.DEFINE_integer('image_size', 32, 'Image side length.')
 tf.app.flags.DEFINE_string('train_dir', '',
                            'Directory to keep training outputs.')