humaneval.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import typing as t
  2. from bigcode_eval.tasks import humaneval
  3. from bigcode_eval.tasks.custom_metrics.execute import check_correctness
  4. from datasets import load_dataset
  5. from lm_eval.evaluator_utils import eval_logger
  6. import dspy
  7. from .datatypes import TaskDatasets
  8. from .helpers import train_val_test_split
  9. if t.TYPE_CHECKING:
  10. from bigcode_eval.base import Task
  11. def signature(instructions: str = "") -> dspy.Signature:
  12. class HumanEval(dspy.Signature):
  13. __doc__ = instructions
  14. prompt: str = dspy.InputField()
  15. solution: str = dspy.OutputField()
  16. return HumanEval
  17. def metric(gold: dspy.Example, pred: dspy.Example, trace=False) -> bool:
  18. program = gold.prompt + "\n" + pred.solution + "\n" + gold.dspy_test
  19. result = check_correctness(
  20. program,
  21. timeout=30,
  22. task_id=gold.dspy_task_id,
  23. completion_id=None,
  24. )
  25. if result["passed"]:
  26. return True
  27. eval_logger.debug(f"{gold.dspy_task_id}: {result['result']}")
  28. return False
  29. def datasets(
  30. train_size: float = 0.1,
  31. validation_size: float = 0.2,
  32. ) -> TaskDatasets:
  33. dataset = load_dataset("codeparrot/instructhumaneval")
  34. train_docs, validation_docs, test_docs = train_val_test_split(
  35. dataset,
  36. train_size=train_size,
  37. validation_size=validation_size,
  38. )
  39. return TaskDatasets(
  40. trainset=map(_task_doc_example, train_docs),
  41. valset=map(_task_doc_example, validation_docs),
  42. testset=map(_task_doc_example, test_docs),
  43. )
  44. class TaskDoc(t.TypedDict):
  45. task_id: str
  46. prompt: str
  47. canonical_solution: str
  48. test: str
  49. inputs = ["prompt"]
  50. outputs = ["solution"]
  51. def _task_doc_example(doc: TaskDoc) -> dspy.Example:
  52. return dspy.Example(
  53. prompt=doc["prompt"],
  54. solution=doc["canonical_solution"],
  55. # dspy_ keys are hidden
  56. dspy_task_id=doc["task_id"],
  57. dspy_test=doc["test"],
  58. ).with_inputs(*inputs)