# embeddings.py

import os
import shutil
import sys
import tempfile
from typing import List, Optional, Tuple

import numpy
from scipy.sparse import csr_matrix

from labours.cors_web_server import web_server

IDEAL_SHARD_SIZE = 4096


def train_embeddings(
        index: List[str],
        matrix: csr_matrix,
        tmpdir: Optional[str],
        shard_size: int = IDEAL_SHARD_SIZE,
) -> Tuple[List[Tuple[str, numpy.int64]], List[numpy.ndarray]]:
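    """Train Swivel embeddings on a symmetric co-occurrence matrix.

    :param index: Row/column names of `matrix`.
    :param matrix: Square sparse co-occurrence matrix.
    :param tmpdir: Directory for the intermediate files; None means the system default.
    :param shard_size: Edge length of the square Swivel submatrix shards.
    :return: List of (name, number of occurrences) pairs, the counts taken from the
        matrix diagonal, and the list of embedding vectors (the average of each
        row and column embedding).
    """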
    # Deferred imports: TensorFlow is heavy and only needed for training.
    import tensorflow as tf

    from labours._vendor import swivel

    assert matrix.shape[0] == matrix.shape[1]
    assert len(index) <= matrix.shape[0]
    # Clip the top 1% of co-occurrence values to suppress outliers.
    outlier_threshold = numpy.percentile(matrix.data, 99)
    matrix.data[matrix.data > outlier_threshold] = outlier_threshold
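    # Swivel trains on a square grid of nshards x nshards submatrices, so the
    # vocabulary must divide evenly into nshards * shard_size rows. Shrink the
    # shard size if needed, then drop the rows with the fewest non-zero entries
    # until the remainder fits.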
    nshards = len(index) // shard_size
    if nshards * shard_size < len(index):
        nshards += 1
        shard_size = len(index) // nshards
        nshards = len(index) // shard_size
    remainder = len(index) - nshards * shard_size
    if remainder > 0:
        lengths = matrix.indptr[1:] - matrix.indptr[:-1]
        filtered = sorted(numpy.argsort(lengths)[remainder:])
    else:
        filtered = list(range(len(index)))
    if len(filtered) < matrix.shape[0]:
        print("Truncating the sparse matrix...")
        matrix = matrix[filtered, :][:, filtered]
    meta_index = []
    for i, j in enumerate(filtered):
        meta_index.append((index[j], matrix[i, i]))
    index = [mi[0] for mi in meta_index]
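    # Lay out the training data the way the vendored Swivel expects it:
    # row/column vocabularies, row/column marginal sums, and one protobuf
    # file per submatrix shard.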
    with tempfile.TemporaryDirectory(prefix="hercules_labours_", dir=tmpdir or None) as tmproot:
        print("Writing Swivel metadata...")
        vocabulary = "\n".join(index)
        with open(os.path.join(tmproot, "row_vocab.txt"), "w") as out:
            out.write(vocabulary)
        with open(os.path.join(tmproot, "col_vocab.txt"), "w") as out:
            out.write(vocabulary)
        del vocabulary
        bool_sums = matrix.indptr[1:] - matrix.indptr[:-1]
        bool_sums_str = "\n".join(map(str, bool_sums.tolist()))
        with open(os.path.join(tmproot, "row_sums.txt"), "w") as out:
            out.write(bool_sums_str)
        with open(os.path.join(tmproot, "col_sums.txt"), "w") as out:
            out.write(bool_sums_str)
        del bool_sums_str
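        # Order rows/columns by decreasing frequency and deal them out
        # round-robin across the shards, so every shard gets a mix of dense
        # and sparse rows.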
        reorder = numpy.argsort(-bool_sums)
        print("Writing Swivel shards...")

        def _int64s(xs):
            return tf.train.Feature(
                int64_list=tf.train.Int64List(value=list(xs)))

        def _floats(xs):
            return tf.train.Feature(
                float_list=tf.train.FloatList(value=list(xs)))

        for row in range(nshards):
            for col in range(nshards):
                indices_row = reorder[row::nshards]
                indices_col = reorder[col::nshards]
                shard = matrix[indices_row][:, indices_col].tocoo()
                example = tf.train.Example(features=tf.train.Features(feature={
                    "global_row": _int64s(indices_row),
                    "global_col": _int64s(indices_col),
                    "sparse_local_row": _int64s(shard.row),
                    "sparse_local_col": _int64s(shard.col),
                    "sparse_value": _floats(shard.data)}))
                with open(os.path.join(tmproot, "shard-%03d-%03d.pb" % (row, col)), "wb") as out:
                    out.write(example.SerializeToString())
  76. print("Training Swivel model...")
  77. swivel.FLAGS.submatrix_rows = shard_size
  78. swivel.FLAGS.submatrix_cols = shard_size
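        # Scale the hyperparameters with the vocabulary size: more epochs for
        # small vocabularies, larger embeddings for big ones.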
        if len(meta_index) <= IDEAL_SHARD_SIZE / 16:
            embedding_size = 50
            num_epochs = 100000
        elif len(meta_index) <= IDEAL_SHARD_SIZE:
            embedding_size = 50
            num_epochs = 50000
        elif len(meta_index) <= IDEAL_SHARD_SIZE * 2:
            embedding_size = 60
            num_epochs = 10000
        elif len(meta_index) <= IDEAL_SHARD_SIZE * 4:
            embedding_size = 70
            num_epochs = 8000
        elif len(meta_index) <= IDEAL_SHARD_SIZE * 10:
            embedding_size = 80
            num_epochs = 5000
        elif len(meta_index) <= IDEAL_SHARD_SIZE * 25:
            embedding_size = 100
            num_epochs = 1000
        elif len(meta_index) <= IDEAL_SHARD_SIZE * 100:
            embedding_size = 200
            num_epochs = 600
        else:
            embedding_size = 300
            num_epochs = 300
  103. if os.getenv("CI"):
  104. # Travis, AppVeyor etc. during the integration tests
  105. num_epochs /= 10
  106. swivel.FLAGS.embedding_size = embedding_size
  107. swivel.FLAGS.input_base_path = tmproot
  108. swivel.FLAGS.output_base_path = tmproot
  109. swivel.FLAGS.loss_multiplier = 1.0 / shard_size
  110. swivel.FLAGS.num_epochs = num_epochs
  111. # Tensorflow 1.5 parses sys.argv unconditionally *applause*
  112. argv_backup = sys.argv[1:]
  113. del sys.argv[1:]
  114. swivel.main(None)
  115. sys.argv.extend(argv_backup)
  116. print("Reading Swivel embeddings...")
        embeddings = []
        with open(os.path.join(tmproot, "row_embedding.tsv")) as frow:
            with open(os.path.join(tmproot, "col_embedding.tsv")) as fcol:
                for lrow, lcol in zip(frow, fcol):
                    prow, pcol = (l.split("\t", 1) for l in (lrow, lcol))
                    assert prow[0] == pcol[0]
                    erow, ecol = \
                        (numpy.array(p[1].split("\t"), dtype=numpy.float32)
                         for p in (prow, pcol))
                    embeddings.append((erow + ecol) / 2)
    return meta_index, embeddings
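

# A minimal usage sketch (hypothetical `index`/`matrix` built by the caller,
# e.g. from hercules' couples analysis):
#
#     meta_index, embeddings = train_embeddings(index, matrix, tmpdir=None)
#     write_embeddings("files", "/tmp/out", run_server=False,
#                      index=meta_index, embeddings=embeddings)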


def write_embeddings(
        name: str,
        output: str,
        run_server: bool,
        index: List[Tuple[str, numpy.int64]],
        embeddings: List[numpy.ndarray],
) -> None:
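    """Write the trained embeddings in TensorFlow Projector format.

    :param name: Name of the embedding (used in the output file names).
    :param output: Path prefix of the output files.
    :param run_server: Whether to serve the files over HTTP and open the
        Projector in a browser.
    :param index: (name, number of occurrences) pairs, as returned by
        train_embeddings().
    :param embeddings: Embedding vectors, as returned by train_embeddings().
    """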
  135. print("Writing Tensorflow Projector files...")
  136. if not output:
  137. output = "couples"
  138. if output.endswith(".json"):
  139. output = os.path.join(output[:-5], "couples")
  140. run_server = False
  141. metaf = "%s_%s_meta.tsv" % (output, name)
  142. with open(metaf, "w") as fout:
  143. fout.write("name\tcommits\n")
  144. for pair in index:
  145. fout.write("%s\t%s\n" % pair)
  146. print("Wrote", metaf)
  147. dataf = "%s_%s_data.tsv" % (output, name)
  148. with open(dataf, "w") as fout:
  149. for vec in embeddings:
  150. fout.write("\t".join(str(v) for v in vec))
  151. fout.write("\n")
  152. print("Wrote", dataf)
  153. jsonf = "%s_%s.json" % (output, name)
  154. with open(jsonf, "w") as fout:
  155. fout.write("""{
  156. "embeddings": [
  157. {
  158. "tensorName": "%s %s coupling",
  159. "tensorShape": [%s, %s],
  160. "tensorPath": "http://0.0.0.0:8000/%s",
  161. "metadataPath": "http://0.0.0.0:8000/%s"
  162. }
  163. ]
  164. }
  165. """ % (output, name, len(embeddings), len(embeddings[0]), dataf, metaf))
  166. print("Wrote %s" % jsonf)
    if run_server and not web_server.running:
        web_server.start()
    url = "http://projector.tensorflow.org/?config=http://0.0.0.0:8000/" + jsonf
    print(url)
    if run_server:
        if shutil.which("xdg-open") is not None:
            os.system("xdg-open " + url)
        else:
            browser = os.getenv("BROWSER", "")
            if browser:
                os.system(browser + " " + url)
            else:
                print("\t" + url)