
adding self-curation using LLM

Kai Wu, 11 months ago
parent
commit
274ed14aa0

+ 40 - 0
recipes/finetuning/datasets/chatbot_dataset.py

@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
+
+
+import copy
+import datasets
+from datasets import Dataset, load_dataset, DatasetDict
+import itertools
+
+
+B_INST, E_INST = "[INST]", "[/INST]"
+
+def tokenize_dialog(q_a_pair, tokenizer):
+    prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(question).strip()} {E_INST}", add_special_tokens=False) for question in q_a_pair["question"]]
+    answer_tokens = [tokenizer.encode(f"{answer.strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in q_a_pair["answer"]]
+    dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
+    # Add labels: convert prompt tokens to -100 so they are ignored by the loss function
+    labels_tokens = [len(c)*[-100,] if i % 2 == 0 else c for i,c in enumerate(dialog_tokens)]
+
+    combined_tokens = {
+        "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
+        "labels": list(itertools.chain(*(t for t in labels_tokens))),
+    }
+
+    return dict(combined_tokens, attention_mask=[1]*len(combined_tokens["input_ids"]))
+
+
+def get_custom_dataset(dataset_config, tokenizer, split, split_ratio=0.8):
+    dataset = load_dataset('json', data_files=dataset_config.data_path)
+    dataset = dataset['train'].train_test_split(test_size=1-split_ratio, shuffle=True)
+
+    dataset = dataset[split].map(lambda sample: {
+        "question": sample["question"],
+        "answer": sample["answer"],
+        },
+        batched=True,
+    )
+    dataset = dataset.map(lambda x: tokenize_dialog(x, tokenizer))
+    return dataset
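
The key idea in tokenize_dialog above is that prompt tokens are masked out of the loss while answer tokens are kept. Below is a minimal, self-contained sketch of that masking; FakeTokenizer is purely hypothetical and stands in for the real Hugging Face tokenizer the recipe loads.

# Toy illustration (not part of the recipe) of the -100 masking in tokenize_dialog.
import itertools

class FakeTokenizer:
    bos_token, eos_token = "<s>", "</s>"
    def encode(self, text, add_special_tokens=False):
        # naive whitespace "tokenization", just to show the shapes
        return [abs(hash(word)) % 1000 for word in text.split()]

tok = FakeTokenizer()
sample = {"question": ["What is Llama?"], "answer": ["A family of LLMs released by Meta."]}

prompt_tokens = [tok.encode(f"{tok.bos_token}[INST] {q.strip()} [/INST]") for q in sample["question"]]
answer_tokens = [tok.encode(f"{a.strip()} {tok.eos_token}") for a in sample["answer"]]
dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
labels_tokens = [len(c) * [-100] if i % 2 == 0 else c for i, c in enumerate(dialog_tokens)]

print(list(itertools.chain(*dialog_tokens)))   # input_ids
print(list(itertools.chain(*labels_tokens)))   # labels: -100 wherever the prompt was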

File diff suppressed because it is too large
+ 7 - 1
recipes/use_cases/end2end-recipes/chatbot/data_pipelines/REAME.md


+ 13 - 3
recipes/use_cases/end2end-recipes/chatbot/data_pipelines/config.yaml

@@ -1,10 +1,9 @@
 question_prompt_template: >
   You are a language model skilled in creating quiz questions.
   You will be provided with a document,
-  read it and generate question and answer pairs
-  that are most likely be asked by a use of llama that just want to start,
+  read it and generate question and answer pairs that are most likely to be asked by a user of the Llama model,
   please make sure you follow those rules,
-  1. Generate only {num_questions} question answer pairs.
+  1. Generate at most {num_questions} question and answer pairs; you may generate fewer if there is nothing in the document related to Llama.
   2. Generate in {language}.
   3. The questions can be answered based *solely* on the given passage.
   4. Avoid asking questions with similar meaning.
@@ -23,6 +22,17 @@ question_prompt_template: >
       }}
     ]
 
+eval_prompt_template: >
+  Below is a question and answer pair about the Llama language model. Evaluate
+  whether or not this question and answer pair will be helpful for a user of the Llama language model.
+  Respond with only a single JSON blob with a "Reason" field that is a short (less than 100 words)
+  explanation of your answer and an "Answer" field which is YES or NO. Only generate the answer in {language}.
+  Return the result in json format with the template:
+    {{
+      "Reason": "your reason here.",
+      "Answer": "YES or NO."
+    }}
+
 data_dir: "./data"
 
 language: "English"

File diff suppressed because it is too large
+ 147 - 0
recipes/use_cases/end2end-recipes/chatbot/data_pipelines/evalset.json


+ 22 - 11
recipes/use_cases/end2end-recipes/chatbot/data_pipelines/generate_question_answers.py

@@ -5,7 +5,7 @@ import argparse
 import asyncio
 import json
 from config import load_config
-from generator_utils import generate_question_batches, parse_qa_to_json
+from generator_utils import generate_question_batches, parse_qa_to_json, generate_data_eval
 from itertools import chain
 import logging
 import aiofiles  # Ensure aiofiles is installed for async file operations
@@ -27,14 +27,14 @@ MODEL_NAME_MAPPING={"meta-llama-3-70b-instruct":"meta-llama/Meta-Llama-3-70B-Ins
 ,"llama-2-70b-chat":"meta-llama/Llama-2-70b-chat-hf"}
 class ChatService(ABC):
     @abstractmethod
-    async def execute_chat_request_async(self, api_context: dict, chat_request):
+    async def execute_chat_request_async(self, api_context: dict, chat_request, eval=False):
         pass
 
 # Please implement your own chat service class here.
 # The class should inherit from the ChatService class and implement the execute_chat_request_async method.
 # The following are two example chat service classes that you can use as a reference.
 class OctoAIChatService(ChatService):
-    async def execute_chat_request_async(self, api_context: dict, chat_request):
+    async def execute_chat_request_async(self, api_context: dict, chat_request, eval=False):
         async with request_limiter:
             try:
                 event_loop = asyncio.get_running_loop()
@@ -47,7 +47,10 @@ class OctoAIChatService(ChatService):
                 )
                 response = await event_loop.run_in_executor(None, api_chat_call)
                 assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
-                assistant_response_json = parse_qa_to_json(assistant_response)
+                if eval:
+                    assistant_response_json = json.loads(assistant_response)
+                else:
+                    assistant_response_json = parse_qa_to_json(assistant_response)
 
                 return assistant_response_json
             except Exception as error:
@@ -56,7 +59,7 @@ class OctoAIChatService(ChatService):
 # Use the local vllm openai compatible server for generating question/answer pairs to make API call syntax consistent
 # please read for more detail:https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html.
 class VllmChatService(ChatService):
-    async def execute_chat_request_async(self, api_context: dict, chat_request):
+    async def execute_chat_request_async(self, api_context: dict, chat_request, eval=False):
         async with request_limiter:
             try:
                 event_loop = asyncio.get_running_loop()
@@ -70,9 +73,10 @@ class VllmChatService(ChatService):
                 )
                 response = await event_loop.run_in_executor(None, api_chat_call)
                 assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
-                assistant_response_json = parse_qa_to_json(assistant_response)
-                if len(assistant_response_json)==0:
-                    logging.error("No question/answer pairs generated. Please check the input context or model configuration.")
+                if eval:
+                    assistant_response_json = json.loads(assistant_response)
+                else:
+                    assistant_response_json = parse_qa_to_json(assistant_response)
                 return assistant_response_json
             except Exception as error:
                 logging.error(f"Error during chat request execution: {error}",exc_info=True)
@@ -90,12 +94,19 @@ async def main(context):
             logging.warning("No data generated. Please check the input context or model configuration.")
             return
         flattened_list = list(chain.from_iterable(data))
+        # with open("data.json") as fp:
+        #     flattened_list = json.load(fp)
         logging.info(f"Successfully generated {len(flattened_list)} question/answer pairs.")
         # Use asynchronous file operation for writing to the file
-        async with aiofiles.open("data.json", "w") as output_file:
-            await output_file.write(json.dumps(flattened_list, indent=4))
-        logging.info("Data successfully written to 'data.json'. Process completed.")
 
+        curated_data = await generate_data_eval(chat_service, context, flattened_list)
+        logging.info(f"Only {len(curated_data)} question/answer pairs passed the self-curation")
+        async with aiofiles.open("curated_data.json", "w") as output_file:
+            await output_file.write(json.dumps(curated_data, indent=4))
+        logging.info("Data successfully written to 'curated_data.json'. Process completed.")
     except Exception as e:
         logging.error(f"An unexpected error occurred during the process: {e}",exc_info=True)
 
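
The ChatService comments above invite users to plug in their own backend; any subclass now also has to honor the new eval flag, returning a parsed JSON dict for curation requests. The EchoChatService below is a purely hypothetical stand-in that makes no API calls and only shows the expected shape of the method:

import asyncio
import json
from abc import ABC, abstractmethod

class ChatService(ABC):
    @abstractmethod
    async def execute_chat_request_async(self, api_context: dict, chat_request, eval=False):
        pass

class EchoChatService(ChatService):
    # Hypothetical backend: returns canned responses instead of calling a model.
    async def execute_chat_request_async(self, api_context: dict, chat_request, eval=False):
        if eval:
            # Curation requests expect a JSON blob matching eval_prompt_template.
            return json.loads('{"Reason": "Relevant to Llama users.", "Answer": "YES"}')
        # Generation requests would normally go through parse_qa_to_json.
        return [{"question": "What is Llama?", "answer": "Meta's family of open LLMs."}]

print(asyncio.run(EchoChatService().execute_chat_request_async({}, [], eval=True)))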

+ 37 - 2
recipes/use_cases/end2end-recipes/chatbot/data_pipelines/generator_utils.py

@@ -121,10 +121,28 @@ def parse_qa_to_json(response_string):
 async def prepare_and_send_request(chat_service, api_context: dict, document_content: str, num_questions: int) -> dict:
     prompt_for_system = api_context['question_prompt_template'].format(num_questions=num_questions, language=api_context["language"])
     chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': document_content}]
-    result = await chat_service.execute_chat_request_async(api_context, chat_request_payload)
+    result = await chat_service.execute_chat_request_async(api_context, chat_request_payload, eval=False)
     if not result:
         return {}
-    return json.loads(await chat_service.execute_chat_request_async(api_context, chat_request_payload))
+    return json.loads(result)
+# This function is used to evaluate the quality of generated QA pairs. Return the original QA pair if the model eval result is YES. Otherwise, return an empty dict.
+async def data_eval_request(chat_service, api_context: dict, document_content: dict) -> dict:
+    prompt_for_system = api_context['eval_prompt_template'].format(language=api_context["language"])
+    chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': f"Question: {document_content['question']}, Answer: {document_content['answer']}"}]
+    result = await chat_service.execute_chat_request_async(api_context, chat_request_payload, eval=True)
+    if not result:
+        return {}
+    if "Answer" not in result:
+        print("Error: eval response does not contain answer")
+        print(document_content,result)
+        return {}
+    # Send back the original QA pair if the model eval result is YES
+    if result["Answer"] == "YES":
+        return document_content
+    else:
+        print(document_content, result)
+    return {}
+
 
 async def generate_question_batches(chat_service, api_context: dict):
     document_text = read_file_content(api_context)
@@ -158,3 +176,20 @@ async def generate_question_batches(chat_service, api_context: dict):
     question_generation_results = await asyncio.gather(*generation_tasks)
 
     return question_generation_results
+
+async def generate_data_eval(chat_service, api_context: dict, generated_questions: list):
+    eval_tasks = []
+    for qa_pair in generated_questions:
+        try:
+            task = data_eval_request(chat_service, api_context, qa_pair)
+            eval_tasks.append(task)
+        except Exception as e:
+            print(f"Error during data eval request execution: {e}")
+
+    eval_results = await asyncio.gather(*eval_tasks)
+    curated_data = []
+    for item in eval_results:
+        # if the item is not empty, add it to the curated data list
+        if item:
+            curated_data.append(item)
+    return curated_data
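
To see the curation step end to end without a model in the loop, here is a toy fan-out in the same spirit as generate_data_eval: evaluate every pair concurrently and keep only the non-empty results. fake_eval is hypothetical and simply approves pairs that mention Llama.

import asyncio

async def fake_eval(pair):
    # Stand-in for data_eval_request: approve pairs that mention Llama.
    verdict = "YES" if "llama" in (pair["question"] + pair["answer"]).lower() else "NO"
    return pair if verdict == "YES" else {}

async def curate(pairs):
    results = await asyncio.gather(*(fake_eval(p) for p in pairs))
    # Drop the empty dicts, exactly like the curated_data loop above.
    return [r for r in results if r]

pairs = [
    {"question": "What is Llama 3?", "answer": "Meta's latest open LLM."},
    {"question": "What is the capital of France?", "answer": "Paris."},
]
print(asyncio.run(curate(pairs)))  # only the Llama-related pair survives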

+ 10 - 9
src/llama_recipes/configs/datasets.py

@@ -3,32 +3,33 @@
 
 from dataclasses import dataclass
 
-    
+
 @dataclass
 class samsum_dataset:
     dataset: str =  "samsum_dataset"
     train_split: str = "train"
     test_split: str = "validation"
-    
-    
+
+
 @dataclass
 class grammar_dataset:
     dataset: str = "grammar_dataset"
-    train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv" 
+    train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv"
     test_split: str = "src/llama_recipes/datasets/grammar_dataset/grammar_validation.csv"
 
-    
+
 @dataclass
 class alpaca_dataset:
     dataset: str = "alpaca_dataset"
     train_split: str = "train"
     test_split: str = "val"
     data_path: str = "src/llama_recipes/datasets/alpaca_data.json"
-    
-    
+
+
 @dataclass
 class custom_dataset:
     dataset: str = "custom_dataset"
-    file: str = "examples/custom_dataset.py"
+    file: str = "recipes/finetuning/datasets/custom_dataset.py"
     train_split: str = "train"
-    test_split: str = "validation"
+    test_split: str = "validation"
+    data_path: str = ""

+ 1 - 1
src/llama_recipes/finetuning.py

@@ -134,7 +134,7 @@ def main(**kwargs):
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
     tokenizer.pad_token_id = tokenizer.eos_token_id
 
-    # If there is a mismatch between tokenizer vocab size and embedding matrix, 
+    # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
         print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")