瀏覽代碼

Refactor labours.py to enable PB loading

Vadim Markovtsev 7 年之前
父節點
當前提交
4d9cb213f4
共有 1 個文件被更改,包括 184 次插入109 次删除
  1. 184 109
      labours.py

+ 184 - 109
labours.py

@@ -29,6 +29,7 @@ def parse_args():
                         help="Path to the output file/directory (empty for display).")
     parser.add_argument("-i", "--input", default="-",
                         help="Path to the input file (- for stdin).")
+    parser.add_argument("-f", "--input-format", default="yaml", choices=["yaml", "pb"])
     parser.add_argument("--text-size", default=12, type=int,
                         help="Size of the labels and legend.")
     parser.add_argument("--backend", help="Matplotlib backend to use.")
@@ -56,29 +57,122 @@ def parse_args():
     return args
 
 
+class Reader(object):
+    def read(self, file):
+        raise NotImplementedError
+
+    def get_name(self):
+        raise NotImplementedError
+
+    def get_header(self):
+        raise NotImplementedError
+
+    def get_project_burndown(self):
+        raise NotImplementedError
+
+    def get_files_burndown(self):
+        raise NotImplementedError
+
+    def get_people_burndown(self):
+        raise NotImplementedError
+
+    def get_ownership_burndown(self):
+        raise NotImplementedError
+
+    def get_people_interaction(self):
+        raise NotImplementedError
+
+    def get_files_coocc(self):
+        raise NotImplementedError
+
+    def get_people_coocc(self):
+        raise NotImplementedError
+
+
+class YamlReader(Reader):
+    def read(self, file):
+        yaml.reader.Reader.NON_PRINTABLE = re.compile(r"(?!x)x")
+        try:
+            loader = yaml.CLoader
+        except AttributeError:
+            print("Warning: failed to import yaml.CLoader, falling back to slow yaml.Loader")
+            loader = yaml.Loader
+        try:
+            if file != "-":
+                with open(file) as fin:
+                    data = yaml.load(fin, Loader=loader)
+            else:
+                data = yaml.load(sys.stdin, Loader=loader)
+        except (UnicodeEncodeError, yaml.reader.ReaderError) as e:
+            print("\nInvalid unicode in the input: %s\nPlease filter it through "
+                  "fix_yaml_unicode.py" % e)
+            sys.exit(1)
+        print("done")
+        self.data = data
+
+    def get_name(self):
+        return next(iter(self.data["project"]))
+
+    def get_header(self):
+        header = self.data["burndown"]
+        return header["begin"], header["end"], header["sampling"], header["granularity"]
+
+    def get_project_burndown(self):
+        name, matrix = next(iter(self.data["project"].items()))
+        return name, self._parse_burndown_matrix(matrix).T
+
+    def get_files_burndown(self):
+        return [(p[0], self._parse_burndown_matrix(p[1]).T) for p in self.data["files"].items()]
+
+    def get_people_burndown(self):
+        return [(p[0], self._parse_burndown_matrix(p[1]).T) for p in self.data["people"].items()]
+
+    def get_ownership_burndown(self):
+        return self.data["people_sequence"], {p[0]: self._parse_burndown_matrix(p[1])
+                                              for p in self.data["people"].items()}
+
+    def get_people_interaction(self):
+        return self.data["people_sequence"], self._parse_burndown_matrix(self.data["people_interaction"])
+
+    def get_files_coocc(self):
+        coocc = self.data["files_coocc"]
+        return coocc["index"], self._parse_coocc_matrix(coocc["matrix"])
+
+    def get_people_coocc(self):
+        coocc = self.data["people_coocc"]
+        return coocc["index"], self._parse_coocc_matrix(coocc["matrix"])
+
+    def _parse_burndown_matrix(self, matrix):
+        return numpy.array([numpy.fromstring(line, dtype=int, sep=" ")
+                            for line in matrix.split("\n")])
+
+    def _parse_coocc_matrix(self, matrix):
+        from scipy.sparse import csr_matrix
+        data = []
+        indices = []
+        indptr = [0]
+        for row in matrix:
+            for k, v in sorted(row.items()):
+                data.append(v)
+                indices.append(k)
+            indptr.append(indptr[-1] + len(row))
+        return csr_matrix((data, indices, indptr), shape=(len(matrix),) * 2)
+
+
+class ProtobufReader(Reader):
+    def read(self, file):
+        pass
+
+
+READERS = {"yaml": YamlReader, "pb": ProtobufReader}
+
+
 def read_input(args):
     sys.stdout.write("Reading the input... ")
     sys.stdout.flush()
