build_mscoco_data.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Converts MSCOCO data to TFRecord file format with SequenceExample protos.
  16. The MSCOCO images are expected to reside in JPEG files located in the following
  17. directory structure:
  18. train_image_dir/COCO_train2014_000000000151.jpg
  19. train_image_dir/COCO_train2014_000000000260.jpg
  20. ...
  21. and
  22. val_image_dir/COCO_val2014_000000000042.jpg
  23. val_image_dir/COCO_val2014_000000000073.jpg
  24. ...
  25. The MSCOCO annotations JSON files are expected to reside in train_captions_file
  26. and val_captions_file respectively.
  27. This script converts the combined MSCOCO data into sharded data files consisting
  28. of 256, 4 and 8 TFRecord files, respectively:
  29. output_dir/train-00000-of-00256
  30. output_dir/train-00001-of-00256
  31. ...
  32. output_dir/train-00255-of-00256
  33. and
  34. output_dir/val-00000-of-00004
  35. ...
  36. output_dir/val-00003-of-00004
  37. and
  38. output_dir/test-00000-of-00008
  39. ...
  40. output_dir/test-00007-of-00008
  41. Each TFRecord file contains ~2300 records. Each record within the TFRecord file
  42. is a serialized SequenceExample proto consisting of precisely one image-caption
  43. pair. Note that each image has multiple captions (usually 5) and therefore each
  44. image is replicated multiple times in the TFRecord files.
  45. The SequenceExample proto contains the following fields:
  46. context:
  47. image/image_id: integer MSCOCO image identifier
  48. image/data: string containing JPEG encoded image in RGB colorspace
  49. feature_lists:
  50. image/caption: list of strings containing the (tokenized) caption words
  51. image/caption_ids: list of integer ids corresponding to the caption words
  52. The captions are tokenized using the NLTK (http://www.nltk.org/) word tokenizer.
  53. The vocabulary of word identifiers is constructed from the sorted list (by
  54. descending frequency) of word tokens in the training set. Only tokens appearing
  55. at least 4 times are considered; all other words get the "unknown" word id.
  56. NOTE: This script will consume around 100GB of disk space because each image
  57. in the MSCOCO dataset is replicated ~5 times (once per caption) in the output.
  58. This is done for two reasons:
  59. 1. In order to better shuffle the training data.
  60. 2. It makes it easier to perform asynchronous preprocessing of each image in
  61. TensorFlow.
  62. Running this script using 16 threads may take around 1 hour on a HP Z420.
  63. """
  64. from __future__ import absolute_import
  65. from __future__ import division
  66. from __future__ import print_function
  67. from collections import Counter
  68. from collections import namedtuple
  69. from datetime import datetime
  70. import json
  71. import os.path
  72. import random
  73. import sys
  74. import threading
  75. import nltk.tokenize
  76. import numpy as np
  77. import tensorflow as tf
  78. tf.flags.DEFINE_string("train_image_dir", "/tmp/train2014/",
  79. "Training image directory.")
  80. tf.flags.DEFINE_string("val_image_dir", "/tmp/val2014",
  81. "Validation image directory.")
  82. tf.flags.DEFINE_string("train_captions_file", "/tmp/captions_train2014.json",
  83. "Training captions JSON file.")
  84. tf.flags.DEFINE_string("val_captions_file", "/tmp/captions_train2014.json",
  85. "Validation captions JSON file.")
  86. tf.flags.DEFINE_string("output_dir", "/tmp/", "Output data directory.")
  87. tf.flags.DEFINE_integer("train_shards", 256,
  88. "Number of shards in training TFRecord files.")
  89. tf.flags.DEFINE_integer("val_shards", 4,
  90. "Number of shards in validation TFRecord files.")
  91. tf.flags.DEFINE_integer("test_shards", 8,
  92. "Number of shards in testing TFRecord files.")
  93. tf.flags.DEFINE_string("start_word", "<S>",
  94. "Special word added to the beginning of each sentence.")
  95. tf.flags.DEFINE_string("end_word", "</S>",
  96. "Special word added to the end of each sentence.")
  97. tf.flags.DEFINE_string("unknown_word", "<UNK>",
  98. "Special word meaning 'unknown'.")
  99. tf.flags.DEFINE_integer("min_word_count", 4,
  100. "The minimum number of occurrences of each word in the "
  101. "training set for inclusion in the vocabulary.")
  102. tf.flags.DEFINE_string("word_counts_output_file", "/tmp/word_counts.txt",
  103. "Output vocabulary file of word counts.")
  104. tf.flags.DEFINE_integer("num_threads", 8,
  105. "Number of threads to preprocess the images.")
  106. FLAGS = tf.flags.FLAGS
  107. ImageMetadata = namedtuple("ImageMetadata",
  108. ["image_id", "filename", "captions"])
  109. class Vocabulary(object):
  110. """Simple vocabulary wrapper."""
  111. def __init__(self, vocab, unk_id):
  112. """Initializes the vocabulary.
  113. Args:
  114. vocab: A dictionary of word to word_id.
  115. unk_id: Id of the special 'unknown' word.
  116. """
  117. self._vocab = vocab
  118. self._unk_id = unk_id
  119. def word_to_id(self, word):
  120. """Returns the integer id of a word string."""
  121. if word in self._vocab:
  122. return self._vocab[word]
  123. else:
  124. return self._unk_id
  125. class ImageDecoder(object):
  126. """Helper class for decoding images in TensorFlow."""
  127. def __init__(self):
  128. # Create a single TensorFlow Session for all image decoding calls.
  129. self._sess = tf.Session()
  130. # TensorFlow ops for JPEG decoding.
  131. self._encoded_jpeg = tf.placeholder(dtype=tf.string)
  132. self._decode_jpeg = tf.image.decode_jpeg(self._encoded_jpeg, channels=3)
  133. def decode_jpeg(self, encoded_jpeg):
  134. image = self._sess.run(self._decode_jpeg,
  135. feed_dict={self._encoded_jpeg: encoded_jpeg})
  136. assert len(image.shape) == 3
  137. assert image.shape[2] == 3
  138. return image
  139. def _int64_feature(value):
  140. """Wrapper for inserting an int64 Feature into a SequenceExample proto."""
  141. return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
  142. def _bytes_feature(value):
  143. """Wrapper for inserting a bytes Feature into a SequenceExample proto."""
  144. return tf.train.Feature(bytes_list=tf.train.BytesList(value=[str(value)]))
  145. def _int64_feature_list(values):
  146. """Wrapper for inserting an int64 FeatureList into a SequenceExample proto."""
  147. return tf.train.FeatureList(feature=[_int64_feature(v) for v in values])
  148. def _bytes_feature_list(values):
  149. """Wrapper for inserting a bytes FeatureList into a SequenceExample proto."""
  150. return tf.train.FeatureList(feature=[_bytes_feature(v) for v in values])
  151. def _to_sequence_example(image, decoder, vocab):
  152. """Builds a SequenceExample proto for an image-caption pair.
  153. Args:
  154. image: An ImageMetadata object.
  155. decoder: An ImageDecoder object.
  156. vocab: A Vocabulary object.
  157. Returns:
  158. A SequenceExample proto.
  159. """
  160. with open(image.filename, "r") as f:
  161. encoded_image = f.read()
  162. try:
  163. decoder.decode_jpeg(encoded_image)
  164. except (tf.errors.InvalidArgumentError, AssertionError):
  165. print("Skipping file with invalid JPEG data: %s" % image.filename)
  166. return
  167. context = tf.train.Features(feature={
  168. "image/image_id": _int64_feature(image.image_id),
  169. "image/data": _bytes_feature(encoded_image),
  170. })
  171. assert len(image.captions) == 1
  172. caption = image.captions[0]
  173. caption_ids = [vocab.word_to_id(word) for word in caption]
  174. feature_lists = tf.train.FeatureLists(feature_list={
  175. "image/caption": _bytes_feature_list(caption),
  176. "image/caption_ids": _int64_feature_list(caption_ids)
  177. })
  178. sequence_example = tf.train.SequenceExample(
  179. context=context, feature_lists=feature_lists)
  180. return sequence_example
  181. def _process_image_files(thread_index, ranges, name, images, decoder, vocab,
  182. num_shards):
  183. """Processes and saves a subset of images as TFRecord files in one thread.
  184. Args:
  185. thread_index: Integer thread identifier within [0, len(ranges)].
  186. ranges: A list of pairs of integers specifying the ranges of the dataset to
  187. process in parallel.
  188. name: Unique identifier specifying the dataset.
  189. images: List of ImageMetadata.
  190. decoder: An ImageDecoder object.
  191. vocab: A Vocabulary object.
  192. num_shards: Integer number of shards for the output files.
  193. """
  194. # Each thread produces N shards where N = num_shards / num_threads. For
  195. # instance, if num_shards = 128, and num_threads = 2, then the first thread
  196. # would produce shards [0, 64).
  197. num_threads = len(ranges)
  198. assert not num_shards % num_threads
  199. num_shards_per_batch = int(num_shards / num_threads)
  200. shard_ranges = np.linspace(ranges[thread_index][0], ranges[thread_index][1],
  201. num_shards_per_batch + 1).astype(int)
  202. num_images_in_thread = ranges[thread_index][1] - ranges[thread_index][0]
  203. counter = 0
  204. for s in xrange(num_shards_per_batch):
  205. # Generate a sharded version of the file name, e.g. 'train-00002-of-00010'
  206. shard = thread_index * num_shards_per_batch + s
  207. output_filename = "%s-%.5d-of-%.5d" % (name, shard, num_shards)
  208. output_file = os.path.join(FLAGS.output_dir, output_filename)
  209. writer = tf.python_io.TFRecordWriter(output_file)
  210. shard_counter = 0
  211. images_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int)
  212. for i in images_in_shard:
  213. image = images[i]
  214. sequence_example = _to_sequence_example(image, decoder, vocab)
  215. if sequence_example is not None:
  216. writer.write(sequence_example.SerializeToString())
  217. shard_counter += 1
  218. counter += 1
  219. if not counter % 1000:
  220. print("%s [thread %d]: Processed %d of %d items in thread batch." %
  221. (datetime.now(), thread_index, counter, num_images_in_thread))
  222. sys.stdout.flush()
  223. writer.close()
  224. print("%s [thread %d]: Wrote %d image-caption pairs to %s" %
  225. (datetime.now(), thread_index, shard_counter, output_file))
  226. sys.stdout.flush()
  227. shard_counter = 0
  228. print("%s [thread %d]: Wrote %d image-caption pairs to %d shards." %
  229. (datetime.now(), thread_index, counter, num_shards_per_batch))
  230. sys.stdout.flush()
  231. def _process_dataset(name, images, vocab, num_shards):
  232. """Processes a complete data set and saves it as a TFRecord.
  233. Args:
  234. name: Unique identifier specifying the dataset.
  235. images: List of ImageMetadata.
  236. vocab: A Vocabulary object.
  237. num_shards: Integer number of shards for the output files.
  238. """
  239. # Break up each image into a separate entity for each caption.
  240. images = [ImageMetadata(image.image_id, image.filename, [caption])
  241. for image in images for caption in image.captions]
  242. # Shuffle the ordering of images. Make the randomization repeatable.
  243. random.seed(12345)
  244. random.shuffle(images)
  245. # Break the images into num_threads batches. Batch i is defined as
  246. # images[ranges[i][0]:ranges[i][1]].
  247. num_threads = min(num_shards, FLAGS.num_threads)
  248. spacing = np.linspace(0, len(images), num_threads + 1).astype(np.int)
  249. ranges = []
  250. threads = []
  251. for i in xrange(len(spacing) - 1):
  252. ranges.append([spacing[i], spacing[i + 1]])
  253. # Create a mechanism for monitoring when all threads are finished.
  254. coord = tf.train.Coordinator()
  255. # Create a utility for decoding JPEG images to run sanity checks.
  256. decoder = ImageDecoder()
  257. # Launch a thread for each batch.
  258. print("Launching %d threads for spacings: %s" % (num_threads, ranges))
  259. for thread_index in xrange(len(ranges)):
  260. args = (thread_index, ranges, name, images, decoder, vocab, num_shards)
  261. t = threading.Thread(target=_process_image_files, args=args)
  262. t.start()
  263. threads.append(t)
  264. # Wait for all the threads to terminate.
  265. coord.join(threads)
  266. print("%s: Finished processing all %d image-caption pairs in data set '%s'." %
  267. (datetime.now(), len(images), name))
  268. def _create_vocab(captions):
  269. """Creates the vocabulary of word to word_id.
  270. The vocabulary is saved to disk in a text file of word counts. The id of each
  271. word in the file is its corresponding 0-based line number.
  272. Args:
  273. captions: A list of lists of strings.
  274. Returns:
  275. A Vocabulary object.
  276. """
  277. print("Creating vocabulary.")
  278. counter = Counter()
  279. for c in captions:
  280. counter.update(c)
  281. print("Total words:", len(counter))
  282. # Filter uncommon words and sort by descending count.
  283. word_counts = [x for x in counter.items() if x[1] >= FLAGS.min_word_count]
  284. word_counts.sort(key=lambda x: x[1], reverse=True)
  285. print("Words in vocabulary:", len(word_counts))
  286. # Write out the word counts file.
  287. with tf.gfile.FastGFile(FLAGS.word_counts_output_file, "w") as f:
  288. f.write("\n".join(["%s %d" % (w, c) for w, c in word_counts]))
  289. print("Wrote vocabulary file:", FLAGS.word_counts_output_file)
  290. # Create the vocabulary dictionary.
  291. reverse_vocab = [x[0] for x in word_counts]
  292. unk_id = len(reverse_vocab)
  293. vocab_dict = dict([(x, y) for (y, x) in enumerate(reverse_vocab)])
  294. vocab = Vocabulary(vocab_dict, unk_id)
  295. return vocab
  296. def _process_caption(caption):
  297. """Processes a caption string into a list of tonenized words.
  298. Args:
  299. caption: A string caption.
  300. Returns:
  301. A list of strings; the tokenized caption.
  302. """
  303. tokenized_caption = [FLAGS.start_word]
  304. tokenized_caption.extend(nltk.tokenize.word_tokenize(caption.lower()))
  305. tokenized_caption.append(FLAGS.end_word)
  306. return tokenized_caption
  307. def _load_and_process_metadata(captions_file, image_dir):
  308. """Loads image metadata from a JSON file and processes the captions.
  309. Args:
  310. captions_file: JSON file containing caption annotations.
  311. image_dir: Directory containing the image files.
  312. Returns:
  313. A list of ImageMetadata.
  314. """
  315. with tf.gfile.FastGFile(captions_file, "r") as f:
  316. caption_data = json.load(f)
  317. # Extract the filenames.
  318. id_to_filename = [(x["id"], x["file_name"]) for x in caption_data["images"]]
  319. # Extract the captions. Each image_id is associated with multiple captions.
  320. id_to_captions = {}
  321. for annotation in caption_data["annotations"]:
  322. image_id = annotation["image_id"]
  323. caption = annotation["caption"]
  324. id_to_captions.setdefault(image_id, [])
  325. id_to_captions[image_id].append(caption)
  326. assert len(id_to_filename) == len(id_to_captions)
  327. assert set([x[0] for x in id_to_filename]) == set(id_to_captions.keys())
  328. print("Loaded caption metadata for %d images from %s" %
  329. (len(id_to_filename), captions_file))
  330. # Process the captions and combine the data into a list of ImageMetadata.
  331. print("Proccessing captions.")
  332. image_metadata = []
  333. num_captions = 0
  334. for image_id, base_filename in id_to_filename:
  335. filename = os.path.join(image_dir, base_filename)
  336. captions = [_process_caption(c) for c in id_to_captions[image_id]]
  337. image_metadata.append(ImageMetadata(image_id, filename, captions))
  338. num_captions += len(captions)
  339. print("Finished processing %d captions for %d images in %s" %
  340. (num_captions, len(id_to_filename), captions_file))
  341. return image_metadata
  342. def main(unused_argv):
  343. def _is_valid_num_shards(num_shards):
  344. """Returns True if num_shards is compatible with FLAGS.num_threads."""
  345. return num_shards < FLAGS.num_threads or not num_shards % FLAGS.num_threads
  346. assert _is_valid_num_shards(FLAGS.train_shards), (
  347. "Please make the FLAGS.num_threads commensurate with FLAGS.train_shards")
  348. assert _is_valid_num_shards(FLAGS.val_shards), (
  349. "Please make the FLAGS.num_threads commensurate with FLAGS.val_shards")
  350. assert _is_valid_num_shards(FLAGS.test_shards), (
  351. "Please make the FLAGS.num_threads commensurate with FLAGS.test_shards")
  352. if not tf.gfile.IsDirectory(FLAGS.output_dir):
  353. tf.gfile.MakeDirs(FLAGS.output_dir)
  354. # Load image metadata from caption files.
  355. mscoco_train_dataset = _load_and_process_metadata(FLAGS.train_captions_file,
  356. FLAGS.train_image_dir)
  357. mscoco_val_dataset = _load_and_process_metadata(FLAGS.val_captions_file,
  358. FLAGS.val_image_dir)
  359. # Redistribute the MSCOCO data as follows:
  360. # train_dataset = 100% of mscoco_train_dataset + 85% of mscoco_val_dataset.
  361. # val_dataset = 5% of mscoco_val_dataset (for validation during training).
  362. # test_dataset = 10% of mscoco_val_dataset (for final evaluation).
  363. train_cutoff = int(0.85 * len(mscoco_val_dataset))
  364. val_cutoff = int(0.90 * len(mscoco_val_dataset))
  365. train_dataset = mscoco_train_dataset + mscoco_val_dataset[0:train_cutoff]
  366. val_dataset = mscoco_val_dataset[train_cutoff:val_cutoff]
  367. test_dataset = mscoco_val_dataset[val_cutoff:]
  368. # Create vocabulary from the training captions.
  369. train_captions = [c for image in train_dataset for c in image.captions]
  370. vocab = _create_vocab(train_captions)
  371. _process_dataset("train", train_dataset, vocab, FLAGS.train_shards)
  372. _process_dataset("val", val_dataset, vocab, FLAGS.val_shards)
  373. _process_dataset("test", test_dataset, vocab, FLAGS.test_shards)
  374. if __name__ == "__main__":
  375. tf.app.run()