@@ -1,13 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
-from chat_utils import OctoAIChatService, VllmChatService
import logging
import evaluate
import argparse
from config import load_config
import json
from itertools import chain
-from langchain_community.llms import VLLMOpenAI
+from langchain_openai import ChatOpenAI

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
@@ -26,17 +25,17 @@ def generate_answers_model_only(model_name,question_list,api_url="http://localho
    # Use langchain to load the documents from data directory
    # Load the RAFT model

-    llm = VLLMOpenAI(
+    llm = ChatOpenAI(
        openai_api_key=key,
        openai_api_base=api_url,
        model_name=model_name,
        temperature=0.0,
-        max_tokens=100
+        max_tokens=1000
        )
-    system_prompt = SystemMessage(content=context['eval_prompt_template'])
-    generated_answers = []
-    all_tasks = [[system_prompt, HumanMessage(content=question)] for question in question_list]
+
+    all_tasks = [api_config['eval_prompt_template'].format(question=question) for question in question_list]
    generated_answers = llm.batch(all_tasks)
+    generated_answers = [ item.content for item in generated_answers]
    if len(generated_answers) == 0:
        logging.error("No model answers generated. Please check the input context or model configuration in ",model_name)
        return []
@@ -48,7 +47,12 @@ def format_docs_raft(docs):
    return context
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)
-def generate_answers_with_RAG(model_name, data_dir,question_list,rag_template,api_url="http://localhost:8000/v1",key="EMPTY"):
+def generate_answers_with_RAG(model_name, question_list,api_config,api_url_overwrite=None):
+    data_dir = api_config['data_dir']
+    api_url = "http://localhost:"+str(api_config['vllm_endpoint'])+"/v1"
+    if api_url_overwrite:
+        api_url = api_url_overwrite
+    key = api_config['api_key']
    # Use langchain to load the documents from data directory
    loader = DirectoryLoader(data_dir)
    docs = loader.load()
@@ -62,12 +66,12 @@ def generate_answers_with_RAG(model_name, data_dir,question_list,rag_template,ap
        search_kwargs={"k": 5}
    )
    # Load the RAFT model
-    llm = VLLMOpenAI(
+    llm = ChatOpenAI(
        openai_api_key=key,
        openai_api_base=api_url,
        model_name=model_name,
        temperature=0.0,
-        max_tokens=100
+        max_tokens=1000
        )
    all_tasks = []
    for q in question_list:
@@ -79,9 +83,10 @@ def generate_answers_with_RAG(model_name, data_dir,question_list,rag_template,ap
        else:
            documents = format_docs_raft(retrieved_docs)
        # create a prompt
-        text = rag_template.format(context=documents,question=q)
+        text = api_config["RAG_prompt_template"].format(context=documents,question=q)
        all_tasks.append(text)
    generated_answers = llm.batch(all_tasks)
+    generated_answers = [ item.content for item in generated_answers]
    if len(generated_answers) == 0:
        logging.error("No RAG answers generated. Please check the input context or model configuration in ",model_name)
        return []
@@ -101,6 +106,7 @@ def clean_text_list(text_list):
        index = text.rfind("<ANSWER>")
        if index!= -1:
            text = text[index:]
+            text = text.replace("</ANSWER>:","")
        text = text.replace("begin_quote","")
        text = text.replace("end_quote","")
        text = text.replace("##","")
@@ -144,27 +150,24 @@ def compute_bert_score(generated : list, reference: list):
    precision = score["precision"]
    recall = score["recall"]
    return sum(precision)/len(precision), sum(recall)/len(recall), sum(f1)/len(f1)
-def compute_judge_score(questions: list, generated : list, reference: list, context,api_url="http://localhost:8001/v1",key="EMPTY"):
+def compute_judge_score(questions: list, generated : list, reference: list, api_config,api_url="http://localhost:8001/v1",key="EMPTY"):
    correct_num = 0
    model_name = "meta-llama/Meta-Llama-3-70B-Instruct"
-    llm = VLLMOpenAI(
+    llm = ChatOpenAI(
        openai_api_key=key,
        openai_api_base=api_url,
        model_name=model_name,
+        max_tokens=1000,
        temperature=0.0)
    all_tasks = []
-    for q,pred,gold in zip(questions, generated,reference):
-        messages = [
-            HumanMessage(content=f"Question: {q} \n Teacher's Answer: {gold} \n Student's Answer: {pred} "),
-            SystemMessage(content=context['judge_prompt_template'])
-        ]
-        all_tasks.append(messages)
-    response = llm.batch(all_tasks)
-    for response in response:
-        if "YES" in response:
-            correct_num += 1
-    return correct_num/len(questions)
-def score_single(context,generated,reference,questions, run_exact_match=True,run_rouge=True, run_bert=True, run_llm_as_judge=False):
+    for question,prediction,gold in zip(questions, generated,reference):
+        message = api_config['judge_prompt_template'].format(question=question,prediction=prediction,gold=gold)
+        all_tasks.append(message)
+    judge_responses = llm.batch(all_tasks)
+    judge_responses = ["YES" in item.content.split("<ANSWER>")[-1] for item in judge_responses]
+    correct_num = sum(judge_responses)
+    return correct_num/len(questions),judge_responses
+def score_single(api_config,generated,reference,questions, run_exact_match=True,run_rouge=True, run_bert=True, run_llm_as_judge=True):
    # set metric to default -1, means no metric is computed
    metric = {
        "Rouge_score": -1,
@@ -184,22 +187,23 @@ def score_single(context,generated,reference,questions, run_exact_match=True,run
        metric["BERTScore_Precision"] = P
        metric["BERTScore_Recall"] = R
        metric["BERTScore_F1"] = F1
-    if context["judge_endpoint"] and run_llm_as_judge:
-        api_url = "http://localhost:"+str(context["judge_endpoint"])+"/v1"
-        LLM_judge_score = compute_judge_score(questions, generated, reference, context,api_url=api_url)
+    if api_config["judge_endpoint"] and run_llm_as_judge:
+        api_url = "http://localhost:"+str(api_config["judge_endpoint"])+"/v1"
+        LLM_judge_score,judge_responses = compute_judge_score(questions, generated, reference, api_config,api_url=api_url)
        metric["LLM_judge_score"] = LLM_judge_score
+        metric["LLM_judge_responses"] = judge_responses
        print(f"LLM_judge_score: {LLM_judge_score}")
    if run_exact_match:
        exact_match = exact_match_score(generated,reference)
        print(f"Exact_match_percentage: {exact_match:.4f}")
        metric["Exact_match"] = exact_match
    return metric
-def main(context):
+def main(api_config):
    # Since the eval set is small, we can run the eval without async functions
    try:
-        api_url = "http://localhost:"+str(context["vllm_endpoint"])+"/v1"
+        api_url = "http://localhost:"+str(api_config["vllm_endpoint"])+"/v1"
        logging.info("Starting to generate answer given the eval set.")
-        with open(context["eval_json"]) as fp:
+        with open(api_config["eval_json"]) as fp:
            eval_json = json.load(fp)
        questions,groud_truth = [],[]
        for index, item in enumerate(eval_json):
@@ -210,38 +214,73 @@ def main(context):
            "RAFT_RAG": [],
            "Baseline": [],
            "Baseline_RAG": [],
+            "70B_RAG": [],
+            "70B_Base": [],
+
        }
        # Generate answers for baseline
-        base_model_name = context["base_model_name"]
-        #generated_answers["Baseline"] = generate_answers_model_only(base_model_name,questions,api_url)
-        generated_answers["Baseline_RAG"] = generate_answers_with_RAG(base_model_name, context["data_dir"],questions,context['RAG_prompt_template'],api_url)
+        base_model_name = api_config["base_model_name"]
+        generated_answers["Baseline"] = generate_answers_model_only(base_model_name,questions,api_url)
+        generated_answers["Baseline_RAG"] = generate_answers_with_RAG(base_model_name, questions,api_config)
        # Generate answers for RAFT
-        raft_model_name = context["raft_model_name"]
-        #generated_answers["RAFT"] = generate_answers_model_only(raft_model_name,questions,api_url)
-        generated_answers["RAFT_RAG"] = generate_answers_with_RAG(raft_model_name, context["data_dir"],questions,context['RAG_prompt_template'],api_url)
+        raft_model_name = api_config["raft_model_name"]
+        generated_answers["RAFT"] = generate_answers_model_only(raft_model_name,questions,api_url)
+        generated_answers["RAFT_RAG"] = generate_answers_with_RAG(raft_model_name, questions,api_config)
+
+        large_model_name = "meta-llama/Meta-Llama-3-70B-Instruct"
+        large_api_url = "http://localhost:"+str(api_config["judge_endpoint"])+"/v1"
+        generated_answers["70B_Base"] = generate_answers_model_only(large_model_name,questions,large_api_url)
+        generated_answers["70B_RAG"] = generate_answers_with_RAG(large_model_name, questions,api_config,large_api_url,)
        logging.info(f"Successfully generated {len(generated_answers['Baseline_RAG'])} answers for all models.")
        # for generate answer from each model, compute the score metric
+        all_metrics = []
        for model_name,model_answer in generated_answers.items():
            if len(model_answer) != len(groud_truth):
                print(f"The length of {model_name} answer is not equal to the length of ground truth.")
                continue
-            metric = score_single(context,model_answer,groud_truth,questions)
+            metric = score_single(api_config,model_answer,groud_truth,questions)
            print(f"The eval result for {model_name} is: {metric}")
-            with open(context["output_log"],"a") as fp:
+            with open(api_config["output_log"],"a") as fp:
                fp.write(f"Eval_result for {model_name} \n")
                fp.write(f"Rouge_score: {metric['Rouge_score']} \n")
                fp.write(f"BERTScore Precision: {metric['BERTScore_Precision']:.4f}, Recall: {metric['BERTScore_Recall']:.4f}, F1: {metric['BERTScore_F1']:.4f} \n")
                fp.write(f"Exact_match_percentage: {metric['Exact_match']} \n")
-                if context["judge_endpoint"]:
+                judge_responses = ["None"] * len(questions)
+                if api_config["judge_endpoint"]:
                    fp.write(f"LLM_judge_score: {metric['LLM_judge_score']} \n")
+                    judge_responses = metric["LLM_judge_responses"]
+                    all_metrics.append((model_name,metric['LLM_judge_score'],metric["LLM_judge_responses"]))
                fp.write(f"QA details: \n")
-                for item in zip(questions,model_answer,groud_truth):
+                for item in zip(questions,model_answer,groud_truth,judge_responses):
                    fp.write(f"question: {item[0]} \n")
                    fp.write(f"generated_answers: {item[1]} \n")
                    fp.write(f"groud_truth: {item[2]} \n")
+                    fp.write(f"LLM_judge_response: {item[3]} \n")
                    fp.write("\n")
                fp.write("\n------------------------------------\n")
- logging.info(f"Eval successfully, the eval result is saved to {context['output_log']}.")
|
|
|
+ # Now we want to take a closer look at the questions that are not answered the same by all the models.
|
|
|
+ judge_zip = list(zip(*[item[-1] for item in all_metrics]))
|
|
|
+ with open(api_config["output_log"],"a") as fp:
|
|
|
+ for item in all_metrics:
|
|
|
+ fp.write(f"Model_Name: {item[0]}, LLM_SCORE: {item[1]} \n")
|
|
|
+ for idx,item in enumerate(judge_zip):
|
|
|
+                # if all the responses are "YES" (True) or all are "NO" (False), then we skip this question
+                if all(item) or not any(item):
+                    continue
+                else:
+                    fp.write(f"Comparing question of interest: {questions[idx]} \n")
+                    fp.write(f"groud_truth: {groud_truth[idx]} \n")
+                    fp.write(f"{item[2]} Baseline_answers: {generated_answers['Baseline'][idx]} \n")
+                    fp.write(f"{item[3]} Baseline_RAG_answers: {generated_answers['Baseline_RAG'][idx]} \n")
+                    fp.write(f"{item[0]} RAFT_answers: {generated_answers['RAFT'][idx]} \n")
+                    fp.write(f"{item[1]} RAFT_RAG_answers: {generated_answers['RAFT_RAG'][idx]} \n")
+                    fp.write(f"{item[5]} 70B_Base_answers: {generated_answers['70B_Base'][idx]} \n")
+                    fp.write(f"{item[4]} 70B_RAG_answers: {generated_answers['70B_RAG'][idx]} \n")
+                    fp.write("-------\n")
+
+
+
+        logging.info(f"Eval finished successfully, the eval result is saved to {api_config['output_log']}.")
        # Saving the eval result to a log file
    except Exception as e:
        logging.error(f"An unexpected error occurred during the process: {e}",exc_info=True)
@@ -283,20 +322,26 @@ def parse_arguments():
        default="eval_result.log",
        help="save the eval result to a log file. Default is eval_result.log"
    )
-
+    parser.add_argument(
+        "-k", "--api_key",
+        default="EMPTY",
+        type=str,
+        help="API key for the OpenAI-compatible LLM endpoints used during eval. Default is EMPTY (for local vllm)."
+    )
    return parser.parse_args()

if __name__ == "__main__":
    logging.info("Initializing the process and loading configuration...")
    args = parse_arguments()
-    context = load_config(args.config_path)
-    context["vllm_endpoint"] = args.vllm_endpoint
+    api_config = load_config(args.config_path)
+    api_config["vllm_endpoint"] = args.vllm_endpoint
    if args.data_dir:
-        context["data_dir"] = args.data_dir
+        api_config["data_dir"] = args.data_dir
    if args.raft_model_name:
-        context["raft_model_name"] = args.raft_model_name
-    context["judge_endpoint"] = args.judge_endpoint
-    context["output_log"] = args.output_log
-    if context["judge_endpoint"]:
+        api_config["raft_model_name"] = args.raft_model_name
+    api_config["judge_endpoint"] = args.judge_endpoint
+    api_config["output_log"] = args.output_log
+    api_config["api_key"] = args.api_key
+    if api_config["judge_endpoint"]:
        logging.info(f"Use local vllm service for judge at port: '{args.judge_endpoint}'.")
-    main(context)
+    main(api_config)