#!/usr/bin/env python
#
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  16. """Prepare a corpus for processing by swivel.
  17. Creates a sharded word co-occurrence matrix from a text file input corpus.
  18. Usage:
  19. prep.py --output_dir <output-dir> --input <text-file>
  20. Options:
  21. --input <filename>
  22. The input text.
  23. --output_dir <directory>
  24. Specifies the output directory where the various Swivel data
  25. files should be placed.
  26. --shard_size <int>
  27. Specifies the shard size; default 4096.
  28. --min_count <int>
  29. Specifies the minimum number of times a word should appear
  30. to be included in the vocabulary; default 5.
  31. --max_vocab <int>
  32. Specifies the maximum vocabulary size; default shard size
  33. times 1024.
  34. --vocab <filename>
  35. Use the specified unigram vocabulary instead of generating
  36. it from the corpus.
  37. --window_size <int>
  38. Specifies the window size for computing co-occurrence stats;
  39. default 10.
  40. --bufsz <int>
  41. The number of co-occurrences that are buffered; default 16M.
  42. """
import itertools
import math
import os
import struct
import sys

import tensorflow as tf

flags = tf.app.flags

flags.DEFINE_string('input', '', 'The input text.')
flags.DEFINE_string('output_dir', '/tmp/swivel_data',
                    'Output directory for Swivel data')
flags.DEFINE_integer('shard_size', 4096, 'The size for each shard')
flags.DEFINE_integer('min_count', 5,
                     'The minimum number of times a word should occur to be '
                     'included in the vocabulary')
flags.DEFINE_integer('max_vocab', 4096 * 64, 'The maximum vocabulary size')
flags.DEFINE_string('vocab', '', 'Vocabulary to use instead of generating one')
flags.DEFINE_integer('window_size', 10, 'The window size')
flags.DEFINE_integer('bufsz', 16 * 1024 * 1024,
                     'The number of co-occurrences to buffer')

FLAGS = flags.FLAGS

shard_cooc_fmt = struct.Struct('iif')
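# Each buffered co-occurrence is packed as two 32-bit ints (the row and column
# offsets within a shard) followed by a 32-bit float (the weighted count),
# 12 bytes per record; e.g. shard_cooc_fmt.pack(3, 7, 1.5) produces the record
# that write_shards() later reads back with unpack_from().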


def words(line):
  """Splits a line of text into tokens."""
  return line.strip().split()


def create_vocabulary(lines):
  """Reads text lines and generates a vocabulary."""
  lines.seek(0, os.SEEK_END)
  nbytes = lines.tell()
  lines.seek(0, os.SEEK_SET)

  vocab = {}
  for lineno, line in enumerate(lines, start=1):
    for word in words(line):
      vocab.setdefault(word, 0)
      vocab[word] += 1

    if lineno % 100000 == 0:
      pos = lines.tell()
      sys.stdout.write('\rComputing vocabulary: %0.1f%% (%d/%d)...' % (
          100.0 * pos / nbytes, pos, nbytes))
      sys.stdout.flush()

  sys.stdout.write('\n')

  # Keep frequent tokens, order by descending count, cap at --max_vocab, and
  # truncate the vocabulary to a multiple of the shard size.
  vocab = [(tok, n) for tok, n in vocab.iteritems() if n >= FLAGS.min_count]
  vocab.sort(key=lambda kv: (-kv[1], kv[0]))

  num_words = max(len(vocab), FLAGS.shard_size)
  num_words = min(num_words, FLAGS.max_vocab)
  if num_words % FLAGS.shard_size != 0:
    num_words -= num_words % FLAGS.shard_size
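  # For illustration: with the default shard_size of 4096 and max_vocab of
  # 262144, a corpus in which (say) 10,000 distinct tokens survive the
  # min_count filter is truncated to its 8,192 most frequent tokens, the
  # largest multiple of the shard size that fits, so the co-occurrence matrix
  # divides evenly into a 2x2 grid of shards.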
  if not num_words:
    raise Exception('empty vocabulary')

  print 'vocabulary contains %d tokens' % num_words

  vocab = vocab[:num_words]
  return [tok for tok, n in vocab]


def write_vocab_and_sums(vocab, sums, vocab_filename, sums_filename):
  """Writes vocabulary and marginal sum files."""
  with open(os.path.join(FLAGS.output_dir, vocab_filename), 'w') as vocab_out:
    with open(os.path.join(FLAGS.output_dir, sums_filename), 'w') as sums_out:
      for tok, cnt in itertools.izip(vocab, sums):
        print >> vocab_out, tok
        print >> sums_out, cnt


def compute_coocs(lines, vocab):
  """Compute the co-occurrence statistics from the text.

  This generates a temporary file for each shard that contains the intermediate
  counts from the shard: these counts must be subsequently sorted and collated.
  """
  word_to_id = {tok: idx for idx, tok in enumerate(vocab)}

  lines.seek(0, os.SEEK_END)
  nbytes = lines.tell()
  lines.seek(0, os.SEEK_SET)

  num_shards = len(vocab) / FLAGS.shard_size

  shardfiles = {}
  for row in range(num_shards):
    for col in range(num_shards):
      filename = os.path.join(
          FLAGS.output_dir, 'shard-%03d-%03d.tmp' % (row, col))
      shardfiles[(row, col)] = open(filename, 'w+')

  def flush_coocs():
    for (row_id, col_id), cnt in coocs.iteritems():
      row_shard = row_id % num_shards
      row_off = row_id / num_shards
      col_shard = col_id % num_shards
      col_off = col_id / num_shards
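      # For example, with 64 shards the token whose vocabulary id is 200 maps
      # to shard index 200 % 64 = 8 and local offset 200 / 64 = 3, so its
      # counts land at row (or column) 3 of the shards in row (or column) 8.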
      # Since we only stored (a, b), we emit both (a, b) and (b, a).
      shardfiles[(row_shard, col_shard)].write(
          shard_cooc_fmt.pack(row_off, col_off, cnt))

      shardfiles[(col_shard, row_shard)].write(
          shard_cooc_fmt.pack(col_off, row_off, cnt))

  coocs = {}
  sums = [0.0] * len(vocab)

  for lineno, line in enumerate(lines, start=1):
    # Computes the word IDs for each word in the sentence. This has the effect
    # of "stretching" the window past OOV tokens.
    wids = filter(
        lambda wid: wid is not None,
        (word_to_id.get(w) for w in words(line)))

    for pos in xrange(len(wids)):
      lid = wids[pos]
      window_extent = min(FLAGS.window_size + 1, len(wids) - pos)
      for off in xrange(1, window_extent):
        rid = wids[pos + off]
        pair = (min(lid, rid), max(lid, rid))
        count = 1.0 / off
        sums[lid] += count
        sums[rid] += count
        coocs.setdefault(pair, 0.0)
        coocs[pair] += count

      sums[lid] += 1.0
      pair = (lid, lid)
      coocs.setdefault(pair, 0.0)
      coocs[pair] += 0.5  # Only add 1/2 since we output (a, b) and (b, a)
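      # To illustrate the weighting: if a line's in-vocabulary tokens are
      # [a, b, c], then the pair (a, b) accumulates 1.0, (a, c) accumulates
      # 1/2, and (b, c) accumulates 1.0; each token also adds 0.5 to its own
      # diagonal cell (doubled when flush_coocs emits both orderings) and an
      # extra 1.0 to its marginal sum on top of the pair counts.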
    if lineno % 10000 == 0:
      pos = lines.tell()
      sys.stdout.write('\rComputing co-occurrences: %0.1f%% (%d/%d)...' % (
          100.0 * pos / nbytes, pos, nbytes))
      sys.stdout.flush()

      # Periodically spill the in-memory counts to the shard files to bound
      # memory use.
      if len(coocs) > FLAGS.bufsz:
        flush_coocs()
        coocs = {}

  flush_coocs()
  sys.stdout.write('\n')

  return shardfiles, sums


def write_shards(vocab, shardfiles):
  """Processes the temporary files to generate the final shard data.

  The data for each shard is stored as a serialized tf.Example proto. The
  temporary files are removed from the filesystem once they've been processed.
  """
  num_shards = len(vocab) / FLAGS.shard_size

  ix = 0
  for (row, col), fh in shardfiles.iteritems():
    ix += 1
    sys.stdout.write('\rwriting shard %d/%d' % (ix, len(shardfiles)))
    sys.stdout.flush()

    # Read the entire binary co-occurrence file and unpack it into an array.
    fh.seek(0)
    buf = fh.read()
    os.unlink(fh.name)
    fh.close()

    coocs = [
        shard_cooc_fmt.unpack_from(buf, off)
        for off in range(0, len(buf), shard_cooc_fmt.size)]

    # Sort and merge co-occurrences for the same pairs.
    coocs.sort()

    if coocs:
      current_pos = 0
      current_row_col = (coocs[current_pos][0], coocs[current_pos][1])

      for next_pos in range(1, len(coocs)):
        next_row_col = (coocs[next_pos][0], coocs[next_pos][1])

        if current_row_col == next_row_col:
          coocs[current_pos] = (
              coocs[current_pos][0],
              coocs[current_pos][1],
              coocs[current_pos][2] + coocs[next_pos][2])
        else:
          current_pos += 1
          if current_pos < next_pos:
            coocs[current_pos] = coocs[next_pos]

          current_row_col = (coocs[current_pos][0], coocs[current_pos][1])

      coocs = coocs[:(1 + current_pos)]
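      # For instance, a sorted run like [(0, 3, 0.5), (0, 3, 1.0), (2, 1, 1.0)]
      # collapses to [(0, 3, 1.5), (2, 1, 1.0)].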
    # Convert to a TF Example proto.
    def _int64s(xs):
      return tf.train.Feature(int64_list=tf.train.Int64List(value=list(xs)))

    def _floats(xs):
      return tf.train.Feature(float_list=tf.train.FloatList(value=list(xs)))

    example = tf.train.Example(features=tf.train.Features(feature={
        'global_row': _int64s(
            row + num_shards * i for i in range(FLAGS.shard_size)),
        'global_col': _int64s(
            col + num_shards * i for i in range(FLAGS.shard_size)),
        'sparse_local_row': _int64s(cooc[0] for cooc in coocs),
        'sparse_local_col': _int64s(cooc[1] for cooc in coocs),
        'sparse_value': _floats(cooc[2] for cooc in coocs),
    }))
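    # As a concrete example, with a 64x64 grid of shards the shard at
    # (row=1, col=2) covers global rows 1, 65, 129, ... and global columns
    # 2, 66, 130, ..., the inverse of the id -> (shard, offset) mapping used
    # in flush_coocs above.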
    filename = os.path.join(FLAGS.output_dir, 'shard-%03d-%03d.pb' % (row, col))
    with open(filename, 'w') as out:
      out.write(example.SerializeToString())

  sys.stdout.write('\n')


def main(_):
  # Create the output directory, if necessary.
  if FLAGS.output_dir and not os.path.isdir(FLAGS.output_dir):
    os.makedirs(FLAGS.output_dir)

  # Read the file once to create the vocabulary.
  if FLAGS.vocab:
    with open(FLAGS.vocab, 'r') as lines:
      vocab = [line.strip() for line in lines]
  else:
    with open(FLAGS.input, 'r') as lines:
      vocab = create_vocabulary(lines)

  # Now read the file again to determine the co-occurrence stats.
  with open(FLAGS.input, 'r') as lines:
    shardfiles, sums = compute_coocs(lines, vocab)

  # Collate the temporary shard files into the final per-shard protos.
  write_shards(vocab, shardfiles)

  # Now write the marginals. They're symmetric for this application.
  write_vocab_and_sums(vocab, sums, 'row_vocab.txt', 'row_sums.txt')
  write_vocab_and_sums(vocab, sums, 'col_vocab.txt', 'col_sums.txt')
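  # At this point the output directory contains row_vocab.txt, row_sums.txt,
  # col_vocab.txt, col_sums.txt, and one shard-RRR-CCC.pb file per shard pair;
  # a 2x2 shard grid, for example, yields shard-000-000.pb through
  # shard-001-001.pb.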
  print 'done!'


if __name__ == '__main__':
  tf.app.run()