浏览代码

Merge pull request #129 from vmarkovtsev/master

Improve the plot styling
Vadim Markovtsev 6 年之前
父节点
当前提交
7af1d6671e
共有 3 个文件被更改,包括 52 次插入36 次删除
  1. 2 1
      README.md
  2. 49 34
      labours.py
  3. 1 1
      requirements.txt

+ 2 - 1
README.md

@@ -317,7 +317,8 @@ These options affects all plots:
 python3 labours.py [--style=white|black] [--backend=] [--size=Y,X]
 python3 labours.py [--style=white|black] [--backend=] [--size=Y,X]
 ```
 ```
 
 
-`--style` changes the background to be either white ("black" foreground) or black ("white" foreground).
+`--style` sets the general style of the plot (see `labours.py --help`).
+`--background` changes the plot background to be either white or black.
 `--backend` chooses the Matplotlib backend.
 `--backend` chooses the Matplotlib backend.
 `--size` sets the size of the figure in inches. The default is `12,9`.
 `--size` sets the size of the figure in inches. The default is `12,9`.
 
 

+ 49 - 34
labours.py

@@ -5,6 +5,7 @@ import json
 import os
 import os
 import re
 import re
 import shutil
 import shutil
+import subprocess
 import sys
 import sys
 import tempfile
 import tempfile
 import threading
 import threading
@@ -34,6 +35,14 @@ PB_MESSAGES = {
 }
 }
 
 
 
 
