eval.py
'''
Evaluate Needle-in-a-Haystack predictions with an LLM judge.
Source code from:
https://github.com/gkamradt/LLMTest_NeedleInAHaystack/tree/main
https://github.com/THUDM/LongAlign/tree/main/Needle_test
'''
import yaml
import os
import json
import re
import tqdm
import time
import requests
import argparse

def pred_openai(model_name, msg):
    # Query the OpenAI chat completions API, retrying up to 5 times on transient errors.
    # Note: api_key is a module-level name set in the __main__ block below.
    tries = 0
    while tries < 5:
        tries += 1
        try:
            headers = {
                'Authorization': f"Bearer {api_key}"
            }
            resp = requests.post("https://api.openai.com/v1/chat/completions", json={
                "model": model_name,
                "messages": msg,
                "temperature": 0.
            }, headers=headers, timeout=120)
            if resp.status_code != 200:
                raise Exception(resp.text)
            resp = resp.json()
            break
        except KeyboardInterrupt as e:
            raise e
        except Exception as e:
            # Context-length errors will not go away on retry, so fail immediately.
            if "maximum context length" in str(e):
                raise e
            print("Error occurred: \"%s\". Retrying ..." % str(e))
            time.sleep(1)
    else:
        print("Max tries. Failed.")
        return None
    return resp["choices"][0]["message"]["content"]

USER_TEMPLATE = '''[Instruction]\nPlease act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user question displayed below. {criteria}[Ground truth]\n{reference}\nBegin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: "[[rating]]", for example: "Rating: [[5]]".\n\n[Question]\n{input}\n\n[The Start of Assistant\'s Answer]\n{prediction}\n[The End of Assistant\'s Answer]'''

SYSTEM_TEMPLATE = 'You are a helpful assistant.'

CRITERIA = {
    "accuracy": """
Score 1: The answer is completely unrelated to the reference.
Score 3: The answer has minor relevance but does not align with the reference.
Score 5: The answer has moderate relevance but contains inaccuracies.
Score 7: The answer aligns with the reference but has minor omissions.
Score 10: The answer is completely accurate and aligns perfectly with the reference.
Only respond with a numerical score
"""
}

def get_criteria():
    cri = 'For this evaluation, you should primarily consider the following criteria:\n'
    for key, value in CRITERIA.items():
        cri += f'{key}: {value}\n'
    return cri


def get_user_template(input, prediction, reference, criteria):
    return USER_TEMPLATE.format(
        input=input,
        prediction=prediction,
        reference=reference,
        criteria=criteria
    )

if __name__ == '__main__':
    with open('utils/needle_test/config-eval.yaml', 'r') as file:
        config = yaml.load(file, Loader=yaml.FullLoader)

    api_key = os.environ.get("OPENAI_API_KEY")  # the judge API key is read from the environment
    if api_key is None:
        raise ValueError("Please set the OPENAI_API_KEY environment variable")

    parser = argparse.ArgumentParser()
    parser.add_argument('--input-path', type=str, default='None')
    parser.add_argument('--output-path', type=str, default='None')
    args = parser.parse_args()
    pred_dir = args.input_path
    save_dir = args.output_path

    model_name = config['model']['model_name']
    model_provider = config['model']['model_provider']
    criteria = get_criteria()
    reference = config['prompt']['needle']
    input = config['prompt']['retrieval_question']
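
    # For reference, config-eval.yaml must provide the keys read above; a minimal sketch
    # (the concrete values are illustrative assumptions, not taken from the source):
    #
    #   model:
    #     model_name: gpt-4          # judge model name passed to the OpenAI API
    #     model_provider: OpenAI     # only 'OpenAI' is handled below
    #   prompt:
    #     needle: "<ground-truth needle sentence>"
    #     retrieval_question: "<question asking the model to retrieve the needle>"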

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    result_dict = {}
    for filename in tqdm.tqdm(os.listdir(pred_dir)):
        if not filename.endswith('.txt'):
            continue
        with open(f'{pred_dir}/{filename}', 'r') as f:
            data = f.read().strip()
        prediction = data
        user_template = get_user_template(input, prediction, reference, criteria)
        if model_provider == 'OpenAI':
            msg = [{
                "role": "system",
                "content": SYSTEM_TEMPLATE
            }, {
                "role": "user",
                "content": user_template
            }]
            result = pred_openai(model_name, msg)
        else:
            raise NotImplementedError(f'Not implemented model provider: {model_provider}')

        # Extract the "[[rating]]" score from the judge's reply; record None if the API call
        # failed or no rating was found.
        pattern = r"\[\[(\d+)\]\]"
        match = re.search(pattern, result) if result else None
        score = int(match.group(1)) if match else None
        result_dict[filename.replace('.txt', '')] = {
            'prediction': prediction,
            'score': score
        }

    with open(f'{save_dir}/{model_provider}_{model_name}_eval.json', 'w') as f:
        json.dump(result_dict, f, indent=4)
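
# Example invocation (directory names are illustrative, not fixed by the script):
#   OPENAI_API_KEY=sk-... python eval.py --input-path pred --output-path results
# Every *.txt file under --input-path is judged, and the scores are written to
# <output-path>/<model_provider>_<model_name>_eval.json.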