Browse Source

Add type hints using monkeytype

Signed-off-by: Aarni Koskela <akx@iki.fi>
Aarni Koskela 5 năm trước
mục cha commit eee2342c0d

+ 22 - 12
python/labours/burndown.py

@@ -1,4 +1,5 @@
 from datetime import datetime, timedelta
+from typing import TYPE_CHECKING, List, Tuple
 import warnings
 
 import numpy
@@ -6,8 +7,12 @@ import tqdm
 
 from labours.utils import floor_datetime
 
+if TYPE_CHECKING:
+    from lifelines import KaplanMeierFitter
+    from pandas.core.indexes.datetimes import DatetimeIndex
+
 
-def fit_kaplan_meier(matrix):
+def fit_kaplan_meier(matrix: numpy.ndarray) -> 'KaplanMeierFitter':
     from lifelines import KaplanMeierFitter
 
     T = []
@@ -41,7 +46,7 @@ def fit_kaplan_meier(matrix):
     return kmf
 
 
-def print_survival_function(kmf, sampling):
+def print_survival_function(kmf: 'KaplanMeierFitter', sampling: int) -> None:
     sf = kmf.survival_function_
     sf.index = [timedelta(days=d) for d in sf.index * sampling]
     sf.columns = ["Ratio of survived lines"]
@@ -51,7 +56,12 @@ def print_survival_function(kmf, sampling):
         pass
 
 
-def interpolate_burndown_matrix(matrix, granularity, sampling, progress=False):
+def interpolate_burndown_matrix(
+    matrix: numpy.ndarray,
+    granularity: int,
+    sampling: int,
+    progress: bool = False
+) -> numpy.ndarray:
     daily = numpy.zeros(
         (matrix.shape[0] * granularity, matrix.shape[1] * sampling),
         dtype=numpy.float32)
@@ -186,13 +196,13 @@ def import_pandas():
 
 
 def load_burndown(
-    header,
-    name,
-    matrix,
-    resample,
-    report_survival=True,
-    interpolation_progress=False
-):
+    header: Tuple[int, int, int, int, float],
+    name: str,
+    matrix: numpy.ndarray,
+    resample: str,
+    report_survival: bool = True,
+    interpolation_progress: bool = False
+) -> Tuple[str, numpy.ndarray, 'DatetimeIndex', List[int], int, int, str]:
     pandas = import_pandas()
 
     start, last, sampling, granularity, tick = header
@@ -263,8 +273,8 @@ def load_burndown(
     else:
         labels = [
             "%s - %s" % ((start + timedelta(seconds=i * granularity * tick)).date(),
-                         (
-                         start + timedelta(seconds=(i + 1) * granularity * tick)).date())
+            (
+                start + timedelta(seconds=(i + 1) * granularity * tick)).date())
             for i in range(matrix.shape[0])]
         if len(labels) > 18:
             warnings.warn("Too many labels - consider resampling.")

+ 5 - 3
python/labours/cli.py

@@ -1,8 +1,10 @@
 import argparse
+from argparse import Namespace
 import os
 import subprocess
 import sys
 import time
+from typing import List
 
 import numpy
 
@@ -21,7 +23,7 @@ from labours.modes.shotness import show_shotness_stats
 from labours.readers import read_input
 
 
-def list_matplotlib_styles():
+def list_matplotlib_styles() -> List[str]:
     script = "import sys; from matplotlib import pyplot; " \
              "sys.stdout.write(repr(pyplot.style.available))"
     styles = eval(subprocess.check_output([sys.executable, "-c", script]))
@@ -29,7 +31,7 @@ def list_matplotlib_styles():
     return ["default", "classic"] + styles
 
 
