prep.py

#!/usr/bin/env python
#
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  16. """Prepare a corpus for processing by swivel.
  17. Creates a sharded word co-occurrence matrix from a text file input corpus.
  18. Usage:
  19. prep.py --output_dir <output-dir> --input <text-file>
  20. Options:
  21. --input <filename>
  22. The input text.
  23. --output_dir <directory>
  24. Specifies the output directory where the various Swivel data
  25. files should be placed.
  26. --shard_size <int>
  27. Specifies the shard size; default 4096.
  28. --min_count <int>
  29. Specifies the minimum number of times a word should appear
  30. to be included in the vocabulary; default 5.
  31. --max_vocab <int>
  32. Specifies the maximum vocabulary size; default shard size
  33. times 1024.
  34. --vocab <filename>
  35. Use the specified unigram vocabulary instead of generating
  36. it from the corpus.
  37. --window_size <int>
  38. Specifies the window size for computing co-occurrence stats;
  39. default 10.
  40. --bufsz <int>
  41. The number of co-occurrences that are buffered; default 16M.
  42. """
import itertools
import math
import os
import struct
import sys

import tensorflow as tf
flags = tf.app.flags

flags.DEFINE_string('input', '', 'The input text.')
flags.DEFINE_string('output_dir', '/tmp/swivel_data',
                    'Output directory for Swivel data')
flags.DEFINE_integer('shard_size', 4096, 'The size for each shard')
flags.DEFINE_integer('min_count', 5,
                     'The minimum number of times a word should occur to be '
                     'included in the vocabulary')
flags.DEFINE_integer('max_vocab', 4096 * 64, 'The maximum vocabulary size')
flags.DEFINE_string('vocab', '', 'Vocabulary to use instead of generating one')
flags.DEFINE_integer('window_size', 10, 'The window size')
flags.DEFINE_integer('bufsz', 16 * 1024 * 1024,
                     'The number of co-occurrences to buffer')

FLAGS = flags.FLAGS
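
# Binary layout of one temporary co-occurrence record: two int32s (the row
# and column offsets within a shard) followed by a float32 count.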
shard_cooc_fmt = struct.Struct('iif')


def words(line):
  """Splits a line of text into tokens."""
  return line.strip().split()


def create_vocabulary(lines):
  """Reads text lines and generates a vocabulary."""
  lines.seek(0, os.SEEK_END)
  nbytes = lines.tell()
  lines.seek(0, os.SEEK_SET)

  vocab = {}
  for lineno, line in enumerate(lines, start=1):
    for word in words(line):
      vocab.setdefault(word, 0)
      vocab[word] += 1

    if lineno % 100000 == 0:
      pos = lines.tell()
      sys.stdout.write('\rComputing vocabulary: %0.1f%% (%d/%d)...' % (
          100.0 * pos / nbytes, pos, nbytes))
      sys.stdout.flush()

  sys.stdout.write('\n')

  vocab = [(tok, n) for tok, n in vocab.iteritems() if n >= FLAGS.min_count]
  vocab.sort(key=lambda kv: (-kv[1], kv[0]))
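
  # Cap the vocabulary at --max_vocab, then round it down to a multiple of the
  # shard size so the co-occurrence matrix divides into equal square shards.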
  num_words = min(len(vocab), FLAGS.max_vocab)
  if num_words % FLAGS.shard_size != 0:
    num_words -= num_words % FLAGS.shard_size

  if not num_words:
    raise Exception('empty vocabulary')

  print 'vocabulary contains %d tokens' % num_words

  vocab = vocab[:num_words]
  return [tok for tok, n in vocab]


def write_vocab_and_sums(vocab, sums, vocab_filename, sums_filename):
  """Writes vocabulary and marginal sum files."""
  with open(os.path.join(FLAGS.output_dir, vocab_filename), 'w') as vocab_out:
    with open(os.path.join(FLAGS.output_dir, sums_filename), 'w') as sums_out:
      for tok, cnt in itertools.izip(vocab, sums):
        print >> vocab_out, tok
        print >> sums_out, cnt


def compute_coocs(lines, vocab):
  """Compute the co-occurrence statistics from the text.

  This generates a temporary file for each shard that contains the intermediate
  counts from the shard: these counts must be subsequently sorted and collated.
  """
  word_to_id = {tok: idx for idx, tok in enumerate(vocab)}

  lines.seek(0, os.SEEK_END)
  nbytes = lines.tell()
  lines.seek(0, os.SEEK_SET)

  num_shards = len(vocab) / FLAGS.shard_size

  shardfiles = {}
  for row in range(num_shards):
    for col in range(num_shards):
      filename = os.path.join(
          FLAGS.output_dir, 'shard-%03d-%03d.tmp' % (row, col))

      shardfiles[(row, col)] = open(filename, 'w+')

  def flush_coocs():
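    # Word IDs are striped across shards: ID i belongs to shard
    # (i % num_shards) at local offset (i / num_shards) within that shard.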
    for (row_id, col_id), cnt in coocs.iteritems():
      row_shard = row_id % num_shards
      row_off = row_id / num_shards
      col_shard = col_id % num_shards
      col_off = col_id / num_shards

      # Since we only stored (a, b), we emit both (a, b) and (b, a).
      shardfiles[(row_shard, col_shard)].write(
          shard_cooc_fmt.pack(row_off, col_off, cnt))

      shardfiles[(col_shard, row_shard)].write(
          shard_cooc_fmt.pack(col_off, row_off, cnt))

  coocs = {}
  sums = [0.0] * len(vocab)

  for lineno, line in enumerate(lines, start=1):
    # Computes the word IDs for each word in the sentence. This has the effect
    # of "stretching" the window past OOV tokens.
    wids = filter(
        lambda wid: wid is not None,
        (word_to_id.get(w) for w in words(line)))
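    # Each co-occurring pair contributes 1/distance to its count, so tokens
    # that appear closer together are weighted more heavily.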
    for pos in xrange(len(wids)):
      lid = wids[pos]
      window_extent = min(FLAGS.window_size + 1, len(wids) - pos)
      for off in xrange(1, window_extent):
        rid = wids[pos + off]
        pair = (min(lid, rid), max(lid, rid))
        count = 1.0 / off
        sums[lid] += count
        sums[rid] += count
        coocs.setdefault(pair, 0.0)
        coocs[pair] += count

      sums[lid] += 1.0
      pair = (lid, lid)
      coocs.setdefault(pair, 0.0)
      coocs[pair] += 0.5  # Only add 1/2 since we output (a, b) and (b, a).

    if lineno % 10000 == 0:
      pos = lines.tell()
      sys.stdout.write('\rComputing co-occurrences: %0.1f%% (%d/%d)...' % (
          100.0 * pos / nbytes, pos, nbytes))
      sys.stdout.flush()

      if len(coocs) > FLAGS.bufsz:
        flush_coocs()
        coocs = {}

  flush_coocs()
  sys.stdout.write('\n')

  return shardfiles, sums


def write_shards(vocab, shardfiles):
  """Processes the temporary files to generate the final shard data.

  Each shard's data is written to its own file as a serialized tf.Example
  proto. The temporary files are removed from the filesystem once they've
  been processed.
  """
  num_shards = len(vocab) / FLAGS.shard_size

  ix = 0
  for (row, col), fh in shardfiles.iteritems():
    ix += 1
    sys.stdout.write('\rwriting shard %d/%d' % (ix, len(shardfiles)))
    sys.stdout.flush()

    # Read the entire binary co-occurrence file and unpack it into an array.
    fh.seek(0)
    buf = fh.read()
    os.unlink(fh.name)
    fh.close()

    coocs = [
        shard_cooc_fmt.unpack_from(buf, off)
        for off in range(0, len(buf), shard_cooc_fmt.size)]
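
    # The same (row, col) pair can appear several times in one temporary file,
    # since the in-memory buffer is flushed whenever it grows past --bufsz.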
    # Sort and merge co-occurrences for the same pairs.
    coocs.sort()

    if coocs:
      current_pos = 0
      current_row_col = (coocs[current_pos][0], coocs[current_pos][1])
      for next_pos in range(1, len(coocs)):
        next_row_col = (coocs[next_pos][0], coocs[next_pos][1])
        if current_row_col == next_row_col:
          coocs[current_pos] = (
              coocs[current_pos][0],
              coocs[current_pos][1],
              coocs[current_pos][2] + coocs[next_pos][2])
        else:
          current_pos += 1
          if current_pos < next_pos:
            coocs[current_pos] = coocs[next_pos]

          current_row_col = (coocs[current_pos][0], coocs[current_pos][1])

      coocs = coocs[:(1 + current_pos)]

    # Convert to a TF Example proto.
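    # 'global_row' and 'global_col' list the vocabulary IDs covered by this
    # shard (inverting the striping used in flush_coocs); the sparse_* features
    # hold the merged local (row, col, count) triples.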
    def _int64s(xs):
      return tf.train.Feature(int64_list=tf.train.Int64List(value=list(xs)))

    def _floats(xs):
      return tf.train.Feature(float_list=tf.train.FloatList(value=list(xs)))

    example = tf.train.Example(features=tf.train.Features(feature={
        'global_row': _int64s(
            row + num_shards * i for i in range(FLAGS.shard_size)),
        'global_col': _int64s(
            col + num_shards * i for i in range(FLAGS.shard_size)),
        'sparse_local_row': _int64s(cooc[0] for cooc in coocs),
        'sparse_local_col': _int64s(cooc[1] for cooc in coocs),
        'sparse_value': _floats(cooc[2] for cooc in coocs),
    }))

    filename = os.path.join(
        FLAGS.output_dir, 'shard-%03d-%03d.pb' % (row, col))

    with open(filename, 'w') as out:
      out.write(example.SerializeToString())

  sys.stdout.write('\n')


def main(_):
  # Create the output directory, if necessary.
  if FLAGS.output_dir and not os.path.isdir(FLAGS.output_dir):
    os.makedirs(FLAGS.output_dir)

  # Read the file once to create the vocabulary.
  if FLAGS.vocab:
    with open(FLAGS.vocab, 'r') as lines:
      vocab = [line.strip() for line in lines]
  else:
    with open(FLAGS.input, 'r') as lines:
      vocab = create_vocabulary(lines)

  # Now read the file again to determine the co-occurrence stats.
  with open(FLAGS.input, 'r') as lines:
    shardfiles, sums = compute_coocs(lines, vocab)

  # Process the temporary shard files into the final shard-*.pb files.
  write_shards(vocab, shardfiles)

  # Now write the marginals. They're symmetric for this application.
  write_vocab_and_sums(vocab, sums, 'row_vocab.txt', 'row_sums.txt')
  write_vocab_and_sums(vocab, sums, 'col_vocab.txt', 'col_sums.txt')

  print 'done!'

if __name__ == '__main__':
  tf.app.run()