
add chat service abstraction

Hamid Shojanazeri 1 year ago
parent
commit
65af017cc4

+ 39 - 2
tutorials/chatbot/data_pipelines/generate_question_answers.py

@@ -5,18 +5,55 @@ import argparse
 import asyncio
 import json
 from config import load_config
-from generator_utils import generate_question_batches
+from generator_utils import generate_question_batches, parse_qa_to_json
 from itertools import chain
 import logging
 import aiofiles  # Ensure aiofiles is installed for async file operations
+from abc import ABC, abstractmethod
+from octoai.client import Client
+from functools import partial
 
 # Configure logging to include the timestamp, log level, and message
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
+# Manage rate limits with throttling
+rate_limit_threshold = 2000
+allowed_concurrent_requests = int(rate_limit_threshold * 0.75)
+request_limiter = asyncio.Semaphore(allowed_concurrent_requests)
+
+class ChatService(ABC):
+    @abstractmethod
+    async def execute_chat_request_async(self, api_context: dict, chat_request):
+        pass
+
+# Please implement your own chat service class here.
+# The class should inherit from the ChatService class and implement the execute_chat_request_async method.
+class OctoAIChatService(ChatService):
+    async def execute_chat_request_async(self, api_context: dict, chat_request):
+        async with request_limiter:
+            try:
+                event_loop = asyncio.get_running_loop()
+                client = Client(api_context['api_key'])
+                api_chat_call = partial(
+                    client.chat.completions.create,
+                    model=api_context['model'],
+                    messages=chat_request,
+                    temperature=0.0
+                )
+                response = await event_loop.run_in_executor(None, api_chat_call)
+                assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
+                assistant_response_json = parse_qa_to_json(assistant_response)
+
+                return assistant_response_json
+            except Exception as error:
+                logging.error(f"Error during chat request execution: {error}")
+                return ""
+            
 async def main(context):
+    chat_service = OctoAIChatService()
     try:
         logging.info("Starting to generate question/answer pairs.")
-        data = await generate_question_batches(context)
+        data = await generate_question_batches(chat_service, context)
         if not data:
             logging.warning("No data generated. Please check the input context or model configuration.")
             return
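
The comment above invites downstream users to plug in their own backend: subclass ChatService and implement execute_chat_request_async. As a rough sketch only (not part of this commit), the same pattern could wrap any OpenAI-compatible SDK; the OpenAICompatChatService name and the 'base_url' key in api_context are illustrative assumptions:

# Hypothetical sketch, not part of this commit: an alternative ChatService
# backed by an OpenAI-compatible endpoint. It reuses the module-level
# request_limiter semaphore and the parse_qa_to_json helper shown above.
import asyncio
import logging
from functools import partial
from openai import OpenAI  # assumed extra dependency, not used by this repo

class OpenAICompatChatService(ChatService):
    async def execute_chat_request_async(self, api_context: dict, chat_request):
        async with request_limiter:  # shared throttle defined above
            try:
                event_loop = asyncio.get_running_loop()
                # 'base_url' is an assumed api_context key pointing at any
                # OpenAI-compatible server; the commit itself does not define it.
                client = OpenAI(api_key=api_context['api_key'],
                                base_url=api_context.get('base_url'))
                api_chat_call = partial(
                    client.chat.completions.create,
                    model=api_context['model'],
                    messages=chat_request,
                    temperature=0.0,
                )
                # Run the blocking SDK call in a worker thread so the event
                # loop stays free while the request is in flight.
                response = await event_loop.run_in_executor(None, api_chat_call)
                return parse_qa_to_json(response.choices[0].message.content)
            except Exception as error:
                logging.error(f"Error during chat request execution: {error}")
                return ""

Because request_limiter is a module-level semaphore, any implementation that enters async with request_limiter shares the same throttle, so total concurrency stays below the rate-limit threshold no matter which backend is active.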

+ 5 - 35
tutorials/chatbot/data_pipelines/generator_utils.py

@@ -8,18 +8,12 @@ from octoai.client import Client
 import asyncio
 import magic
 from PyPDF2 import PdfReader
-from functools import partial
 import json
 from doc_processor import split_text_into_chunks
 import logging
-
 # Initialize logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
-# Manage rate limits with throttling
-rate_limit_threshold = 2000
-allowed_concurrent_requests = int(rate_limit_threshold * 0.75)
-request_limiter = asyncio.Semaphore(allowed_concurrent_requests)
 
 def read_text_file(file_path):
     try:
@@ -81,36 +75,13 @@ def parse_qa_to_json(response_string):
     # Convert the list to a JSON string
     return json.dumps(qa_list, indent=4)
 
-async def execute_chat_request_async(api_context: dict, chat_request):
-    async with request_limiter:
-        try:
-            event_loop = asyncio.get_running_loop()
-            # Prepare the API call
-            client = Client(api_context['api_key'])
-            api_chat_call = partial(
-                client.chat.completions.create,
-                model=api_context['model'],
-                messages=chat_request,
-                temperature=0.0
-            )
-            # Execute the API call in a separate thread
-            response = await event_loop.run_in_executor(None, api_chat_call)
-            # Extract and return the assistant's response
-            # return next((message['message']['content'] for message in response.choices if message['message']['role'] == 'assistant'), "")
-            assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
-            assistant_response_json = parse_qa_to_json(assistant_response)
-                  
-            return assistant_response_json
-        except Exception as error:
-            print(f"Error during chat request execution: {error}")
-            return ""
-
-async def prepare_and_send_request(api_context: dict, document_content: str, total_questions: int) -> dict:
+
+async def prepare_and_send_request(chat_service, api_context: dict, document_content: str, total_questions: int) -> dict:
     prompt_for_system = api_context['question_prompt_template'].format(total_questions=total_questions, language=api_context["language"])
     chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': document_content}]
-    return json.loads(await execute_chat_request_async(api_context, chat_request_payload))
+    return json.loads(await chat_service.execute_chat_request_async(api_context, chat_request_payload))
 
-async def generate_question_batches(api_context: dict):
+async def generate_question_batches(chat_service, api_context: dict):
     document_text = read_file_content(api_context)
     tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
     document_batches = split_text_into_chunks(api_context, document_text, tokenizer)
@@ -127,8 +98,7 @@ async def generate_question_batches(api_context: dict):
         # Distribute extra questions across the first few batches
         questions_in_current_batch = base_questions_per_batch + (1 if batch_index < extra_questions else 0)
         print(f"Batch {batch_index + 1} - {questions_in_current_batch} questions ********")
-        generation_tasks.append(prepare_and_send_request(api_context, batch_content, questions_in_current_batch))
-    # generation_tasks.append(prepare_and_send_request(api_context, document_batches_2[0], total_questions))
+        generation_tasks.append(prepare_and_send_request(chat_service, api_context, batch_content, questions_in_current_batch))
 
     question_generation_results = await asyncio.gather(*generation_tasks)
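
With the service object now threaded through prepare_and_send_request and generate_question_batches, the pipeline can be exercised without network access by injecting a stub. A minimal sketch under the interfaces shown in this diff; the canned Q/A payload below is invented for illustration:

# Hypothetical test stub, not part of this commit. It satisfies the
# ChatService interface with a canned JSON string shaped like the output
# of parse_qa_to_json, so the json.loads() call in prepare_and_send_request
# succeeds without touching any API. Assumes ChatService is importable
# from generate_question_answers.py.
import json

class StubChatService(ChatService):
    async def execute_chat_request_async(self, api_context: dict, chat_request):
        return json.dumps([
            {"question": "What does this pipeline produce?",
             "answer": "Question/answer pairs generated from documents."}
        ])

# Injected exactly like OctoAIChatService in main():
# data = asyncio.run(generate_question_batches(StubChatService(), context))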