eval_chatbot.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
from chat_utils import OctoAIChatService, VllmChatService
import logging
import evaluate
import argparse
from config import load_config
import asyncio
import json
from itertools import chain

def compute_rouge_score(generated: list, reference: list):
    rouge_score = evaluate.load('rouge')
    return rouge_score.compute(
        predictions=generated,
        references=reference,
        use_stemmer=True,
        use_aggregator=True
    )

def compute_bert_score(generated: list, reference: list):
    bertscore = evaluate.load("bertscore")
    return bertscore.compute(
        predictions=generated,
        references=reference,
        lang="en"
    )
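
# Illustrative usage (not part of the pipeline): both helpers expect parallel
# lists of strings, e.g.
#   compute_rouge_score(["Paris is the capital of France."],
#                       ["France's capital is Paris."])
# The ROUGE helper returns an aggregated dict with keys such as "rouge1",
# "rouge2", "rougeL", and "rougeLsum"; the BERTScore helper returns per-example
# "precision", "recall", and "f1" lists plus a "hashcode" entry.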

# Ask the model to answer one eval question. The chat service is expected to
# return a string holding a single-element list of dicts; return that element's
# "Answer" field, or an empty dict if the response is malformed.
async def eval_request(chat_service, api_context: dict, question: str) -> dict:
    prompt_for_system = api_context['eval_prompt_template'].format(language=api_context["language"])
    chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': f"Question: {question}"}]
    # The service returns a list of results; in this case there should be exactly one.
    results = await chat_service.execute_chat_request_async(api_context, chat_request_payload, eval=False)
    # Convert the result string into a Python list.
    results = eval(results)
    if not results or len(results) > 1:
        print("results", type(results), len(results), results)
        return {}
    result = results[0]
    if "Answer" not in result:
        print("Error: eval response does not contain answer")
        print(question, result)
        return {}
    print("result", result)
    # Send back the model-generated answer.
    return result["Answer"]
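
# Note on eval_request above: the eval() call assumes the chat service returns
# a string containing a Python-literal list with exactly one dict, e.g.
#   '[{"Answer": "..."}]'
# (the payload shown here is purely illustrative). Any other shape falls into
# the empty-dict branches.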

async def generate_eval_answer(chat_service, api_context: dict, questions: list):
    eval_tasks = []
    for batch_index, question in enumerate(questions):
        try:
            # Create the coroutine now; it is awaited later via asyncio.gather.
            result = eval_request(chat_service, api_context, question)
            eval_tasks.append(result)
        except Exception as e:
            print(f"Error during data eval request execution: {e}")
    print(len(eval_tasks), "eval_tasks")
    eval_results = await asyncio.gather(*eval_tasks)
    return eval_results
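
# asyncio.gather preserves input order, so each entry in eval_results lines up
# with the question (and ground-truth answer) at the same index.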

async def main(context):
    if context["endpoint"]:
        chat_service = VllmChatService()
    else:
        chat_service = OctoAIChatService()
    try:
        logging.info("Starting to generate answers for the eval set.")
        with open(context["eval_json"]) as fp:
            eval_json = json.load(fp)
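        # Expected eval set format (inferred from the fields read below): a JSON
        # list of objects like [{"question": "...", "answer": "..."}, ...].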
        questions, ground_truth = [], []
        for index, item in enumerate(eval_json):
            questions.append(item["question"])
            ground_truth.append(item["answer"])
        generated_answers = await generate_eval_answer(chat_service, context, questions)
        if not generated_answers:
            logging.warning("No answers generated. Please check the input context or model configuration.")
            return
        logging.info(f"Successfully generated {len(generated_answers)} answers.")
        rouge_score = compute_rouge_score(generated_answers, ground_truth)
        print("Rouge_score:", rouge_score)
        bert_score = compute_bert_score(generated_answers, ground_truth)
        print("Bert_score:", bert_score)
        logging.info("Eval finished successfully.")
    except Exception as e:
        logging.error(f"An unexpected error occurred during the process: {e}", exc_info=True)

def parse_arguments():
    # Define command line arguments for the script.
    parser = argparse.ArgumentParser(
        description="Evaluate the chatbot model against a question/answer eval set."
    )
    parser.add_argument(
        "-m", "--model",
        default="chatbot",
        help="Select the model to use for evaluation; this may be a LoRA adapter."
    )
    parser.add_argument(
        "-c", "--config_path",
        default="eval_config.yaml",
        help="Set the path to the configuration file that holds the system prompt, language, and eval set path."
    )
    parser.add_argument(
        "-v", "--vllm_endpoint",
        default=None,
        type=int,
        help="If a port is specified, use the local vllm endpoint for evaluations."
    )
    return parser.parse_args()

if __name__ == "__main__":
    logging.info("Initializing the process and loading configuration...")
    args = parse_arguments()
    context = load_config(args.config_path)
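    # The loaded config is assumed to provide at least the keys this script
    # reads: "eval_prompt_template", "language", and "eval_json"; "model" and
    # "endpoint" are filled in from the command line below.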
    context["model"] = args.model
    context["endpoint"] = args.vllm_endpoint
    if context["endpoint"]:
        logging.info(f"Using local vllm service at port: '{args.vllm_endpoint}'.")
    asyncio.run(main(context))
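
# Example invocation (the port and config path are illustrative):
#   python eval_chatbot.py -m chatbot -c eval_config.yaml -v 8000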