download_and_convert_data.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Downloads and converts a particular dataset.

Supported datasets are "cifar10", "flowers" and "mnist". The output TFRecords
and any temporary files are saved under the directory given by --dataset_dir.
Both --dataset_name and --dataset_dir are required.

Usage:
```shell
$ python download_and_convert_data.py \
    --dataset_name=mnist \
    --dataset_dir=/tmp/mnist

$ python download_and_convert_data.py \
    --dataset_name=cifar10 \
    --dataset_dir=/tmp/cifar10

$ python download_and_convert_data.py \
    --dataset_name=flowers \
    --dataset_dir=/tmp/flowers
```
"""
  29. from __future__ import absolute_import
  30. from __future__ import division
  31. from __future__ import print_function
  32. import tensorflow as tf
  33. from datasets import download_and_convert_cifar10
  34. from datasets import download_and_convert_flowers
  35. from datasets import download_and_convert_mnist
  36. FLAGS = tf.app.flags.FLAGS
  37. tf.app.flags.DEFINE_string(
  38. 'dataset_name',
  39. None,
  40. 'The name of the dataset to convert, one of "cifar10", "flowers", "mnist".')
  41. tf.app.flags.DEFINE_string(
  42. 'dataset_dir',
  43. None,
  44. 'The directory where the output TFRecords and temporary files are saved.')
  45. def main(_):
  46. if not FLAGS.dataset_name:
  47. raise ValueError('You must supply the dataset name with --dataset_name')
  48. if not FLAGS.dataset_dir:
  49. raise ValueError('You must supply the dataset directory with --dataset_dir')
  50. if FLAGS.dataset_name == 'cifar10':
  51. download_and_convert_cifar10.run(FLAGS.dataset_dir)
  52. elif FLAGS.dataset_name == 'flowers':
  53. download_and_convert_flowers.run(FLAGS.dataset_dir)
  54. elif FLAGS.dataset_name == 'mnist':
  55. download_and_convert_mnist.run(FLAGS.dataset_dir)
  56. else:
  57. raise ValueError(
  58. 'dataset_name [%s] was not recognized.' % FLAGS.dataset_dir)
  59. if __name__ == '__main__':
  60. tf.app.run()