labours.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577
  1. import argparse
  2. from datetime import datetime, timedelta
  3. import io
  4. import os
  5. import re
  6. import sys
  7. import tempfile
  8. import warnings
  9. try:
  10. from clint.textui import progress
  11. except ImportError:
  12. print("Warning: clint is not installed, no fancy progressbars in the terminal for you.")
  13. progress = None
  14. import numpy
  15. import yaml
  16. if sys.version_info[0] < 3:
  17. # OK, ancients, I will support Python 2, but you owe me a beer
  18. input = raw_input
  19. def parse_args():
  20. parser = argparse.ArgumentParser()
  21. parser.add_argument("-o", "--output", default="",
  22. help="Path to the output file/directory (empty for display).")
  23. parser.add_argument("-i", "--input", default="-",
  24. help="Path to the input file (- for stdin).")
  25. parser.add_argument("--text-size", default=12, type=int,
  26. help="Size of the labels and legend.")
  27. parser.add_argument("--backend", help="Matplotlib backend to use.")
  28. parser.add_argument("--style", choices=["black", "white"], default="black",
  29. help="Plot's general color scheme.")
  30. parser.add_argument("--relative", action="store_true",
  31. help="Occupy 100%% height for every measurement.")
  32. parser.add_argument("--couples-tmp-dir", help="Temporary directory to work with couples.")
  33. parser.add_argument("-m", "--mode",
  34. choices=["project", "file", "person", "matrix", "people", "couples",
  35. "all"],
  36. default="project", help="What to plot.")
  37. parser.add_argument(
  38. "--resample", default="year",
  39. help="The way to resample the time series. Possible values are: "
  40. "\"month\", \"year\", \"no\", \"raw\" and pandas offset aliases ("
  41. "http://pandas.pydata.org/pandas-docs/stable/timeseries.html"
  42. "#offset-aliases).")
  43. args = parser.parse_args()
  44. return args
  45. def read_input(args):
  46. sys.stdout.write("Reading the input... ")
  47. sys.stdout.flush()
  48. yaml.reader.Reader.NON_PRINTABLE = re.compile(r"(?!x)x")
  49. try:
  50. loader = yaml.CLoader
  51. except AttributeError:
  52. print("Warning: failed to import yaml.CLoader, falling back to slow yaml.Loader")
  53. loader = yaml.Loader
  54. try:
  55. if args.input != "-":
  56. with open(args.input) as fin:
  57. data = yaml.load(fin, Loader=loader)
  58. else:
  59. data = yaml.load(sys.stdin, Loader=loader)
  60. except (UnicodeEncodeError, yaml.reader.ReaderError) as e:
  61. print("\nInvalid unicode in the input: %s\nPlease filter it through fix_yaml_unicode.py" %
  62. e)
  63. sys.exit(1)
  64. print("done")
  65. return data["burndown"], data["project"], data.get("files"), data.get("people_sequence"), \
  66. data.get("people"), data.get("people_interaction"), data.get("files_coocc"), \
  67. data.get("people_coocc")
  68. def calculate_average_lifetime(matrix):
  69. lifetimes = numpy.zeros(matrix.shape[1] - 1)
  70. for band in matrix:
  71. start = 0
  72. for i, line in enumerate(band):
  73. if i == 0 or band[i - 1] == 0:
  74. start += 1
  75. continue
  76. lifetimes[i - start] = band[i - 1] - line
  77. lifetimes[i - start] = band[i - 1]
  78. return (lifetimes.dot(numpy.arange(1, matrix.shape[1], 1))
  79. / (lifetimes.sum() * matrix.shape[1]))
