@@ -8,7 +8,7 @@ from config import load_config
 import asyncio
 import json
 from itertools import chain
-from generator_utils import parse_qa_to_json
+from generator_utils import parse_qa_to_json, generate_LLM_eval
 
 def compute_rouge_score(generated : str, reference: str):
     rouge_score = evaluate.load('rouge')
@@ -20,11 +20,15 @@ def compute_rouge_score(generated : str, reference: str):
     )
 def compute_bert_score(generated : str, reference: str):
     bertscore = evaluate.load("bertscore")
-    return bertscore.compute(
+    score = bertscore.compute(
         predictions=generated,
         references=reference,
         lang="en"
     )
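+    # bertscore.compute returns per-example lists for precision/recall/f1; average them to corpus-level scores.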
+    f1 = score["f1"]
+    precision = score["precision"]
+    recall = score["recall"]
+    return sum(precision)/len(precision), sum(recall)/len(recall), sum(f1)/len(f1)
 # Evaluate the fine-tuned model: given the question, generate the answer.
 async def eval_request(chat_service, api_context: dict, question: str) -> dict:
     prompt_for_system = api_context['eval_prompt_template'].format(language=api_context["language"])
@@ -75,11 +79,38 @@ async def main(context):
             logging.warning("No answers generated. Please check the input context or model configuration.")
             return
         logging.info(f"Successfully generated {len(generated_answers)} answers.")
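+        # Pair each generated answer with its question and ground truth; this serves as both the judge input and the QA detail log.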
+        judge_list = []
+        for index, item in enumerate(generated_answers):
+            judge_list.append({"Question":questions[index],"Ground_truth":groud_truth[index],"Generated_answer":generated_answers[index]})
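+        # If a judge endpoint is configured, run an LLM-as-judge pass that grades each answer YES or NO against the ground truth.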
+        if context["judge_endpoint"]:
+            # make a copy of the context, then point its VLLM endpoint at the judge_endpoint
+            context_copy = dict(context)
+            context_copy["endpoint"] = context["judge_endpoint"]
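+            # The judge model is pinned to Llama 3 70B Instruct, independent of the model under evaluation.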
+            context_copy["model"] = "meta-llama/Meta-Llama-3-70B-Instruct"
+            judge_results = await generate_LLM_eval(chat_service, context_copy, judge_list)
+            correct_num = 0
+            for result in judge_results:
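+                # result["Result"] == "YES" evaluates to a bool, so summing it tallies the YES verdicts.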
+                correct_num += result["Result"] == "YES"
+            LLM_judge_score = correct_num/len(judge_results)
+            print(f"The accuracy of the model is {LLM_judge_score}")
         rouge_score = compute_rouge_score(generated_answers,groud_truth)
         print("Rouge_score:",rouge_score)
-        bert_score = compute_bert_score(generated_answers,groud_truth)
-        print("Bert_score:",bert_score)
-        logging.info("Eval successfully")
+        P, R, F1 = compute_bert_score(generated_answers,groud_truth)
+        print(f"BERTScore Precision: {P:.4f}, Recall: {R:.4f}, F1: {F1:.4f}")
+        # Save the eval result to a log file
+        with open(context["output_log"],"a") as fp:
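+            # Append mode, so repeated eval runs accumulate in the same log file.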
+            fp.write(f"Eval_result for {context['model']} \n")
+            fp.write(f"Rouge_score: {rouge_score} \n")
+            fp.write(f"BERTScore Precision: {P:.4f}, Recall: {R:.4f}, F1: {F1:.4f} \n")
+            if context["judge_endpoint"]:
+                fp.write(f"LLM_judge_score: {LLM_judge_score} \n")
+            fp.write("QA details: \n")
+            for item in judge_list:
+                fp.write(f"question: {item['Question']} \n")
+                fp.write(f"generated_answer: {item['Generated_answer']} \n")
+                fp.write(f"ground_truth: {item['Ground_truth']} \n")
+                fp.write("\n")
+        logging.info(f"Eval finished successfully; the eval result is saved to {context['output_log']}.")
 
     except Exception as e:
         logging.error(f"An unexpected error occurred during the process: {e}",exc_info=True)
@@ -104,15 +135,29 @@ def parse_arguments():
         type=int,
         help="If a port is specified, then use local vllm endpoint for evaluations."
     )
+    parser.add_argument(
+        "-j", "--judge_endpoint",
+        default=None,
+        type=int,
+        help="If a port is specified, then use a local vllm endpoint as the judge LLM."
+    )
+    parser.add_argument(
+        "-o", "--output_log",
+        default="eval_result.log",
+        help="Save the eval result to a log file. Defaults to eval_result.log."
+    )
     return parser.parse_args()
 
 if __name__ == "__main__":
     logging.info("Initializing the process and loading configuration...")
     args = parse_arguments()
-
     context = load_config(args.config_path)
     context["model"] = args.model
     context["endpoint"] = args.vllm_endpoint
+    context["judge_endpoint"] = args.judge_endpoint
+    context["output_log"] = args.output_log
     if context["endpoint"]:
-        logging.info(f"Use local vllm service at port: '{args.vllm_endpoint}'.")
+        logging.info(f"Use local vllm service for eval at port: '{args.vllm_endpoint}'.")
+    if context["judge_endpoint"]:
+        logging.info(f"Use local vllm service for judge at port: '{args.judge_endpoint}'.")
     asyncio.run(main(context))