
adding LLM_as_judge feature

Kai Wu · 1 year ago
commit 9092139aca

+ 1 - 4
recipes/use_cases/end2end-recipes/chatbot/README.md

@@ -185,7 +185,7 @@ Model tends to ignore providing the bigger picture in the questions, for example
 
 #### Data Insights
 
-We generated a dataset of almost 800 Q&A pairs from some of the open source documents about Llama models, including getting started guide from Llama website, its FAQ, Llama 3, Purple Llama, Code Llama papers and Llama-Recipes documentations.
+We generated a dataset of almost 3600 Q&A pairs from some of the open source documents about Llama models, including getting started guide from Llama website, its FAQ, Llama 3, Purple Llama, Code Llama papers and Llama-Recipes documentations.
 
 We have run some fine-tuning experiments with single GPU using quantization with different LORA configs (all linear layer versus query and key projections only) and different number of epochs. Although train and eval loss shows decrease specially with using all linear layers in LORA configs and training with 6 epochs, still the result is far from acceptable in real tests.
 
@@ -205,6 +205,3 @@ Below are some examples of real test on the fine-tuned model with very poor resu
   <img src=./poor-test-1.png alt="Poor Test Results example 1" width="48%" style="margin-right: 2%;"/>
   <img src=./poor-test-2.png alt="Poor Test Results example 1" width="48%"/>
 </p>
-
-
-Next, we are looking into augmenting our datasets. One way to do so, is to use our Llama 70B model to read our question answer pairs and come up with two paraphrase versions of each pair to augment our data.
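
The LORA comparison described in the README above ("all linear layers" versus projections only) corresponds to the `target_modules` field in `src/llama_recipes/configs/peft.py`, which this commit also changes further down. A minimal sketch of the two settings, assuming Hugging Face `peft`'s `LoraConfig` (values taken from the config file in this commit):

```python
from peft import LoraConfig

# Projection-only LORA (the previous default in src/llama_recipes/configs/peft.py).
lora_projections_only = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# LORA on all linear layers of each decoder block (attention and MLP projections),
# the setting this commit makes the default.
lora_all_linear = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
```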

File diff suppressed because it is too large
+ 26 - 6
recipes/use_cases/end2end-recipes/chatbot/pipelines/README.md


+ 52 - 7
recipes/use_cases/end2end-recipes/chatbot/pipelines/eval_chatbot.py

@@ -8,7 +8,7 @@ from config import load_config
 import asyncio
 import json
 from itertools import chain
-from generator_utils import parse_qa_to_json
+from generator_utils import parse_qa_to_json, generate_LLM_eval
 
 def compute_rouge_score(generated : str, reference: str):
     rouge_score = evaluate.load('rouge')
@@ -20,11 +20,15 @@ def compute_rouge_score(generated : str, reference: str):
     )
 def compute_bert_score(generated : str, reference: str):
     bertscore = evaluate.load("bertscore")
-    return bertscore.compute(
+    score = bertscore.compute(
         predictions=generated,
         references=reference,
         lang="en"
     )
+    f1 = score["f1"]
+    precision = score["precision"]
+    recall = score["recall"]
+    return sum(precision)/len(precision), sum(recall)/len(recall), sum(f1)/len(f1)
 # This function is used to eval the fine-tuned model, given the question, generate the answer.
 async def eval_request(chat_service, api_context: dict, question: str) -> dict:
     prompt_for_system = api_context['eval_prompt_template'].format(language=api_context["language"])
@@ -75,11 +79,38 @@ async def main(context):
             logging.warning("No answers generated. Please check the input context or model configuration.")
             return
         logging.info(f"Successfully generated {len(generated_answers)} answers.")
+        judge_list = []
+        for index, item in enumerate(generated_answers):
+            judge_list.append({"Question":questions[index],"Ground_truth":groud_truth[index],"Generated_answer":generated_answers[index]})
+        if context["judge_endpoint"]:
+            # make a copy of the context then change the VLLM endpoint to judge_endpoint
+            context_copy = dict(context)
+            context_copy["endpoint"] = context["judge_endpoint"]
+            context_copy["model"] = "meta-llama/Meta-Llama-3-70B-Instruct"
+            judge_results = await generate_LLM_eval(chat_service, context_copy, judge_list)
+            correct_num = 0
+            for result in judge_results:
+                correct_num += result["Result"] == "YES"
+            LLM_judge_score = correct_num/len(judge_results)
+            print(f"The accuracy of the model is {LLM_judge_score}")
         rouge_score = compute_rouge_score(generated_answers,groud_truth)
         print("Rouge_score:",rouge_score)
-        bert_score = compute_bert_score(generated_answers,groud_truth)
-        print("Bert_score:",bert_score)
-        logging.info("Eval successfully")
+        P, R, F1 = compute_bert_score(generated_answers,groud_truth)
+        print(f"BERTScore Precision: {P:.4f}, Recall: {R:.4f}, F1: {F1:.4f}")
+        # Saving the eval result to a log file
+        with open(context["output_log"],"a") as fp:
+            fp.write(f"Eval_result for {context['model']} \n")
+            fp.write(f"Rouge_score: {rouge_score} \n")
+            fp.write(f"BERTScore Precision: {P:.4f}, Recall: {R:.4f}, F1: {F1:.4f} \n")
+            if context["judge_endpoint"]:
+                fp.write(f"LLM_judge_score: {LLM_judge_score} \n")
+            fp.write(f"QA details: \n")
+            for item in judge_list:
+                fp.write(f"question: {item['Question']} \n")
+                fp.write(f"generated_answers: {item['Generated_answer']} \n")
+                fp.write(f"groud_truth: {item['Ground_truth']} \n")
+                fp.write("\n")
+        logging.info(f"Eval successfully, the eval result is saved to {context['output_log']}.")
     except Exception as e:
         logging.error(f"An unexpected error occurred during the process: {e}",exc_info=True)
 
@@ -104,15 +135,29 @@ def parse_arguments():
         type=int,
         help="If a port is specified, then use local vllm endpoint for evaluations."
     )
+    parser.add_argument(
+        "-j", "--judge_endpoint",
+        default=None,
+        type=int,
+        help="If a port is specified, then use local vllm endpoint as judge LLM."
+    )
+    parser.add_argument(
+        "-o", "--output_log",
+        default="eval_result.log",
+        help="save the eval result to a log file. Default is eval_result.log"
+    )
     return parser.parse_args()
 
 if __name__ == "__main__":
     logging.info("Initializing the process and loading configuration...")
     args = parse_arguments()
-
     context = load_config(args.config_path)
     context["model"] = args.model
     context["endpoint"] = args.vllm_endpoint
+    context["judge_endpoint"] = args.judge_endpoint
+    context["output_log"] = args.output_log
     if context["endpoint"]:
-        logging.info(f"Use local vllm service at port: '{args.vllm_endpoint}'.")
+        logging.info(f"Use local vllm service for eval at port: '{args.vllm_endpoint}'.")
+    if context["judge_endpoint"]:
+        logging.info(f"Use local vllm service for judge at port: '{args.judge_endpoint}'.")
     asyncio.run(main(context))
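
The reworked `compute_bert_score` above averages the per-example scores returned by the `evaluate` metric before printing and logging them. A standalone sketch of that aggregation (a minimal example, assuming the `evaluate` and `bert_score` packages are installed; the metric downloads its model on first use):

```python
import evaluate

bertscore = evaluate.load("bertscore")
score = bertscore.compute(
    predictions=["Llama 3 uses a Tiktoken-based tokenizer."],
    references=["Llama 3 switched from SentencePiece to a Tiktoken-based tokenizer."],
    lang="en",
)
# compute() returns one precision/recall/f1 value per prediction, so the
# script reduces each list to its mean before writing it to the log file.
P = sum(score["precision"]) / len(score["precision"])
R = sum(score["recall"]) / len(score["recall"])
F1 = sum(score["f1"]) / len(score["f1"])
print(f"BERTScore Precision: {P:.4f}, Recall: {R:.4f}, F1: {F1:.4f}")
```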

+ 9 - 0
recipes/use_cases/end2end-recipes/chatbot/pipelines/eval_config.yaml

@@ -8,6 +8,15 @@ eval_prompt_template: >
       "Answer": "Your answer to the question"
   }}
   ]
+judge_prompt_template: >
+  You are provided with a question, a teacher's answer and a student's answer. Given the question, you need to score how good the student's answer is compared to
+  the teacher's answer. If the student's answer is correct based on the teacher's answer, then return YES. If the answer is not faithful, then return NO
+  and explain which part of the answer is not faithful in the Reason section.
+  Return the result in json format with the template:
+    {{
+      "Reason": "your reason here.",
+      "Result": "YES or NO."
+    }}
 eval_json: "./evalset.json"
 
 language: "English"

+ 12 - 1
recipes/use_cases/end2end-recipes/chatbot/pipelines/evalset.json

@@ -1,4 +1,16 @@
 [
+    {
+        "question":"What is the difference in the tokenization techniques that Meta Llama 3 uses compared to Llama 2?",
+        "answer": "Llama 2 uses SentencePiece for tokenization, whereas Llama 3 has transitioned to OpenAI’s Tiktoken. Llama 3 also introduces a ChatFormat class and special tokens, including those for end-of-turn markers, and other features to enhance support for chat-based interactions and dialogue processing."
+    },
+    {
+        "question":"How many tokens were used in Llama 3 pretraining?",
+        "answer": "Llama 3 is pretrained on over 15T tokens that were all collected from publicly available sources."
+    },
+    {
+        "question": "What are the goals for Llama 3?",
+        "answer": "With Llama 3, we set out to build the best open models that are on par with the best proprietary models available today. We wanted to address developer feedback to increase the overall helpfulness of Llama 3 and are doing so while continuing to play a leading role on responsible use and deployment of LLMs. We are embracing the open source ethos of releasing early and often to enable the community to get access to these models while they are still in development."
+    },
 {
 "question": "What if I want to access Llama models but I’m not sure if my use is permitted under the Llama 2 Community License?",
 "answer": "On a limited case by case basis, we will consider bespoke licensing requests from individual entities. Please contact llamamodels@meta.com to provide more details about your request."
@@ -144,4 +156,3 @@
 "answer": "No, such companies are not prohibited when their usage of Llama is not related to the operation of critical infrastructure. Llama, however, may not be used in the operation of critical infrastructure by any company, regardless of government certifications."
 }
 ]
-

+ 29 - 2
recipes/use_cases/end2end-recipes/chatbot/pipelines/generator_utils.py

@@ -219,9 +219,9 @@ async def generate_question_batches(chat_service, api_context: dict):
         final_result.extend(parsed_json)
     return final_result
 
-async def generate_data_curation(chat_service, api_context: dict, generated_questions: list):
+async def generate_data_curation(chat_service, api_context: dict, evaluation_list: list):
     eval_tasks = []
-    for batch_index, batch_content in enumerate(generated_questions):
+    for batch_index, batch_content in enumerate(evaluation_list):
         try:
             result = data_curation_request(chat_service, api_context, batch_content)
             eval_tasks.append(result)
@@ -235,3 +235,30 @@ async def generate_data_curation(chat_service, api_context: dict, generated_ques
         if item:
             curated_data.append(item)
     return curated_data
+
+# This function asks the judge LLM to compare a generated answer against the ground truth answer. Returns the judge verdict as a dict with "Reason" and "Result" keys, or an empty dict on failure.
+async def LLM_judge_request(chat_service, api_context: dict, document_content: dict) -> dict:
+    prompt_for_system = api_context['judge_prompt_template'].format(language=api_context["language"])
+    chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': f"Question: {document_content['Question']} \n Teacher's Answer: {document_content['Ground_truth']}\n Student's Answer: {document_content['Generated_answer']} "}]
+    result = await chat_service.execute_chat_request_async(api_context, chat_request_payload)
+    if not result:
+        return {}
+    # The judge is instructed to reply with a JSON object, so load it directly as a dict.
+    try:
+        result = json.loads(result)
+    except json.JSONDecodeError:
+        print("Error: judge response is not valid JSON")
+        print(document_content, result)
+        return {}
+    if "Result" not in result:
+        print("Error: judge response does not contain a Result field")
+        print(document_content, result)
+        return {}
+    return result
+
+async def generate_LLM_eval(chat_service, api_context: dict, judge_list: list):
+    eval_tasks = []
+    for batch_index, batch_content in enumerate(judge_list):
+        try:
+            result = LLM_judge_request(chat_service, api_context, batch_content)
+            eval_tasks.append(result)
+        except Exception as e:
+            print(f"Error during data eval request execution: {e}")
+
+    judge_results = await asyncio.gather(*eval_tasks)
+    return judge_results
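
A usage sketch for the new judge helpers, showing the expected data shapes; `chat_service` and `api_context` stand in for the objects that eval_chatbot.py builds:

```python
import asyncio
from generator_utils import generate_LLM_eval

# Each judge_list entry pairs a question with the reference ("teacher")
# answer and the fine-tuned model's ("student") answer.
judge_list = [{
    "Question": "How many tokens were used in Llama 3 pretraining?",
    "Ground_truth": "Llama 3 is pretrained on over 15T tokens.",
    "Generated_answer": "It was pretrained on roughly 15 trillion tokens.",
}]

async def judge_accuracy(chat_service, api_context):
    # generate_LLM_eval issues one LLM_judge_request per entry and gathers
    # the parsed {"Reason": ..., "Result": "YES"/"NO"} verdicts.
    results = await generate_LLM_eval(chat_service, api_context, judge_list)
    correct = sum(r.get("Result") == "YES" for r in results if r)
    return correct / len(results)

# accuracy = asyncio.run(judge_accuracy(chat_service, api_context))
```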

+ 1 - 1
src/llama_recipes/configs/peft.py

@@ -8,7 +8,7 @@ from typing import List
 class lora_config:
      r: int=8
      lora_alpha: int=32
-     target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
+     target_modules: List[str] = field(default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj","gate_proj", "up_proj", "down_proj"])
      bias= "none"
      task_type: str= "CAUSAL_LM"
      lora_dropout: float=0.05

+ 1 - 0
src/llama_recipes/configs/training.py

@@ -31,6 +31,7 @@ class train_config:
     dataset = "samsum_dataset"
     peft_method: str = "lora" # None, llama_adapter (Caution: llama_adapter is currently not supported with FSDP)
     use_peft: bool=False
+    from_peft_checkpoint: str="" # if not empty and use_peft=True, will load the peft checkpoint and resume the fine-tuning on that checkpoint
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1
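
The new `from_peft_checkpoint` field is consumed in finetuning.py below. A hedged sketch of a resume run driven through the fine-tuning entrypoint (the model name and paths here are placeholders; keyword arguments are mapped onto train_config fields):

```python
from llama_recipes.finetuning import main

# Resume LORA fine-tuning from a previously saved PEFT checkpoint.
main(
    model_name="meta-llama/Meta-Llama-3-8B-Instruct",   # placeholder base model
    use_peft=True,
    peft_method="lora",
    from_peft_checkpoint="PATH/to/saved/PEFT/model",     # placeholder checkpoint
    output_dir="PATH/to/save/PEFT/model",
    num_epochs=1,
)
```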

+ 13 - 8
src/llama_recipes/finetuning.py

@@ -8,7 +8,7 @@ import fire
 import random
 import torch
 import torch.optim as optim
-from peft import get_peft_model, prepare_model_for_kbit_training
+from peft import get_peft_model, prepare_model_for_kbit_training, PeftModel
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     ShardingStrategy
@@ -151,11 +151,17 @@ def main(**kwargs):
         model.to(torch.bfloat16)
 
     if train_config.use_peft:
-        peft_config = generate_peft_config(train_config, kwargs)
-        model = get_peft_model(model, peft_config)
-        model.print_trainable_parameters()
+        # Load the pre-trained peft model checkpoint and setup its configuration
+        if train_config.from_peft_checkpoint:
+            model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
+            # peft_config on a PeftModel is a dict keyed by adapter name, so read the default adapter's config
+            peft_config = model.peft_config["default"]
+        # Generate the peft config and start fine-tuning from the original model
+        else:
+            peft_config = generate_peft_config(train_config, kwargs)
+            model = get_peft_model(model, peft_config)
         if wandb_run:
             wandb_run.config.update(peft_config)
+        model.print_trainable_parameters()
 
 
     hsdp_device_mesh = None
@@ -166,8 +172,7 @@ def main(**kwargs):
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
         if not train_config.use_peft and train_config.freeze_layers:
-
-            freeze_transformer_layers(train_config.num_freeze_layers)
+            freeze_transformer_layers(model, train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
@@ -188,7 +193,7 @@ def main(**kwargs):
             device_id=device_id,
             limit_all_gathers=True,
             sync_module_states=train_config.low_cpu_fsdp,
-            param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
+            param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
             if train_config.low_cpu_fsdp and rank != 0 else None,
         )
         if fsdp_config.fsdp_activation_checkpointing:
@@ -217,7 +222,7 @@ def main(**kwargs):
         split="test",
     )
     if not train_config.enable_fsdp or rank == 0:
-            print(f"--> Validation Set Length = {len(dataset_val)}")
+        print(f"--> Validation Set Length = {len(dataset_val)}")
 
     if train_config.batching_strategy == "packing":
         dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
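
The resume path above hinges on `is_trainable=True`; `PeftModel.from_pretrained` otherwise loads adapters frozen for inference. A small round-trip sketch, assuming `model` is the already-loaded base model and the checkpoint path is a placeholder:

```python
from peft import PeftModel

# A first fine-tuning run saves the adapters (e.g. via model.save_pretrained(train_config.output_dir)).
# To continue training, reload them onto the base model as trainable weights.
model = PeftModel.from_pretrained(
    model,                           # base (non-PEFT) model
    "PATH/to/saved/PEFT/model",      # placeholder checkpoint path
    is_trainable=True,               # without this the adapters stay frozen
)
model.print_trainable_parameters()

# peft_config on a PeftModel is a dict keyed by adapter name ("default" here).
peft_config = model.peft_config["default"]
```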

+ 2 - 0
src/llama_recipes/utils/train_utils.py

@@ -103,6 +103,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     val_loss =[]
 
     if train_config.save_metrics:
+        if not os.path.exists(train_config.output_dir):
+            os.makedirs(train_config.output_dir, exist_ok=True)
         metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
         train_step_perplexity = []
         train_step_loss = []