# embeddings.py — train and export Swivel embeddings for the Tensorflow
# Embedding Projector. (Copy-paste artifacts — file-size header and fused
# line-number run — removed.)
import json
import os
import shutil
import sys
import tempfile

import numpy

from labours.cors_web_server import web_server

IDEAL_SHARD_SIZE = 4096
  8. def train_embeddings(index, matrix, tmpdir, shard_size=IDEAL_SHARD_SIZE):
  9. import tensorflow as tf
  10. from labours._vendor import swivel
  11. assert matrix.shape[0] == matrix.shape[1]
  12. assert len(index) <= matrix.shape[0]
  13. outlier_threshold = numpy.percentile(matrix.data, 99)
  14. matrix.data[matrix.data > outlier_threshold] = outlier_threshold
  15. nshards = len(index) // shard_size
  16. if nshards * shard_size < len(index):
  17. nshards += 1
  18. shard_size = len(index) // nshards
  19. nshards = len(index) // shard_size
  20. remainder = len(index) - nshards * shard_size
  21. if remainder > 0:
  22. lengths = matrix.indptr[1:] - matrix.indptr[:-1]
  23. filtered = sorted(numpy.argsort(lengths)[remainder:])
  24. else:
  25. filtered = list(range(len(index)))
  26. if len(filtered) < matrix.shape[0]:
  27. print("Truncating the sparse matrix...")
  28. matrix = matrix[filtered, :][:, filtered]
  29. meta_index = []
  30. for i, j in enumerate(filtered):
  31. meta_index.append((index[j], matrix[i, i]))
  32. index = [mi[0] for mi in meta_index]
  33. with tempfile.TemporaryDirectory(prefix="hercules_labours_", dir=tmpdir or None) as tmproot:
  34. print("Writing Swivel metadata...")
  35. vocabulary = "\n".join(index)
  36. with open(os.path.join(tmproot, "row_vocab.txt"), "w") as out:
  37. out.write(vocabulary)
  38. with open(os.path.join(tmproot, "col_vocab.txt"), "w") as out:
  39. out.write(vocabulary)
  40. del vocabulary
  41. bool_sums = matrix.indptr[1:] - matrix.indptr[:-1]
  42. bool_sums_str = "\n".join(map(str, bool_sums.tolist()))
  43. with open(os.path.join(tmproot, "row_sums.txt"), "w") as out:
  44. out.write(bool_sums_str)
  45. with open(os.path.join(tmproot, "col_sums.txt"), "w") as out:
  46. out.write(bool_sums_str)
  47. del bool_sums_str
  48. reorder = numpy.argsort(-bool_sums)
  49. print("Writing Swivel shards...")
  50. for row in range(nshards):
  51. for col in range(nshards):
  52. def _int64s(xs):
  53. return tf.train.Feature(
  54. int64_list=tf.train.Int64List(value=list(xs)))
  55. def _floats(xs):
  56. return tf.train.Feature(
  57. float_list=tf.train.FloatList(value=list(xs)))
  58. indices_row = reorder[row::nshards]
  59. indices_col = reorder[col::nshards]
  60. shard = matrix[indices_row][:, indices_col].tocoo()
  61. example = tf.train.Example(features=tf.train.Features(feature={
  62. "global_row": _int64s(indices_row),
  63. "global_col": _int64s(indices_col),
  64. "sparse_local_row": _int64s(shard.row),
  65. "sparse_local_col": _int64s(shard.col),
  66. "sparse_value": _floats(shard.data)}))
  67. with open(os.path.join(tmproot, "shard-%03d-%03d.pb" % (row, col)), "wb") as out:
  68. out.write(example.SerializeToString())
  69. print("Training Swivel model...")
  70. swivel.FLAGS.submatrix_rows = shard_size
  71. swivel.FLAGS.submatrix_cols = shard_size
  72. if len(meta_index) <= IDEAL_SHARD_SIZE / 16:
  73. embedding_size = 50
  74. num_epochs = 100000
  75. elif len(meta_index) <= IDEAL_SHARD_SIZE:
  76. embedding_size = 50
  77. num_epochs = 50000
  78. elif len(meta_index) <= IDEAL_SHARD_SIZE * 2:
  79. embedding_size = 60
  80. num_epochs = 10000
  81. elif len(meta_index) <= IDEAL_SHARD_SIZE * 4:
  82. embedding_size = 70
  83. num_epochs = 8000
  84. elif len(meta_index) <= IDEAL_SHARD_SIZE * 10:
  85. embedding_size = 80
  86. num_epochs = 5000
  87. elif len(meta_index) <= IDEAL_SHARD_SIZE * 25:
  88. embedding_size = 100
  89. num_epochs = 1000
  90. elif len(meta_index) <= IDEAL_SHARD_SIZE * 100:
  91. embedding_size = 200
  92. num_epochs = 600
  93. else:
  94. embedding_size = 300
  95. num_epochs = 300
  96. if os.getenv("CI"):
  97. # Travis, AppVeyor etc. during the integration tests
  98. num_epochs /= 10
  99. swivel.FLAGS.embedding_size = embedding_size
  100. swivel.FLAGS.input_base_path = tmproot
  101. swivel.FLAGS.output_base_path = tmproot
  102. swivel.FLAGS.loss_multiplier = 1.0 / shard_size
  103. swivel.FLAGS.num_epochs = num_epochs
  104. # Tensorflow 1.5 parses sys.argv unconditionally *applause*
  105. argv_backup = sys.argv[1:]
  106. del sys.argv[1:]
  107. swivel.main(None)
  108. sys.argv.extend(argv_backup)
  109. print("Reading Swivel embeddings...")
  110. embeddings = []
  111. with open(os.path.join(tmproot, "row_embedding.tsv")) as frow:
  112. with open(os.path.join(tmproot, "col_embedding.tsv")) as fcol:
  113. for i, (lrow, lcol) in enumerate(zip(frow, fcol)):
  114. prow, pcol = (l.split("\t", 1) for l in (lrow, lcol))
  115. assert prow[0] == pcol[0]
  116. erow, ecol = \
  117. (numpy.fromstring(p[1], dtype=numpy.float32, sep="\t")
  118. for p in (prow, pcol))
  119. embeddings.append((erow + ecol) / 2)
  120. return meta_index, embeddings
  121. def write_embeddings(name, output, run_server, index, embeddings):
  122. print("Writing Tensorflow Projector files...")
  123. if not output:
  124. output = "couples"
  125. if output.endswith(".json"):
  126. output = os.path.join(output[:-5], "couples")
  127. run_server = False
  128. metaf = "%s_%s_meta.tsv" % (output, name)
  129. with open(metaf, "w") as fout:
  130. fout.write("name\tcommits\n")
  131. for pair in index:
  132. fout.write("%s\t%s\n" % pair)
  133. print("Wrote", metaf)
  134. dataf = "%s_%s_data.tsv" % (output, name)
  135. with open(dataf, "w") as fout:
  136. for vec in embeddings:
  137. fout.write("\t".join(str(v) for v in vec))
  138. fout.write("\n")
  139. print("Wrote", dataf)
  140. jsonf = "%s_%s.json" % (output, name)
  141. with open(jsonf, "w") as fout:
  142. fout.write("""{
  143. "embeddings": [
  144. {
  145. "tensorName": "%s %s coupling",
  146. "tensorShape": [%s, %s],
  147. "tensorPath": "http://0.0.0.0:8000/%s",
  148. "metadataPath": "http://0.0.0.0:8000/%s"
  149. }
  150. ]
  151. }
  152. """ % (output, name, len(embeddings), len(embeddings[0]), dataf, metaf))
  153. print("Wrote %s" % jsonf)
  154. if run_server and not web_server.running:
  155. web_server.start()
  156. url = "http://projector.tensorflow.org/?config=http://0.0.0.0:8000/" + jsonf
  157. print(url)
  158. if run_server:
  159. if shutil.which("xdg-open") is not None:
  160. os.system("xdg-open " + url)
  161. else:
  162. browser = os.getenv("BROWSER", "")
  163. if browser:
  164. os.system(browser + " " + url)
  165. else:
  166. print("\t" + url)