glove_to_shards.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. #!/usr/bin/env python
  2. #
  3. # Copyright 2016 Google Inc. All Rights Reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """Converts a Glove binary co-occurrence matrix into Swivel shards.
  17. Usage:
  18. glove_to_shards.py --input <coocs> --vocab <vocab> --output_dir <output_dir>
  19. Options
  20. --input <coocs>
  21. The Glove co-occurrence file.
  22. --vocab <vocab>
  23. Path to the vocabulary text file, one token per line.
  24. --output_dir <directory>
  25. Specifies the output directory where the various Swivel data
  26. files should be placed.
  27. --shard_size <int>
  28. Specifies the shard size; default 4096.
  29. """
  30. from __future__ import print_function
  31. import itertools
  32. import os
  33. import struct
  34. import sys
  35. import tensorflow as tf
  36. flags = tf.app.flags
  37. flags.DEFINE_string('input', 'coocurrences.bin', 'Vocabulary file')
  38. flags.DEFINE_string('vocab', 'vocab.txt', 'Vocabulary file')
  39. flags.DEFINE_string('output_dir', '/tmp/swivel_data', 'Output directory')
  40. flags.DEFINE_integer('shard_size', 4096, 'Shard size')
  41. FLAGS = tf.app.flags.FLAGS
  42. glove_cooc_fmt = struct.Struct('iid')
  43. shard_cooc_fmt = struct.Struct('if')
  44. def make_shard_files(coocs, nshards, vocab_sz):
  45. """Chops the binary Glove co-occurrence matrix into shards.
  46. This reads the Glove binary co-occurrence file and assigns individual
  47. co-occurrence counts to the appropriate Swivel shard.
  48. Args:
  49. coocs: the co-occurrnece file to read
  50. nshards: the number of shards along one dimension of the square matrix
  51. vocab_sz: the vocabulary size
  52. Returns:
  53. A (shard_table, marginals) tuple. The shard_table maps the row and column
  54. shard ID to a file handle containing the co-occurrences for that shard; the
  55. marginals contain the marginal sums.
  56. """
  57. row_sums = [0] * vocab_sz
  58. col_sums = [0] * vocab_sz
  59. coocs.seek(0, os.SEEK_END)
  60. ncoocs = coocs.tell() / glove_cooc_fmt.size
  61. coocs.seek(0, os.SEEK_SET)
  62. shard_files = {}
  63. for row in range(nshards):
  64. for col in range(nshards):
  65. filename = os.path.join(
  66. FLAGS.output_dir, 'shard-%03d-%03d.bin' % (row, col))
  67. shard_files[(row, col)] = open(filename, 'w+')
  68. for ix in xrange(ncoocs):
  69. if ix % 1000000 == 0:
  70. sys.stdout.write('\rsharding co-occurrences: %0.1f%% (%d/%d)' % (
  71. 100.0 * ix / ncoocs, ix, ncoocs))
  72. sys.stdout.flush()
  73. bits = coocs.read(glove_cooc_fmt.size)
  74. if not bits:
  75. break
  76. # Glove has 1-indexed IDs.
  77. row_id, col_id, cnt = glove_cooc_fmt.unpack(bits)
  78. if row_id > vocab_sz or col_id > vocab_sz:
  79. continue
  80. row_id -= 1
  81. row_shard = row_id % nshards
  82. row_off = row_id / nshards
  83. col_id -= 1
  84. col_shard = col_id % nshards
  85. col_off = col_id / nshards
  86. shard_pos = row_off * FLAGS.shard_size + col_off # row major
  87. shard_files[(row_shard, col_shard)].write(
  88. shard_cooc_fmt.pack(shard_pos, cnt))
  89. # Accumulate marginals.
  90. row_sums[row_id] += cnt
  91. col_sums[col_id] += cnt
  92. sys.stdout.write('\n')
  93. if any(abs(r - c) > 0.1 for r, c in itertools.izip(row_sums, col_sums)):
  94. print('WARNING! Row and column marginals differ; is your matrix symmetric?',
  95. file=sys.stderr)
  96. return (shard_files, row_sums)
  97. def main(_):
  98. with open(FLAGS.vocab, 'r') as lines:
  99. orig_vocab_sz = sum(1 for _ in lines)
  100. shard_sz = FLAGS.shard_size
  101. vocab_sz = orig_vocab_sz - orig_vocab_sz % shard_sz
  102. nshards = vocab_sz / shard_sz
  103. print('vocab size is %d (originally %d), %d %dx%d-element shards' % (
  104. vocab_sz, orig_vocab_sz, nshards * nshards, shard_sz, shard_sz))
  105. # Create the output directory, if necessary
  106. if FLAGS.output_dir and not os.path.isdir(FLAGS.output_dir):
  107. os.makedirs(FLAGS.output_dir)
  108. with open(FLAGS.input, 'r') as coocs:
  109. shard_files, marginals = make_shard_files(coocs, nshards, vocab_sz)
  110. # Now sort the shards and write the TFRecords.
  111. filename = os.path.join(FLAGS.output_dir, 'shards.recs')
  112. with tf.python_io.TFRecordWriter(filename) as writer:
  113. ix = 0
  114. for (row, col), fh in shard_files.iteritems():
  115. ix += 1
  116. sys.stdout.write('\rwriting shard %d/%d' % (ix, len(shard_files)))
  117. sys.stdout.flush()
  118. fh.seek(0)
  119. buf = fh.read()
  120. os.unlink(fh.name)
  121. fh.close()
  122. coocs = [
  123. shard_cooc_fmt.unpack_from(buf, off)
  124. for off in range(0, len(buf), shard_cooc_fmt.size)]
  125. # N.B. we assume that there aren't any duplicates here!
  126. coocs.sort(key=lambda kv: kv[0])
  127. def _int64s(xs):
  128. return tf.train.Feature(int64_list=tf.train.Int64List(value=list(xs)))
  129. def _floats(xs):
  130. return tf.train.Feature(float_list=tf.train.FloatList(value=list(xs)))
  131. example = tf.train.Example(features=tf.train.Features(feature={
  132. 'global_row': _int64s(row + nshards * i for i in range(shard_sz)),
  133. 'global_col': _int64s(col + nshards * i for i in range(shard_sz)),
  134. 'sparse_local_row': _int64s(pos / shard_sz for pos, _ in coocs),
  135. 'sparse_local_col': _int64s(pos % shard_sz for pos, _ in coocs),
  136. 'sparse_value': _floats(cnt for _, cnt in coocs)}))
  137. writer.write(example.SerializeToString())
  138. print('\nwriting marginals...')
  139. with open(os.path.join(FLAGS.output_dir, 'marginals.txt'), 'w') as fh:
  140. for cnt in marginals:
  141. fh.write('%0.1f\n' % cnt)
  142. print('done!')
# Script entry point: tf.app.run() parses the flags and invokes main().
if __name__ == '__main__':
  tf.app.run()