+def list_matplotlib_styles():
+    script = "import sys; from matplotlib import pyplot; " \
+             "sys.stdout.write(repr(pyplot.style.available))"
+    styles = eval(subprocess.check_output([sys.executable, "-c", script]))
+    styles.remove("classic")
+    return ["default", "classic"] + styles
+
+
 def parse_args():
 def parse_args():
     parser = argparse.ArgumentParser()
     parser = argparse.ArgumentParser()
     parser.add_argument("-o", "--output", default="",
     parser.add_argument("-o", "--output", default="",
@@ -43,10 +52,12 @@ def parse_args():
     parser.add_argument("-i", "--input", default="-",
     parser.add_argument("-i", "--input", default="-",
                         help="Path to the input file (- for stdin).")
                         help="Path to the input file (- for stdin).")
     parser.add_argument("-f", "--input-format", default="auto", choices=["yaml", "pb", "auto"])
     parser.add_argument("-f", "--input-format", default="auto", choices=["yaml", "pb", "auto"])
-    parser.add_argument("--text-size", default=12, type=int,
+    parser.add_argument("--font-size", default=12, type=int,
                         help="Size of the labels and legend.")
                         help="Size of the labels and legend.")
+    parser.add_argument("--style", default="ggplot", choices=list_matplotlib_styles(),
+                        help="Plot style to use.")
     parser.add_argument("--backend", help="Matplotlib backend to use.")
     parser.add_argument("--backend", help="Matplotlib backend to use.")
-    parser.add_argument("--style", choices=["black", "white"], default="black",
+    parser.add_argument("--background", choices=["black", "white"], default="white",
                         help="Plot's general color scheme.")
                         help="Plot's general color scheme.")
     parser.add_argument("--size", help="Axes' size in inches, for example \"12,9\"")
     parser.add_argument("--size", help="Axes' size in inches, for example \"12,9\"")
     parser.add_argument("--relative", action="store_true",
     parser.add_argument("--relative", action="store_true",
@@ -609,29 +620,41 @@ def load_churn_matrix(people, matrix, max_people):
     return people, matrix
     return people, matrix
 
 
 
 
-def apply_plot_style(figure, axes, legend, style, text_size, axes_size):
+def import_pyplot(backend, style):
+    import matplotlib
+    if backend:
+        matplotlib.use(backend)
+    from matplotlib import pyplot
+    pyplot.style.use(style)
+    return matplotlib, pyplot
+
+
+def apply_plot_style(figure, axes, legend, background, font_size, axes_size):
+    foreground = "black" if background == "white" else "white"
     if axes_size is None:
     if axes_size is None:
         axes_size = (12, 9)
         axes_size = (12, 9)
     else:
     else:
         axes_size = tuple(float(p) for p in axes_size.split(","))
         axes_size = tuple(float(p) for p in axes_size.split(","))
     figure.set_size_inches(*axes_size)
     figure.set_size_inches(*axes_size)
     for side in ("bottom", "top", "left", "right"):
     for side in ("bottom", "top", "left", "right"):
-        axes.spines[side].set_color(style)
+        axes.spines[side].set_color(foreground)
     for axis in (axes.xaxis, axes.yaxis):
     for axis in (axes.xaxis, axes.yaxis):
-        axis.label.update(dict(fontsize=text_size, color=style))
+        axis.label.update(dict(fontsize=font_size, color=foreground))
     for axis in ("x", "y"):
     for axis in ("x", "y"):
-        getattr(axes, axis + "axis").get_offset_text().set_size(text_size)
-        axes.tick_params(axis=axis, colors=style, labelsize=text_size)
+        getattr(axes, axis + "axis").get_offset_text().set_size(font_size)
+        axes.tick_params(axis=axis, colors=foreground, labelsize=font_size)
     try:
     try:
         axes.ticklabel_format(axis="y", style="sci", scilimits=(0, 3))
         axes.ticklabel_format(axis="y", style="sci", scilimits=(0, 3))
     except AttributeError:
     except AttributeError:
         pass
         pass
+    figure.patch.set_facecolor(background)
+    axes.set_facecolor(background)
     if legend is not None:
     if legend is not None:
         frame = legend.get_frame()
         frame = legend.get_frame()
         for setter in (frame.set_facecolor, frame.set_edgecolor):
         for setter in (frame.set_facecolor, frame.set_edgecolor):
-            setter("black" if style == "white" else "white")
+            setter(background)
         for text in legend.get_texts():
         for text in legend.get_texts():
-            text.set_color(style)
+            text.set_color(foreground)
 
 
 
 
 def get_plot_path(base, name):
 def get_plot_path(base, name):
@@ -643,7 +666,7 @@ def get_plot_path(base, name):
     return output
     return output
 
 
 
 
-def deploy_plot(title, output, style):
+def deploy_plot(title, output, background):
     import matplotlib.pyplot as pyplot
     import matplotlib.pyplot as pyplot
 
 
     if not output:
     if not output:
@@ -651,7 +674,7 @@ def deploy_plot(title, output, style):
         pyplot.show()
         pyplot.show()
     else:
     else:
         if title:
         if title:
-            pyplot.title(title, color=style)
+            pyplot.title(title, color="black" if background == "white" else "white")
         try:
         try:
             pyplot.tight_layout()
             pyplot.tight_layout()
         except:  # noqa: E722
         except:  # noqa: E722
@@ -684,10 +707,7 @@ def plot_burndown(args, target, name, matrix, date_range_sampling, labels, granu
             json.dump(data, fout, sort_keys=True, default=default_json)
             json.dump(data, fout, sort_keys=True, default=default_json)
         return
         return
 
 
-    import matplotlib
-    if args.backend:
-        matplotlib.use(args.backend)
-    import matplotlib.pyplot as pyplot
+    matplotlib, pyplot = import_pyplot(args.backend, args.style)
 
 
     pyplot.stackplot(date_range_sampling, matrix, labels=labels)
     pyplot.stackplot(date_range_sampling, matrix, labels=labels)
     if args.relative:
     if args.relative:
@@ -697,10 +717,12 @@ def plot_burndown(args, target, name, matrix, date_range_sampling, labels, granu
         legend_loc = 3
         legend_loc = 3
     else:
     else:
         legend_loc = 2
         legend_loc = 2
-    legend = pyplot.legend(loc=legend_loc, fontsize=args.text_size)
+    pyplot.style.use("ggplot")
+    legend = pyplot.legend(loc=legend_loc, fontsize=args.font_size)
     pyplot.ylabel("Lines of code")
     pyplot.ylabel("Lines of code")
     pyplot.xlabel("Time")
     pyplot.xlabel("Time")
-    apply_plot_style(pyplot.gcf(), pyplot.gca(), legend, args.style, args.text_size, args.size)
+    apply_plot_style(pyplot.gcf(), pyplot.gca(), legend, args.background,
+                     args.font_size, args.size)
     pyplot.xlim(date_range_sampling[0], date_range_sampling[-1])
     pyplot.xlim(date_range_sampling[0], date_range_sampling[-1])
     locator = pyplot.gca().xaxis.get_major_locator()
     locator = pyplot.gca().xaxis.get_major_locator()
     # set the optimal xticks locator
     # set the optimal xticks locator
@@ -775,10 +797,7 @@ def plot_churn_matrix(args, repo, people, matrix):
             json.dump(data, fout, sort_keys=True, default=default_json)
             json.dump(data, fout, sort_keys=True, default=default_json)
         return
         return
 
 
-    import matplotlib
-    if args.backend:
-        matplotlib.use(args.backend)
-    import matplotlib.pyplot as pyplot
+    matplotlib, pyplot = import_pyplot(args.backend, args.style)
 
 
     s = 4 + matrix.shape[1] * 0.3
     s = 4 + matrix.shape[1] * 0.3
     fig = pyplot.figure(figsize=(s, s))
     fig = pyplot.figure(figsize=(s, s))
@@ -793,7 +812,7 @@ def plot_churn_matrix(args, repo, people, matrix):
                        va="bottom", rotation_mode="anchor")
                        va="bottom", rotation_mode="anchor")
     ax.set_yticks(numpy.arange(0.5, matrix.shape[0] + 0.5), minor=True)
     ax.set_yticks(numpy.arange(0.5, matrix.shape[0] + 0.5), minor=True)
     ax.grid(which="minor")
     ax.grid(which="minor")
-    apply_plot_style(fig, ax, None, args.style, args.text_size, args.size)
+    apply_plot_style(fig, ax, None, args.background, args.font_size, args.size)
     if not args.output:
     if not args.output:
         pos1 = ax.get_position()
         pos1 = ax.get_position()
         pos2 = (pos1.x0 + 0.15, pos1.y0 - 0.1, pos1.width * 0.9, pos1.height * 0.9)
         pos2 = (pos1.x0 + 0.15, pos1.y0 - 0.1, pos1.width * 0.9, pos1.height * 0.9)
@@ -822,10 +841,7 @@ def plot_ownership(args, repo, names, people, date_range, last):
             json.dump(data, fout, sort_keys=True, default=default_json)
             json.dump(data, fout, sort_keys=True, default=default_json)
         return
         return
 
 
-    import matplotlib
-    if args.backend:
-        matplotlib.use(args.backend)
-    import matplotlib.pyplot as pyplot
+    matplotlib, pyplot = import_pyplot(args.backend, args.style)
 
 
     pyplot.stackplot(date_range, people, labels=names)
     pyplot.stackplot(date_range, people, labels=names)
     pyplot.xlim(date_range[0], last)
     pyplot.xlim(date_range[0], last)
@@ -836,8 +852,9 @@ def plot_ownership(args, repo, names, people, date_range, last):
         legend_loc = 3
         legend_loc = 3
     else:
     else:
         legend_loc = 2
         legend_loc = 2
-    legend = pyplot.legend(loc=legend_loc, fontsize=args.text_size)
-    apply_plot_style(pyplot.gcf(), pyplot.gca(), legend, args.style, args.text_size, args.size)
+    legend = pyplot.legend(loc=legend_loc, fontsize=args.font_size)
+    apply_plot_style(pyplot.gcf(), pyplot.gca(), legend, args.background,
+                     args.font_size, args.size)
     if args.mode == "all":
     if args.mode == "all":
         output = get_plot_path(args.output, "people")
         output = get_plot_path(args.output, "people")
     else:
     else:
@@ -1070,10 +1087,7 @@ def show_shotness_stats(data):
 
 
 
 
 def show_sentiment_stats(args, name, resample, start, data):
 def show_sentiment_stats(args, name, resample, start, data):
-    import matplotlib
-    if args.backend:
-        matplotlib.use(args.backend)
-    import matplotlib.pyplot as pyplot
+    matplotlib, pyplot = import_pyplot(args.backend, args.style)
 
 
     start = datetime.fromtimestamp(start)
     start = datetime.fromtimestamp(start)
     data = sorted(data.items())
     data = sorted(data.items())
@@ -1092,10 +1106,11 @@ def show_sentiment_stats(args, name, resample, start, data):
             yneg.append(y)
             yneg.append(y)
     pyplot.bar(xpos, ypos, color="g", label="Positive")
     pyplot.bar(xpos, ypos, color="g", label="Positive")
     pyplot.bar(xneg, yneg, color="r", label="Negative")
     pyplot.bar(xneg, yneg, color="r", label="Negative")
-    legend = pyplot.legend(loc=1, fontsize=args.text_size)
+    legend = pyplot.legend(loc=1, fontsize=args.font_size)
     pyplot.ylabel("Lines of code")
     pyplot.ylabel("Lines of code")
     pyplot.xlabel("Time")
     pyplot.xlabel("Time")
-    apply_plot_style(pyplot.gcf(), pyplot.gca(), legend, args.style, args.text_size, args.size)
+    apply_plot_style(pyplot.gcf(), pyplot.gca(), legend, args.background,
+                     args.font_size, args.size)
     pyplot.xlim(xdates[0], xdates[-1])
     pyplot.xlim(xdates[0], xdates[-1])
     locator = pyplot.gca().xaxis.get_major_locator()
     locator = pyplot.gca().xaxis.get_major_locator()
     # set the optimal xticks locator
     # set the optimal xticks locator

+ 1 - 1
requirements.txt

@@ -1,5 +1,5 @@
 clint>=0.5.1,<1.0
 clint>=0.5.1,<1.0
-matplotlib>=2.0,<3.0
+matplotlib>=2.0,<4.0
 numpy>=1.13.1,<2.0
 numpy>=1.13.1,<2.0
 pandas>=0.20.0,<1.0
 pandas>=0.20.0,<1.0
 PyYAML>=3.12,<4.0
 PyYAML>=3.12,<4.0