raft_eval.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
import logging
import evaluate
import argparse
from config import load_config
import json
from langchain_openai import ChatOpenAI
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores.utils import DistanceStrategy
from datetime import datetime
from langchain_community.document_loaders import DirectoryLoader
import re
import string
import pandas as pd
from langchain.retrievers.document_compressors import FlashrankRerank
from transformers import AutoTokenizer
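
# A sketch of the configuration this script expects. The authoritative file is raft_eval_config.yaml
# (loaded via load_config in the __main__ block); the key names listed here are simply the ones this
# script reads, not a complete or guaranteed schema:
#   eval_prompt_template, RAG_prompt_template, judge_prompt_template, eval_file, data_dir, model_name,
#   output_log, plus the CLI-injected vllm_endpoint, judge_endpoint, api_key, chunk_size, rag_topk,
#   rerank_topk.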
def generate_answers_model_only(model_name, question_list, api_url="http://localhost:8000/v1", key="EMPTY"):
    # Load the RAFT model through the OpenAI-compatible vLLM endpoint
    llm = ChatOpenAI(
        openai_api_key=key,
        openai_api_base=api_url,
        model_name=model_name,
        temperature=0.0,
        max_tokens=1000
    )
    # eval_prompt_template comes from the globally loaded api_config
    all_tasks = [api_config['eval_prompt_template'].format(question=question) for question in question_list]
    generated_answers = llm.batch(all_tasks)
    generated_answers = [item.content for item in generated_answers]
    if len(generated_answers) == 0:
        logging.error(f"No model answers generated. Please check the input context or model configuration for {model_name}.")
        return []
    return clean_text_list(generated_answers)
def format_docs_raft(docs):
    context = ""
    for doc in docs:
        context += "\n<DOCUMENT>" + str(doc.page_content) + "</DOCUMENT>\n"
    return context
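
# Illustrative sketch of the context string produced above for two retrieved chunks
# (the chunk text is made up):
#
#   <DOCUMENT>first chunk text</DOCUMENT>
#
#   <DOCUMENT>second chunk text</DOCUMENT>
#
# This <DOCUMENT>...</DOCUMENT> wrapping matches the document format used by the RAFT-style prompts.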
def build_retriever(api_config, embedding_model_name, retrieved_docs_num=5):
    # Use langchain to load the documents from data directory
    loader = DirectoryLoader(api_config['data_dir'])
    docs = loader.load()
    # Split the document into chunks with a specified chunk size
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=api_config["chunk_size"], chunk_overlap=int(api_config["chunk_size"] / 10), add_start_index=True, strip_whitespace=True)
    # text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
    #     AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B"),
    #     chunk_size=api_config["chunk_size"],
    #     chunk_overlap=int(api_config["chunk_size"] / 10),
    #     add_start_index=True,
    #     strip_whitespace=True,
    #     separators=["\n\n", "\n", ".", " ", ""],
    # )
    docs_processed = text_splitter.split_documents(docs)
    # Remove duplicates
    unique_texts = {}
    docs_processed_unique = []
    for doc in docs_processed:
        if doc.page_content not in unique_texts:
            unique_texts[doc.page_content] = True
            docs_processed_unique.append(doc)
    # Store the documents in a vector store with a specific embedding model
    embedding_model = HuggingFaceEmbeddings(
        model_name=embedding_model_name,
        model_kwargs={"device": "cuda"},
        encode_kwargs={"normalize_embeddings": True},  # Set `True` for cosine similarity
    )
    vectorstore = FAISS.from_documents(docs_processed_unique, embedding_model, distance_strategy=DistanceStrategy.COSINE)
    retriever = vectorstore.as_retriever(
        search_kwargs={"k": retrieved_docs_num},
    )
    return retriever
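
# Usage sketch, mirroring the call made in main() below (the question text is illustrative):
#   retriever = build_retriever(api_config, "sentence-transformers/multi-qa-mpnet-base-cos-v1", api_config["rag_topk"])
#   top_docs = retriever.invoke("How do I fine-tune Llama 3 with LoRA?")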
def generate_answers_with_RAG(model_name, question_list, api_config, retriever, api_url_overwrite=None):
    api_url = "http://localhost:"+str(api_config['vllm_endpoint'])+"/v1"
    if api_url_overwrite:
        api_url = api_url_overwrite
    key = api_config['api_key']
    rerank_topk = api_config["rerank_topk"]
    # Load the RAFT model
    llm = ChatOpenAI(
        openai_api_key=key,
        openai_api_base=api_url,
        model_name=model_name,
        temperature=0.0,
        max_tokens=1000
    )
    all_tasks = []
    for q in question_list:
        # retrieve the top K documents
        retrieved_docs = retriever.invoke(q)
        if rerank_topk:
            # optionally rerank and keep only the top rerank_topk documents
            ranker = FlashrankRerank(top_n=rerank_topk)
            retrieved_docs = ranker.compress_documents(retrieved_docs, q)
        # format the documents into a string
        documents = format_docs_raft(retrieved_docs)
        # create a prompt
        text = api_config["RAG_prompt_template"].format(context=documents, question=q)
        all_tasks.append(text)
    generated_answers = llm.batch(all_tasks)
    generated_answers = [item.content for item in generated_answers]
    if len(generated_answers) == 0:
        logging.error(f"No RAG answers generated. Please check the input context or model configuration for {model_name}.")
        return []
    return clean_text_list(generated_answers)
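
# Note: api_config["RAG_prompt_template"] must contain {context} and {question} placeholders.
# An illustrative (not the repo's actual) template would be:
#   "Answer using only the documents below.\n{context}\nQuestion: {question}"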
def compute_rouge_score(generated: list, reference: list):
    rouge_score = evaluate.load('rouge')
    return rouge_score.compute(
        predictions=generated,
        references=reference,
        use_stemmer=True,
        use_aggregator=True
    )
def clean_text_list(text_list):
    result = []
    for text in text_list:
        # For the RAFT model, the final answer starts with <ANSWER>
        index = text.rfind("<ANSWER>")
        if index != -1:
            text = text[index:]
            text = text.replace("</ANSWER>:", "")
        text = text.replace("begin_quote", "")
        text = text.replace("end_quote", "")
        text = text.replace("##", "")
        text = text.strip()
        result.append(text)
    return result
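
# Illustrative example (the model output below is made up):
#   clean_text_list(["Some step-by-step reasoning...\n<ANSWER>: Paris"]) -> ["<ANSWER>: Paris"]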
def normalize_answer(s):
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)
    def white_space_fix(text):
        return ' '.join(text.split())
    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)
    def lower(text):
        return text.lower()
    return white_space_fix(remove_articles(remove_punc(lower(s))))
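
# A quick illustration of the normalization above:
#   normalize_answer("The  Answer, please!") -> "answer please"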
def exact_match_score(prediction, ground_truth):
    """Computes the exact-match rate between a list of predictions and a list of ground-truth answers."""
    num_match = 0
    assert len(prediction) == len(ground_truth), "Answer length does not match prediction length."
    assert len(ground_truth) > 0
    for pred, gold in zip(prediction, ground_truth):
        if normalize_answer(pred) == normalize_answer(gold):
            num_match += 1
    return num_match/len(ground_truth)
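
# Illustrative check: exact_match_score(["The Eiffel Tower", "42"], ["eiffel tower", "43"]) == 0.5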
def compute_bert_score(generated: list, reference: list):
    bertscore = evaluate.load("bertscore")
    score = bertscore.compute(
        predictions=generated,
        references=reference,
        lang="en"
    )
    f1 = score["f1"]
    precision = score["precision"]
    recall = score["recall"]
    return sum(precision)/len(precision), sum(recall)/len(recall), sum(f1)/len(f1)
def compute_judge_score(questions: list, generated: list, reference: list, api_config, api_url="http://localhost:8001/v1", key="EMPTY"):
    correct_num = 0
    model_name = "meta-llama/Meta-Llama-3-70B-Instruct"
    llm = ChatOpenAI(
        openai_api_key=key,
        openai_api_base=api_url,
        model_name=model_name,
        max_tokens=1000,
        temperature=0.0)
    all_tasks = []
    for question, prediction, gold in zip(questions, generated, reference):
        message = api_config['judge_prompt_template'].format(question=question, prediction=prediction, gold=gold)
        all_tasks.append(message)
    judge_responses = llm.batch(all_tasks)
    judge_responses = ["YES" in item.content for item in judge_responses]
    correct_num = sum(judge_responses)
    return correct_num/len(questions), judge_responses
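
# Note: api_config['judge_prompt_template'] must contain {question}, {prediction} and {gold}
# placeholders, and should instruct the judge model to answer YES/NO, since the scoring above
# only checks whether "YES" appears in the judge's reply. The exact wording lives in the config file.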
def score_single(api_config, generated, reference, questions, run_exact_match=True, run_rouge=True, run_bert=False, run_llm_as_judge=True):
    # Default every metric to -1, meaning that metric was not computed
    metric = {
        "Rouge_score": -1,
        "BERTScore_Precision": -1,
        "BERTScore_Recall": -1,
        "BERTScore_F1": -1,
        "LLM_judge_score": -1,
        "Exact_match": -1
    }
    if run_rouge:
        rouge_score = compute_rouge_score(generated, reference)
        metric["Rouge_score"] = rouge_score
        print("Rouge_score:", rouge_score)
    if run_bert:
        P, R, F1 = compute_bert_score(generated, reference)
        print(f"BERTScore Precision: {P:.4f}, Recall: {R:.4f}, F1: {F1:.4f}")
        metric["BERTScore_Precision"] = P
        metric["BERTScore_Recall"] = R
        metric["BERTScore_F1"] = F1
    if api_config["judge_endpoint"] and run_llm_as_judge:
        api_url = "http://localhost:"+str(api_config["judge_endpoint"])+"/v1"
        LLM_judge_score, judge_responses = compute_judge_score(questions, generated, reference, api_config, api_url=api_url)
        metric["LLM_judge_score"] = LLM_judge_score
        metric["LLM_judge_responses"] = judge_responses
        print(f"LLM_judge_score: {LLM_judge_score}")
    if run_exact_match:
        exact_match = exact_match_score(generated, reference)
        print(f"Exact_match_percentage: {exact_match:.4f}")
        metric["Exact_match"] = exact_match
    return metric
def main(api_config):
    # Since the eval set is small, we can run the eval without async functions
    try:
        api_url = "http://localhost:"+str(api_config["vllm_endpoint"])+"/v1"
        logging.info("Starting to generate answers for the eval set.")
        questions, ground_truth = [], []
        if api_config["eval_file"].endswith(".parquet"):
            eval_file = pd.read_parquet(api_config["eval_file"], filters=[('source', '=', 'pt_discuss_forum')])
            for index, item in eval_file.iterrows():
                questions.append(item["question"]+"\nDetails:\n"+item["context"])
                ground_truth.append(item["answer"])
        else:
            with open(api_config["eval_file"]) as fp:
                eval_file = json.load(fp)
                for index, item in enumerate(eval_file):
                    questions.append(item["question"])
                    ground_truth.append(item["answer"])
        generated_answers = {
            "RAFT": [],
            "RAFT_RAG": [],
            "Baseline": [],
            "Baseline_RAG": [],
            "70B_RAG": [],
            "70B_Base": [],
        }
        # Build the retriever
        retriever = build_retriever(api_config, "sentence-transformers/multi-qa-mpnet-base-cos-v1", api_config["rag_topk"])
        # Generate answers for the 8B models
        model_name = api_config["model_name"]
        generated_answers[model_name] = generate_answers_model_only(model_name, questions, api_url)
        generated_answers[model_name+"_RAG"] = generate_answers_with_RAG(model_name, questions, api_config, retriever)
        print("Finished generating answers for ", model_name)
        large_model_name = "meta-llama/Meta-Llama-3-70B-Instruct"
        large_api_url = "http://localhost:"+str(api_config["judge_endpoint"])+"/v1"
        generated_answers["70B_Base"] = generate_answers_model_only(large_model_name, questions, large_api_url)
        generated_answers["70B_RAG"] = generate_answers_with_RAG(large_model_name, questions, api_config, retriever, large_api_url)
        print("Finished generating answers for ", large_model_name)
        logging.info(f"Successfully generated {len(generated_answers[model_name])} answers for all models.")
        # For the generated answers of each model, compute the score metrics
        all_metrics = []
        output_file = api_config["output_log"]+str(datetime.now().strftime("%Y%m%d_%H%M%S"))
        for model_name, model_answer in generated_answers.items():
            if len(model_answer) != len(ground_truth):
                print(f"The length of {model_name} answers does not equal the length of the ground truth.")
                continue
            metric = score_single(api_config, model_answer, ground_truth, questions)
            print(f"The eval result for {model_name} is: {metric}")
            with open(output_file, "a") as fp:
                fp.write(f"Eval_result for {model_name} \n")
                fp.write(f"Rouge_score: {metric['Rouge_score']} \n")
                fp.write(f"BERTScore Precision: {metric['BERTScore_Precision']:.4f}, Recall: {metric['BERTScore_Recall']:.4f}, F1: {metric['BERTScore_F1']:.4f} \n")
                fp.write(f"Exact_match_percentage: {metric['Exact_match']} \n")
                judge_responses = ["None"] * len(questions)
                if api_config["judge_endpoint"]:
                    fp.write(f"LLM_judge_score: {metric['LLM_judge_score']} \n")
                    judge_responses = metric["LLM_judge_responses"]
                    all_metrics.append((model_name, metric['LLM_judge_score'], metric["LLM_judge_responses"]))
                fp.write("QA details: \n")
                for item in zip(questions, model_answer, ground_truth, judge_responses):
                    fp.write(f"question: {item[0]} \n")
                    fp.write(f"generated_answers: {item[1]} \n")
                    fp.write(f"ground_truth: {item[2]} \n")
                    fp.write(f"LLM_judge_response: {item[3]} \n")
                    fp.write("\n")
                fp.write("\n------------------------------------\n")
        # Now take a closer look at the questions that are not answered the same way by all the models.
        judge_zip = list(zip(*[item[-1] for item in all_metrics]))
        model_names = [item[0] for item in all_metrics]
        with open(output_file, "a") as fp:
            for item in all_metrics:
                fp.write(f"Model_Name: {item[0]}, LLM_SCORE: {item[1]} \n")
            for idx, item in enumerate(judge_zip):
                # If every judge response for this question is "YES", skip it
                if sum(item) == len(item):
                    continue
                else:
                    fp.write(f"Comparing interested question: {questions[idx]} \n")
                    fp.write(f"ground_truth: {ground_truth[idx]} \n")
                    for i in range(len(model_names)):
                        fp.write(f"{item[i]} {model_names[i]}_answers: {generated_answers[model_names[i]][idx]} \n")
                    fp.write("------------------------\n")
            fp.write(json.dumps(all_metrics))
        print("Finished evaluating the model.")
        # The eval result has been saved to the output log file
        logging.info(f"Eval finished successfully, the eval result is saved to {api_config['output_log']}.")
    except Exception as e:
        logging.error(f"An unexpected error occurred during the process: {e}", exc_info=True)
def parse_arguments():
    # Define command line arguments for the script
    parser = argparse.ArgumentParser(
        description="Evaluate RAFT model answers against a RAG evaluation set."
    )
    parser.add_argument(
        "-m", "--model_name",
        default=None,
        help="Provide the model_name to use for evaluation. If not specified, the model_name in the config file will be used."
    )
    parser.add_argument(
        "-c", "--config_path",
        default="raft_eval_config.yaml",
        help="Set the configuration file path that has the system prompt along with the language and evalset path."
    )
    parser.add_argument(
        "-d", "--data_dir",
        default=None,
        help="Provide the data folder path to build RAG for evaluation. If not specified, the data_dir in the config file will be used."
    )
    parser.add_argument(
        "-v", "--vllm_endpoint",
        default=8000,
        type=int,
        help="If a port is specified, then use the local vllm endpoint for eval."
    )
    parser.add_argument(
        "-j", "--judge_endpoint",
        default=None,
        type=int,
        help="If a port is specified, then use the local vllm endpoint as the judge LLM."
    )
    parser.add_argument(
        "-o", "--output_log",
        default="./eval_result",
        help="Save the eval result to a log file. A timestamp is appended to this path, e.g. ./eval_result20240101_120000."
    )
    parser.add_argument(
        "-k", "--api_key",
        default="EMPTY",
        type=str,
        help="LLM API key for the OpenAI-compatible endpoints used during evaluation."
    )
    parser.add_argument(
        "-r", "--rag_topk",
        default=5,
        type=int,
        help="Set the number of top-k documents RAG needs to retrieve."
    )
    parser.add_argument(
        "--rerank_topk",
        default=0,
        type=int,
        help="Set the number of top-k documents the reranker needs to keep."
    )
    parser.add_argument("--chunk_size", type=int, default=1000, help="The character size of each chunk used in RAG")
    return parser.parse_args()
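
# Example invocation (the ports and model name are assumptions; match them to your own vLLM deployments
# of the RAFT model and the 70B judge model):
#   python raft_eval.py -m raft-8b -v 8000 -j 8001 -r 5 --rerank_topk 3 -c raft_eval_config.yaml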
if __name__ == "__main__":
    logging.info("Initializing the process and loading configuration...")
    args = parse_arguments()
    api_config = load_config(args.config_path)
    api_config["vllm_endpoint"] = args.vllm_endpoint
    if args.data_dir:
        api_config["data_dir"] = args.data_dir
    if args.model_name:
        api_config["model_name"] = args.model_name
    api_config["judge_endpoint"] = args.judge_endpoint
    api_config["output_log"] = args.output_log
    api_config["api_key"] = args.api_key
    api_config["chunk_size"] = args.chunk_size
    api_config["rag_topk"] = args.rag_topk
    api_config["rerank_topk"] = args.rerank_topk
    if api_config["rag_topk"] < api_config["rerank_topk"]:
        logging.error("The rerank_topk should be smaller than rag_topk.")
    if api_config["judge_endpoint"]:
        logging.info(f"Use local vllm service for judge at port: '{args.judge_endpoint}'.")
    main(api_config)