-    yaml.reader.Reader.NON_PRINTABLE = re.compile(r"(?!x)x")
-    try:
-        loader = yaml.CLoader
-    except AttributeError:
-        print("Warning: failed to import yaml.CLoader, falling back to slow yaml.Loader")
-        loader = yaml.Loader
-    try:
-        if args.input != "-":
-            with open(args.input) as fin:
-                data = yaml.load(fin, Loader=loader)
-        else:
-            data = yaml.load(sys.stdin, Loader=loader)
-    except (UnicodeEncodeError, yaml.reader.ReaderError) as e:
-        print("\nInvalid unicode in the input: %s\nPlease filter it through fix_yaml_unicode.py" %
-              e)
-        sys.exit(1)
-    print("done")
-    return data["burndown"], data["project"], data.get("files"), data.get("people_sequence"), \
-           data.get("people"), data.get("people_interaction"), data.get("files_coocc"), \
-           data.get("people_coocc")
+    reader = READERS[args.input_format]()
+    reader.read(args.input)
+    return reader
 
 
 def calculate_average_lifetime(matrix):
@@ -95,19 +189,12 @@ def calculate_average_lifetime(matrix):
             / (lifetimes.sum() * matrix.shape[1]))
 
 
-def load_main(header, name, matrix, resample):
+def load_burndown(header, name, matrix, resample):
     import pandas
 
-    start = header["begin"]
-    last = header["end"]
-    granularity = header["granularity"]
-    sampling = header["sampling"]
-    start = datetime.fromtimestamp(int(start))
-    last = datetime.fromtimestamp(int(last))
-    granularity = int(granularity)
-    sampling = int(sampling)
-    matrix = numpy.array([numpy.fromstring(line, dtype=int, sep=" ")
-                          for line in matrix.split("\n")]).T
+    start, last, sampling, granularity = header
+    start = datetime.fromtimestamp(start)
+    last = datetime.fromtimestamp(last)
     print(name, "lifetime index:", calculate_average_lifetime(matrix))
     finish = start + timedelta(days=matrix.shape[1] * sampling)
     if resample not in ("no", "raw"):
@@ -184,25 +271,15 @@ def load_main(header, name, matrix, resample):
     return name, matrix, date_range_sampling, labels, granularity, sampling, resample
 
 
-def load_churn_matrix(contents):
-    matrix = numpy.array([numpy.fromstring(line, dtype=int, sep=" ")
-                          for line in contents.split("\n")])
-    return matrix
-
-
 def load_people(header, sequence, contents):
     import pandas
 
-    start = header["begin"]
-    last = header["end"]
-    sampling = header["sampling"]
-    start = datetime.fromtimestamp(int(start))
-    last = datetime.fromtimestamp(int(last))
-    sampling = int(sampling)
+    start, last, sampling, _ = header
+    start = datetime.fromtimestamp(start)
+    last = datetime.fromtimestamp(last)
     people = []
     for name in sequence:
