
adding data generation pipeline

Hamid Shojanazeri 1 year ago
parent
commit
d10ef3202b

+ 47 - 0
recipes/use_cases/end2end-recipes/chatbot/data_pipelines/doc_processor.py

@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+# Rough estimate of the number of tokens each generated question/answer pair consumes
+AVERAGE_TOKENS_PER_RESULT = 100
+
+def get_token_limit_for_model(model: str) -> int:
+    """Returns the context-window token limit for a given model."""
+    if model in ("llama-2-70b-chat-fp16", "llama-2-13b-chat-fp16"):
+        return 4096
+    # Default to the Llama 2 context window so callers never receive None for unknown models
+    return 4096
+
+
+def calculate_num_tokens_for_message(encoded_text) -> int:
+    """Calculates the number of tokens used by a message."""
+    
+    # Add 3 tokens to account for priming the assistant's reply
+    return len(encoded_text) + 3
+
+
+def split_text_into_chunks(context: dict, text: str, tokenizer) -> list[str]:
+    """Splits a long text into substrings based on token length constraints, adjusted for question generation."""
+    # Compute how many tokens remain available for each text chunk
+    encoded_text = tokenizer(text, return_tensors="pt", padding=True)["input_ids"]
+    encoded_text = encoded_text.squeeze()
+    model_token_limit = get_token_limit_for_model(context["model"])
+
+    tokens_for_questions = calculate_num_tokens_for_message(encoded_text)
+    estimated_tokens_per_question = AVERAGE_TOKENS_PER_RESULT
+    estimated_total_question_tokens = estimated_tokens_per_question * context["total_questions"]
+    # Ensure there's a reasonable minimum chunk size
+    max_tokens_for_text = max(model_token_limit - tokens_for_questions - estimated_total_question_tokens, model_token_limit // 10)
+    
+    chunks, current_chunk = [], []
+    print(f"Splitting text into chunks of {max_tokens_for_text} tokens, encoded_text {len(encoded_text)}", flush=True)
+    for token in encoded_text:
+        if len(current_chunk) >= max_tokens_for_text:
+            chunks.append(tokenizer.decode(current_chunk).strip())
+            current_chunk = []
+        # Always append the current token so no tokens are dropped at chunk boundaries
+        current_chunk.append(token)
+
+    if current_chunk:
+        chunks.append(tokenizer.decode(current_chunk).strip())
+
+    print(f"Number of chunks in the processed text: {len(chunks)}", flush=True)
+   
+    return chunks
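
For reference, here is a minimal standalone sketch (not part of this commit) of driving split_text_into_chunks directly, using the same Llama 2 tokenizer that generator_utils.py loads; the sample.txt path and the context values are placeholders.

# Standalone sketch: chunk a local document the way the pipeline does.
# "sample.txt" and the context values below are illustrative placeholders.
from transformers import AutoTokenizer
from doc_processor import split_text_into_chunks

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
context = {"model": "llama-2-70b-chat-fp16", "total_questions": 10}

with open("sample.txt") as f:
    text = f.read()

chunks = split_text_into_chunks(context, text, tokenizer)
print(f"Produced {len(chunks)} chunks")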

+ 103 - 0
recipes/use_cases/end2end-recipes/chatbot/data_pipelines/generate_question_answers.py

@@ -0,0 +1,103 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import argparse
+import asyncio
+import json
+from config import load_config
+from generator_utils import generate_question_batches, parse_qa_to_json
+from itertools import chain
+import logging
+import aiofiles  # Ensure aiofiles is installed for async file operations
+from abc import ABC, abstractmethod
+from octoai.client import Client
+from functools import partial
+
+# Configure logging to include the timestamp, log level, and message
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+# Manage rate limits with throttling
+rate_limit_threshold = 2000
+allowed_concurrent_requests = int(rate_limit_threshold * 0.75)
+request_limiter = asyncio.Semaphore(allowed_concurrent_requests)
+
+class ChatService(ABC):
+    @abstractmethod
+    async def execute_chat_request_async(self, api_context: dict, chat_request):
+        pass
+
+# Please implement your own chat service class here.
+# The class should inherit from the ChatService class and implement the execute_chat_request_async method.
+class OctoAIChatService(ChatService):
+    async def execute_chat_request_async(self, api_context: dict, chat_request):
+        async with request_limiter:
+            try:
+                event_loop = asyncio.get_running_loop()
+                client = Client(api_context['api_key'])
+                api_chat_call = partial(
+                    client.chat.completions.create,
+                    model=api_context['model'],
+                    messages=chat_request,
+                    temperature=0.0
+                )
+                response = await event_loop.run_in_executor(None, api_chat_call)
+                assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
+                assistant_response_json = parse_qa_to_json(assistant_response)
+
+                return assistant_response_json
+            except Exception as error:
+                logging.error(f"Error during chat request execution: {error}")
+                # Return an empty JSON list so the caller's json.loads does not fail
+                return "[]"
+
+async def main(context):
+    chat_service = OctoAIChatService()
+    try:
+        logging.info("Starting to generate question/answer pairs.")
+        data = await generate_question_batches(chat_service, context)
+        if not data:
+            logging.warning("No data generated. Please check the input context or model configuration.")
+            return
+        flattened_list = list(chain.from_iterable(data))
+        logging.info(f"Successfully generated {len(flattened_list)} question/answer pairs.")
+        # Use asynchronous file operation for writing to the file
+        async with aiofiles.open("data.json", "w") as output_file:
+            await output_file.write(json.dumps(flattened_list, indent=4))
+        logging.info("Data successfully written to 'data.json'. Process completed.")
+
+    except Exception as e:
+        logging.error(f"An unexpected error occurred during the process: {e}")
+
+def parse_arguments():
+    # Define command line arguments for the script
+    parser = argparse.ArgumentParser(
+        description="Generate question/answer pairs from documentation."
+    )
+    parser.add_argument(
+        "-t", "--total_questions",
+        type=int,
+        default=10,
+        help="Specify the number of question/answer pairs to generate."
+    )
+    parser.add_argument(
+        "-m", "--model",
+        choices=["llama-2-70b-chat-fp16", "llama-2-13b-chat-fp16"],
+        default="llama-2-70b-chat-fp16",
+        help="Select the model to use for generation."
+    )
+    parser.add_argument(
+        "-c", "--config_path",
+        default="config.yaml",
+        help="Set the configuration file path that has system prompt along with language, dataset path and number of questions."
+    )
+    return parser.parse_args()
+
+if __name__ == "__main__":
+    logging.info("Initializing the process and loading configuration...")
+    args = parse_arguments()
+
+    context = load_config(args.config_path)
+    context["total_questions"] = args.total_questions
+    context["model"] = args.model
+
+    logging.info(f"Configuration loaded. Generating {args.total_questions} question/answer pairs using model '{args.model}'.")
+    asyncio.run(main(context))
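
OctoAIChatService above is just one backend; as the comment in the module notes, ChatService is meant to be subclassed. Below is a hedged sketch of an alternative implementation that posts to an OpenAI-compatible /chat/completions endpoint using aiohttp. It is written to live alongside OctoAIChatService in this module (it reuses request_limiter, parse_qa_to_json and logging), and the 'endpoint' config key is an assumption for this sketch, not something the commit defines.

import aiohttp  # assumed extra dependency for this sketch

class OpenAICompatibleChatService(ChatService):
    """Illustrative backend for an OpenAI-compatible chat completions endpoint."""
    async def execute_chat_request_async(self, api_context: dict, chat_request):
        async with request_limiter:
            payload = {
                "model": api_context["model"],
                "messages": chat_request,
                "temperature": 0.0,
            }
            headers = {"Authorization": f"Bearer {api_context['api_key']}"}
            try:
                async with aiohttp.ClientSession() as session:
                    # 'endpoint' is a hypothetical config key, e.g. an .../v1/chat/completions URL
                    async with session.post(api_context["endpoint"], json=payload, headers=headers) as response:
                        body = await response.json()
                assistant_response = body["choices"][0]["message"]["content"]
                return parse_qa_to_json(assistant_response)
            except Exception as error:
                logging.error(f"Error during chat request execution: {error}")
                return "[]"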

+ 121 - 0
recipes/use_cases/end2end-recipes/chatbot/data_pipelines/generator_utils.py

@@ -0,0 +1,121 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import os
+import re
+from transformers import AutoTokenizer
+import asyncio
+import magic
+from PyPDF2 import PdfReader
+import json
+from doc_processor import split_text_into_chunks
+import logging
+# Initialize logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+
+def read_text_file(file_path):
+    try:
+        with open(file_path, 'r') as f:
+            return f.read().strip() + ' '
+    except Exception as e:
+        logging.error(f"Error reading text file {file_path}: {e}")
+    return ''
+
+def read_pdf_file(file_path):
+    try:
+        with open(file_path, 'rb') as f:
+            pdf_reader = PdfReader(f)
+            num_pages = len(pdf_reader.pages)
+            file_text = [pdf_reader.pages[page_num].extract_text().strip() + ' ' for page_num in range(num_pages)]
+            return ''.join(file_text)
+    except Exception as e:
+        logging.error(f"Error reading PDF file {file_path}: {e}")
+    return ''
+
+def read_json_file(file_path):
+    try:
+        with open(file_path, 'r') as f:
+            data = json.load(f)
+            # Assuming each item in the list has a 'question' and 'answer' key
+            # Concatenating question and answer pairs with a space in between and accumulating them into a single string
+            file_text = ' '.join([item['question'].strip() + ' ' + item['answer'].strip() + ' ' for item in data])
+            return file_text
+    except Exception as e:
+        logging.error(f"Error reading JSON file {file_path}: {e}")
+    return ''
+
+
+def process_file(file_path):
+    file_type = magic.from_file(file_path, mime=True)
+    if file_type in ['text/plain', 'text/markdown']:
+        return read_text_file(file_path)
+    elif file_type == 'application/json':
+        return read_json_file(file_path)
+    elif file_type == 'application/pdf':
+        return read_pdf_file(file_path)
+    else:
+        logging.warning(f"Unsupported file type {file_type} for file {file_path}")
+        return ''
+
+def read_file_content(context):
+    file_strings = []
+
+    for root, _, files in os.walk(context['data_dir']):
+        for file in files:
+            file_path = os.path.join(root, file)
+            file_text = process_file(file_path)
+            if file_text:
+                file_strings.append(file_text)
+
+    return ' '.join(file_strings)
+
+
+
+def parse_qa_to_json(response_string):
+    # Adjusted regex to capture question-answer pairs more flexibly
+    # This pattern accounts for optional numbering and different question/answer lead-ins
+    pattern = re.compile(
+        r"\d*\.\s*Question:\s*(.*?)\nAnswer:\s*(.*?)(?=\n\d*\.\s*Question:|\Z)", 
+        re.DOTALL
+    )
+
+    # Find all matches in the response string
+    matches = pattern.findall(response_string)
+
+    # Convert matches to a structured format
+    qa_list = [{"question": match[0].strip(), "answer": match[1].strip()} for match in matches]
+
+    # Convert the list to a JSON string
+    return json.dumps(qa_list, indent=4)
+
+
+async def prepare_and_send_request(chat_service, api_context: dict, document_content: str, total_questions: int) -> dict:
+    prompt_for_system = api_context['question_prompt_template'].format(total_questions=total_questions, language=api_context["language"])
+    chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': document_content}]
+    return json.loads(await chat_service.execute_chat_request_async(api_context, chat_request_payload))
+
+async def generate_question_batches(chat_service, api_context: dict):
+    document_text = read_file_content(api_context)
+    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
+    document_batches = split_text_into_chunks(api_context, document_text, tokenizer)
+    
+    total_questions = api_context["total_questions"]
+    batches_count = len(document_batches)
+    base_questions_per_batch = total_questions // batches_count
+    extra_questions = total_questions % batches_count
+
+    print(f"Questions per batch: {base_questions_per_batch} (+1 for the first {extra_questions} batches), Total questions: {total_questions}, Batches: {batches_count}")
+    generation_tasks = []
+    for batch_index, batch_content in enumerate(document_batches):
+        print(f"len of batch_content: {len(batch_content)}, batch_index: {batch_index}")
+        #Distribute extra questions across the first few batches
+        questions_in_current_batch = base_questions_per_batch + (1 if batch_index < extra_questions else 0)
+        print(f"Batch {batch_index + 1} - {questions_in_current_batch} questions ********")
+        generation_tasks.append(prepare_and_send_request(chat_service, api_context, batch_content, questions_in_current_batch))
+
+    question_generation_results = await asyncio.gather(*generation_tasks)
+
+    return question_generation_results
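
config.py and config.yaml are not included in this commit. Judging only from how the loaded context is consumed here and in generate_question_answers.py, load_config is expected to return at least the keys sketched below; the values are placeholders and the real schema may differ.

# Illustrative shape of the dict returned by load_config("config.yaml"),
# inferred from how the context keys are consumed in this pipeline; values are placeholders.
example_context = {
    "api_key": "YOUR_OCTOAI_API_KEY",      # used by OctoAIChatService
    "data_dir": "./data",                  # directory walked by read_file_content
    "language": "English",                 # substituted into the system prompt
    "question_prompt_template": (
        "Generate {total_questions} question and answer pairs in {language} "
        "based on the text provided by the user."
    ),
    # "total_questions" and "model" are injected at runtime from the CLI arguments.
}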
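
To make the completion format that parse_qa_to_json expects concrete, here is a small round-trip demo (not part of this commit); the sample completion text is made up.

from generator_utils import parse_qa_to_json

raw_completion = (
    "1. Question: What is Llama 2?\n"
    "Answer: A family of open large language models released by Meta.\n"
    "2. Question: Which document types does this pipeline read?\n"
    "Answer: Plain text, markdown, JSON and PDF files."
)

print(parse_qa_to_json(raw_completion))
# Prints a JSON list of {"question": ..., "answer": ...} objects, one per pair.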