Browse Source

Refactor the labours code

Vadim Markovtsev 7 years ago
parent
commit
ab25b271af
1 changed file with 26 additions and 11 deletions
  1. 26 11
      labours.py

+ 26 - 11
labours.py

@@ -34,13 +34,7 @@ def parse_args():
     return args
     return args
 
 
 
 
-def main():
-    args = parse_args()
-
-    import matplotlib
-    if args.backend:
-        matplotlib.use(args.backend)
-    import matplotlib.pyplot as pyplot
+def load_matrix(args):
     import pandas
     import pandas
 
 
     if args.input != "-":
     if args.input != "-":
@@ -73,12 +67,14 @@ def main():
                          / granularity)[numpy.newaxis, :]
                          / granularity)[numpy.newaxis, :]
                 if (y + 1) * granularity <= x * sampling:
                 if (y + 1) * granularity <= x * sampling:
                     daily_matrix[y * granularity:(y + 1) * granularity,
                     daily_matrix[y * granularity:(y + 1) * granularity,
-                                 x * sampling:(x + 1) * sampling] = value
+                    x * sampling:(x + 1) * sampling] = value
                 elif y * granularity <= (x + 1) * sampling:
                 elif y * granularity <= (x + 1) * sampling:
                     for suby in range(y * granularity, (y + 1) * granularity):
                     for suby in range(y * granularity, (y + 1) * granularity):
                         for subx in range(suby, (x + 1) * sampling):
                         for subx in range(suby, (x + 1) * sampling):
-                            daily_matrix[suby, subx] = matrix[y, x] / granularity
+                            daily_matrix[suby, subx] = matrix[
+                                                           y, x] / granularity
         daily_matrix[(last - start).days:] = 0
         daily_matrix[(last - start).days:] = 0
+        # Resample the time interval
         aliases = {
         aliases = {
             "year": "A",
             "year": "A",
             "month": "M"
             "month": "M"
@@ -90,6 +86,7 @@ def main():
             periods += 1
             periods += 1
             date_range_sampling = pandas.date_range(
             date_range_sampling = pandas.date_range(
                 start, periods=periods, freq=args.resample)
                 start, periods=periods, freq=args.resample)
+        # Fill the new square matrix
         matrix = numpy.zeros((len(date_range_sampling),) * 2,
         matrix = numpy.zeros((len(date_range_sampling),) * 2,
                              dtype=numpy.float32)
                              dtype=numpy.float32)
         for i, gdt in enumerate(date_range_sampling):
         for i, gdt in enumerate(date_range_sampling):
@@ -99,6 +96,7 @@ def main():
                 jfinish = min((date_range_sampling[i + j] - start).days,
                 jfinish = min((date_range_sampling[i + j] - start).days,
                               daily_matrix.shape[1] - 1)
                               daily_matrix.shape[1] - 1)
                 matrix[i, i + j] = daily_matrix[istart:ifinish, jfinish].sum()
                 matrix[i, i + j] = daily_matrix[istart:ifinish, jfinish].sum()
+        # Hardcode some cases to improve labels' readability
         if args.resample in ("year", "A"):
         if args.resample in ("year", "A"):
             labels = [dt.year for dt in date_range_sampling]
             labels = [dt.year for dt in date_range_sampling]
         elif args.resample in ("month", "M"):
         elif args.resample in ("month", "M"):
@@ -108,14 +106,25 @@ def main():
     else:
     else:
         labels = [
         labels = [
             "%s - %s" % ((start + timedelta(days=i * granularity)).date(),
             "%s - %s" % ((start + timedelta(days=i * granularity)).date(),
-                         (start + timedelta(days=(i + 1) * granularity)).date())
+                         (
+                         start + timedelta(days=(i + 1) * granularity)).date())
             for i in range(matrix.shape[0])]
             for i in range(matrix.shape[0])]
         if len(labels) > 18:
         if len(labels) > 18:
             warnings.warn("Too many labels - consider resampling.")
             warnings.warn("Too many labels - consider resampling.")
-        args.resample = "M"
+        args.resample = "M"  # fake resampling type is checked while plotting
         date_range_sampling = pandas.date_range(
         date_range_sampling = pandas.date_range(
             start + timedelta(days=sampling), periods=matrix.shape[1],
             start + timedelta(days=sampling), periods=matrix.shape[1],
             freq="%dD" % sampling)
             freq="%dD" % sampling)
+    return matrix, date_range_sampling, labels, granularity, sampling
+
+
+def plot_matrix(args, matrix, date_range_sampling, labels, granularity,
+                sampling):
+    import matplotlib
+    if args.backend:
+        matplotlib.use(args.backend)
+    import matplotlib.pyplot as pyplot
+
     if args.style == "white":
     if args.style == "white":
         pyplot.gca().spines["bottom"].set_color("white")
         pyplot.gca().spines["bottom"].set_color("white")
         pyplot.gca().spines["top"].set_color("white")
         pyplot.gca().spines["top"].set_color("white")
@@ -190,5 +199,11 @@ def main():
         pyplot.tight_layout()
         pyplot.tight_layout()
         pyplot.savefig(args.output, transparent=True)
         pyplot.savefig(args.output, transparent=True)
 
 
+
+def main():
+    args = parse_args()
+    plot_matrix(args, *load_matrix(args))
+
+
 if __name__ == "__main__":
 if __name__ == "__main__":
     sys.exit(main())
     sys.exit(main())