eval_chatbot.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
from chat_utils import OctoAIChatService, VllmChatService
import logging
import evaluate
import argparse
from config import load_config
import asyncio
import json
from itertools import chain
from generator_utils import parse_qa_to_json, generate_LLM_eval
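
# Evaluation flow implemented below: generate an answer for every question in the eval
# set, optionally grade the answers with a judge LLM, then report ROUGE and BERTScore.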

def compute_rouge_score(generated: list, reference: list):
    # ROUGE from the Hugging Face `evaluate` library, aggregated over all prediction/reference pairs.
    rouge_score = evaluate.load('rouge')
    return rouge_score.compute(
        predictions=generated,
        references=reference,
        use_stemmer=True,
        use_aggregator=True
    )
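
# Illustrative sketch of the return value (Hugging Face `evaluate` ROUGE metric with
# use_aggregator=True): a dict of aggregated scores, for example
#   {"rouge1": 0.67, "rouge2": 0.5, "rougeL": 0.67, "rougeLsum": 0.67}
# Actual values depend on the predictions/references passed in.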

def compute_bert_score(generated: list, reference: list):
    bertscore = evaluate.load("bertscore")
    score = bertscore.compute(
        predictions=generated,
        references=reference,
        lang="en"
    )
    f1 = score["f1"]
    precision = score["precision"]
    recall = score["recall"]
    return sum(precision)/len(precision), sum(recall)/len(recall), sum(f1)/len(f1)
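
# Note: bertscore.compute returns per-example lists for precision/recall/F1 (for lang="en"
# it typically downloads a RoBERTa-based model on first use), so the averages above give
# corpus-level scores.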

# Evaluate the fine-tuned model: given a question, generate an answer.
async def eval_request(chat_service, api_context: dict, question: str) -> dict:
    prompt_for_system = api_context['eval_prompt_template'].format(language=api_context["language"])
    chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': f"Question: {question}"}]
    # The service returns a list of results; in this case there should be exactly one.
    response_string = await chat_service.execute_chat_request_async(api_context, chat_request_payload)
    # Convert the result string to a dict that contains Question and Answer.
    result_list = parse_qa_to_json(response_string)
    if not result_list or len(result_list) > 1:
        print("Error: eval response should be a list of one result dict")
        return {}
    result = result_list[0]
    if "Answer" not in result:
        print("Error: eval response does not contain an answer")
        return {}
    # Send back the model-generated answer.
    return result["Answer"]
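
# For reference: parse_qa_to_json (from generator_utils) is expected to return a list
# holding a single dict with "Question" and "Answer" keys, roughly of the form
#   [{"Question": "How do I ...?", "Answer": "..."}]
# Anything else is treated as an error above.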

async def generate_eval_answer(chat_service, api_context: dict, questions: list):
    eval_tasks = []
    for batch_index, question in enumerate(questions):
        try:
            # eval_request returns a coroutine; all requests are awaited concurrently below.
            result = eval_request(chat_service, api_context, question)
            eval_tasks.append(result)
        except Exception as e:
            print(f"Error during data eval request execution: {e}")
    print(len(eval_tasks), "eval_tasks")
    eval_results = await asyncio.gather(*eval_tasks)
    return eval_results
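
# The eval set referenced by context["eval_json"] is expected to be a JSON list of
# question/answer records, e.g. (illustrative):
#   [{"question": "How do I ...?", "answer": "You can ..."}]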

async def main(context):
    if context["endpoint"]:
        chat_service = VllmChatService()
    else:
        chat_service = OctoAIChatService()
    try:
        logging.info("Starting to generate answers for the eval set.")
        with open(context["eval_json"]) as fp:
            eval_json = json.load(fp)
        questions, ground_truth = [], []
        for index, item in enumerate(eval_json):
            questions.append(item["question"])
            ground_truth.append(item["answer"])
        generated_answers = await generate_eval_answer(chat_service, context, questions)
        if not generated_answers:
            logging.warning("No answers generated. Please check the input context or model configuration.")
            return
        logging.info(f"Successfully generated {len(generated_answers)} answers.")
        judge_list = []
        for index, item in enumerate(generated_answers):
            judge_list.append({"Question": questions[index], "Ground_truth": ground_truth[index], "Generated_answer": generated_answers[index]})
        if context["judge_endpoint"]:
            # Make a copy of the context, then point the vllm endpoint at judge_endpoint.
            context_copy = dict(context)
            context_copy["endpoint"] = context["judge_endpoint"]
            context_copy["model"] = "meta-llama/Meta-Llama-3-70B-Instruct"
            judge_results = await generate_LLM_eval(chat_service, context_copy, judge_list)
            correct_num = 0
            for result in judge_results:
                correct_num += result["Result"] == "YES"
            LLM_judge_score = correct_num / len(judge_results)
            print(f"The accuracy of the model is {LLM_judge_score}")
        rouge_score = compute_rouge_score(generated_answers, ground_truth)
        print("Rouge_score:", rouge_score)
        P, R, F1 = compute_bert_score(generated_answers, ground_truth)
        print(f"BERTScore Precision: {P:.4f}, Recall: {R:.4f}, F1: {F1:.4f}")
        # Save the eval result to a log file.
        with open(context["output_log"], "a") as fp:
            fp.write(f"Eval_result for {context['model']} \n")
            fp.write(f"Rouge_score: {rouge_score} \n")
            fp.write(f"BERTScore Precision: {P:.4f}, Recall: {R:.4f}, F1: {F1:.4f} \n")
            if context["judge_endpoint"]:
                fp.write(f"LLM_judge_score: {LLM_judge_score} \n")
            fp.write("QA details: \n")
            for item in judge_list:
                fp.write(f"question: {item['Question']} \n")
                fp.write(f"generated_answers: {item['Generated_answer']} \n")
                fp.write(f"ground_truth: {item['Ground_truth']} \n")
                fp.write("\n")
        logging.info(f"Eval finished successfully; the result is saved to {context['output_log']}.")
    except Exception as e:
        logging.error(f"An unexpected error occurred during the process: {e}", exc_info=True)

def parse_arguments():
    # Define command line arguments for the script.
    parser = argparse.ArgumentParser(
        description="Evaluate the fine-tuned chatbot model against an eval set."
    )
    parser.add_argument(
        "-m", "--model",
        default="chatbot",
        help="Select the model to use for evaluation; this may be a LoRA adapter."
    )
    parser.add_argument(
        "-c", "--config_path",
        default="eval_config.yaml",
        help="Set the configuration file path that has the system prompt along with the language and eval set path."
    )
    parser.add_argument(
        "-v", "--vllm_endpoint",
        default=None,
        type=int,
        help="If a port is specified, then use the local vllm endpoint for evaluations."
    )
    parser.add_argument(
        "-j", "--judge_endpoint",
        default=None,
        type=int,
        help="If a port is specified, then use the local vllm endpoint as the judge LLM."
    )
    parser.add_argument(
        "-o", "--output_log",
        default="eval_result.log",
        help="Save the eval result to a log file. Default is eval_result.log."
    )
    return parser.parse_args()
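
# Example invocation (illustrative; ports and file names depend on your setup):
#   python eval_chatbot.py -m chatbot -c eval_config.yaml -v 8000 -j 8001 -o eval_result.log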

if __name__ == "__main__":
    logging.info("Initializing the process and loading configuration...")
    args = parse_arguments()
    context = load_config(args.config_path)
    context["model"] = args.model
    context["endpoint"] = args.vllm_endpoint
    context["judge_endpoint"] = args.judge_endpoint
    context["output_log"] = args.output_log
    if context["endpoint"]:
        logging.info(f"Using local vllm service for eval at port: '{args.vllm_endpoint}'.")
    if context["judge_endpoint"]:
        logging.info(f"Using local vllm service for the judge at port: '{args.judge_endpoint}'.")
    asyncio.run(main(context))