瀏覽代碼

Improve xaxis ticks for yearly sampling

Vadim Markovtsev 8 年之前
父節點
當前提交
d41e7d1e93
共有 1 個文件被更改,包括 43 次插入13 次删除
  1. 43 13
      labours.py

+ 43 - 13
labours.py

@@ -55,10 +55,12 @@ def main():
             "year": "A",
             "year": "A",
             "month": "M"
             "month": "M"
         }
         }
+        args.resample = aliases.get(args.resample, args.resample)
         daily_matrix = numpy.zeros(
         daily_matrix = numpy.zeros(
             (matrix.shape[0] * granularity, matrix.shape[1]),
             (matrix.shape[0] * granularity, matrix.shape[1]),
             dtype=numpy.float32)
             dtype=numpy.float32)
-        for i in range(1, matrix.shape[0]):
+        daily_start = 1 if "M" in args.resample else 0
+        for i in range(daily_start, matrix.shape[0]):
             daily_matrix[i * granularity:(i + 1) * granularity] = \
             daily_matrix[i * granularity:(i + 1) * granularity] = \
                 matrix[i] / granularity
                 matrix[i] / granularity
         date_range_granularity = pandas.date_range(
         date_range_granularity = pandas.date_range(
@@ -67,13 +69,15 @@ def main():
             dr: pandas.Series(row, index=date_range_sampling)
             dr: pandas.Series(row, index=date_range_sampling)
             for dr, row in zip(date_range_granularity, daily_matrix)
             for dr, row in zip(date_range_granularity, daily_matrix)
         }).T
         }).T
-        df = df.resample(aliases.get(args.resample, args.resample)).sum()
-        row0 = matrix[0]
+        df = df.resample(args.resample).sum()
+        if "M" in args.resample:
+            row0 = matrix[0]
         matrix = df.as_matrix()
         matrix = df.as_matrix()
-        matrix[0] = row0
-        for i in range(1, matrix.shape[0]):
-            matrix[i, i] += matrix[i, :i].sum()
-            matrix[i, :i] = 0
+        if "M" in args.resample:
+            matrix[0] = row0
+            for i in range(1, matrix.shape[0]):
+                matrix[i, i] += matrix[i, :i].sum()
+                matrix[i, :i] = 0
         if args.resample in ("year", "A"):
         if args.resample in ("year", "A"):
             labels = [dt.year for dt in df.index]
             labels = [dt.year for dt in df.index]
         elif args.resample in ("month", "M"):
         elif args.resample in ("month", "M"):
@@ -87,6 +91,7 @@ def main():
             for i in range(matrix.shape[0])]
             for i in range(matrix.shape[0])]
         if len(labels) > 18:
         if len(labels) > 18:
             warnings.warn("Too many labels - consider resampling.")
             warnings.warn("Too many labels - consider resampling.")
+        args.resample = "M"
     if args.style == "white":
     if args.style == "white":
         pyplot.gca().spines["bottom"].set_color("white")
         pyplot.gca().spines["bottom"].set_color("white")
         pyplot.gca().spines["top"].set_color("white")
         pyplot.gca().spines["top"].set_color("white")
@@ -115,16 +120,41 @@ def main():
     pyplot.tick_params(labelsize=args.text_size)
     pyplot.tick_params(labelsize=args.text_size)
     pyplot.xlim(date_range_sampling[0], date_range_sampling[-1])
     pyplot.xlim(date_range_sampling[0], date_range_sampling[-1])
     pyplot.gcf().set_size_inches(12, 9)
     pyplot.gcf().set_size_inches(12, 9)
-    # add border ticks
+    locator = pyplot.gca().xaxis.get_major_locator()
+    # set the optimal xticks locator
+    if "M" not in args.resample:
+        pyplot.gca().xaxis.set_major_locator(matplotlib.dates.YearLocator())
     locs = pyplot.gca().get_xticks().tolist()
     locs = pyplot.gca().get_xticks().tolist()
-    locs.extend(pyplot.xlim())
+    if len(locs) >= 16:
+        pyplot.gca().xaxis.set_major_locator(matplotlib.dates.YearLocator())
+        locs = pyplot.gca().get_xticks().tolist()
+        if len(locs) >= 16:
+            pyplot.gca().xaxis.set_major_locator(locator)
+    if locs[0] < pyplot.xlim()[0]:
+        del locs[0]
+    endindex = -1
+    if pyplot.xlim()[1] - locs[-1] >= (locs[-1] - locs[-2]) / 2:
+        locs.append(pyplot.xlim()[1])
+        endindex = len(locs) - 1
+    startindex = -1
+    if locs[0] - pyplot.xlim()[0] >= (locs[1] - locs[0]) / 2:
+        locs.append(pyplot.xlim()[0])
+        startindex = len(locs) - 1
     pyplot.gca().set_xticks(locs)
     pyplot.gca().set_xticks(locs)
     # hacking time!
     # hacking time!
     labels = pyplot.gca().get_xticklabels()
     labels = pyplot.gca().get_xticklabels()
-    labels[-2].set_text(date_range_sampling[0].date())
-    labels[-2].set_text = lambda _: None
-    labels[-1].set_text(date_range_sampling[-1].date())
-    labels[-1].set_text = lambda _: None
+    if startindex >= 0:
+        if "M" in args.resample:
+            labels[startindex].set_text(date_range_sampling[0].date())
+            labels[startindex].set_text = lambda _: None
+        labels[startindex].set_rotation(30)
+        labels[startindex].set_ha("right")
+    if endindex >= 0:
+        if "M" in args.resample:
+            labels[endindex].set_text(date_range_sampling[-1].date())
+            labels[endindex].set_text = lambda _: None
+        labels[endindex].set_rotation(30)
+        labels[endindex].set_ha("right")
     if not args.output:
     if not args.output:
         pyplot.gcf().canvas.set_window_title(
         pyplot.gcf().canvas.set_window_title(
             "Hercules %d x %d (granularity %d, sampling %d)" %
             "Hercules %d x %d (granularity %d, sampling %d)" %