123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Contains utilities for downloading and converting datasets."""
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import os
- import sys
- import tarfile
- from six.moves import urllib
- import tensorflow as tf
# Default file name for the "<label>:<class name>" map used by
# write_label_file, has_labels and read_label_file.
LABELS_FILENAME = "labels.txt"
def int64_feature(values):
    """Wraps one value or a list/tuple of values in an int64 `tf.train.Feature`.

    Args:
        values: A scalar, or a list/tuple of values.

    Returns:
        A `tf.train.Feature` whose `int64_list` holds `values`.
    """
    # Lists and tuples are forwarded unchanged; anything else is treated as a
    # single value and boxed into a one-element list.
    value_list = values if isinstance(values, (list, tuple)) else [values]
    int64_list = tf.train.Int64List(value=value_list)
    return tf.train.Feature(int64_list=int64_list)
def bytes_feature(values):
    """Returns a TF-Feature holding a single bytes value.

    Args:
        values: A single `bytes` value, or a text string (`str` on Python 3,
            `unicode` on Python 2), which is UTF-8 encoded before storing.

    Returns:
        A `tf.train.Feature` whose `bytes_list` holds exactly one element.
    """
    # BytesList only accepts bytes on Python 3; without this, passing text
    # (e.g. an image format name such as 'png') raises a TypeError. Bytes
    # values are passed through untouched.
    if not isinstance(values, bytes) and hasattr(values, 'encode'):
        values = values.encode('utf-8')
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
def image_to_tfexample(image_data, image_format, height, width, class_id):
    """Packs an image payload and its metadata into a `tf.train.Example`.

    Args:
        image_data: Image payload, stored via `bytes_feature` under
            'image/encoded'.
        image_format: Format string, stored via `bytes_feature` under
            'image/format'.
        height: Image height, stored via `int64_feature` under 'image/height'.
        width: Image width, stored via `int64_feature` under 'image/width'.
        class_id: Class label, stored via `int64_feature` under
            'image/class/label'.

    Returns:
        A `tf.train.Example` containing exactly these five features.
    """
    feature_map = {
        'image/encoded': bytes_feature(image_data),
        'image/format': bytes_feature(image_format),
        'image/class/label': int64_feature(class_id),
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature_map))
def download_and_uncompress_tarball(tarball_url, dataset_dir):
    """Downloads the `tarball_url` and uncompresses it locally.

    The archive is saved in `dataset_dir` under the last '/'-separated
    component of the URL, then opened as a gzip-compressed tar and extracted
    into `dataset_dir`.

    Args:
        tarball_url: The URL of a gzip-compressed tarball file.
        dataset_dir: The directory where the downloaded archive is stored and
            into which its contents are extracted.
    """
    filename = tarball_url.split('/')[-1]
    filepath = os.path.join(dataset_dir, filename)

    def _progress(count, block_size, total_size):
        # urlretrieve passes total_size <= 0 when the size is unknown (no
        # Content-Length), so only report a percentage when it is meaningful.
        downloaded = count * block_size
        if total_size > 0:
            # The last block may overshoot the real file size; clamp to 100%.
            percent = min(100.0, 100.0 * downloaded / total_size)
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, percent))
        else:
            sys.stdout.write(
                '\r>> Downloading %s %d bytes' % (filename, downloaded))
        sys.stdout.flush()

    filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    # NOTE(review): extractall trusts the member paths stored in the archive;
    # a malicious tarball can write outside dataset_dir. Only use trusted URLs.
    with tarfile.open(filepath, 'r:gz') as tar:
        tar.extractall(dataset_dir)
def write_label_file(labels_to_class_names, dataset_dir,
                     filename=LABELS_FILENAME):
    """Writes one '<label>:<class name>' line per entry of the label map.

    Args:
        labels_to_class_names: A map of (integer) labels to class names.
        dataset_dir: Directory in which the labels file is created.
        filename: Name of the labels file inside `dataset_dir`.
    """
    target_path = os.path.join(dataset_dir, filename)
    with tf.gfile.Open(target_path, 'w') as out_file:
        for label_id in labels_to_class_names:
            line = '%d:%s\n' % (label_id, labels_to_class_names[label_id])
            out_file.write(line)
def has_labels(dataset_dir, filename=LABELS_FILENAME):
    """Tells whether `dataset_dir` already holds a label map file.

    Args:
        dataset_dir: Directory searched for the labels file.
        filename: Name of the labels file inside `dataset_dir`.

    Returns:
        `True` if the labels file exists and `False` otherwise.
    """
    labels_path = os.path.join(dataset_dir, filename)
    return tf.gfile.Exists(labels_path)
def read_label_file(dataset_dir, filename=LABELS_FILENAME):
    """Reads the labels file and returns a mapping from ID to class name.

    Each non-empty line has the form '<label>:<class name>'. Only the first
    ':' separates the two parts, so class names may themselves contain ':'.

    Args:
        dataset_dir: The directory in which the labels file is found.
        filename: The filename where the class names are written.

    Returns:
        A dict from a label (integer) to class name (text string).

    Raises:
        ValueError: If a non-empty line has no ':' or its label is not an
            integer.
    """
    labels_filename = os.path.join(dataset_dir, filename)
    # Read raw bytes and decode explicitly: a text-mode read can return `str`,
    # which has no `.decode()` on Python 3.
    with tf.gfile.Open(labels_filename, 'rb') as f:
        contents = f.read()
    if isinstance(contents, bytes):
        contents = contents.decode('utf-8')

    labels_to_class_names = {}
    for line in contents.split('\n'):
        # Tolerate files written with Windows ('\r\n') line endings.
        line = line.rstrip('\r')
        if not line:
            continue
        index = line.index(':')
        labels_to_class_names[int(line[:index])] = line[index + 1:]
    return labels_to_class_names
|