-        people.append(numpy.array([numpy.fromstring(line, dtype=int, sep=" ")
-                                   for line in contents[name].split("\n")]).sum(axis=1))
+        people.append(contents[name].sum(axis=1))
     people = numpy.array(people)
     date_range_sampling = pandas.date_range(
         start + timedelta(days=sampling), periods=people[0].shape[0],
@@ -327,13 +404,13 @@ def plot_burndown(args, target, name, matrix, date_range_sampling, labels, granu
 def plot_many(args, target, header, parts):
     if not args.output:
         print("Warning: output not set, showing %d plots." % len(parts))
-    itercnt = progress.bar(parts.items(), expected_size=len(parts)) \
-        if progress is not None else parts.items()
+    itercnt = progress.bar(parts, expected_size=len(parts)) \
+        if progress is not None else parts
     stdout = io.StringIO()
     for name, matrix in itercnt:
         backup = sys.stdout
         sys.stdout = stdout
-        plot_burndown(args, target, *load_main(header, name, matrix, args.resample))
+        plot_burndown(args, target, *load_burndown(header, name, matrix, args.resample))
         sys.stdout = backup
     sys.stdout.write(stdout.getvalue())
 
@@ -420,15 +497,15 @@ def plot_people(args, repo, names, people, date_range, last):
     deploy_plot("%s code ownership through time" % repo, output, args.style)
 
 
-def train_embeddings(coocc_tree, tmpdir, shard_size=4096):
-    from scipy.sparse import csr_matrix
+def train_embeddings(index, matrix, tmpdir, shard_size=4096):
     try:
         from . import swivel
     except (SystemError, ImportError):
         import swivel
     import tensorflow as tf
 
-    index = coocc_tree["index"]
+    assert matrix.shape[0] == matrix.shape[1]
+    assert len(index) <= matrix.shape[0]
     nshards = len(index) // shard_size
     if nshards * shard_size < len(index):
         nshards += 1
@@ -436,23 +513,12 @@ def train_embeddings(coocc_tree, tmpdir, shard_size=4096):
         nshards = len(index) // shard_size
     remainder = len(index) - nshards * shard_size
     if remainder > 0:
-        lengths = numpy.array([len(cd) for cd in coocc_tree["matrix"]])
+        lengths = matrix.indptr[1:] - matrix.indptr[:-1]
         filtered = sorted(numpy.argsort(lengths)[remainder:])
     else:
         filtered = list(range(len(index)))
-    print("Reading the sparse matrix...")
-    data = []
-    indices = []
-    indptr = [0]
-    for row, cd in enumerate(coocc_tree["matrix"]):
-        if row >= len(index):
-            break
-        for col, val in sorted(cd.items()):
-            data.append(val)
-            indices.append(col)
-        indptr.append(indptr[-1] + len(cd))
-    matrix = csr_matrix((data, indices, indptr), shape=(len(index), len(index)))
-    if len(filtered) < len(index):
+    if len(filtered) < matrix.shape[0]:
+        print("Truncating the sparse matrix...")
         matrix = matrix[filtered, :][:, filtered]
     meta_index = []
     for i, j in enumerate(filtered):
@@ -617,62 +683,71 @@ def write_embeddings(name, output, run_server, index, embeddings):
 
 def main():
     args = parse_args()
-    header, main_contents, files_contents, people_sequence, people_contents, people_matrix, \
-        files_coocc, people_coocc = read_input(args)
-    name = next(iter(main_contents))
+    reader = read_input(args)
+    header = reader.get_header()
+    name = reader.get_name()
 
     files_warning = "Files stats were not collected. Re-run hercules with -files."
     people_warning = "People stats were not collected. Re-run hercules with -people."
     couples_warning = "Coupling stats were not collected. Re-run hercules with -couples."
 
-    if args.mode == "project":
+    def project_burndown():
         plot_burndown(args, "project",
-                      *load_main(header, name, main_contents[name], args.resample))
-    elif args.mode == "file":
-        if not files_contents:
+                      *load_burndown(header, *reader.get_project_burndown(), args.resample))
+
+    def files_burndown():
+        try:
+            plot_many(args, "file", header, reader.get_files_burndown())
+        except KeyError:
             print(files_warning)
-            return
-        plot_many(args, "file", header, files_contents)
-    elif args.mode == "person":
-        if not people_contents:
+
+    def people_burndown():
+        try:
+            plot_many(args, "person", header, reader.get_people_burndown())
+        except KeyError:
             print(people_warning)
-            return
-        plot_many(args, "person", header, people_contents)
-    elif args.mode == "churn_matrix":
-        if not people_contents:
+
+    def churn_matrix():
+        try:
+            plot_churn_matrix(args, name, *reader.get_people_interaction())
+        except KeyError:
             print(people_warning)
-            return
-        plot_churn_matrix(args, name, people_sequence, load_churn_matrix(people_matrix))
-    elif args.mode == "people":
-        if not people_contents:
+
+    def ownership_burndown():
+        try:
+            plot_people(args, name, *load_people(header, *reader.get_ownership_burndown()))
+        except KeyError:
             print(people_warning)
-            return
-        plot_people(args, name, *load_people(header, people_sequence, people_contents))
-    elif args.mode == "couples":
-        if not files_coocc or not people_coocc:
-            print(couples_warning)
-            return
-        write_embeddings("files", args.output, not args.disable_projector,
-                         *train_embeddings(files_coocc, args.couples_tmp_dir))
-        write_embeddings("people", args.output, not args.disable_projector,
-                         *train_embeddings(people_coocc, args.couples_tmp_dir))
-    elif args.mode == "all":
-        plot_burndown(args, "project",
-                      *load_main(header, name, main_contents[name], args.resample))
-        if files_contents:
-            plot_many(args, "file", header, files_contents)
-        if people_contents:
-            plot_many(args, "person", header, people_contents)
-            plot_churn_matrix(args, name, people_sequence, load_churn_matrix(people_matrix))
-            plot_people(args, name, *load_people(header, people_sequence, people_contents))
-        if people_coocc:
-            if not files_coocc or not people_coocc:
-                print(couples_warning)
-                return
+
+    def couples():
+        try:
             write_embeddings("files", args.output, not args.disable_projector,
-                             *train_embeddings(files_coocc, args.couples_tmp_dir))
+                             *train_embeddings(*reader.get_files_coocc(), args.couples_tmp_dir))
             write_embeddings("people", args.output, not args.disable_projector,
-                             *train_embeddings(people_coocc, args.couples_tmp_dir))
+                             *train_embeddings(*reader.get_people_coocc(), args.couples_tmp_dir))
+        except KeyError:
+            print(couples_warning)
+
+    if args.mode == "project":
+        project_burndown()
+    elif args.mode == "file":
+        files_burndown()
+    elif args.mode == "person":
+        people_burndown()
+    elif args.mode == "churn_matrix":
+        churn_matrix()
+    elif args.mode == "people":
+        ownership_burndown()
+    elif args.mode == "couples":
+        couples()
+    elif args.mode == "all":
+        project_burndown()
+        files_burndown()
+        people_burndown()
+        churn_matrix()
+        ownership_burndown()
+        couples()
+
     if web_server.running:
         print("Sleeping for 60 seconds, safe to Ctrl-C")
         try: