plot.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. import io
  2. import json
  3. from pathlib import Path
  4. from typing import List, Union
  5. import itertools
  6. import modal
  7. try:
  8. volume = modal.Volume.lookup("humaneval", create_if_missing=False)
  9. except modal.exception.NotFoundError:
  10. raise Exception("Generate results first with modal run generate --data-dir test --no-dry-run --n 1000 --subsample 100")
  11. image = modal.Image.debian_slim(python_version="3.11").pip_install(
  12. "numpy==1.26.4",
  13. "pandas==2.2.3",
  14. "matplotlib==3.9.2",
  15. "seaborn==0.13.2",
  16. )
  17. app = modal.App("many-llamas-human-eval", image=image)
  18. DATA_DIR = Path("/mnt/humaneval")
  19. with image.imports():
  20. import numpy as np
  21. import pandas as pd
  22. import matplotlib.pyplot as plt
  23. import seaborn as sns
  24. @app.function(volumes={DATA_DIR: volume})
  25. def render_plots():
  26. run_dirs = list(sorted((DATA_DIR / "test").glob("run-*")))
  27. for run_dir in reversed(run_dirs):
  28. if len(list(run_dir.iterdir())) < 150:
  29. print(f"skipping incomplete run {run_dir}")
  30. else:
  31. break
  32. all_result_paths = list(run_dir.glob("*.jsonl_results.jsonl"))
  33. data = []
  34. for path in all_result_paths:
  35. data += [json.loads(line) for line in path.read_text(encoding='utf-8').splitlines()]
  36. for element in data:
  37. del element["completion"]
  38. df = pd.DataFrame.from_records(data)
  39. gb = df.groupby("task_id")
  40. passes = gb["passed"].sum()
  41. def estimate_pass_at_k(
  42. num_samples: Union[int, List[int], np.ndarray],
  43. num_correct: Union[List[int], np.ndarray],
  44. k: int
  45. ) -> np.ndarray:
  46. """
  47. Estimates pass@k of each problem and returns them in an array.
  48. """
  49. def estimator(n: int, c: int, k: int) -> float:
  50. """
  51. Calculates 1 - comb(n - c, k) / comb(n, k).
  52. """
  53. if n - c < k:
  54. return 1.0
  55. return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
  56. if isinstance(num_samples, int):
  57. num_samples_it = itertools.repeat(num_samples, len(num_correct))
  58. else:
  59. assert len(num_samples) == len(num_correct)
  60. num_samples_it = iter(num_samples)
  61. return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])
  62. pass_at_ks = {}
  63. for k in [1, 10, 100, 1000]:
  64. pass_at_ks[k] = estimate_pass_at_k(1000, passes, k)
  65. pass_at_k = {k: np.mean(v) for k, v in pass_at_ks.items()}
  66. plot_df = pd.DataFrame(
  67. {"k": pass_at_k.keys(),
  68. "pass@k": pass_at_k.values()}
  69. )
  70. plot_df["fail@k"] = 1 - plot_df["pass@k"]
  71. sns.set_theme(style='dark')
  72. plt.style.use("dark_background")
  73. plt.rcParams['font.sans-serif'] = ["Inter", "Arial", "DejaVu Sans", "Liberation Sans", "Bitstream Vera Sans", "sans-serif"]
  74. sns.despine()
  75. sns.set_context("talk", rc={"lines.linewidth": 2.5})
  76. gpt4o_benchmark = 0.902
  77. # First plot
  78. plt.figure(figsize=(10, 6))
  79. fg = sns.lineplot(
  80. x="k",
  81. y="pass@k",
  82. data=plot_df,
  83. color="#7FEE64",
  84. linewidth=6,
  85. alpha=0.9,
  86. label="Llama 3.2 3B Instruct pass@k"
  87. )
  88. initial_lim = fg.axes.get_xlim()
  89. fg.axes.hlines(
  90. gpt4o_benchmark, *initial_lim,
  91. linestyle="--",
  92. alpha=0.6,
  93. zorder=-1,
  94. label="GPT-4o fail@1"
  95. )
  96. fg.axes.set_xlim(*initial_lim)
  97. fg.axes.set_ylabel("")
  98. fg.axes.set_ylim(0, 1)
  99. plt.tight_layout(pad=1.2)
  100. plt.legend()
  101. # Save the first plot as bytes
  102. img_buffer = io.BytesIO()
  103. plt.savefig(img_buffer, format='jpeg')
  104. plot_1_img_bytes = img_buffer.getvalue()
  105. plt.close()
  106. # Second plot
  107. plt.figure(figsize=(10, 6))
  108. fg = sns.lineplot(
  109. x="k",
  110. y="fail@k",
  111. data=plot_df,
  112. color="#7FEE64",
  113. linewidth=6,
  114. alpha=0.9,
  115. label="Llama 3.2 3B Instruct fail@k"
  116. )
  117. initial_lim = fg.axes.get_xlim()
  118. fg.axes.hlines(
  119. 1 - gpt4o_benchmark, *initial_lim,
  120. linestyle="--",
  121. alpha=0.6,
  122. zorder=-1,
  123. label="GPT-4o fail@1"
  124. )
  125. fg.axes.set_xlim(*initial_lim)
  126. fg.axes.set_ylabel("")
  127. fg.axes.set_yscale("log")
  128. fg.axes.set_xscale("log")
  129. fg.axes.set_xlim(0.5, 2000)
  130. fg.axes.set_ylim(1e-2, 1e0)
  131. plt.tight_layout(pad=1.2)
  132. plt.legend()
  133. # Save the second plot as bytes
  134. img_buffer = io.BytesIO()
  135. plt.savefig(img_buffer, format='jpeg')
  136. plot_2_img_bytes = img_buffer.getvalue()
  137. plt.close()
  138. return [plot_1_img_bytes, plot_2_img_bytes]
  139. @app.local_entrypoint()
  140. def main():
  141. plots = render_plots.remote()
  142. assert len(plots) == 2
  143. with open ("/tmp/plot-pass-k.jpeg", "wb") as f:
  144. f.write(plots[0])
  145. with open ("/tmp/plot-fail-k.jpeg", "wb") as f:
  146. f.write(plots[1])
  147. print("Plots saved to:")
  148. print(" /tmp/plot-pass-k.jpeg")
  149. print(" /tmp/plot-fail-k.jpeg")