# embeddings.py
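"""Train Swivel embeddings over hercules "couples" co-occurrence matrices and
export them in Tensorflow Projector format."""
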
import os
import shutil
import sys
import tempfile
from typing import List, Optional, Tuple

import numpy
from scipy.sparse import csr_matrix

from labours.cors_web_server import web_server

IDEAL_SHARD_SIZE = 4096


def train_embeddings(
    index: List[str],
    matrix: csr_matrix,
    tmpdir: Optional[str],
    shard_size: int = IDEAL_SHARD_SIZE,
) -> Tuple[List[Tuple[str, numpy.int64]], List[numpy.ndarray]]:
    """Train Swivel embeddings for the given co-occurrence matrix.

    Returns (name, commits) pairs for the retained tokens and the matching
    list of embedding vectors.
    """
    # Heavy dependencies are imported lazily, inside the function.
    import tensorflow as tf

    from labours._vendor import swivel

    assert matrix.shape[0] == matrix.shape[1]
    assert len(index) <= matrix.shape[0]
    # Clip the top 1% of co-occurrence values to dampen outliers.
    outlier_threshold = numpy.percentile(matrix.data, 99)
    matrix.data[matrix.data > outlier_threshold] = outlier_threshold
    # Round the shard count up, then rebalance the shard size so the shards
    # cover the vocabulary as evenly as possible.
    nshards = len(index) // shard_size
    if nshards * shard_size < len(index):
        nshards += 1
        shard_size = len(index) // nshards
        nshards = len(index) // shard_size
    remainder = len(index) - nshards * shard_size
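    # For example, with len(index) == 10000 and the default shard_size of
    # 4096 (illustrative numbers): nshards starts at 2, is rounded up to 3,
    # shard_size rebalances to 10000 // 3 == 3333, and remainder == 1, so one
    # row has to be dropped below to make the matrix divide evenly.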
    if remainder > 0:
        lengths = matrix.indptr[1:] - matrix.indptr[:-1]
        # Drop the `remainder` rows with the fewest nonzeros.
        filtered = sorted(numpy.argsort(lengths)[remainder:])
    else:
        filtered = list(range(len(index)))
    if len(filtered) < matrix.shape[0]:
        print("Truncating the sparse matrix...")
        matrix = matrix[filtered, :][:, filtered]
    # Pair each kept name with its diagonal value (the number of commits).
    meta_index = []
    for i, j in enumerate(filtered):
        meta_index.append((index[j], matrix[i, i]))
    index = [mi[0] for mi in meta_index]
    with tempfile.TemporaryDirectory(
        prefix="hercules_labours_", dir=tmpdir or None
    ) as tmproot:
        print("Writing Swivel metadata...")
        vocabulary = "\n".join(index)
        with open(os.path.join(tmproot, "row_vocab.txt"), "w") as out:
            out.write(vocabulary)
        with open(os.path.join(tmproot, "col_vocab.txt"), "w") as out:
            out.write(vocabulary)
        del vocabulary
        bool_sums = matrix.indptr[1:] - matrix.indptr[:-1]
        bool_sums_str = "\n".join(map(str, bool_sums.tolist()))
        with open(os.path.join(tmproot, "row_sums.txt"), "w") as out:
            out.write(bool_sums_str)
        with open(os.path.join(tmproot, "col_sums.txt"), "w") as out:
            out.write(bool_sums_str)
        del bool_sums_str
        # Process rows/columns in decreasing order of their nonzero count.
        reorder = numpy.argsort(-bool_sums)
  58. print("Writing Swivel shards...")
  59. for row in range(nshards):
  60. for col in range(nshards):
  61. def _int64s(xs):
  62. return tf.train.Feature(
  63. int64_list=tf.train.Int64List(value=list(xs))
  64. )
  65. def _floats(xs):
  66. return tf.train.Feature(
  67. float_list=tf.train.FloatList(value=list(xs))
  68. )
  69. indices_row = reorder[row::nshards]
  70. indices_col = reorder[col::nshards]
  71. shard = matrix[indices_row][:, indices_col].tocoo()
  72. example = tf.train.Example(
  73. features=tf.train.Features(
  74. feature={
  75. "global_row": _int64s(indices_row),
  76. "global_col": _int64s(indices_col),
  77. "sparse_local_row": _int64s(shard.row),
  78. "sparse_local_col": _int64s(shard.col),
  79. "sparse_value": _floats(shard.data),
  80. }
  81. )
  82. )
  83. with open(
  84. os.path.join(tmproot, "shard-%03d-%03d.pb" % (row, col)), "wb"
  85. ) as out:
  86. out.write(example.SerializeToString())
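        # Sanity-check sketch (not part of the original flow): a written shard
        # can be read back with the standard protobuf API, e.g.
        #   with open(os.path.join(tmproot, "shard-000-000.pb"), "rb") as f:
        #       check = tf.train.Example.FromString(f.read())
        #   list(check.features.feature["global_row"].int64_list.value)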
  87. print("Training Swivel model...")
  88. swivel.FLAGS.submatrix_rows = shard_size
  89. swivel.FLAGS.submatrix_cols = shard_size
  90. if len(meta_index) <= IDEAL_SHARD_SIZE / 16:
  91. embedding_size = 50
  92. num_epochs = 100000
  93. elif len(meta_index) <= IDEAL_SHARD_SIZE:
  94. embedding_size = 50
  95. num_epochs = 50000
  96. elif len(meta_index) <= IDEAL_SHARD_SIZE * 2:
  97. embedding_size = 60
  98. num_epochs = 10000
  99. elif len(meta_index) <= IDEAL_SHARD_SIZE * 4:
  100. embedding_size = 70
  101. num_epochs = 8000
  102. elif len(meta_index) <= IDEAL_SHARD_SIZE * 10:
  103. embedding_size = 80
  104. num_epochs = 5000
  105. elif len(meta_index) <= IDEAL_SHARD_SIZE * 25:
  106. embedding_size = 100
  107. num_epochs = 1000
  108. elif len(meta_index) <= IDEAL_SHARD_SIZE * 100:
  109. embedding_size = 200
  110. num_epochs = 600
  111. else:
  112. embedding_size = 300
  113. num_epochs = 300
  114. if os.getenv("CI"):
  115. # Travis, AppVeyor etc. during the integration tests
  116. num_epochs /= 10
  117. swivel.FLAGS.embedding_size = embedding_size
  118. swivel.FLAGS.input_base_path = tmproot
  119. swivel.FLAGS.output_base_path = tmproot
  120. swivel.FLAGS.loss_multiplier = 1.0 / shard_size
  121. swivel.FLAGS.num_epochs = num_epochs
  122. # Tensorflow 1.5 parses sys.argv unconditionally *applause*
  123. argv_backup = sys.argv[1:]
  124. del sys.argv[1:]
  125. swivel.main(None)
  126. sys.argv.extend(argv_backup)
  127. print("Reading Swivel embeddings...")
  128. embeddings = []
  129. with open(os.path.join(tmproot, "row_embedding.tsv")) as frow:
  130. with open(os.path.join(tmproot, "col_embedding.tsv")) as fcol:
  131. for i, (lrow, lcol) in enumerate(zip(frow, fcol)):
  132. prow, pcol = (l.split("\t", 1) for l in (lrow, lcol))
  133. assert prow[0] == pcol[0]
  134. erow, ecol = (
  135. numpy.fromstring(p[1], dtype=numpy.float32, sep="\t")
  136. for p in (prow, pcol)
  137. )
  138. embeddings.append((erow + ecol) / 2)
  139. return meta_index, embeddings


def write_embeddings(
    name: str,
    output: str,
    run_server: bool,
    index: List[Tuple[str, numpy.int64]],
    embeddings: List[numpy.ndarray],
) -> None:
  147. print("Writing Tensorflow Projector files...")
  148. if not output:
  149. output = "couples"
  150. if output.endswith(".json"):
  151. output = os.path.join(output[:-5], "couples")
  152. run_server = False
  153. metaf = "%s_%s_meta.tsv" % (output, name)
  154. with open(metaf, "w") as fout:
  155. fout.write("name\tcommits\n")
  156. for pair in index:
  157. fout.write("%s\t%s\n" % pair)
  158. print("Wrote", metaf)
  159. dataf = "%s_%s_data.tsv" % (output, name)
  160. with open(dataf, "w") as fout:
  161. for vec in embeddings:
  162. fout.write("\t".join(str(v) for v in vec))
  163. fout.write("\n")
  164. print("Wrote", dataf)
    jsonf = "%s_%s.json" % (output, name)
    with open(jsonf, "w") as fout:
        fout.write(
            """{
  "embeddings": [
    {
      "tensorName": "%s %s coupling",
      "tensorShape": [%s, %s],
      "tensorPath": "http://0.0.0.0:8000/%s",
      "metadataPath": "http://0.0.0.0:8000/%s"
    }
  ]
}
"""
            % (output, name, len(embeddings), len(embeddings[0]), dataf, metaf)
        )
    print("Wrote %s" % jsonf)
    if run_server and not web_server.running:
        web_server.start()
    url = "http://projector.tensorflow.org/?config=http://0.0.0.0:8000/" + jsonf
    print(url)
    if run_server:
        if shutil.which("xdg-open") is not None:
            os.system("xdg-open " + url)
        else:
            browser = os.getenv("BROWSER", "")
            if browser:
                os.system(browser + " " + url)
            else:
                print("\t" + url)