소스 검색

fix generate_qa function

Kai Wu 1 년 전
부모
커밋
6204d5ae38

BIN
recipes/use_cases/end2end-recipes/chatbot/data_pipelines/._faq-data


+ 1 - 1
recipes/use_cases/end2end-recipes/chatbot/data_pipelines/config.py

@@ -13,6 +13,6 @@ def load_config(config_path: str = "./config.yaml"):
         config["api_key"] = os.environ["OCTOAI_API_TOKEN"]
     except KeyError:
         print("API token did not found, please set the OCTOAI_API_TOKEN environment variable if using OctoAI, otherwise set api_key to default EMPTY")
-        # local Vllm endpoint did not need API key, so set the API key to "EMPTY" if not found
+        # local Vllm endpoint did not need API key, so set the API key to "EMPTY" if OCTOAI_API_TOKEN not found
         config["api_key"] = "EMPTY"
     return config

+ 2 - 2
recipes/use_cases/end2end-recipes/chatbot/data_pipelines/config.yaml

@@ -4,7 +4,7 @@ question_prompt_template: >
   read it and generate question and answer pairs
   that are most likely be asked by a use of llama that just want to start,
   please make sure you follow those rules,
-  1. Generate only {total_questions} question answer pairs.
+  1. Generate only {num_questions} question answer pairs.
   2. Generate in {language}.
   3. The questions can be answered based *solely* on the given passage.
   4. Avoid asking questions with similar meaning.
@@ -27,4 +27,4 @@ data_dir: "./data"
 
 language: "English"
 
-total_questions: 1000
+num_questions: 2

+ 10 - 6
recipes/use_cases/end2end-recipes/chatbot/data_pipelines/generate_question_answers.py

@@ -5,7 +5,7 @@ import argparse
 import asyncio
 import json
 from config import load_config
-from generator_utils import generate_question_batches, parse_qa_to_json, get_model_name
+from generator_utils import generate_question_batches, parse_qa_to_json
 from itertools import chain
 import logging
 import aiofiles  # Ensure aiofiles is installed for async file operations
@@ -21,7 +21,10 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
 rate_limit_threshold = 2000
 allowed_concurrent_requests = int(rate_limit_threshold * 0.75)
 request_limiter = asyncio.Semaphore(allowed_concurrent_requests)
-
+# Since OctoAI has different naming for llama models, create this mapping to get huggingface offical model name given OctoAI names.
+MODEL_NAME_MAPPING={"meta-llama-3-70b-instruct":"meta-llama/Meta-Llama-3-70B-Instruct",
+"meta-llama-3-8b-instruct":"meta-llama/Meta-Llama-3-8B-Instruct","llama-2-7b-chat":"meta-llama/Llama-2-7b-chat-hf"
+,"llama-2-70b-chat":"meta-llama/Llama-2-70b-chat-hf"}
 class ChatService(ABC):
     @abstractmethod
     async def execute_chat_request_async(self, api_context: dict, chat_request):
@@ -57,7 +60,7 @@ class VllmChatService(ChatService):
         async with request_limiter:
             try:
                 event_loop = asyncio.get_running_loop()
