plotting.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import os
  2. def import_pyplot(backend, style):
  3. import matplotlib
  4. if backend:
  5. matplotlib.use(backend)
  6. from matplotlib import pyplot
  7. pyplot.style.use(style)
  8. print("matplotlib: backend is", matplotlib.get_backend())
  9. return matplotlib, pyplot
  10. def apply_plot_style(figure, axes, legend, background, font_size, axes_size):
  11. foreground = "black" if background == "white" else "white"
  12. if axes_size is None:
  13. axes_size = (16, 12)
  14. else:
  15. axes_size = tuple(float(p) for p in axes_size.split(","))
  16. figure.set_size_inches(*axes_size)
  17. for side in ("bottom", "top", "left", "right"):
  18. axes.spines[side].set_color(foreground)
  19. for axis in (axes.xaxis, axes.yaxis):
  20. axis.label.update(dict(fontsize=font_size, color=foreground))
  21. for axis in ("x", "y"):
  22. getattr(axes, axis + "axis").get_offset_text().set_size(font_size)
  23. axes.tick_params(axis=axis, colors=foreground, labelsize=font_size)
  24. try:
  25. axes.ticklabel_format(axis="y", style="sci", scilimits=(0, 3))
  26. except AttributeError:
  27. pass
  28. figure.patch.set_facecolor(background)
  29. axes.set_facecolor(background)
  30. if legend is not None:
  31. frame = legend.get_frame()
  32. for setter in (frame.set_facecolor, frame.set_edgecolor):
  33. setter(background)
  34. for text in legend.get_texts():
  35. text.set_color(foreground)
  36. def get_plot_path(base: str, name: str) -> str:
  37. root, ext = os.path.splitext(base)
  38. if not ext:
  39. ext = ".png"
  40. output = os.path.join(root, name + ext)
  41. os.makedirs(os.path.dirname(output), exist_ok=True)
  42. return output
  43. def deploy_plot(title: str, output: str, background: str, tight: bool = True) -> None:
  44. import matplotlib.pyplot as pyplot
  45. if not output:
  46. pyplot.gcf().canvas.set_window_title(title)
  47. pyplot.show()
  48. else:
  49. if title:
  50. pyplot.title(title, color="black" if background == "white" else "white")
  51. if tight:
  52. try:
  53. pyplot.tight_layout()
  54. except: # noqa: E722
  55. print("Warning: failed to set the tight layout")
  56. print("Writing plot to %s" % output)
  57. pyplot.savefig(output, transparent=True)
  58. pyplot.clf()