dataset_utils.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Contains utilities for downloading and converting datasets."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import os
  20. import tensorflow as tf
# Default file name of the label map; used as the default `filename` argument
# of write_label_file, has_labels and read_label_file.
LABELS_FILENAME = 'labels.txt'
  22. def int64_feature(values):
  23. """Returns a TF-Feature of int64s.
  24. Args:
  25. values: A scalar or list of values.
  26. Returns:
  27. a TF-Feature.
  28. """
  29. if not isinstance(values, (tuple, list)):
  30. values = [values]
  31. return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
  32. def bytes_feature(values):
  33. """Returns a TF-Feature of bytes.
  34. Args:
  35. values: A string.
  36. Returns:
  37. a TF-Feature.
  38. """
  39. return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
  40. def image_to_tfexample(image_data, image_format, height, width, class_id):
  41. return tf.train.Example(features=tf.train.Features(feature={
  42. 'image/encoded': bytes_feature(image_data),
  43. 'image/format': bytes_feature(image_format),
  44. 'image/class/label': int64_feature(class_id),
  45. 'image/height': int64_feature(height),
  46. 'image/width': int64_feature(width),
  47. }))
  48. def write_label_file(labels_to_class_names, dataset_dir,
  49. filename=LABELS_FILENAME):
  50. """Writes a file with the list of class names.
  51. Args:
  52. labels_to_class_names: A map of (integer) labels to class names.
  53. dataset_dir: The directory in which the labels file should be written.
  54. filename: The filename where the class names are written.
  55. """
  56. labels_filename = os.path.join(dataset_dir, filename)
  57. with tf.gfile.Open(labels_filename, 'w') as f:
  58. for label in labels_to_class_names:
  59. class_name = labels_to_class_names[label]
  60. f.write('%d:%s\n' % (label, class_name))
  61. def has_labels(dataset_dir, filename=LABELS_FILENAME):
  62. """Specifies whether or not the dataset directory contains a label map file.
  63. Args:
  64. dataset_dir: The directory in which the labels file is found.
  65. filename: The filename where the class names are written.
  66. Returns:
  67. `True` if the labels file exists and `False` otherwise.
  68. """
  69. return tf.gfile.Exists(os.path.join(dataset_dir, filename))
  70. def read_label_file(dataset_dir, filename=LABELS_FILENAME):
  71. """Reads the labels file and returns a mapping from ID to class name.
  72. Args:
  73. dataset_dir: The directory in which the labels file is found.
  74. filename: The filename where the class names are written.
  75. Returns:
  76. A map from a label (integer) to class name.
  77. """
  78. labels_filename = os.path.join(dataset_dir, filename)
  79. with tf.gfile.Open(labels_filename, 'r') as f:
  80. lines = f.read()
  81. lines = lines.split('\n')
  82. lines = filter(None, lines)
  83. labels_to_class_names = {}
  84. for line in lines:
  85. index = line.index(':')
  86. labels_to_class_names[int(line[:index])] = line[index+1:]
  87. return labels_to_class_names