
restructured folders and added eval pipeline

Kai Wu 1 year ago
parent
commit
9add30acb2

File diff suppressed because it is too large
+ 44 - 4
recipes/use_cases/end2end-recipes/chatbot/data_pipelines/README.md


+ 11 - 87
recipes/use_cases/end2end-recipes/chatbot/data_pipelines/generate_question_answers.py

@@ -1,30 +1,21 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
-import argparse
 import asyncio
-import json
-from config import load_config
-from generator_utils import generate_question_batches, parse_qa_to_json, generate_data_eval
-from itertools import chain
 import logging
-import aiofiles  # Ensure aiofiles is installed for async file operations
 from abc import ABC, abstractmethod
 from octoai.client import OctoAI
 from functools import partial
 from openai import OpenAI
-
+import json
+from generator_utils import generate_question_batches, parse_qa_to_json, generate_data_eval
 # Configure logging to include the timestamp, log level, and message
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-
-# Manage rate limits with throttling
-rate_limit_threshold = 2000
-allowed_concurrent_requests = int(rate_limit_threshold * 0.75)
-request_limiter = asyncio.Semaphore(allowed_concurrent_requests)
 # Since OctoAI uses different names for the Llama models, this mapping converts an OctoAI model name to the official Hugging Face model name.
 MODEL_NAME_MAPPING={"meta-llama-3-70b-instruct":"meta-llama/Meta-Llama-3-70B-Instruct",
 "meta-llama-3-8b-instruct":"meta-llama/Meta-Llama-3-8B-Instruct","llama-2-7b-chat":"meta-llama/Llama-2-7b-chat-hf"
 ,"llama-2-70b-chat":"meta-llama/Llama-2-70b-chat-hf"}
+# Manage rate limits with throttling
+rate_limit_threshold = 2000
+allowed_concurrent_requests = int(rate_limit_threshold * 0.75)
+request_limiter = asyncio.Semaphore(allowed_concurrent_requests)
 class ChatService(ABC):
     @abstractmethod
     async def execute_chat_request_async(self, api_context: dict, chat_request, eval=False):
@@ -63,7 +54,10 @@ class VllmChatService(ChatService):
         async with request_limiter:
             try:
                 event_loop = asyncio.get_running_loop()
-                model_name = MODEL_NAME_MAPPING[api_context['model']]
+                if api_context["model"] in MODEL_NAME_MAPPING:
+                    model_name = MODEL_NAME_MAPPING[api_context['model']]
+                else:
+                    model_name = api_context['model']
                 client = OpenAI(api_key=api_context['api_key'], base_url="http://localhost:"+ str(api_context['endpoint'])+"/v1")
                 api_chat_call = partial(
                     client.chat.completions.create,
@@ -74,6 +68,7 @@ class VllmChatService(ChatService):
                 response = await event_loop.run_in_executor(None, api_chat_call)
                 assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
                 if eval:
+                    print(assistant_response)
                     assistant_response_json = json.loads(assistant_response)
                 else:
                     assistant_response_json = parse_qa_to_json(assistant_response)
@@ -81,74 +76,3 @@ class VllmChatService(ChatService):
             except Exception as error:
                 logging.error(f"Error during chat request execution: {error}",exc_info=True)
                 return ""
-
-async def main(context):
-    if context["endpoint"]:
-        chat_service = VllmChatService()
-    else:
-        chat_service = OctoAIChatService()
-    try:
-        logging.info("Starting to generate question/answer pairs.")
-        data = await generate_question_batches(chat_service, context)
-        if not data:
-            logging.warning("No data generated. Please check the input context or model configuration.")
-            return
-        flattened_list = list(chain.from_iterable(data))
-        # with open("data.json") as fp:
-        #     flattened_list = json.load(fp)
-        logging.info(f"Successfully generated {len(flattened_list)} question/answer pairs.")
-        # Use asynchronous file operation for writing to the file
-
-        # async with aiofiles.open("data.json", "w") as output_file:
-        #     await output_file.write(json.dumps(flattened_list, indent=4))
-        # logging.info("Data successfully written to 'data.json'. Process completed.")
-        curated_data = await generate_data_eval(chat_service, context,flattened_list)
-        logging.info(f"Only {len(curated_data)} question/answer pairs pass the self-curation")
-        async with aiofiles.open("curated_data.json", "w") as curated_data:
-             await curated_data.write(json.dumps(flattened_list, indent=4))
-        logging.info("Data successfully written to 'curated_data.json'. Process completed.")
-    except Exception as e:
-        logging.error(f"An unexpected error occurred during the process: {e}",exc_info=True)
-
-def parse_arguments():
-    # Define command line arguments for the script
-    parser = argparse.ArgumentParser(
-        description="Generate question/answer pairs from documentation."
-    )
-    parser.add_argument(
-        "-t", "--total_questions",
-        type=int,
-        default=100,
-        help="Specify the total number of question/answer pairs to generate."
-    )
-    parser.add_argument(
-        "-m", "--model",
-        choices=["meta-llama-3-70b-instruct","meta-llama-3-8b-instruct","llama-2-13b-chat", "llama-2-70b-chat"],
-        default="meta-llama-3-70b-instruct",
-        help="Select the model to use for generation."
-    )
-    parser.add_argument(
-        "-c", "--config_path",
-        default="config.yaml",
-        help="Set the configuration file path that has system prompt along with language, dataset path and number of questions."
-    )
-    parser.add_argument(
-        "-v", "--vllm_endpoint",
-        default=None,
-        type=int,
-        help="If a port is specified, then use local vllm endpoint for generating question/answer pairs."
-    )
-    return parser.parse_args()
-
-if __name__ == "__main__":
-    logging.info("Initializing the process and loading configuration...")
-    args = parse_arguments()
-
-    context = load_config(args.config_path)
-    context["total_questions"] = args.total_questions
-    context["model"] = args.model
-    context["endpoint"] = args.vllm_endpoint
-    logging.info(f"Configuration loaded. Generating {args.total_questions} question/answer pairs using model '{args.model}'.")
-    if context["endpoint"]:
-        logging.info(f"Use local vllm service at port: '{args.vllm_endpoint}'.")
-    asyncio.run(main(context))
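
Side note on the unmapped-model fallback introduced earlier in this file's diff: the same behavior can be written as a single dictionary lookup. A minimal sketch, not part of the commit; resolve_model_name is a hypothetical helper and the mapping is trimmed for brevity.

# A trimmed copy of the mapping from the hunk above; resolve_model_name is a
# hypothetical helper, not part of the recipe.
MODEL_NAME_MAPPING = {
    "meta-llama-3-70b-instruct": "meta-llama/Meta-Llama-3-70B-Instruct",
    "meta-llama-3-8b-instruct": "meta-llama/Meta-Llama-3-8B-Instruct",
}

def resolve_model_name(requested: str) -> str:
    # dict.get returns the mapped Hugging Face name when present, otherwise the
    # original name, so LoRA adapters or local model ids pass through unchanged.
    return MODEL_NAME_MAPPING.get(requested, requested)

assert resolve_model_name("meta-llama-3-8b-instruct") == "meta-llama/Meta-Llama-3-8B-Instruct"
assert resolve_model_name("my-custom-adapter") == "my-custom-adapter"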

recipes/use_cases/end2end-recipes/chatbot/data_pipelines/config.py → recipes/use_cases/end2end-recipes/chatbot/pipelines/config.py


recipes/use_cases/end2end-recipes/chatbot/data_pipelines/doc_processor.py → recipes/use_cases/end2end-recipes/chatbot/pipelines/doc_processor.py


+ 118 - 0
recipes/use_cases/end2end-recipes/chatbot/pipelines/eval_chatbot.py

@@ -0,0 +1,118 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
+from chat_utils import OctoAIChatService, VllmChatService
+import logging
+import evaluate
+import argparse
+from config import load_config
+import asyncio
+import json
+from itertools import chain
+
+def compute_rouge_score(generated : str, reference: str):
+    rouge_score = evaluate.load('rouge')
+    return rouge_score.compute(
+        predictions=generated,
+        references=reference,
+        use_stemmer=True,
+        use_aggregator=True
+    )
+def compute_bert_score(generated : str, reference: str):
+    bertscore = evaluate.load("bertscore")
+    return bertscore.compute(
+        predictions=generated,
+        references=reference,
+        lang="en"
+    )
+# Send an eval-set question to the model under evaluation and return its generated answer. Returns an empty dict if the response is malformed.
+async def eval_request(chat_service, api_context: dict, question: str) -> dict:
+    prompt_for_system = api_context['eval_prompt_template'].format(language=api_context["language"])
+    chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': f"Question: {question}"}]
+    # The service returns a list of results; in this case there should be exactly one.
+    results = await chat_service.execute_chat_request_async(api_context, chat_request_payload,eval=False)
+    # convert the result string to a list
+    results = eval(results)
+    if not results or len(results) > 1:
+        print("results",type(results),len(results),results)
+        return {}
+    result = results[0]
+    if "Answer" not in result:
+        print("Error: eval response does not contain answer")
+        print(question,result)
+        return {}
+    print("result",result)
+    # Send back the model generated answer
+    return result["Answer"]
+
+async def generate_eval_answer(chat_service, api_context: dict, questions: list):
+    eval_tasks = []
+    for batch_index, question in enumerate(questions):
+        try:
+            result = eval_request(chat_service, api_context, question)
+            eval_tasks.append(result)
+        except Exception as e:
+            print(f"Error during data eval request execution: {e}")
+    print(len(eval_tasks),"eval_tasks")
+    eval_results = await asyncio.gather(*eval_tasks)
+
+    return eval_results
+
+async def main(context):
+    if context["endpoint"]:
+        chat_service = VllmChatService()
+    else:
+        chat_service = OctoAIChatService()
+    try:
+        logging.info("Starting to generate answer given the eval set.")
+        with open(context["eval_json"]) as fp:
+            eval_json = json.load(fp)
+        questions,groud_truth = [],[]
+        for index, item in enumerate(eval_json):
+            questions.append(item["question"])
+            groud_truth.append(item["answer"])
+        generated_answers = await generate_eval_answer(chat_service, context,questions)
+        if not generated_answers:
+            logging.warning("No answers generated. Please check the input context or model configuration.")
+            return
+        logging.info(f"Successfully generated {len(generated_answers)} answers.")
+        rouge_score = compute_rouge_score(generated_answers,groud_truth)
+        print("Rouge_score:",rouge_score)
+        bert_score = compute_bert_score(generated_answers,groud_truth)
+        print("Bert_score:",bert_score)
+        logging.info("Eval successfully")
+    except Exception as e:
+        logging.error(f"An unexpected error occurred during the process: {e}",exc_info=True)
+
+def parse_arguments():
+    # Define command line arguments for the script
+    parser = argparse.ArgumentParser(
+        description="Evaluate chatbot answers against the evaluation set."
+    )
+    parser.add_argument(
+        "-m", "--model",
+        default="chatbot",
+        help="Select the model to use for evaluation, this maybe a LoRA adapter."
+    )
+    parser.add_argument(
+        "-c", "--config_path",
+        default="eval_config.yaml",
+        help="Set the configuration file path that has system prompt along with language, evalset path."
+    )
+    parser.add_argument(
+        "-v", "--vllm_endpoint",
+        default=None,
+        type=int,
+        help="If a port is specified, then use local vllm endpoint for evaluations."
+    )
+    return parser.parse_args()
+
+if __name__ == "__main__":
+    logging.info("Initializing the process and loading configuration...")
+    args = parse_arguments()
+
+    context = load_config(args.config_path)
+    context["model"] = args.model
+    context["endpoint"] = args.vllm_endpoint
+    if context["endpoint"]:
+        logging.info(f"Use local vllm service at port: '{args.vllm_endpoint}'.")
+    asyncio.run(main(context))
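
For context on the two metric helpers at the top of eval_chatbot.py, the following standalone sketch shows the Hugging Face evaluate library being called with the same arguments. The prediction and reference strings are invented for illustration; the packages come from the requirements.txt additions at the end of this diff.

# Standalone sketch of the metric calls in eval_chatbot.py above; the strings are
# made up and the exact scores will differ.
import evaluate

generated = ["Llama 3 comes in 8B and 70B parameter sizes."]
reference = ["Llama 3 is available in 8B and 70B sizes."]

rouge = evaluate.load("rouge")
print(rouge.compute(predictions=generated, references=reference,
                    use_stemmer=True, use_aggregator=True))

bertscore = evaluate.load("bertscore")
print(bertscore.compute(predictions=generated, references=reference, lang="en"))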

+ 12 - 0
recipes/use_cases/end2end-recipes/chatbot/pipelines/eval_config.yaml

@@ -0,0 +1,12 @@
+eval_prompt_template: >
+  You are an AI assistant skilled in answering questions related to the Llama model.
+  Below is a question from a Llama user. Please answer it in {language}; make the answer as concise as possible, at most 100 words.
+  Return the result using the template:
+  {{
+      "Question": "The question user asked to you"
+      "Answer": "Your answer to the question"
+  }}
+
+eval_json: "./evalset.json"
+
+language: "English"
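
To illustrate how eval_chatbot.py consumes this config, here is a hedged sketch. load_config() is not shown in this diff, so yaml.safe_load stands in for it as an assumption; the eval-set keys follow main() in eval_chatbot.py above.

# Hedged sketch of how eval_chatbot.py reads this file; yaml.safe_load is an
# assumption standing in for the load_config() helper that is not in this diff.
import json
import yaml

with open("eval_config.yaml") as fp:
    context = yaml.safe_load(fp)

# The system prompt is the template with {language} substituted; the literal braces
# in the YAML are escaped as {{ }} so str.format leaves them in place.
system_prompt = context["eval_prompt_template"].format(language=context["language"])

# The eval set referenced by eval_json is a JSON list of question/answer records.
with open(context["eval_json"]) as fp:
    eval_set = json.load(fp)
questions = [item["question"] for item in eval_set]
ground_truth = [item["answer"] for item in eval_set]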

recipes/use_cases/end2end-recipes/chatbot/data_pipelines/evalset.json → recipes/use_cases/end2end-recipes/chatbot/pipelines/evalset.json


+ 88 - 0
recipes/use_cases/end2end-recipes/chatbot/pipelines/generate_question_answers.py

@@ -0,0 +1,88 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
+
+import argparse
+import asyncio
+import json
+from config import load_config
+from generator_utils import generate_question_batches, generate_data_eval
+from chat_utils import OctoAIChatService, VllmChatService
+from itertools import chain
+import logging
+import aiofiles  # Ensure aiofiles is installed for async file operations
+
+
+# Configure logging to include the timestamp, log level, and message
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+
+async def main(context):
+    if context["endpoint"]:
+        chat_service = VllmChatService()
+    else:
+        chat_service = OctoAIChatService()
+    try:
+        logging.info("Starting to generate question/answer pairs.")
+        data = await generate_question_batches(chat_service, context)
+        if not data:
+            logging.warning("No data generated. Please check the input context or model configuration.")
+            return
+        flattened_list = list(chain.from_iterable(data))
+        # with open("data.json") as fp:
+        #     flattened_list = json.load(fp)
+        logging.info(f"Successfully generated {len(flattened_list)} question/answer pairs.")
+        # Use asynchronous file operation for writing to the file
+
+        # async with aiofiles.open("data.json", "w") as output_file:
+        #     await output_file.write(json.dumps(flattened_list, indent=4))
+        # logging.info("Data successfully written to 'data.json'. Process completed.")
+        curated_data = await generate_data_eval(chat_service, context,flattened_list)
+        logging.info(f"Only {len(curated_data)} question/answer pairs pass the self-curation")
+        async with aiofiles.open("curated_data.json", "w") as output_file:
+            await output_file.write(json.dumps(curated_data, indent=4))
+        logging.info("Data successfully written to 'curated_data.json'. Process completed.")
+    except Exception as e:
+        logging.error(f"An unexpected error occurred during the process: {e}",exc_info=True)
+
+def parse_arguments():
+    # Define command line arguments for the script
+    parser = argparse.ArgumentParser(
+        description="Generate question/answer pairs from documentation."
+    )
+    parser.add_argument(
+        "-t", "--total_questions",
+        type=int,
+        default=100,
+        help="Specify the total number of question/answer pairs to generate."
+    )
+    parser.add_argument(
+        "-m", "--model",
+        choices=["meta-llama-3-70b-instruct","meta-llama-3-8b-instruct","llama-2-13b-chat", "llama-2-70b-chat"],
+        default="meta-llama-3-70b-instruct",
+        help="Select the model to use for generation."
+    )
+    parser.add_argument(
+        "-c", "--config_path",
+        default="./generation_config.yaml",
+        help="Set the configuration file path that has system prompt along with language, dataset path and number of questions."
+    )
+    parser.add_argument(
+        "-v", "--vllm_endpoint",
+        default=None,
+        type=int,
+        help="If a port is specified, then use local vllm endpoint for generating question/answer pairs."
+    )
+    return parser.parse_args()
+
+if __name__ == "__main__":
+    logging.info("Initializing the process and loading configuration...")
+    args = parse_arguments()
+
+    context = load_config(args.config_path)
+    context["total_questions"] = args.total_questions
+    context["model"] = args.model
+    context["endpoint"] = args.vllm_endpoint
+    logging.info(f"Configuration loaded. Generating {args.total_questions} question/answer pairs using model '{args.model}'.")
+    if context["endpoint"]:
+        logging.info(f"Use local vllm service at port: '{args.vllm_endpoint}'.")
+    asyncio.run(main(context))
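
The __main__ block above layers CLI flags on top of the YAML config before handing the merged context to main(). A hedged sketch of that pattern, again with yaml.safe_load standing in for the load_config() helper that is not part of this diff:

# Sketch of the context-assembly pattern in the __main__ block above: the YAML config
# supplies prompts and dataset settings, and CLI flags are layered on top.
import yaml

def load_config(config_path: str) -> dict:
    with open(config_path) as fp:
        return yaml.safe_load(fp)

context = load_config("./generation_config.yaml")
context["total_questions"] = 100                     # from --total_questions
context["model"] = "meta-llama-3-70b-instruct"       # from --model
context["endpoint"] = None                           # from --vllm_endpoint; None selects OctoAI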

+ 1 - 1
recipes/use_cases/end2end-recipes/chatbot/data_pipelines/config.yaml

@@ -22,7 +22,7 @@ question_prompt_template: >
       }}
     ]
 
-eval_prompt_template: >
+curation_prompt_template: >
  Below is a question and answer pair about the Llama language model. Evaluate
  whether or not this question and answer pair will be helpful for a user of the Llama language model.
  Respond with only a single JSON blob with an "explanation" field that is a short (less than 100 word)

+ 5 - 5
recipes/use_cases/end2end-recipes/chatbot/data_pipelines/generator_utils.py

@@ -14,7 +14,6 @@ import json
 # Initialize logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
-
 def read_text_file(file_path):
     try:
         with open(file_path, 'r') as f:
@@ -86,7 +85,7 @@ def clean(s):
             if any(c.isalnum() for c in item):
                 result.append(item)
         return " ".join(result)
-
+# Given a response string, return a JSON string of the parsed question/answer pairs.
 def parse_qa_to_json(response_string):
     split_lines = response_string.split("\n")
     start,end = None,None
@@ -114,7 +113,8 @@ def parse_qa_to_json(response_string):
         question = " ".join(split_lines[start:end]).split('"Question":')[1]
         answer = " ".join(split_lines[end:]).split('"Answer":')[1]
         qa_set.add((clean(question), clean(answer)))
-    qa_list = [{"question": q, "answer":a} for q,a in qa_set]
+    qa_list = [{"Question": q, "Answer":a} for q,a in qa_set]
+
     return json.dumps(qa_list, indent=4)
 
 
@@ -127,8 +127,8 @@ async def prepare_and_send_request(chat_service, api_context: dict, document_con
     return json.loads(await chat_service.execute_chat_request_async(api_context, chat_request_payload,eval=False))
 # This function is used to evaluate the quality of generated QA pairs. Return the original QA pair if the model eval result is YES. Otherwise, return an empty dict.
 async def data_eval_request(chat_service, api_context: dict, document_content: dict) -> dict:
-    prompt_for_system = api_context['eval_prompt_template'].format(language=api_context["language"])
-    chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': f"Question: {document_content['question']}, Answer: {document_content['answer']}"}]
+    prompt_for_system = api_context['curation_prompt_template'].format(language=api_context["language"])
+    chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': f"Question: {document_content['Question']}, Answer: {document_content['Answer']}"}]
     result = await chat_service.execute_chat_request_async(api_context, chat_request_payload,eval=True)
     if not result:
         return {}
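
This hunk switches the QA keys from "question"/"answer" to "Question"/"Answer" in both parse_qa_to_json and data_eval_request. A small, made-up illustration of the resulting record shape and the user message built from it:

# Illustration of the record shape after this change; the question/answer text is invented.
import json

qa_set = {("What sizes does Llama 3 come in?", "8B and 70B parameters.")}
qa_list = [{"Question": q, "Answer": a} for q, a in qa_set]
print(json.dumps(qa_list, indent=4))

for pair in qa_list:
    user_message = f"Question: {pair['Question']}, Answer: {pair['Answer']}"
    print(user_message)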

+ 3 - 0
requirements.txt

@@ -23,3 +23,6 @@ octoai
 python-magic
 PyPDF2
 aiofiles
+evaluate
+rouge_score
+bert_score
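
A quick, optional way to confirm the three new evaluation dependencies are installed; this snippet is not part of the diff.

# Sanity check: evaluate needs rouge_score and bert_score installed to load these metrics.
import evaluate

for metric_name in ("rouge", "bertscore"):
    metric = evaluate.load(metric_name)  # raises if the backend package is missing
    print(metric_name, "loaded:", type(metric).__name__)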