reader.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Utilities for parsing PTB text files."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import collections
  20. import os
  21. import tensorflow as tf
  22. def _read_words(filename):
  23. with tf.gfile.GFile(filename, "r") as f:
  24. return f.read().decode("utf-8").replace("\n", "<eos>").split()
  25. def _build_vocab(filename):
  26. data = _read_words(filename)
  27. counter = collections.Counter(data)
  28. count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
  29. words, _ = list(zip(*count_pairs))
  30. word_to_id = dict(zip(words, range(len(words))))
  31. return word_to_id
  32. def _file_to_word_ids(filename, word_to_id):
  33. data = _read_words(filename)
  34. return [word_to_id[word] for word in data if word in word_to_id]
  35. def ptb_raw_data(data_path=None):
  36. """Load PTB raw data from data directory "data_path".
  37. Reads PTB text files, converts strings to integer ids,
  38. and performs mini-batching of the inputs.
  39. The PTB dataset comes from Tomas Mikolov's webpage:
  40. http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
  41. Args:
  42. data_path: string path to the directory where simple-examples.tgz has
  43. been extracted.
  44. Returns:
  45. tuple (train_data, valid_data, test_data, vocabulary)
  46. where each of the data objects can be passed to PTBIterator.
  47. """
  48. train_path = os.path.join(data_path, "ptb.train.txt")
  49. valid_path = os.path.join(data_path, "ptb.valid.txt")
  50. test_path = os.path.join(data_path, "ptb.test.txt")
  51. word_to_id = _build_vocab(train_path)
  52. train_data = _file_to_word_ids(train_path, word_to_id)
  53. valid_data = _file_to_word_ids(valid_path, word_to_id)
  54. test_data = _file_to_word_ids(test_path, word_to_id)
  55. vocabulary = len(word_to_id)
  56. return train_data, valid_data, test_data, vocabulary
  57. def ptb_producer(raw_data, batch_size, num_steps, name=None):
  58. """Iterate on the raw PTB data.
  59. This chunks up raw_data into batches of examples and returns Tensors that
  60. are drawn from these batches.
  61. Args:
  62. raw_data: one of the raw data outputs from ptb_raw_data.
  63. batch_size: int, the batch size.
  64. num_steps: int, the number of unrolls.
  65. name: the name of this operation (optional).
  66. Returns:
  67. A pair of Tensors, each shaped [batch_size, num_steps]. The second element
  68. of the tuple is the same data time-shifted to the right by one.
  69. Raises:
  70. tf.errors.InvalidArgumentError: if batch_size or num_steps are too high.
  71. """
  72. with tf.name_scope(name, "PTBProducer", [raw_data, batch_size, num_steps]):
  73. raw_data = tf.convert_to_tensor(raw_data, name="raw_data", dtype=tf.int32)
  74. data_len = tf.size(raw_data)
  75. batch_len = data_len // batch_size
  76. data = tf.reshape(raw_data[0 : batch_size * batch_len],
  77. [batch_size, batch_len])
  78. epoch_size = (batch_len - 1) // num_steps
  79. assertion = tf.assert_positive(
  80. epoch_size,
  81. message="epoch_size == 0, decrease batch_size or num_steps")
  82. with tf.control_dependencies([assertion]):
  83. epoch_size = tf.identity(epoch_size, name="epoch_size")
  84. i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()
  85. x = tf.strided_slice(data, [0, i * num_steps],
  86. [batch_size, (i + 1) * num_steps])
  87. x.set_shape([batch_size, num_steps])
  88. y = tf.strided_slice(data, [0, i * num_steps + 1],
  89. [batch_size, (i + 1) * num_steps + 1])
  90. y.set_shape([batch_size, num_steps])
  91. return x, y