eval_chatbot.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
from chat_utils import OctoAIChatService, VllmChatService
import logging
import evaluate
import argparse
from config import load_config
import asyncio
import json
from itertools import chain
from generator_utils import parse_qa_to_json

def compute_rouge_score(generated: list, reference: list):
    # ROUGE expects parallel lists of predictions and references.
    rouge_score = evaluate.load('rouge')
    return rouge_score.compute(
        predictions=generated,
        references=reference,
        use_stemmer=True,
        use_aggregator=True
    )

def compute_bert_score(generated: list, reference: list):
    bertscore = evaluate.load("bertscore")
    return bertscore.compute(
        predictions=generated,
        references=reference,
        lang="en"
    )

# This function is used to eval the fine-tuned model: given a question, generate the answer.
async def eval_request(chat_service, api_context: dict, question: str) -> str:
    prompt_for_system = api_context['eval_prompt_template'].format(language=api_context["language"])
    chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': f"Question: {question}"}]
    # Getting a list of results; in this case, there should be only one result.
    response_string = await chat_service.execute_chat_request_async(api_context, chat_request_payload)
    # Convert the result string to a dict that contains Question, Answer.
    result_list = parse_qa_to_json(response_string)
    if not result_list or len(result_list) > 1:
        print("Error: eval response should be a list of one result dict")
        return ""
    result = result_list[0]
    if "Answer" not in result:
        print("Error: eval response does not contain answer")
        return ""
    # Send back the model-generated answer.
    return result["Answer"]

async def generate_eval_answer(chat_service, api_context: dict, questions: list):
    eval_tasks = []
    for batch_index, question in enumerate(questions):
        try:
            # eval_request returns a coroutine; collect them all and run them concurrently below.
            result = eval_request(chat_service, api_context, question)
            eval_tasks.append(result)
        except Exception as e:
            print(f"Error during data eval request execution: {e}")
    print(len(eval_tasks), "eval_tasks")
    eval_results = await asyncio.gather(*eval_tasks)
    return eval_results

async def main(context):
    if context["endpoint"]:
        chat_service = VllmChatService()
    else:
        chat_service = OctoAIChatService()
    try:
        logging.info("Starting to generate answers for the eval set.")
        with open(context["eval_json"]) as fp:
            eval_json = json.load(fp)
        questions, ground_truth = [], []
        for index, item in enumerate(eval_json):
            questions.append(item["question"])
            ground_truth.append(item["answer"])
        generated_answers = await generate_eval_answer(chat_service, context, questions)
        if not generated_answers:
            logging.warning("No answers generated. Please check the input context or model configuration.")
            return
        logging.info(f"Successfully generated {len(generated_answers)} answers.")
        rouge_score = compute_rouge_score(generated_answers, ground_truth)
        print("Rouge_score:", rouge_score)
        bert_score = compute_bert_score(generated_answers, ground_truth)
        print("Bert_score:", bert_score)
        logging.info("Eval completed successfully.")
    except Exception as e:
        logging.error(f"An unexpected error occurred during the process: {e}", exc_info=True)

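# Note (illustrative; inferred from how main() reads the eval set, not stated elsewhere in this file):
# the file at context["eval_json"] is expected to be a JSON list of objects of the form
#   [{"question": "...", "answer": "..."}, ...]
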
def parse_arguments():
    # Define command line arguments for the script
    parser = argparse.ArgumentParser(
        description="Evaluate the fine-tuned chatbot model on a question/answer eval set."
    )
    parser.add_argument(
        "-m", "--model",
        default="chatbot",
        help="Select the model to use for evaluation; this may be a LoRA adapter."
    )
    parser.add_argument(
        "-c", "--config_path",
        default="eval_config.yaml",
        help="Path to the configuration file that holds the system prompt, language, and eval set path."
    )
    parser.add_argument(
        "-v", "--vllm_endpoint",
        default=None,
        type=int,
        help="If a port is specified, use the local vllm endpoint for evaluations."
    )
    return parser.parse_args()

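# Note (illustrative; key names inferred from how this script reads the loaded config):
# the YAML file passed via --config_path is expected to provide at least:
#   eval_prompt_template: system prompt template containing a "{language}" placeholder
#   language: e.g. "English"
#   eval_json: path to the eval set JSON file
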
if __name__ == "__main__":
    logging.info("Initializing the process and loading configuration...")
    args = parse_arguments()
    context = load_config(args.config_path)
    context["model"] = args.model
    context["endpoint"] = args.vllm_endpoint
    if context["endpoint"]:
        logging.info(f"Using local vllm service at port: '{args.vllm_endpoint}'.")
    asyncio.run(main(context))
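
# Example invocation (port 8000 is only an assumption; use whatever port your local vllm server listens on):
#   python eval_chatbot.py -m chatbot -c eval_config.yaml -v 8000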