@@ -2,7 +2,9 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

 import os
-import openai
+import re
+from transformers import AutoTokenizer
+from octoai.client import Client
 import asyncio
 import magic
 from PyPDF2 import PdfReader
@@ -61,21 +63,44 @@ def read_file_content(context):
     return ' '.join(file_strings)

+
+def parse_qa_to_json(response_string):
+    # Adjusted regex to capture question-answer pairs more flexibly
+    # This pattern accounts for optional numbering and different question/answer lead-ins
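+    # A lookahead stops each answer at the next numbered "Question:" or at the end of the string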
+    pattern = re.compile(
+        r"\d*\.\s*Question:\s*(.*?)\nAnswer:\s*(.*?)(?=\n\d*\.\s*Question:|\Z)",
+        re.DOTALL
+    )
+
+    # Find all matches in the response string
+    matches = pattern.findall(response_string)
+
+    # Convert matches to a structured format
+    qa_list = [{"question": match[0].strip(), "answer": match[1].strip()} for match in matches]
+
+    # Convert the list to a JSON string
+    return json.dumps(qa_list, indent=4)
+
 async def execute_chat_request_async(api_context: dict, chat_request):
     async with request_limiter:
         try:
             event_loop = asyncio.get_running_loop()
-            # Prepare the OpenAI API call
-            openai_chat_call = partial(
-                openai.ChatCompletion.create,
+            # Prepare the API call
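+            # OctoAI client authenticated with the key stored in api_context['api_key']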
+            client = Client(api_context['api_key'])
+            api_chat_call = partial(
+                client.chat.completions.create,
                 model=api_context['model'],
                 messages=chat_request,
                 temperature=0.0
             )
             # Execute the API call in a separate thread
-            response = await event_loop.run_in_executor(None, openai_chat_call)
+            response = await event_loop.run_in_executor(None, api_chat_call)
             # Extract and return the assistant's response
-            return next((message['message']['content'] for message in response.choices if message['message']['role'] == 'assistant'), "")
+            assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
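+            # Parse the "Question:"/"Answer:" text into a JSON string of question/answer pairs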
+            assistant_response_json = parse_qa_to_json(assistant_response)
+
+            return assistant_response_json
         except Exception as error:
             print(f"Error during chat request execution: {error}")
             return ""
@@ -87,7 +112,8 @@ async def prepare_and_send_request(api_context: dict, document_content: str, tot

 async def generate_question_batches(api_context: dict):
     document_text = read_file_content(api_context)
-    document_batches = split_text_into_chunks(api_context, document_text)
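+    # Llama 2 tokenizer is passed to split_text_into_chunks so chunk sizes can be measured in tokens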
+    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
+    document_batches = split_text_into_chunks(api_context, document_text, tokenizer)

     total_questions = api_context["total_questions"]
     batches_count = len(document_batches)
@@ -95,13 +121,14 @@ async def generate_question_batches(api_context: dict):
     extra_questions = total_questions % batches_count

     print(f"Questions per batch: {base_questions_per_batch} (+1 for the first {extra_questions} batches), Total questions: {total_questions}, Batches: {batches_count}")
-
     generation_tasks = []
     for batch_index, batch_content in enumerate(document_batches):
-        # Distribute extra questions across the first few batches
+        print(f"len of batch_content: {len(batch_content)}, batch_index: {batch_index}")
+        # Distribute extra questions across the first few batches
         questions_in_current_batch = base_questions_per_batch + (1 if batch_index < extra_questions else 0)
         print(f"Batch {batch_index + 1} - {questions_in_current_batch} questions ********")
         generation_tasks.append(prepare_and_send_request(api_context, batch_content, questions_in_current_batch))

     question_generation_results = await asyncio.gather(*generation_tasks)