dataset_utils.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Contains utilities for downloading and converting datasets."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import os
  20. import sys
  21. import tarfile
  22. from six.moves import urllib
  23. import tensorflow as tf
  24. LABELS_FILENAME = 'labels.txt'
  25. def int64_feature(values):
  26. """Returns a TF-Feature of int64s.
  27. Args:
  28. values: A scalar or list of values.
  29. Returns:
  30. a TF-Feature.
  31. """
  32. if not isinstance(values, (tuple, list)):
  33. values = [values]
  34. return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
  35. def bytes_feature(values):
  36. """Returns a TF-Feature of bytes.
  37. Args:
  38. values: A string.
  39. Returns:
  40. a TF-Feature.
  41. """
  42. return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
  43. def image_to_tfexample(image_data, image_format, height, width, class_id):
  44. return tf.train.Example(features=tf.train.Features(feature={
  45. 'image/encoded': bytes_feature(image_data),
  46. 'image/format': bytes_feature(image_format),
  47. 'image/class/label': int64_feature(class_id),
  48. 'image/height': int64_feature(height),
  49. 'image/width': int64_feature(width),
  50. }))
  51. def download_and_uncompress_tarball(tarball_url, dataset_dir):
  52. """Downloads the `tarball_url` and uncompresses it locally.
  53. Args:
  54. tarball_url: The URL of a tarball file.
  55. dataset_dir: The directory where the temporary files are stored.
  56. """
  57. filename = tarball_url.split('/')[-1]
  58. filepath = os.path.join(dataset_dir, filename)
  59. def _progress(count, block_size, total_size):
  60. sys.stdout.write('\r>> Downloading %s %.1f%%' % (
  61. filename, float(count * block_size) / float(total_size) * 100.0))
  62. sys.stdout.flush()
  63. filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
  64. print()
  65. statinfo = os.stat(filepath)
  66. print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  67. tarfile.open(filepath, 'r:gz').extractall(dataset_dir)
  68. def write_label_file(labels_to_class_names, dataset_dir,
  69. filename=LABELS_FILENAME):
  70. """Writes a file with the list of class names.
  71. Args:
  72. labels_to_class_names: A map of (integer) labels to class names.
  73. dataset_dir: The directory in which the labels file should be written.
  74. filename: The filename where the class names are written.
  75. """
  76. labels_filename = os.path.join(dataset_dir, filename)
  77. with tf.gfile.Open(labels_filename, 'w') as f:
  78. for label in labels_to_class_names:
  79. class_name = labels_to_class_names[label]
  80. f.write('%d:%s\n' % (label, class_name))
  81. def has_labels(dataset_dir, filename=LABELS_FILENAME):
  82. """Specifies whether or not the dataset directory contains a label map file.
  83. Args:
  84. dataset_dir: The directory in which the labels file is found.
  85. filename: The filename where the class names are written.
  86. Returns:
  87. `True` if the labels file exists and `False` otherwise.
  88. """
  89. return tf.gfile.Exists(os.path.join(dataset_dir, filename))
  90. def read_label_file(dataset_dir, filename=LABELS_FILENAME):
  91. """Reads the labels file and returns a mapping from ID to class name.
  92. Args:
  93. dataset_dir: The directory in which the labels file is found.
  94. filename: The filename where the class names are written.
  95. Returns:
  96. A map from a label (integer) to class name.
  97. """
  98. labels_filename = os.path.join(dataset_dir, filename)
  99. with tf.gfile.Open(labels_filename, 'r') as f:
  100. lines = f.read().decode()
  101. lines = lines.split('\n')
  102. lines = filter(None, lines)
  103. labels_to_class_names = {}
  104. for line in lines:
  105. index = line.index(':')
  106. labels_to_class_names[int(line[:index])] = line[index+1:]
  107. return labels_to_class_names