def load_main(header, name, matrix, resample):
    """Parse one burndown matrix and, unless resampling is disabled, re-bin it.

    :param header: the "burndown" section of the input; reads "begin" and
        "end" (Unix timestamps), "granularity" and "sampling" (both in days,
        as used with ``timedelta(days=...)`` below).
    :param name: name of the matrix; printed with its lifetime index.
    :param matrix: text with one line per time sample, each line holding
        space-separated integer band values.
    :param resample: "no" or "raw" to keep the original bands; otherwise
        "year", "month" or a pandas offset alias used to re-bin the bands
        into calendar periods.
    :return: (name, matrix, date_range_sampling, labels, granularity,
        sampling, resample). ``matrix`` has one row per band and one column
        per point of ``date_range_sampling``. ``labels`` names the rows.
        ``resample`` is the alias actually used; the raw branch returns a
        fake "M" that the caller checks while plotting.
    """
    import pandas
    start = header["begin"]
    last = header["end"]
    granularity = header["granularity"]
    sampling = header["sampling"]
    start = datetime.fromtimestamp(int(start))
    last = datetime.fromtimestamp(int(last))
    granularity = int(granularity)
    sampling = int(sampling)
    # One input line per time sample; transpose so that rows are bands and
    # columns are time samples.
    matrix = numpy.array([numpy.fromstring(line, dtype=int, sep=" ")
                          for line in matrix.split("\n")]).T
    print(name, "lifetime index:", calculate_average_lifetime(matrix))
    finish = start + timedelta(days=matrix.shape[1] * sampling)
    if resample not in ("no", "raw"):
        # Interpolate the day x day matrix.
        # Each day brings equal weight in the granularity.
        # Sampling's interpolation is linear.
        # Rows: days on which lines were added (each band spread over
        # `granularity` days); columns: days of history (each sample spread
        # over `sampling` days).
        daily_matrix = numpy.zeros(
            (matrix.shape[0] * granularity, matrix.shape[1] * sampling),
            dtype=numpy.float32)
        epsrange = numpy.arange(0, 1, 1.0 / sampling)
        for y in range(matrix.shape[0]):
            for x in range(matrix.shape[1]):
                previous = matrix[y, x - 1] if x > 0 else 0
                # 1 x sampling row: linear ramp from the previous sample's
                # value towards this one, split evenly over the band's days.
                value = ((previous + (matrix[y, x] - previous) * epsrange)
                         / granularity)[numpy.newaxis, :]
                if (y + 1) * granularity <= x * sampling:
                    # The band's days all precede this sample interval:
                    # fill the whole block with the ramp.
                    daily_matrix[y * granularity:(y + 1) * granularity,
                                 x * sampling:(x + 1) * sampling] = value
                elif y * granularity <= (x + 1) * sampling:
                    # The band's days overlap this interval: every day `suby`
                    # of the band gets a constant share from column `suby`
                    # up to the end of the interval.
                    for suby in range(y * granularity, (y + 1) * granularity):
                        for subx in range(suby, (x + 1) * sampling):
                            daily_matrix[suby, subx] = matrix[
                                y, x] / granularity
                # Otherwise the band starts after this interval: stays zero.
        # Zero the rows (days lines were added) past the end of history.
        daily_matrix[(last - start).days:] = 0
        # Resample the bands
        # NOTE(review): pandas >= 2.2 deprecates the "A"/"M" aliases in favour
        # of "YE"/"ME" -- confirm against the installed pandas version.
        aliases = {
            "year": "A",
            "month": "M"
        }
        resample = aliases.get(resample, resample)
        periods = 0
        date_granularity_sampling = [start]
        # Grow the period boundaries until the last one reaches `finish`.
        while date_granularity_sampling[-1] < finish:
            periods += 1
            date_granularity_sampling = pandas.date_range(
                start, periods=periods, freq=resample)
        # One point per day from the first period boundary up to `finish`.
        date_range_sampling = pandas.date_range(
            date_granularity_sampling[0],
            periods=(finish - date_granularity_sampling[0]).days,
            freq="1D")
        # Fill the new (periods x days) matrix
        matrix = numpy.zeros(
            (len(date_granularity_sampling), len(date_range_sampling)),
            dtype=numpy.float32)
        for i, gdt in enumerate(date_granularity_sampling):
            # Period i aggregates the daily rows [istart, ifinish).
            istart = (date_granularity_sampling[i - 1] - start).days \
                if i > 0 else 0
            ifinish = (gdt - start).days
            # Find the first day column at or after istart. `j` and `sdt` are
            # used after the loop; without a break they keep their last
            # values (and are undefined if date_range_sampling is empty).
            for j, sdt in enumerate(date_range_sampling):
                if (sdt - start).days >= istart:
                    break
            matrix[i, j:] = \
                daily_matrix[istart:ifinish, (sdt - start).days:].sum(axis=0)
        # Hardcode some cases to improve labels' readability
        if resample in ("year", "A"):
            labels = [dt.year for dt in date_granularity_sampling]
        elif resample in ("month", "M"):
            labels = [dt.strftime("%Y %B") for dt in date_granularity_sampling]
        else:
            labels = [dt.date() for dt in date_granularity_sampling]
    else:
        # Raw mode: keep the original bands, labelled by their date span.
        labels = [
            "%s - %s" % ((start + timedelta(days=i * granularity)).date(),
                         (start + timedelta(days=(i + 1) * granularity)).date())
            for i in range(matrix.shape[0])]
        if len(labels) > 18:
            warnings.warn("Too many labels - consider resampling.")
        resample = "M"  # fake resampling type is checked while plotting
        date_range_sampling = pandas.date_range(
            start + timedelta(days=sampling), periods=matrix.shape[1],
            freq="%dD" % sampling)
    return name, matrix, date_range_sampling, labels, granularity, sampling, resample
  165. def load_matrix(contents):
  166. matrix = numpy.array([numpy.fromstring(line, dtype=int, sep=" ")
  167. for line in contents.split("\n")])
  168. return matrix
  169. def load_people(header, sequence, contents):
  170. import pandas
  171. start = header["begin"]
  172. last = header["end"]
  173. sampling = header["sampling"]
  174. start = datetime.fromtimestamp(int(start))
  175. last = datetime.fromtimestamp(int(last))
  176. sampling = int(sampling)
  177. people = []
  178. for name in sequence:
  179. people.append(numpy.array([numpy.fromstring(line, dtype=int, sep=" ")
  180. for line in contents[name].split("\n")]).sum(axis=1))
  181. people = numpy.array(people)
  182. date_range_sampling = pandas.date_range(
  183. start + timedelta(days=sampling), periods=people[0].shape[0],
  184. freq="%dD" % sampling)
  185. return sequence, people, date_range_sampling, last
  186. def apply_plot_style(figure, axes, legend, style, text_size):
  187. figure.set_size_inches(12, 9)
  188. for side in ("bottom", "top", "left", "right"):
  189. axes.spines[side].set_color(style)
  190. for axis in (axes.xaxis, axes.yaxis):
  191. axis.label.update(dict(fontsize=text_size, color=style))
  192. for axis in ("x", "y"):
  193. axes.tick_params(axis=axis, colors=style, labelsize=text_size)
  194. if legend is not None:
  195. frame = legend.get_frame()
  196. for setter in (frame.set_facecolor, frame.set_edgecolor):
  197. setter("black" if style == "white" else "white")
  198. for text in legend.get_texts():
  199. text.set_color(style)
  200. def get_plot_path(base, name):
  201. root, ext = os.path.splitext(base)
  202. if not ext:
  203. ext = ".png"
  204. output = os.path.join(root, name + ext)
  205. os.makedirs(os.path.dirname(output), exist_ok=True)
  206. return output
  207. def deploy_plot(title, output, style):
  208. import matplotlib.pyplot as pyplot
  209. if not output:
  210. pyplot.gcf().canvas.set_window_title(title)
  211. pyplot.show()
  212. else:
  213. if title:
  214. pyplot.title(title, color=style)
  215. pyplot.tight_layout()
  216. pyplot.savefig(output, transparent=True)
  217. pyplot.clf()
  218. def plot_burndown(args, target, name, matrix, date_range_sampling, labels, granularity,
  219. sampling, resample):
  220. import matplotlib
  221. if args.backend:
  222. matplotlib.use(args.backend)
  223. import matplotlib.pyplot as pyplot
  224. pyplot.stackplot(date_range_sampling, matrix, labels=labels)
  225. if args.relative:
  226. for i in range(matrix.shape[1]):
  227. matrix[:, i] /= matrix[:, i].sum()
  228. pyplot.ylim(0, 1)
  229. legend_loc = 3
  230. else:
  231. legend_loc = 2
  232. legend = pyplot.legend(loc=legend_loc, fontsize=args.text_size)
  233. pyplot.ylabel("Lines of code")
  234. pyplot.xlabel("Time")
  235. apply_plot_style(pyplot.gcf(), pyplot.gca(), legend, args.style, args.text_size)
  236. pyplot.xlim(date_range_sampling[0], date_range_sampling[-1])
  237. locator = pyplot.gca().xaxis.get_major_locator()
  238. # set the optimal xticks locator
  239. if "M" not in resample:
  240. pyplot.gca().xaxis.set_major_locator(matplotlib.dates.YearLocator())
  241. locs = pyplot.gca().get_xticks().tolist()
  242. if len(locs) >= 16:
  243. pyplot.gca().xaxis.set_major_locator(matplotlib.dates.YearLocator())
  244. locs = pyplot.gca().get_xticks().tolist()
  245. if len(locs) >= 16:
  246. pyplot.gca().xaxis.set_major_locator(locator)
  247. if locs[0] < pyplot.xlim()[0]:
  248. del locs[0]
  249. endindex = -1
  250. if len(locs) >= 2 and \
  251. pyplot.xlim()[1] - locs[-1] > (locs[-1] - locs[-2]) / 2:
  252. locs.append(pyplot.xlim()[1])
  253. endindex = len(locs) - 1
  254. startindex = -1
  255. if len(locs) >= 2 and \
  256. locs[0] - pyplot.xlim()[0] > (locs[1] - locs[0]) / 2:
  257. locs.append(pyplot.xlim()[0])
  258. startindex = len(locs) - 1
  259. pyplot.gca().set_xticks(locs)
  260. # hacking time!
  261. labels = pyplot.gca().get_xticklabels()
  262. if startindex >= 0:
  263. labels[startindex].set_text(date_range_sampling[0].date())
  264. labels[startindex].set_text = lambda _: None
  265. labels[startindex].set_rotation(30)
  266. labels[startindex].set_ha("right")
  267. if endindex >= 0:
  268. labels[endindex].set_text(date_range_sampling[-1].date())
  269. labels[endindex].set_text = lambda _: None
  270. labels[endindex].set_rotation(30)
  271. labels[endindex].set_ha("right")
  272. title = "%s %d x %d (granularity %d, sampling %d)" % \
  273. ((name,) + matrix.shape + (granularity, sampling))
  274. output = args.output
  275. if output:
  276. if args.mode == "project" and target == "project":
  277. output = args.output
  278. else:
  279. if target == "project":
  280. name = "project"
  281. output = get_plot_path(args.output, name)
  282. deploy_plot(title, output, args.style)
  283. def plot_many(args, target, header, parts):
  284. if not args.output:
  285. print("Warning: output not set, showing %d plots." % len(parts))
  286. itercnt = progress.bar(parts.items(), expected_size=len(parts)) \
  287. if progress is not None else parts.items()
  288. stdout = io.StringIO()
  289. for name, matrix in itercnt:
  290. backup = sys.stdout
  291. sys.stdout = stdout
  292. plot_burndown(args, target, *load_main(header, name, matrix, args.resample))
  293. sys.stdout = backup
  294. sys.stdout.write(stdout.getvalue())
  295. def plot_matrix(args, repo, people, matrix):
  296. matrix = matrix.astype(float)
  297. zeros = matrix[:, 0] == 0
  298. matrix[zeros, :] = 1
  299. matrix /= matrix[:, 0][:, None]
  300. matrix = -matrix[:, 1:]
  301. matrix[zeros, :] = 0
  302. import matplotlib
  303. if args.backend:
  304. matplotlib.use(args.backend)
  305. import matplotlib.pyplot as pyplot
  306. s = 4 + matrix.shape[1] * 0.3
  307. fig = pyplot.figure(figsize=(s, s))
  308. ax = fig.add_subplot(111)
  309. ax.xaxis.set_label_position("top")
  310. ax.matshow(matrix, cmap=pyplot.cm.OrRd)
  311. ax.set_xticks(numpy.arange(0, matrix.shape[1]))
  312. ax.set_yticks(numpy.arange(0, matrix.shape[0]))
  313. ax.set_xticklabels(["Unidentified"] + people, rotation=90, ha="center")
  314. ax.set_yticklabels(people, va="center")
  315. ax.set_xticks(numpy.arange(0.5, matrix.shape[1] + 0.5), minor=True)
  316. ax.set_yticks(numpy.arange(0.5, matrix.shape[0] + 0.5), minor=True)
  317. ax.grid(which="minor")
  318. apply_plot_style(fig, ax, None, args.style, args.text_size)
  319. if not args.output:
  320. pos1 = ax.get_position()
  321. pos2 = (pos1.x0 + 0.245, pos1.y0 - 0.1, pos1.width * 0.9, pos1.height * 0.9)
  322. ax.set_position(pos2)
  323. if args.mode == "all":
  324. output = get_plot_path(args.output, "matrix")
  325. else:
  326. output = args.output
  327. title = "%s %d developers overwrite" % (repo, matrix.shape[0])
  328. if args.output:
  329. # FIXME(vmarkovtsev): otherwise the title is screwed in savefig()
  330. title = ""
  331. deploy_plot(title, output, args.style)
  332. def plot_people(args, repo, names, people, date_range, last):
  333. import matplotlib
  334. if args.backend:
  335. matplotlib.use(args.backend)
  336. import matplotlib.pyplot as pyplot
  337. pyplot.stackplot(date_range, people, labels=names)
  338. pyplot.xlim(date_range[0], last)
  339. if args.relative:
  340. for i in range(people.shape[1]):
  341. people[:, i] /= people[:, i].sum()
  342. pyplot.ylim(0, 1)
  343. legend_loc = 3
  344. else:
  345. legend_loc = 2
  346. legend = pyplot.legend(loc=legend_loc, fontsize=args.text_size)
  347. apply_plot_style(pyplot.gcf(), pyplot.gca(), legend, args.style, args.text_size)
  348. if args.mode == "all":
  349. output = get_plot_path(args.output, "people")
  350. else:
  351. output = args.output
  352. deploy_plot("%s code ratio through time" % repo, output, args.style)
  353. def train_embeddings(coocc_tree, tmpdir, shard_size=4096):
  354. from scipy.sparse import csr_matrix
  355. import tensorflow as tf
  356. try:
  357. from . import swivel
  358. except SystemError:
  359. import swivel
  360. index = coocc_tree["index"]
  361. nshards = len(index) // shard_size
  362. if nshards * shard_size < len(index):
  363. nshards += 1
  364. shard_size = len(index) // nshards
  365. nshards = len(index) // shard_size
  366. remainder = len(index) - nshards * shard_size
  367. if remainder > 0:
  368. lengths = numpy.array([len(cd) for cd in coocc_tree["matrix"]])
  369. filtered = sorted(numpy.argsort(lengths)[remainder:])
  370. else:
  371. filtered = list(range(len(index)))
  372. print("Reading the sparse matrix...")
  373. data = []
  374. indices = []
  375. indptr = [0]
  376. for row, cd in enumerate(coocc_tree["matrix"]):
  377. for col, val in sorted(cd.items()):
  378. data.append(val)
  379. indices.append(col)
  380. indptr.append(indptr[-1] + len(cd))
  381. matrix = csr_matrix((data, indices, indptr), shape=(len(index), len(index)))
  382. if len(filtered) < len(index):
  383. matrix = matrix[filtered, :][:, filtered]
  384. meta_index = []
  385. for i, j in enumerate(filtered):
  386. meta_index.append((index[j], matrix[i, i]))
  387. index = [mi[0] for mi in meta_index]
  388. with tempfile.TemporaryDirectory(prefix="hercules_labours_", dir=tmpdir or None) as tmproot:
  389. print("Writing Swivel metadata...")
  390. vocabulary = "\n".join(index)
  391. with open(os.path.join(tmproot, "row_vocab.txt"), "w") as out:
  392. out.write(vocabulary)
  393. with open(os.path.join(tmproot, "col_vocab.txt"), "w") as out:
  394. out.write(vocabulary)
  395. del vocabulary
  396. bool_sums = matrix.indptr[1:] - matrix.indptr[:-1]
  397. bool_sums_str = "\n".join(map(str, bool_sums.tolist()))
  398. with open(os.path.join(tmproot, "row_sums.txt"), "w") as out:
  399. out.write(bool_sums_str)
  400. with open(os.path.join(tmproot, "col_sums.txt"), "w") as out:
  401. out.write(bool_sums_str)
  402. del bool_sums_str
  403. reorder = numpy.argsort(-bool_sums)
  404. print("Writing Swivel shards...")
  405. for row in range(nshards):
  406. for col in range(nshards):
  407. def _int64s(xs):
  408. return tf.train.Feature(
  409. int64_list=tf.train.Int64List(value=list(xs)))
  410. def _floats(xs):
  411. return tf.train.Feature(
  412. float_list=tf.train.FloatList(value=list(xs)))
  413. indices_row = reorder[row::nshards]
  414. indices_col = reorder[col::nshards]
  415. shard = matrix[indices_row][:, indices_col].tocoo()
  416. example = tf.train.Example(features=tf.train.Features(feature={
  417. "global_row": _int64s(indices_row),
  418. "global_col": _int64s(indices_col),
  419. "sparse_local_row": _int64s(shard.row),
  420. "sparse_local_col": _int64s(shard.col),
  421. "sparse_value": _floats(shard.data)}))
  422. with open(os.path.join(tmproot, "shard-%03d-%03d.pb" % (row, col)), "wb") as out:
  423. out.write(example.SerializeToString())
  424. print("Training Swivel model...")
  425. swivel.FLAGS.submatrix_rows = shard_size
  426. swivel.FLAGS.submatrix_cols = shard_size
  427. if len(meta_index) < 10000:
  428. embedding_size = 50
  429. num_epochs = 100
  430. elif len(meta_index) < 100000:
  431. embedding_size = 100
  432. num_epochs = 200
  433. elif len(meta_index) < 500000:
  434. embedding_size = 200
  435. num_epochs = 300
  436. else:
  437. embedding_size = 300
  438. num_epochs = 200
  439. swivel.FLAGS.embedding_size = embedding_size
  440. swivel.FLAGS.input_base_path = tmproot
  441. swivel.FLAGS.output_base_path = tmproot
  442. swivel.FLAGS.loss_multiplier = 1.0 / shard_size
  443. swivel.FLAGS.num_epochs = num_epochs
  444. swivel.main(None)
  445. print("Reading Swivel embeddings...")
  446. embeddings = []
  447. with open(os.path.join(tmproot, "row_embedding.tsv")) as frow:
  448. with open(os.path.join(tmproot, "col_embedding.tsv")) as fcol:
  449. for i, (lrow, lcol) in enumerate(zip(frow, fcol)):
  450. prow, pcol = (l.split("\t", 1) for l in (lrow, lcol))
  451. assert prow[0] == pcol[0]
  452. erow, ecol = \
  453. (numpy.fromstring(p[1], dtype=numpy.float32, sep="\t")
  454. for p in (prow, pcol))
  455. embeddings.append((erow + ecol) / 2)
  456. return meta_index, embeddings
  457. def write_embeddings(name, output, index, embeddings):
  458. print("Writing Tensorflow Projector files...")
  459. if not output:
  460. output = "couples_" + name
  461. metaf = "%s_%s_meta.tsv" % (output, name)
  462. with open(metaf, "w") as fout:
  463. fout.write("name\tcommits\n")
  464. for pair in index:
  465. fout.write("%s\t%s\n" % pair)
  466. print("Wrote", metaf)
  467. dataf = "%s_%s_data.tsv" % (output, name)
  468. with open(dataf, "w") as fout:
  469. for vec in embeddings:
  470. fout.write("\t".join(str(v) for v in vec))
  471. fout.write("\n")
  472. print("Wrote", dataf)
  473. def main():
  474. args = parse_args()
  475. header, main_contents, files_contents, people_sequence, people_contents, people_matrix, \
  476. files_coocc, people_coocc = read_input(args)
  477. name = next(iter(main_contents))
  478. files_warning = "Files stats were not collected. Re-run hercules with -files."
  479. people_warning = "People stats were not collected. Re-run hercules with -people."
  480. if args.mode == "project":
  481. plot_burndown(args, "project",
  482. *load_main(header, name, main_contents[name], args.resample))
  483. elif args.mode == "file":
  484. if not files_contents:
  485. print(files_warning)
  486. return
  487. plot_many(args, "file", header, files_contents)
  488. elif args.mode == "person":
  489. if not people_contents:
  490. print(people_warning)
  491. return
  492. plot_many(args, "person", header, people_contents)
  493. elif args.mode == "matrix":
  494. if not people_contents:
  495. print(people_warning)
  496. return
  497. plot_matrix(args, name, people_sequence, load_matrix(people_matrix))
  498. elif args.mode == "people":
  499. if not people_contents:
  500. print(people_warning)
  501. return
  502. plot_people(args, name, *load_people(header, people_sequence, people_contents))
  503. elif args.mode == "couples":
  504. write_embeddings("files", args.output,
  505. *train_embeddings(files_coocc, args.couples_tmp_dir))
  506. write_embeddings("people", args.output,
  507. *train_embeddings(people_coocc, args.couples_tmp_dir))
  508. elif args.mode == "all":
  509. plot_burndown(args, "project",
  510. *load_main(header, name, main_contents[name], args.resample))
  511. if files_contents:
  512. plot_many(args, "file", header, files_contents)
  513. if people_contents:
  514. plot_many(args, "person", header, people_contents)
  515. plot_matrix(args, name, people_sequence, load_matrix(people_matrix))
  516. plot_people(args, name, *load_people(header, people_sequence, people_contents))
  517. if people_coocc:
  518. assert files_coocc
  519. write_embeddings("files", args.output,
  520. *train_embeddings(files_coocc, args.couples_tmp_dir))
  521. write_embeddings("people", args.output,
  522. *train_embeddings(people_coocc, args.couples_tmp_dir))
  523. if __name__ == "__main__":
  524. sys.exit(main())