Jelajahi Sumber

Fix the survival analysis: now proper Kaplan-Meier

Signed-off-by: Vadim Markovtsev <vadim@sourced.tech>
Vadim Markovtsev 6 tahun lalu
induk
melakukan
10e10780a5
3 mengubah file dengan 51 tambahan dan 19 penghapusan
  1. 48 18
      python/labours/labours.py
  2. 1 0
      python/requirements.txt
  3. 2 1
      python/setup.py

+ 48 - 18
python/labours/labours.py

@@ -427,21 +427,48 @@ class DevDay(namedtuple("DevDay", ("Commits", "Added", "Removed", "Changed", "La
                       Languages=dict(langs))
 
 
-def calculate_average_lifetime(matrix):
-    lifetimes = numpy.zeros(matrix.shape[1] - 1)
-    for band in matrix:
-        start = 0
-        for i, line in enumerate(band):
-            if i == 0 or band[i - 1] == 0:
-                start += 1
-                continue
-            lifetimes[i - start] = band[i - 1] - line
-        lifetimes[i - start] = band[i - 1]
-    lsum = lifetimes.sum()
-    if lsum != 0:
-        total = lifetimes.dot(numpy.arange(1, matrix.shape[1], 1))
-        return total / (lsum * matrix.shape[1])
-    return numpy.nan
+def fit_kaplan_meier(matrix):
+    from lifelines import KaplanMeierFitter
+
+    T = []
+    W = []
+    indexes = numpy.arange(matrix.shape[0], dtype=int)
+    entries = numpy.zeros(matrix.shape[0], int)
+    dead = set()
+    for i in range(1, matrix.shape[1]):
+        diff = matrix[:, i - 1] - matrix[:, i]
+        entries[diff < 0] = i
+        mask = diff > 0
+        deaths = diff[mask]
+        T.append(numpy.full(len(deaths), i) - entries[indexes[mask]])
+        W.append(deaths)
+        entered = entries > 0
+        entered[0] = True
+        dead = dead.union(set(numpy.where((matrix[:, i] == 0) & entered)[0]))
+    # add the survivors as censored
+    nnzind = entries != 0
+    nnzind[0] = True
+    nnzind[sorted(dead)] = False
+    T.append(numpy.full(nnzind.sum(), matrix.shape[1]) - entries[nnzind])
+    W.append(matrix[nnzind, -1])
+    T = numpy.concatenate(T)
+    E = numpy.ones(len(T), bool)
+    E[-nnzind.sum():] = 0
+    W = numpy.concatenate(W)
+    if T.size == 0:
+        return None
+    kmf = KaplanMeierFitter().fit(T, E, weights=W)
+    return kmf
+
+
+def print_survival_function(kmf, sampling):
+    sf = kmf.survival_function_
+    sf.index = [timedelta(days=d) for d in sf.index * sampling]
+    sf.columns = ["Ratio of survived lines"]
+    try:
+        print(sf[len(sf) // 6::len(sf) // 6].append(sf.tail(1)))
+    except ValueError:
+        pass
 
 
 def interpolate_burndown_matrix(matrix, granularity, sampling):
@@ -578,7 +605,7 @@ def import_pandas():
     return pandas
 
 
-def load_burndown(header, name, matrix, resample):
+def load_burndown(header, name, matrix, resample, report_survival=True):
     pandas = import_pandas()
 
     start, last, sampling, granularity = header
@@ -586,7 +613,10 @@ def load_burndown(header, name, matrix, resample):
     assert granularity > 0
     start = datetime.fromtimestamp(start)
     last = datetime.fromtimestamp(last)
-    print(name, "lifetime index:", calculate_average_lifetime(matrix))
+    if report_survival:
+        kmf = fit_kaplan_meier(matrix)
+        if kmf is not None:
+            print_survival_function(kmf, sampling)
     finish = start + timedelta(days=matrix.shape[1] * sampling)
     if resample not in ("no", "raw"):
         print("resampling to %s, please wait..." % resample)
@@ -610,7 +640,7 @@ def load_burndown(header, name, matrix, resample):
         if date_granularity_sampling[0] > finish:
             if resample == "A":
                 print("too loose resampling - by year, trying by month")
-                return load_burndown(header, name, matrix, "month")
+                return load_burndown(header, name, matrix, "month", report_survival=False)
             else:
                 raise ValueError("Too loose resampling: %s. Try finer." % resample)
         date_range_sampling = pandas.date_range(

+ 1 - 0
python/requirements.txt

@@ -10,3 +10,4 @@ hdbscan==0.8.18
 seriate==1.0.0
 fastdtw==0.3.2
 python-dateutil==2.8.0
+lifelines==0.20.4

+ 2 - 1
python/setup.py

@@ -15,7 +15,7 @@ setup(
     description="Python companion for github.com/src-d/hercules to visualize the results.",
     long_description=long_description,
     long_description_content_type="text/markdown",
-    version="10.0.1",
+    version="10.0.2",
     license="Apache-2.0",
     author="source{d}",
     author_email="machine-learning@sourced.tech",
@@ -36,6 +36,7 @@ setup(
         "seriate>=1.0,<2.0",
         "fastdtw>=0.3.2,<2.0",
         "python-dateutil>=2.6.0,<3.0",
+        "lifelines>=0.20.0,<2.0",
     ],
     package_data={"labours": ["../LICENSE.md", "../README.md", "../requirements.txt"]},
     entry_points={