Browse Source

update for llama

Hamid Shojanazeri 1 year ago
parent
commit
14d935c95a
1 changed file with 36 additions and 9 deletions

+ 36 - 9   tutorials/chatbot/data_pipelines/generator_utils.py

@@ -2,7 +2,9 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

 import os
-import openai
+import re
+from transformers import AutoTokenizer
+from octoai.client import Client
 import asyncio
 import magic
 from PyPDF2 import PdfReader
@@ -61,21 +63,44 @@ def read_file_content(context):
     return ' '.join(file_strings)


+
+def parse_qa_to_json(response_string):
+    # Adjusted regex to capture question-answer pairs more flexibly
+    # This pattern accounts for optional numbering and different question/answer lead-ins
+    pattern = re.compile(
+        r"\d*\.\s*Question:\s*(.*?)\nAnswer:\s*(.*?)(?=\n\d*\.\s*Question:|\Z)",
+        re.DOTALL
+    )
+
+    # Find all matches in the response string
+    matches = pattern.findall(response_string)
+
+    # Convert matches to a structured format
+    qa_list = [{"question": match[0].strip(), "answer": match[1].strip()} for match in matches]
+
+    # Convert the list to a JSON string
+    return json.dumps(qa_list, indent=4)
+
 async def execute_chat_request_async(api_context: dict, chat_request):
     async with request_limiter:
         try:
             event_loop = asyncio.get_running_loop()
-            # Prepare the OpenAI API call
-            openai_chat_call = partial(
-                openai.ChatCompletion.create,
+            # Prepare the API call
+            client = Client(api_context['api_key'])
+            api_chat_call = partial(
+                client.chat.completions.create,
                 model=api_context['model'],
                 messages=chat_request,
                 temperature=0.0
             )
             # Execute the API call in a separate thread
-            response = await event_loop.run_in_executor(None, openai_chat_call)
+            response = await event_loop.run_in_executor(None, api_chat_call)
             # Extract and return the assistant's response
-            return next((message['message']['content'] for message in response.choices if message['message']['role'] == 'assistant'), "")
+            # return next((message['message']['content'] for message in response.choices if message['message']['role'] == 'assistant'), "")
+            assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
+            assistant_response_json = parse_qa_to_json(assistant_response)
+
+            return assistant_response_json
         except Exception as error:
             print(f"Error during chat request execution: {error}")
             return ""
@@ -87,7 +112,8 @@ async def prepare_and_send_request(api_context: dict, document_content: str, tot

 async def generate_question_batches(api_context: dict):
     document_text = read_file_content(api_context)
-    document_batches = split_text_into_chunks(api_context, document_text)
+    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
+    document_batches = split_text_into_chunks(api_context, document_text, tokenizer)

     total_questions = api_context["total_questions"]
     batches_count = len(document_batches)
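
Note: `split_text_into_chunks` (defined elsewhere in this file) now receives the Llama 2 tokenizer, so chunk sizes can be measured in tokens rather than characters, which is what the model's context window actually counts. A rough sketch of such a token-based chunker, assuming a `chunk_size` entry in `api_context`; this is illustrative, not the repository's implementation:

    def split_text_into_chunks(api_context, text, tokenizer):
        # Sketch: encode once, slice the token ids into fixed-size windows,
        # and decode each window back into a text chunk.
        chunk_size = api_context.get("chunk_size", 1000)  # assumed key and default
        token_ids = tokenizer.encode(text, add_special_tokens=False)
        return [
            tokenizer.decode(token_ids[i:i + chunk_size])
            for i in range(0, len(token_ids), chunk_size)
        ]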
@@ -95,13 +121,14 @@ async def generate_question_batches(api_context: dict):
     extra_questions = total_questions % batches_count

     print(f"Questions per batch: {base_questions_per_batch} (+1 for the first {extra_questions} batches), Total questions: {total_questions}, Batches: {batches_count}")
-    
     generation_tasks = []
     for batch_index, batch_content in enumerate(document_batches):
-        # Distribute extra questions across the first few batches
+        print(f"len of batch_content: {len(batch_content)}, batch_index: {batch_index}")
+        # Distribute extra questions across the first few batches
         questions_in_current_batch = base_questions_per_batch + (1 if batch_index < extra_questions else 0)
         print(f"Batch {batch_index + 1} - {questions_in_current_batch} questions ********")
         generation_tasks.append(prepare_and_send_request(api_context, batch_content, questions_in_current_batch))
+    # generation_tasks.append(prepare_and_send_request(api_context, document_batches_2[0], total_questions))

     question_generation_results = await asyncio.gather(*generation_tasks)
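
Note: the question budget above is spread by integer division, with the remainder absorbed one question at a time by the leading batches. A quick worked example with assumed totals:

    total_questions = 10  # assumed value for illustration
    batches_count = 3     # assumed value for illustration
    base, extra = divmod(total_questions, batches_count)  # base=3, extra=1
    per_batch = [base + (1 if i < extra else 0) for i in range(batches_count)]
    print(per_batch)  # [4, 3, 3] -- the first `extra` batches get one extra question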