@@ -16,7 +16,6 @@ import re
 import string
 import pandas as pd
 from langchain.retrievers.document_compressors import FlashrankRerank
-from transformers import AutoTokenizer
 
 
 def generate_answers_model_only(model_name,question_list,api_url="http://localhost:8000/v1",key="EMPTY"):
@@ -48,15 +47,7 @@ def build_retriever(api_config,embedding_model_name,retrieved_docs_num=5):
     loader = DirectoryLoader(api_config['data_dir'])
     docs = loader.load()
     # Split the document into chunks with a specified chunk size
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=api_config["chunk_size"],chunk_overlap=int(api_config["chunk_size"] / 10),add_start_index=True,strip_whitespace=True)
-    # text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
-    #     AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B"),
-    #     chunk_size=api_config["chunk_size"],
-    #     chunk_overlap=int(api_config["chunk_size"] / 10),
-    #     add_start_index=True,
-    #     strip_whitespace=True,
-    #     separators=["\n\n", "\n", ".", " ", ""],
-    # )
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=api_config["chunk_size"],chunk_overlap=int(api_config["chunk_size"] / 10),separators=["----------","\n\n", "\n", " ", ""],strip_whitespace=True)
     docs_processed = text_splitter.split_documents(docs)
     # Remove duplicates
     unique_texts = {}
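
The new separator list changes how documents are chunked: the splitter now tries the "----------" delimiter before falling back to blank lines, newlines, and spaces. A minimal sketch of that behavior, outside this patch, with made-up text and a made-up chunk size standing in for api_config["chunk_size"] (it assumes the eval documents actually use "----------" as a section divider):

```python
# Illustrative only -- not part of this patch.
from langchain.text_splitter import RecursiveCharacterTextSplitter

sample = (
    "Notes on RAFT fine-tuning.\n"
    "----------\n"
    "Notes on RAG retrieval settings.\n"
    "----------\n"
    "Notes on LLM-as-judge scoring."
)
splitter = RecursiveCharacterTextSplitter(
    chunk_size=40,                                     # stand-in for api_config["chunk_size"]
    chunk_overlap=4,                                    # chunk_size / 10, mirroring the patch
    separators=["----------", "\n\n", "\n", " ", ""],   # the divider is tried first
    strip_whitespace=True,
)
for chunk in splitter.split_text(sample):
    print(repr(chunk))
# Each section between "----------" markers comes back as its own chunk.
```
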
@@ -65,7 +56,7 @@ def build_retriever(api_config,embedding_model_name,retrieved_docs_num=5):
         if doc.page_content not in unique_texts:
             unique_texts[doc.page_content] = True
             docs_processed_unique.append(doc)
-
+    logging.info(f"Total number of docs_processed used by vectorstore: {len(docs_processed_unique)}")
     # Store the document into a vector store with a specific embedding model
     embedding_model = HuggingFaceEmbeddings(
         model_name=embedding_model_name,
@@ -78,7 +69,7 @@ def build_retriever(api_config,embedding_model_name,retrieved_docs_num=5):
     )
     return retriever
 def generate_answers_with_RAG(model_name, question_list,api_config,retriever,api_url_overwrite=None):
-    api_url = "http://localhost:"+str(api_config['vllm_endpoint'])+"/v1"
+    api_url = api_config['model_endpoint_url']
     if api_url_overwrite:
         api_url = api_url_overwrite
     key = api_config['api_key']
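
With this change the script takes a full OpenAI-compatible endpoint URL (for example the new default http://localhost:8000/v1 served by vLLM) instead of composing one from a port number. A minimal sketch of how such a URL is typically consumed, shown with the openai client for illustration rather than whatever chat wrapper the script actually builds:

```python
# Illustrative only -- not part of this patch. The model name is a placeholder;
# the real URL, key, and model name come from api_config.
from openai import OpenAI

api_url = "http://localhost:8000/v1"   # api_config['model_endpoint_url']
key = "EMPTY"                          # api_config['api_key']

client = OpenAI(base_url=api_url, api_key=key)
response = client.chat.completions.create(
    model="raft-8b",                   # placeholder model name
    messages=[{"role": "user", "content": "What is RAFT fine-tuning?"}],
)
print(response.choices[0].message.content)
```
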
@@ -206,8 +197,8 @@ def score_single(api_config,generated,reference,questions, run_exact_match=True,
     metric["BERTScore_Precision"] = P
     metric["BERTScore_Recall"] = R
     metric["BERTScore_F1"] = F1
-    if api_config["judge_endpoint"] and run_llm_as_judge:
-        api_url = "http://localhost:"+str(api_config["judge_endpoint"])+"/v1"
+    if api_config["judge_endpoint_url"] and run_llm_as_judge:
+        api_url = api_config["judge_endpoint_url"]
         LLM_judge_score,judge_responses = compute_judge_score(questions, generated, reference, api_config,api_url=api_url)
         metric["LLM_judge_score"] = LLM_judge_score
         metric["LLM_judge_responses"] = judge_responses
@@ -220,7 +211,7 @@ def score_single(api_config,generated,reference,questions, run_exact_match=True,
 def main(api_config):
     # Since the eval set is small, we can run the eval without async functions
     try:
-        api_url = "http://localhost:"+str(api_config["vllm_endpoint"])+"/v1"
+        api_url = api_config["model_endpoint_url"]
         logging.info("Starting to generate answer given the eval set.")
         questions,groud_truth = [],[]
         if api_config["eval_file"].endswith(".parquet"):
@@ -234,15 +225,7 @@ def main(api_config):
                 for index, item in enumerate(eval_file):
                     questions.append(item["question"])
                     groud_truth.append(item["answer"])
-        generated_answers = {
-            "RAFT": [],
-            "RAFT_RAG": [],
-            "Baseline": [],
-            "Baseline_RAG": [],
-            "70B_RAG": [],
-            "70B_Base": [],
-
-        }
+        generated_answers = {}
         # build retriver
         retriever = build_retriever(api_config,"sentence-transformers/multi-qa-mpnet-base-cos-v1",api_config["rag_topk"])
         # Generate answers for 8B models
@@ -251,11 +234,11 @@ def main(api_config):
         generated_answers[model_name+"_RAG"] = generate_answers_with_RAG(model_name, questions,api_config,retriever)
         print("Finished generating answers for ", model_name)
         large_model_name = "meta-llama/Meta-Llama-3-70B-Instruct"
-        large_api_url = "http://localhost:"+str(api_config["judge_endpoint"])+"/v1"
-        generated_answers["70B_Base"] = generate_answers_model_only(large_model_name,questions,large_api_url)
-        generated_answers["70B_RAG"] = generate_answers_with_RAG(large_model_name, questions,api_config,retriever,large_api_url)
+        large_api_url = api_config["judge_endpoint_url"]
+        #generated_answers["70B_Base"] = generate_answers_model_only(large_model_name,questions,large_api_url)
+        #generated_answers["70B_RAG"] = generate_answers_with_RAG(large_model_name, questions,api_config,retriever,large_api_url)
         print("Finished generating answers for ", large_model_name)
-        logging.info(f"Successfully generated {len(generated_answers[model_name])} answers for all models.")
+        logging.info(f"Successfully generated {len(generated_answers[model_name+'_RAG'])} answers for all models.")
         # for generate answer from each model, compute the score metric
         all_metrics = []
         output_file = api_config["output_log"]+str(datetime.now().strftime("%Y%m%d_%H%M%S"))
@@ -272,7 +255,7 @@ def main(api_config):
             fp.write(f"BERTScore Precision: {metric['BERTScore_Precision']:.4f}, Recall: {metric['BERTScore_Recall']:.4f}, F1: {metric['BERTScore_F1']:.4f} \n")
             fp.write(f"Exact_match_percentage: {metric['Exact_match']} \n")
             judge_responses = ["None"] * len(questions)
-            if api_config["judge_endpoint"]:
+            if api_config["judge_endpoint_url"]:
                 fp.write(f"LLM_judge_score: {metric['LLM_judge_score']} \n")
                 judge_responses = metric["LLM_judge_responses"]
                 all_metrics.append((model_name,metric['LLM_judge_score'],metric["LLM_judge_responses"]))
@@ -330,16 +313,16 @@ def parse_arguments():
|
|
help="Provide the data folder path to build RAG for evaluation. If not specified, the data_dir in eval_config.yaml will be used."
|
|
help="Provide the data folder path to build RAG for evaluation. If not specified, the data_dir in eval_config.yaml will be used."
|
|
)
|
|
)
|
|
parser.add_argument(
|
|
parser.add_argument(
|
|
- "-v", "--vllm_endpoint",
|
|
|
|
- default=8000,
|
|
|
|
- type=int,
|
|
|
|
- help="If a port is specified, then use local vllm endpoint for eval."
|
|
|
|
|
|
+ "-u", "--model_endpoint_url",
|
|
|
|
+ default="http://localhost:8000/v1",
|
|
|
|
+ type=str,
|
|
|
|
+ help="The raft model endpoint url for eval."
|
|
)
|
|
)
|
|
parser.add_argument(
|
|
parser.add_argument(
|
|
- "-j", "--judge_endpoint",
|
|
|
|
|
|
+ "-j", "--judge_endpoint_url",
|
|
default=None,
|
|
default=None,
|
|
- type=int,
|
|
|
|
- help="If a port is specified, then use local vllm endpoint as judge LLM."
|
|
|
|
|
|
+ type=str,
|
|
|
|
+ help="The large model endpoint url for judge as LLM."
|
|
)
|
|
)
|
|
parser.add_argument(
|
|
parser.add_argument(
|
|
"-o", "--output_log",
|
|
"-o", "--output_log",
|
|
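
The two endpoint flags now carry full URLs rather than port numbers. A small sketch, separate from the patch, of how the new arguments parse (the judge URL value here is just an example):

```python
# Illustrative only -- mirrors the two new arguments added above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "-u", "--model_endpoint_url",
    default="http://localhost:8000/v1",
    type=str,
    help="The RAFT model endpoint URL for evaluation.",
)
parser.add_argument(
    "-j", "--judge_endpoint_url",
    default=None,
    type=str,
    help="The endpoint URL of the large model used as the LLM judge.",
)

args = parser.parse_args(["-u", "http://localhost:8000/v1",
                          "-j", "http://localhost:8001/v1"])  # example judge URL
print(args.model_endpoint_url, args.judge_endpoint_url)
```
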
@@ -371,12 +354,12 @@ if __name__ == "__main__":
     logging.info("Initializing the process and loading configuration...")
     args = parse_arguments()
     api_config = load_config(args.config_path)
-    api_config["vllm_endpoint"] = args.vllm_endpoint
+    api_config["model_endpoint_url"] = args.model_endpoint_url
     if args.data_dir:
         api_config["data_dir"] = args.data_dir
-    if args.raft_model_name:
+    if args.model_name:
         api_config["model_name"] = args.model_name
-    api_config["judge_endpoint"] = args.judge_endpoint
+    api_config["judge_endpoint_url"] = args.judge_endpoint_url
     api_config["output_log"] = args.output_log
     api_config["api_key"] = args.api_key
     api_config["chunk_size"] = args.chunk_size
@@ -384,6 +367,6 @@ if __name__ == "__main__":
     api_config["rerank_topk"] = args.rerank_topk
     if api_config["rag_topk"] < api_config["rerank_topk"]:
         logging.error("The rerank_topk should be smaller than rag_topk.")
-    if api_config["judge_endpoint"]:
-        logging.info(f"Use local vllm service for judge at port: '{args.judge_endpoint}'.")
+    if api_config["judge_endpoint_url"]:
+        logging.info(f"The judge model URL is: '{args.judge_endpoint_url}'.")
     main(api_config)