utils.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. import dataclasses
  2. from typing import Dict, Optional, Union
  3. from lm_eval.tasks.ifeval import instructions_registry
  4. @dataclasses.dataclass
  5. class InputExample:
  6. key: int
  7. instruction_id_list: list[str]
  8. prompt: str
  9. kwargs: list[Dict[str, Optional[Union[str, int]]]]
  10. @dataclasses.dataclass
  11. class OutputExample:
  12. instruction_id_list: list[str]
  13. prompt: str
  14. response: str
  15. follow_all_instructions: bool
  16. follow_instruction_list: list[bool]
  17. def test_instruction_following_strict(
  18. inp,
  19. response,
  20. ):
  21. """Tests response to see if instructions are followed."""
  22. instruction_list = inp.instruction_id_list
  23. is_following_list = []
  24. for index, instruction_id in enumerate(instruction_list):
  25. instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
  26. instruction = instruction_cls(instruction_id)
  27. # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method.
  28. kwargs = {k: v for k, v in inp.kwargs[index].items() if v}
  29. instruction.build_description(**kwargs)
  30. args = instruction.get_instruction_args()
  31. if args and "prompt" in args:
  32. instruction.build_description(prompt=inp.prompt)
  33. if response.strip() and instruction.check_following(response):
  34. is_following_list.append(True)
  35. else:
  36. is_following_list.append(False)
  37. return OutputExample(
  38. instruction_id_list=inp.instruction_id_list,
  39. prompt=inp.prompt,
  40. response=response,
  41. follow_all_instructions=all(is_following_list),
  42. follow_instruction_list=is_following_list,
  43. )
  44. def test_instruction_following_loose(
  45. inp,
  46. response,
  47. ):
  48. """Tests response for an upper bound for following instructions."""
  49. r = response.split("\n")
  50. response_remove_first = "\n".join(r[1:]).strip()
  51. response_remove_last = "\n".join(r[:-1]).strip()
  52. response_remove_both = "\n".join(r[1:-1]).strip()
  53. revised_response = response.replace("*", "")
  54. revised_response_remove_first = response_remove_first.replace("*", "")
  55. revised_response_remove_last = response_remove_last.replace("*", "")
  56. revised_response_remove_both = response_remove_both.replace("*", "")
  57. all_responses = [
  58. response,
  59. revised_response,
  60. response_remove_first,
  61. response_remove_last,
  62. response_remove_both,
  63. revised_response_remove_first,
  64. revised_response_remove_last,
  65. revised_response_remove_both,
  66. ]
  67. instruction_list = inp.instruction_id_list
  68. is_following_list = []
  69. for index, instruction_id in enumerate(instruction_list):
  70. instruction_cls = instructions_registry.INSTRUCTION_DICT[instruction_id]
  71. instruction = instruction_cls(instruction_id)
  72. # Remove None values from kwargs to avoid unexpected keyword argument errors in build_description method.
  73. kwargs = {k: v for k, v in inp.kwargs[index].items() if v}
  74. instruction.build_description(**kwargs)
  75. args = instruction.get_instruction_args()
  76. if args and "prompt" in args:
  77. instruction.build_description(prompt=inp.prompt)
  78. is_following = False
  79. for r in all_responses:
  80. if r.strip() and instruction.check_following(r):
  81. is_following = True
  82. break
  83. is_following_list.append(is_following)
  84. return OutputExample(
  85. instruction_id_list=inp.instruction_id_list,
  86. prompt=inp.prompt,
  87. response=response,
  88. follow_all_instructions=all(is_following_list),
  89. follow_instruction_list=is_following_list,
  90. )
  91. def process_results(doc, results):
  92. new_kwargs = []
  93. for item in doc["kwargs"]:
  94. if item["nth_paragraph"]:
  95. item["nth_paragraph"] = int(item["nth_paragraph"])
  96. new_kwargs.append(item)
  97. inp = InputExample(
  98. key=doc["key"],
  99. instruction_id_list=doc["instruction_id_list"],
  100. prompt=doc["prompt"],
  101. kwargs=new_kwargs,
  102. )
  103. response = results[0]
  104. out_strict = test_instruction_following_strict(inp, response)
  105. out_loose = test_instruction_following_loose(inp, response)
  106. return {
  107. "prompt_level_strict_acc": out_strict.follow_all_instructions,
  108. "inst_level_strict_acc": out_strict.follow_instruction_list,
  109. "prompt_level_loose_acc": out_loose.follow_all_instructions,
  110. "inst_level_loose_acc": out_loose.follow_instruction_list,
  111. }
  112. def agg_inst_level_acc(items):
  113. flat_items = [item for sublist in items for item in sublist]
  114. inst_level_acc = sum(flat_items) / len(flat_items)
  115. return inst_level_acc