| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112 |
- # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Contains utilities for downloading and converting datasets."""
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import os
- import tensorflow as tf
# Default name of the label-map file; used as the default `filename` argument
# of write_label_file, has_labels and read_label_file.
LABELS_FILENAME = 'labels.txt'
def int64_feature(values):
  """Wraps one or more integers in a `tf.train.Feature` (int64_list).

  Args:
    values: Either a single integer, or a list/tuple of integers. Any value
      that is not a list or tuple is treated as a single element.

  Returns:
    A `tf.train.Feature` whose `int64_list` holds the given value(s).
  """
  int_values = values if isinstance(values, (tuple, list)) else [values]
  return tf.train.Feature(int64_list=tf.train.Int64List(value=int_values))
def bytes_feature(values):
  """Returns a TF-Feature holding a single bytes value.

  Args:
    values: A single byte string (e.g. encoded image data or a format name
      such as b'jpg'). A text string (`str` on Python 3, `unicode` on
      Python 2) is also accepted and is UTF-8 encoded first.

  Returns:
    A `tf.train.Feature` whose `bytes_list` contains exactly one element.
  """
  # tf.train.BytesList only stores bytes; on Python 3 a plain `str` is
  # rejected by protobuf. tf.compat.as_bytes returns bytes unchanged and
  # UTF-8 encodes text, so existing bytes callers see no difference.
  return tf.train.Feature(
      bytes_list=tf.train.BytesList(value=[tf.compat.as_bytes(values)]))
def image_to_tfexample(image_data, image_format, height, width, class_id):
  """Packs one image and its metadata into a `tf.train.Example`.

  Args:
    image_data: The image payload (presumably already-encoded image bytes;
      confirm with callers), stored under 'image/encoded'.
    image_format: The image format value, stored under 'image/format'.
    height: Image height, stored as int64 under 'image/height'.
    width: Image width, stored as int64 under 'image/width'.
    class_id: The class label(s), stored as int64 under 'image/class/label'.

  Returns:
    A `tf.train.Example` containing the five features listed above.
  """
  feature_map = {
      'image/encoded': bytes_feature(image_data),
      'image/format': bytes_feature(image_format),
      'image/class/label': int64_feature(class_id),
      'image/height': int64_feature(height),
      'image/width': int64_feature(width),
  }
  return tf.train.Example(features=tf.train.Features(feature=feature_map))
def write_label_file(labels_to_class_names, dataset_dir,
                     filename=LABELS_FILENAME):
  """Writes the label map to `dataset_dir/filename`, one `label:name` per line.

  Entries are written in the mapping's iteration order, each formatted as
  '%d:%s' followed by a newline.

  Args:
    labels_to_class_names: A mapping from integer labels to class names.
    dataset_dir: Directory in which the labels file is created.
    filename: Name of the labels file inside `dataset_dir`.
  """
  output_path = os.path.join(dataset_dir, filename)
  with tf.gfile.Open(output_path, 'w') as label_file:
    for label, class_name in labels_to_class_names.items():
      label_file.write('%d:%s\n' % (label, class_name))
def has_labels(dataset_dir, filename=LABELS_FILENAME):
  """Reports whether a labels file is present in `dataset_dir`.

  Args:
    dataset_dir: Directory to look in.
    filename: Name of the labels file to look for inside `dataset_dir`.

  Returns:
    `True` if `dataset_dir/filename` exists, `False` otherwise.
  """
  labels_path = os.path.join(dataset_dir, filename)
  return tf.gfile.Exists(labels_path)
def read_label_file(dataset_dir, filename=LABELS_FILENAME):
  """Reads the labels file and returns a mapping from ID to class name.

  Each non-blank line must have the form `<integer label>:<class name>`, the
  format produced by `write_label_file`. Only the first ':' separates the
  label from the name, so class names may themselves contain colons.

  Args:
    dataset_dir: The directory in which the labels file is found.
    filename: The filename where the class names are written.

  Returns:
    A dict mapping each integer label to its class name.

  Raises:
    ValueError: If a non-blank line has no ':' separator, or if its label
      part cannot be parsed as an integer.
  """
  labels_filename = os.path.join(dataset_dir, filename)
  with tf.gfile.Open(labels_filename, 'r') as f:
    contents = f.read()

  labels_to_class_names = {}
  # splitlines() (unlike split('\n')) also removes the '\r' of CRLF line
  # endings, so class names from Windows-edited files don't end in '\r'.
  for line_number, line in enumerate(contents.splitlines(), start=1):
    if not line.strip():
      # Skip blank / whitespace-only lines (e.g. the trailing newline).
      continue
    label, separator, class_name = line.partition(':')
    if not separator:
      raise ValueError('Line %d of %s has no ":" separator: %r'
                       % (line_number, labels_filename, line))
    labels_to_class_names[int(label)] = class_name
  return labels_to_class_names
|