-                model_name = get_model_name(api_context['model'])
+                model_name = MODEL_NAME_MAPPING[api_context['model']]
                 client = OpenAI(api_key=api_context['api_key'], base_url="http://localhost:"+ str(api_context['endpoint'])+"/v1")
                 api_chat_call = partial(
                     client.chat.completions.create,
@@ -68,7 +71,8 @@ class VllmChatService(ChatService):
                 response = await event_loop.run_in_executor(None, api_chat_call)
                 assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
                 assistant_response_json = parse_qa_to_json(assistant_response)
-                assert(len(assistant_response_json)!=0)
+                if len(assistant_response_json)==0:
+                    logging.error("No question/answer pairs generated. Please check the input context or model configuration.")
                 return assistant_response_json
             except Exception as error:
                 logging.error(f"Error during chat request execution: {error}",exc_info=True)
@@ -103,8 +107,8 @@ def parse_arguments():
     parser.add_argument(
         "-t", "--total_questions",
         type=int,
-        default=10,
-        help="Specify the number of question/answer pairs to generate."
+        default=100,
+        help="Specify the total number of question/answer pairs to generate."
     )
     parser.add_argument(
         "-m", "--model",

+ 19 - 22
recipes/use_cases/end2end-recipes/chatbot/data_pipelines/generator_utils.py

@@ -14,16 +14,7 @@ import json
 # Initialize logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
-# Since OctoAI has different naming for llama models, get the huggingface offical model name using OctoAI names.
-def get_model_name(model):
-    if model == "meta-llama-3-70b-instruct":
-        return "meta-llama/Meta-Llama-3-70B-Instruct"
-    elif model == "meta-llama-3-8b-instruct":
-        return "meta-llama/Meta-Llama-3-8B-Instruct"
-    elif model == "llama-2-7b-chat":
-        return "meta-llama/Llama-2-7b-chat-hf"
-    else:
-        return "meta-llama/Llama-2-70b-chat-hf"
+
 def read_text_file(file_path):
     try:
         with open(file_path, 'r') as f:
@@ -88,8 +79,13 @@ def read_file_content(context):
     if len(text) == 0:
         logging.error(f"Error reading files, text is empty")
     return ' '.join(file_strings)
-
-
+# clean the text by removing all parts that did not contain any alphanumeric characters
+def clean(s):
+        result = []
+        for item in s.split('"'):
+            if any(c.isalnum() for c in item):
+                result.append(item)
+        return " ".join(result)
 
 def parse_qa_to_json(response_string):
     split_lines = response_string.split("\n")
@@ -109,21 +105,21 @@ def parse_qa_to_json(response_string):
                 end = i
             # found Question means we have reached the end of the question, so add it to qa_list
             elif '"Question":' in line:
-                question = " ".join(" ".join(split_lines[start:end]).split('"Question":')[1].split('"')[1:-1])
-                answer = " ".join(" ".join(split_lines[end:i]).split('"Answer":')[1].split('"')[1:-1])
+                question = " ".join(split_lines[start:end]).split('"Question":')[1]
+                answer = " ".join(split_lines[end:i]).split('"Answer":')[1]
                 start,end = i,None
-                qa_set.add((question, answer))
+                qa_set.add((clean(question), clean(answer)))
         # adding last question back to qa_list
-        if start and end:
-            question = " ".join(" ".join(split_lines[start:end]).split('"Question":')[1].split('"')[1:-1])
-            answer = " ".join(" ".join(split_lines[end:i]).split('"Answer":')[1].split('"')[1:-1])
-            qa_set.add((question, answer))
+    if start and end:
+        question = " ".join(split_lines[start:end]).split('"Question":')[1]
+        answer = " ".join(split_lines[end:]).split('"Answer":')[1]
+        qa_set.add((clean(question), clean(answer)))
     qa_list = [{"question": q, "answer":a} for q,a in qa_set]
     return json.dumps(qa_list, indent=4)
 
 
-async def prepare_and_send_request(chat_service, api_context: dict, document_content: str, total_questions: int) -> dict:
-    prompt_for_system = api_context['question_prompt_template'].format(total_questions=total_questions, language=api_context["language"])
+async def prepare_and_send_request(chat_service, api_context: dict, document_content: str, num_questions: int) -> dict:
+    prompt_for_system = api_context['question_prompt_template'].format(num_questions=num_questions, language=api_context["language"])
     chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': document_content}]
     result = await chat_service.execute_chat_request_async(api_context, chat_request_payload)
     if not result:
@@ -142,7 +138,8 @@ async def generate_question_batches(chat_service, api_context: dict):
 
     total_questions = api_context["total_questions"]
     batches_count = len(document_batches)
-    base_questions_per_batch = total_questions // batches_count
+    # each batch should have at least 1 question
+    base_questions_per_batch = max(total_questions // batches_count,1)
     extra_questions = total_questions % batches_count
 
     print(f"Questions per batch: {base_questions_per_batch} (+1 for the first {extra_questions} batches), Total questions: {total_questions}, Batches: {batches_count}")