# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
import re
import asyncio
import json
import logging
from functools import partial

import magic
from PyPDF2 import PdfReader
from transformers import AutoTokenizer
from octoai.client import Client

from doc_processor import split_text_into_chunks

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Manage rate limits with throttling: allow only 75% of the rate-limit
# threshold to be in flight at once, leaving headroom below the cap.
rate_limit_threshold = 2000
allowed_concurrent_requests = int(rate_limit_threshold * 0.75)
request_limiter = asyncio.Semaphore(allowed_concurrent_requests)

def read_text_file(file_path):
    """Read a plain-text file and return its contents, or '' on failure."""
    try:
        with open(file_path, 'r') as f:
            return f.read().strip() + ' '
    except Exception as e:
        logging.error(f"Error reading text file {file_path}: {e}")
        return ''

def read_pdf_file(file_path):
    """Extract and concatenate the text of every page of a PDF, or '' on failure."""
    try:
        with open(file_path, 'rb') as f:
            pdf_reader = PdfReader(f)
            file_text = [page.extract_text().strip() + ' ' for page in pdf_reader.pages]
            return ''.join(file_text)
    except Exception as e:
        logging.error(f"Error reading PDF file {file_path}: {e}")
        return ''

def process_file(file_path):
    """Dispatch a file to the matching reader based on its MIME type."""
    file_type = magic.from_file(file_path, mime=True)
    if file_type in ['text/plain', 'text/markdown']:
        return read_text_file(file_path)
    elif file_type == 'application/pdf':
        return read_pdf_file(file_path)
    else:
        logging.warning(f"Unsupported file type {file_type} for file {file_path}")
        return ''

def read_file_content(context):
    """Walk context['data_dir'] and concatenate the text of every supported file."""
    file_strings = []
    for root, _, files in os.walk(context['data_dir']):
        for file in files:
            file_path = os.path.join(root, file)
            file_text = process_file(file_path)
            if file_text:
                file_strings.append(file_text)
    return ' '.join(file_strings)

def parse_qa_to_json(response_string):
    """Parse 'Question:'/'Answer:' pairs out of a model response into a JSON string."""
    # The pattern tolerates optional numbering ("1. Question: ...") and matches
    # each answer lazily up to the next numbered question or the end of the string.
    pattern = re.compile(
        r"\d*\.\s*Question:\s*(.*?)\nAnswer:\s*(.*?)(?=\n\d*\.\s*Question:|\Z)",
        re.DOTALL
    )
    matches = pattern.findall(response_string)
    qa_list = [{"question": match[0].strip(), "answer": match[1].strip()} for match in matches]
    return json.dumps(qa_list, indent=4)
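
# A minimal illustration of the input format parse_qa_to_json() expects; the
# sample text below is hypothetical, not actual model output:
#
#     sample = (
#         "1. Question: What is the capital of France?\n"
#         "Answer: Paris.\n"
#         "2. Question: What is 2 + 2?\n"
#         "Answer: 4."
#     )
#     parse_qa_to_json(sample)
#     # -> pretty-printed JSON:
#     # [{"question": "What is the capital of France?", "answer": "Paris."},
#     #  {"question": "What is 2 + 2?", "answer": "4."}]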

async def execute_chat_request_async(api_context: dict, chat_request):
    async with request_limiter:
        try:
            event_loop = asyncio.get_running_loop()
            # Prepare the API call
            client = Client(api_context['api_key'])
            api_chat_call = partial(
                client.chat.completions.create,
                model=api_context['model'],
                messages=chat_request,
                temperature=0.0
            )
            # Execute the blocking API call in a separate thread
            response = await event_loop.run_in_executor(None, api_chat_call)
            # Extract the assistant's reply and convert it to a JSON string of Q&A pairs
            assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
            return parse_qa_to_json(assistant_response)
        except Exception as error:
            logging.error(f"Error during chat request execution: {error}")
            # Return an empty JSON list so callers can still json.loads() the result
            return "[]"

async def prepare_and_send_request(api_context: dict, document_content: str, total_questions: int) -> dict:
    prompt_for_system = api_context['question_prompt_template'].format(total_questions=total_questions, language=api_context["language"])
    chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': document_content}]
    return json.loads(await execute_chat_request_async(api_context, chat_request_payload))

async def generate_question_batches(api_context: dict):
    document_text = read_file_content(api_context)
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
    document_batches = split_text_into_chunks(api_context, document_text, tokenizer)

    total_questions = api_context["total_questions"]
    batches_count = len(document_batches)
    # Distribute questions evenly; the remainder goes to the first few batches
    base_questions_per_batch = total_questions // batches_count
    extra_questions = total_questions % batches_count
    logging.info(f"Questions per batch: {base_questions_per_batch} (+1 for the first {extra_questions} batches), Total questions: {total_questions}, Batches: {batches_count}")

    generation_tasks = []
    for batch_index, batch_content in enumerate(document_batches):
        questions_in_current_batch = base_questions_per_batch + (1 if batch_index < extra_questions else 0)
        logging.info(f"Batch {batch_index + 1}: length {len(batch_content)}, {questions_in_current_batch} questions")
        generation_tasks.append(prepare_and_send_request(api_context, batch_content, questions_in_current_batch))

    question_generation_results = await asyncio.gather(*generation_tasks)
    return question_generation_results
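
# A minimal, hypothetical usage sketch. The api_context keys mirror the ones this
# module reads (data_dir, api_key, model, question_prompt_template, language,
# total_questions); the template text and model id below are placeholders, and
# split_text_into_chunks() may expect additional keys not shown here.
if __name__ == "__main__":
    example_context = {
        "data_dir": "./data",  # folder of .txt/.md/.pdf source documents
        "api_key": os.environ.get("OCTOAI_API_KEY", ""),
        "model": "<your-octoai-model>",  # placeholder model id
        "language": "English",
        "total_questions": 10,
        "question_prompt_template": (
            "Generate {total_questions} question-answer pairs in {language} "
            "from the given text, formatted as 'Question: ...' and 'Answer: ...'."
        ),
    }
    results = asyncio.run(generate_question_batches(example_context))
    print(json.dumps(results, indent=4))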