-def parse_args():
+def parse_args() -> Namespace:
     parser = argparse.ArgumentParser()
     parser.add_argument("-o", "--output", default="",
                         help="Path to the output file/directory (empty for display). "
@@ -82,7 +84,7 @@ def parse_args():
     return args
 
 
-def main():
+def main() -> None:
     args = parse_args()
     reader = read_input(args)
     header = reader.get_header()

+ 4 - 4
python/labours/cors_web_server.py

@@ -2,7 +2,7 @@ import threading
 
 
 class CORSWebServer(object):
-    def __init__(self):
+    def __init__(self) -> None:
         self.thread = threading.Thread(target=self.serve)
         self.server = None
 
@@ -23,16 +23,16 @@ class CORSWebServer(object):
 
         test(CORSRequestHandler, ClojureServer)
 
-    def start(self):
+    def start(self) -> None:
         self.thread.start()
 
-    def stop(self):
+    def stop(self) -> None:
         if self.running:
             self.server.shutdown()
             self.thread.join()
 
     @property
-    def running(self):
+    def running(self) -> bool:
         return self.server is not None
 
 

+ 15 - 2
python/labours/embeddings.py

@@ -2,15 +2,22 @@ import os
 import shutil
 import sys
 import tempfile
+from typing import List, Optional, Tuple
 
 import numpy
+from scipy.sparse.csr import csr_matrix
 
 from labours.cors_web_server import web_server
 
 IDEAL_SHARD_SIZE = 4096
 
 
-def train_embeddings(index, matrix, tmpdir, shard_size=IDEAL_SHARD_SIZE):
+def train_embeddings(
+    index: List[str],
+    matrix: csr_matrix,
+    tmpdir: 'Optional[str]',
+    shard_size: int = IDEAL_SHARD_SIZE
+) -> Tuple[List[Tuple[str, numpy.int64]], List[numpy.ndarray]]:
     import tensorflow as tf
     from labours._vendor import swivel
 
@@ -131,7 +138,13 @@ def train_embeddings(index, matrix, tmpdir, shard_size=IDEAL_SHARD_SIZE):
     return meta_index, embeddings
 
 
-def write_embeddings(name, output, run_server, index, embeddings):
+def write_embeddings(
+    name: str,
+    output: str,
+    run_server: bool,
+    index: List[Tuple[str, numpy.int64]],
+    embeddings: List[numpy.ndarray]
+) -> None:
     print("Writing Tensorflow Projector files...")
     if not output:
         output = "couples"

+ 18 - 3
python/labours/modes/burndown.py

@@ -1,17 +1,32 @@
+from argparse import Namespace
 import contextlib
 import io
 import json
 import sys
+from typing import List, TYPE_CHECKING
 
+import numpy
 import tqdm
 
 from labours.burndown import load_burndown
 from labours.plotting import apply_plot_style, deploy_plot, get_plot_path, import_pyplot
 from labours.utils import default_json, parse_date
 
+if TYPE_CHECKING:
+    from pandas.core.indexes.datetimes import DatetimeIndex
 
-def plot_burndown(args, target, name, matrix, date_range_sampling, labels, granularity,
-                  sampling, resample):
+
+def plot_burndown(
+    args: Namespace,
+    target: str,
+    name: str,
+    matrix: numpy.ndarray,
+    date_range_sampling: 'DatetimeIndex',
+    labels: List[int],
+    granularity: int,
+    sampling: int,
+    resample: str
+) -> None:
     if args.output and args.output.endswith(".json"):
         data = locals().copy()
         del data["args"]
@@ -89,7 +104,7 @@ def plot_burndown(args, target, name, matrix, date_range_sampling, labels, granu
     deploy_plot(title, output, args.background)
 
 
-def plot_many_burndown(args, target, header, parts):
+def plot_many_burndown(args: Namespace, target: str, header, parts):
     if not args.output:
         print("Warning: output not set, showing %d plots." % len(parts))
     stdout = io.StringIO()

+ 26 - 4
python/labours/modes/devs.py

@@ -1,6 +1,8 @@
+from argparse import Namespace
 from collections import defaultdict
 from datetime import datetime, timedelta
 import sys
+from typing import Dict, List, Set, Tuple
 
 import numpy
 import tqdm
@@ -10,7 +12,15 @@ from labours.plotting import apply_plot_style, deploy_plot, get_plot_path, impor
 from labours.utils import _format_number
 
 
-def show_devs(args, name, start_date, end_date, people, days, max_people=50):
+def show_devs(
+    args: Namespace,
+    name: str,
+    start_date: int,
+    end_date: int,
+    people: List[str],
+    days: Dict[int, Dict[int, DevDay]],
+    max_people: int = 50
+) -> None:
     from scipy.signal import convolve, slepian
 
     if len(people) > max_people:
@@ -111,7 +121,11 @@ def show_devs(args, name, start_date, end_date, people, days, max_people=50):
     deploy_plot(title, output, args.background)
 
 
-def order_commits(chosen_people, days, people):
+def order_commits(
+    chosen_people: Set[str],
+    days: Dict[int, Dict[int, DevDay]],
+    people: List[str]
+) -> Tuple[numpy.ndarray, defaultdict, defaultdict, List[int]]:
     from seriate import seriate
     try:
         from fastdtw import fastdtw
@@ -162,7 +176,7 @@ def order_commits(chosen_people, days, people):
     return dists, devseries, devstats, route
 
 
-def hdbscan_cluster_routed_series(dists, route):
+def hdbscan_cluster_routed_series(dists: numpy.ndarray, route: List[int]) -> numpy.ndarray:
     try:
         from hdbscan import HDBSCAN
     except ImportError as e:
@@ -175,7 +189,15 @@ def hdbscan_cluster_routed_series(dists, route):
     return clusters
 
 
-def show_devs_efforts(args, name, start_date, end_date, people, days, max_people):
+def show_devs_efforts(
+    args: Namespace,
+    name: str,
+    start_date: int,
+    end_date: int,
+    people: List[str],
+    days: Dict[int, Dict[int, DevDay]],
+    max_people: int
+) -> None:
     from scipy.signal import convolve, slepian
 
     start_date = datetime.fromtimestamp(start_date)

+ 9 - 2
python/labours/modes/devs_parallel.py

@@ -1,14 +1,21 @@
 from collections import defaultdict
 import sys
+from typing import Any, Dict, List, Tuple
 
 import numpy
+from scipy.sparse.csr import csr_matrix
 
 from labours.modes.devs import hdbscan_cluster_routed_series, order_commits
-from labours.objects import ParallelDevData
+from labours.objects import DevDay, ParallelDevData
 from labours.plotting import deploy_plot, import_pyplot
 
 
-def load_devs_parallel(ownership, couples, devs, max_people):
+def load_devs_parallel(
+    ownership: Tuple[List[Any], Dict[Any, Any]],
+    couples: Tuple[List[str], csr_matrix],
+    devs: Tuple[List[str], Dict[int, Dict[int, DevDay]]],
+    max_people: int
+):
     from seriate import seriate
     try:
         from hdbscan import HDBSCAN

+ 12 - 1
python/labours/modes/languages.py

@@ -1,9 +1,20 @@
+from argparse import Namespace
 from collections import defaultdict
+from typing import Dict, List
 
 import numpy
 
+from labours.objects import DevDay
 
-def show_languages(args, name, start_date, end_date, people, days):
+
+def show_languages(
+    args: Namespace,
+    name: str,
+    start_date: int,
+    end_date: int,
+    people: List[str],
+    days: Dict[int, Dict[int, DevDay]],
+) -> None:
     devlangs = defaultdict(lambda: defaultdict(lambda: numpy.zeros(3, dtype=int)))
     for day, devs in days.items():
         for dev, stats in devs.items():

+ 10 - 1
python/labours/modes/old_vs_new.py

@@ -1,5 +1,7 @@
+from argparse import Namespace
 from datetime import datetime, timedelta
 from itertools import chain
+from typing import Dict, List
 
 import numpy
 
@@ -7,7 +9,14 @@ from labours.objects import DevDay
 from labours.plotting import deploy_plot, get_plot_path, import_pyplot
 
 
-def show_old_vs_new(args, name, start_date, end_date, people, days):
+def show_old_vs_new(
+    args: Namespace,
+    name: str,
+    start_date: int,
+    end_date: int,
+    people: List[str],
+    days: Dict[int, Dict[int, DevDay]]
+) -> None:
     from scipy.signal import convolve, slepian
 
     start_date = datetime.fromtimestamp(start_date)

+ 2 - 1
python/labours/modes/ownership.py

@@ -1,5 +1,6 @@
 from datetime import datetime, timedelta
 import json
+from typing import Any, Dict, List, Tuple
 
 import numpy
 
@@ -8,7 +9,7 @@ from labours.plotting import apply_plot_style, deploy_plot, get_plot_path, impor
 from labours.utils import default_json, floor_datetime, parse_date
 
 
-def load_ownership(header, sequence, contents, max_people, order_by_time):
+def load_ownership(header: Tuple[int, int, int, int, float], sequence: List[Any], contents: Dict[Any, Any], max_people: int, order_by_time: bool):
     pandas = import_pandas()
 
     start, last, sampling, _, tick = header

+ 1 - 1
python/labours/objects.py

@@ -2,7 +2,7 @@ from collections import defaultdict, namedtuple
 
 
 class DevDay(namedtuple("DevDay", ("Commits", "Added", "Removed", "Changed", "Languages"))):
-    def add(self, dd):
+    def add(self, dd: 'DevDay') -> 'DevDay':
         langs = defaultdict(lambda: [0] * 3)
         for key, val in self.Languages.items():
             for i in range(3):

+ 2 - 2
python/labours/plotting.py

@@ -39,7 +39,7 @@ def apply_plot_style(figure, axes, legend, background, font_size, axes_size):
             text.set_color(foreground)
 
 
-def get_plot_path(base, name):
+def get_plot_path(base: str, name: str) -> str:
     root, ext = os.path.splitext(base)
     if not ext:
         ext = ".png"
@@ -48,7 +48,7 @@ def get_plot_path(base, name):
     return output
 
 
-def deploy_plot(title, output, background, tight=True):
+def deploy_plot(title: str, output: str, background: str, tight: bool = True) -> None:
     import matplotlib.pyplot as pyplot
 
     if not output:

+ 17 - 12
python/labours/readers.py

@@ -1,12 +1,17 @@
+from argparse import Namespace
 from importlib import import_module
 import re
 import sys
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple
 
 import numpy
 import yaml
 
 from labours.objects import DevDay
 
+if TYPE_CHECKING:
+    from scipy.sparse.csr import csr_matrix
+
 
 class Reader(object):
     def read(self, file):
@@ -56,7 +61,7 @@ class Reader(object):
 
 
 class YamlReader(Reader):
-    def read(self, file):
+    def read(self, file: str):
         yaml.reader.Reader.NON_PRINTABLE = re.compile(r"(?!x)x")
         try:
             loader = yaml.CLoader
@@ -182,7 +187,7 @@ class YamlReader(Reader):
 
 
 class ProtobufReader(Reader):
-    def read(self, file):
+    def read(self, file: str) -> None:
         try:
             from labours.pb_pb2 import AnalysisResults
         except ImportError as e:
@@ -212,27 +217,27 @@ class ProtobufReader(Reader):
     def get_run_times(self):
         return {key: val for key, val in self.data.header.run_time_per_item.items()}
 
-    def get_name(self):
+    def get_name(self) -> str:
         return self.data.header.repository
 
-    def get_header(self):
+    def get_header(self) -> Tuple[int, int]:
         header = self.data.header
         return header.begin_unix_time, header.end_unix_time
 
-    def get_burndown_parameters(self):
+    def get_burndown_parameters(self) -> Tuple[int, int, float]:
         burndown = self.contents["Burndown"]
         return burndown.sampling, burndown.granularity, burndown.tick_size / 1000000000
 
-    def get_project_burndown(self):
+    def get_project_burndown(self) -> Tuple[str, numpy.ndarray]:
         return self._parse_burndown_matrix(self.contents["Burndown"].project)
 
     def get_files_burndown(self):
         return [self._parse_burndown_matrix(i) for i in self.contents["Burndown"].files]
 
-    def get_people_burndown(self):
+    def get_people_burndown(self) -> List[Any]:
         return [self._parse_burndown_matrix(i) for i in self.contents["Burndown"].people]
 
-    def get_ownership_burndown(self):
+    def get_ownership_burndown(self) -> Tuple[List[Any], Dict[Any, Any]]:
         people = self.get_people_burndown()
         return [p[0] for p in people], {p[0]: p[1].T for p in people}
 
@@ -241,11 +246,11 @@ class ProtobufReader(Reader):
         return [i.name for i in burndown.people], \
             self._parse_sparse_matrix(burndown.people_interaction).toarray()
 
-    def get_files_coocc(self):
+    def get_files_coocc(self) -> Tuple[List[str], 'csr_matrix']:
         node = self.contents["Couples"].file_couples
         return list(node.index), self._parse_sparse_matrix(node.matrix)
 
-    def get_people_coocc(self):
+    def get_people_coocc(self) -> Tuple[List[str], 'csr_matrix']:
         node = self.contents["Couples"].people_couples
         return list(node.index), self._parse_sparse_matrix(node.matrix)
 
@@ -279,7 +284,7 @@ class ProtobufReader(Reader):
             raise KeyError
         return byday
 
-    def get_devs(self):
+    def get_devs(self) -> Tuple[List[str], Dict[int, Dict[int, DevDay]]]:
         people = list(self.contents["Devs"].dev_index)
         days = {d: {dev: DevDay(stats.commits, stats.stats.added, stats.stats.removed,
                                 stats.stats.changed, {k: [v.added, v.removed, v.changed]
@@ -310,7 +315,7 @@ PB_MESSAGES = {
 }
 
 
-def read_input(args):
+def read_input(args: Namespace) -> ProtobufReader:
     sys.stdout.write("Reading the input... ")
     sys.stdout.flush()
     if args.input != "-":

+ 8 - 3
python/labours/utils.py

@@ -1,9 +1,14 @@
 from datetime import datetime
+from numbers import Number
+from typing import TYPE_CHECKING, Optional
 
 import numpy
 
+if TYPE_CHECKING:
+    from pandas import Timestamp
 
-def floor_datetime(dt, duration):
+
+def floor_datetime(dt: datetime, duration: float) -> datetime:
     return datetime.fromtimestamp(dt.timestamp() - dt.timestamp() % duration)
 
 
@@ -15,14 +20,14 @@ def default_json(x):
     return x
 
 
-def parse_date(text, default):
-def parse_date(text, default):
+def parse_date(text: 'Optional[str]', default: 'Timestamp') -> 'Timestamp':
     if not text:
         return default
     from dateutil.parser import parse
     return parse(text)
 
 
-def _format_number(n):
+def _format_number(n: Number) -> str:
     if n == 0:
         return "0"
     power = int(numpy.log10(abs(n)))

+ 3 - 1
python/setup.cfg

@@ -1,6 +1,8 @@
 [flake8]
 exclude = labours/pb_pb2.py
-ignore = D,B007
+ignore = D,B007,
+         # Spurious "unused import" / "redefinition" errors:
+         F401,F811
 import-order-style = appnexus
 inline-quotes = "
 max-line-length = 99