plot.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
# ## Plotting HumanEval Results
# This script calculates pass@k and fail@k for our experiment and plots them.
#
# Run it with:
#
#     modal run plot
  6. import io
  7. import json
  8. from pathlib import Path
  9. from typing import List, Union
  10. import itertools
  11. import modal
  12. try:
  13. volume = modal.Volume.lookup("humaneval", create_if_missing=False)
  14. except modal.exception.NotFoundError:
  15. raise Exception("Generate results first with modal run generate --data-dir test --no-dry-run --n 1000 --subsample 100")
  16. image = modal.Image.debian_slim(python_version="3.11").pip_install(
  17. "numpy==1.26.4",
  18. "pandas==2.2.3",
  19. "matplotlib==3.9.2",
  20. "seaborn==0.13.2",
  21. )
  22. app = modal.App("many-llamas-human-eval", image=image)
  23. DATA_DIR = Path("/mnt/humaneval")
  24. with image.imports():
  25. import numpy as np
  26. import pandas as pd
  27. import matplotlib.pyplot as plt
  28. import seaborn as sns
  29. @app.function(volumes={DATA_DIR: volume})
  30. def render_plots():
  31. run_dirs = list(sorted((DATA_DIR / "test").glob("run-*")))
  32. for run_dir in reversed(run_dirs):
  33. if len(list(run_dir.iterdir())) < 150:
  34. print(f"skipping incomplete run {run_dir}")
  35. else:
  36. break
  37. all_result_paths = list(run_dir.glob("*.jsonl_results.jsonl"))
  38. data = []
  39. for path in all_result_paths:
  40. data += [json.loads(line) for line in path.read_text(encoding='utf-8').splitlines()]
  41. for element in data:
  42. del element["completion"]
  43. df = pd.DataFrame.from_records(data)
  44. gb = df.groupby("task_id")
  45. passes = gb["passed"].sum()
  46. def estimate_pass_at_k(
  47. num_samples: Union[int, List[int], np.ndarray],
  48. num_correct: Union[List[int], np.ndarray],
  49. k: int
  50. ) -> np.ndarray:
  51. """
  52. Estimates pass@k of each problem and returns them in an array.
  53. """
  54. def estimator(n: int, c: int, k: int) -> float:
  55. """
  56. Calculates 1 - comb(n - c, k) / comb(n, k).
  57. """
  58. if n - c < k:
  59. return 1.0
  60. return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
  61. if isinstance(num_samples, int):
  62. num_samples_it = itertools.repeat(num_samples, len(num_correct))
  63. else:
  64. assert len(num_samples) == len(num_correct)
  65. num_samples_it = iter(num_samples)
  66. return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])
  67. pass_at_ks = {}
  68. for k in [1, 10, 100, 1000]:
  69. pass_at_ks[k] = estimate_pass_at_k(1000, passes, k)
  70. pass_at_k = {k: np.mean(v) for k, v in pass_at_ks.items()}
  71. plot_df = pd.DataFrame(
  72. {"k": pass_at_k.keys(),
  73. "pass@k": pass_at_k.values()}
  74. )
  75. plot_df["fail@k"] = 1 - plot_df["pass@k"]
  76. sns.set_theme(style='dark')
  77. plt.style.use("dark_background")
  78. plt.rcParams['font.sans-serif'] = ["Inter", "Arial", "DejaVu Sans", "Liberation Sans", "Bitstream Vera Sans", "sans-serif"]
  79. sns.despine()
  80. sns.set_context("talk", rc={"lines.linewidth": 2.5})
  81. gpt4o_benchmark = 0.902
  82. # First plot
  83. plt.figure(figsize=(10, 6))
  84. fg = sns.lineplot(
  85. x="k",
  86. y="pass@k",
  87. data=plot_df,
  88. color="#7FEE64",
  89. linewidth=6,
  90. alpha=0.9,
  91. label="Llama 3.2 3B Instruct pass@k"
  92. )
  93. initial_lim = fg.axes.get_xlim()
  94. fg.axes.hlines(
  95. gpt4o_benchmark, *initial_lim,
  96. linestyle="--",
  97. alpha=0.6,
  98. zorder=-1,
  99. label="GPT-4o fail@1"
  100. )
  101. fg.axes.set_xlim(*initial_lim)
  102. fg.axes.set_ylabel("")
  103. fg.axes.set_ylim(0, 1)
  104. plt.tight_layout(pad=1.2)
  105. plt.legend()
  106. # Save the first plot as bytes
  107. img_buffer = io.BytesIO()
  108. plt.savefig(img_buffer, format='jpeg')
  109. plot_1_img_bytes = img_buffer.getvalue()
  110. plt.close()
  111. # Second plot
  112. plt.figure(figsize=(10, 6))
  113. fg = sns.lineplot(
  114. x="k",
  115. y="fail@k",
  116. data=plot_df,
  117. color="#7FEE64",
  118. linewidth=6,
  119. alpha=0.9,
  120. label="Llama 3.2 3B Instruct fail@k"
  121. )
  122. initial_lim = fg.axes.get_xlim()
  123. fg.axes.hlines(
  124. 1 - gpt4o_benchmark, *initial_lim,
  125. linestyle="--",
  126. alpha=0.6,
  127. zorder=-1,
  128. label="GPT-4o fail@1"
  129. )
  130. fg.axes.set_xlim(*initial_lim)
  131. fg.axes.set_ylabel("")
  132. fg.axes.set_yscale("log")
  133. fg.axes.set_xscale("log")
  134. fg.axes.set_xlim(0.5, 2000)
  135. fg.axes.set_ylim(1e-2, 1e0)
  136. plt.tight_layout(pad=1.2)
  137. plt.legend()
  138. # Save the second plot as bytes
  139. img_buffer = io.BytesIO()
  140. plt.savefig(img_buffer, format='jpeg')
  141. plot_2_img_bytes = img_buffer.getvalue()
  142. plt.close()
  143. return [plot_1_img_bytes, plot_2_img_bytes]
  144. @app.local_entrypoint()
  145. def main():
  146. plots = render_plots.remote()
  147. assert len(plots) == 2
  148. with open ("/tmp/plot-pass-k.jpeg", "wb") as f:
  149. f.write(plots[0])
  150. with open ("/tmp/plot-fail-k.jpeg", "wb") as f:
  151. f.write(plots[1])
  152. print("Plots saved to:")
  153. print(" /tmp/plot-pass-k.jpeg")
  154. print(" /tmp/plot-fail-k.jpeg")