generator_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
import openai
import asyncio
import magic
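# Note: "magic" here is the python-magic package, which requires the libmagic
# system library to be installed.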
from PyPDF2 import PdfReader
from functools import partial
import json
from token_processor import split_text_into_tokenized_chunks
# from file_handler import read_file_content
import logging

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Manage rate limits with throttling
rate_limit_threshold = 100
allowed_concurrent_requests = int(rate_limit_threshold * 0.75)
request_limiter = asyncio.Semaphore(allowed_concurrent_requests)
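# Note: the semaphore bounds how many requests are in flight at once (75% of an
# assumed rate-limit budget); it does not enforce a true requests-per-minute rate.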

def read_text_file(file_path):
    try:
        with open(file_path, 'r') as f:
            return f.read().strip() + ' '
    except Exception as e:
        logging.error(f"Error reading text file {file_path}: {e}")
        return ''

def read_pdf_file(file_path):
    try:
        with open(file_path, 'rb') as f:
            pdf_reader = PdfReader(f)
            num_pages = len(pdf_reader.pages)
            file_text = [pdf_reader.pages[page_num].extract_text().strip() + ' ' for page_num in range(num_pages)]
            return ''.join(file_text)
    except Exception as e:
        logging.error(f"Error reading PDF file {file_path}: {e}")
        return ''
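
# Note: PdfReader.extract_text() returns an empty string for image-only (scanned)
# pages, so such PDFs contribute no text to read_pdf_file's output.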

def process_file(file_path):
    file_type = magic.from_file(file_path, mime=True)
    if file_type in ['text/plain', 'text/markdown']:
        return read_text_file(file_path)
    elif file_type == 'application/pdf':
        return read_pdf_file(file_path)
    else:
        logging.warning(f"Unsupported file type {file_type} for file {file_path}")
        return ''

def read_file_content(context):
    file_strings = []
    for root, _, files in os.walk(context['data_dir']):
        for file in files:
            file_path = os.path.join(root, file)
            file_text = process_file(file_path)
            if file_text:
                file_strings.append(file_text)
    return ' '.join(file_strings)
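
# Note: execute_chat_request_async below targets the pre-1.0 openai SDK;
# openai.ChatCompletion was removed in openai>=1.0, where the equivalent call
# is client.chat.completions.create(...).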

async def execute_chat_request_async(api_context: dict, chat_request):
    async with request_limiter:
        try:
            event_loop = asyncio.get_running_loop()
            # Prepare the OpenAI API call
            openai_chat_call = partial(
                openai.ChatCompletion.create,
                model=api_context['model'],
                messages=chat_request,
                temperature=0.0
            )
            # Execute the blocking API call in a separate thread
            response = await event_loop.run_in_executor(None, openai_chat_call)
            # Extract and return the first assistant response among the choices
            return next((choice['message']['content'] for choice in response.choices if choice['message']['role'] == 'assistant'), "")
        except Exception as error:
            logging.error(f"Error during chat request execution: {error}")
            return ""

async def prepare_and_send_request(api_context: dict, document_content: str, total_questions: int) -> dict:
    prompt_for_system = api_context['question_prompt_template'].format(total_questions=total_questions, language=api_context["language"])
    chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': document_content}]
    return json.loads(await execute_chat_request_async(api_context, chat_request_payload))
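
# Note: json.loads above assumes the model responded with valid JSON; a malformed
# completion (or the empty string returned on a failed request) raises
# json.JSONDecodeError here.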

async def generate_question_batches(api_context: dict):
    document_text = read_file_content(api_context)
    logging.info("Loaded document text from data_dir")
    document_batches = split_text_into_tokenized_chunks(api_context, document_text)
    logging.info(f"Split text into {len(document_batches)} tokenized batches")
    questions_per_batch = api_context["total_questions"] // len(document_batches)
    generation_tasks = []
    for batch_index, batch_content in enumerate(document_batches):
        # Give the last batch the remainder so the batch totals add up to total_questions
        if batch_index == len(document_batches) - 1:
            questions_in_current_batch = api_context["total_questions"] - questions_per_batch * (len(document_batches) - 1)
        else:
            questions_in_current_batch = questions_per_batch
        generation_tasks.append(prepare_and_send_request(api_context, batch_content, questions_in_current_batch))
    question_generation_results = await asyncio.gather(*generation_tasks)
    logging.info(f"Collected results for {len(question_generation_results)} batches")
    return question_generation_results
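
# A minimal usage sketch (not part of the original module), assuming an OpenAI
# API key is configured via the OPENAI_API_KEY environment variable. The
# api_context keys below are exactly the ones this file reads; the directory,
# model name, and prompt template are placeholder values for illustration, and
# split_text_into_tokenized_chunks may read additional keys not shown here.
if __name__ == "__main__":
    example_context = {
        'data_dir': './data',          # directory scanned by read_file_content
        'model': 'gpt-3.5-turbo',      # model passed to openai.ChatCompletion.create
        'language': 'English',
        'total_questions': 100,
        'question_prompt_template': (
            'Generate {total_questions} question-answer pairs in {language} '
            'from the following text. Respond only with JSON.'
        ),
    }
    batches = asyncio.run(generate_question_batches(example_context))
    logging.info(f"Generated {len(batches)} question batches")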