generator_utils.py

import os
import openai
import asyncio
import magic
from PyPDF2 import PdfReader
from functools import partial
import json
from token_processor import split_text_into_tokenized_chunks
# from file_handler import read_file_content
import logging

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Manage rate limits with throttling: allow concurrent requests up to 75%
# of the rate-limit threshold so bursts stay below the API's ceiling.
rate_limit_threshold = 100
allowed_concurrent_requests = int(rate_limit_threshold * 0.75)
request_limiter = asyncio.Semaphore(allowed_concurrent_requests)
def read_text_file(file_path):
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read().strip() + ' '
    except Exception as e:
        logging.error(f"Error reading text file {file_path}: {e}")
        return ''
def read_pdf_file(file_path):
    try:
        with open(file_path, 'rb') as f:
            pdf_reader = PdfReader(f)
            # extract_text() can return None for image-only pages, so guard with `or ''`
            file_text = [(page.extract_text() or '').strip() + ' ' for page in pdf_reader.pages]
            return ''.join(file_text)
    except Exception as e:
        logging.error(f"Error reading PDF file {file_path}: {e}")
        return ''
def process_file(file_path):
    file_type = magic.from_file(file_path, mime=True)
    if file_type in ['text/plain', 'text/markdown']:
        return read_text_file(file_path)
    elif file_type == 'application/pdf':
        return read_pdf_file(file_path)
    else:
        logging.warning(f"Unsupported file type {file_type} for file {file_path}")
        return ''
def read_file_content(context):
    file_strings = []
    for root, _, files in os.walk(context['data_dir']):
        for file in files:
            file_path = os.path.join(root, file)
            file_text = process_file(file_path)
            if file_text:
                file_strings.append(file_text)
    return ' '.join(file_strings)
async def execute_chat_request_async(api_context: dict, chat_request):
    async with request_limiter:
        try:
            event_loop = asyncio.get_running_loop()
            # Prepare the OpenAI API call (legacy openai<1.0 ChatCompletion interface)
            openai_chat_call = partial(
                openai.ChatCompletion.create,
                model=api_context['model'],
                messages=chat_request,
                temperature=0.0
            )
            # Execute the blocking API call in a separate thread
            response = await event_loop.run_in_executor(None, openai_chat_call)
            # Extract and return the first assistant message among the choices
            return next((choice['message']['content'] for choice in response.choices if choice['message']['role'] == 'assistant'), "")
        except Exception as error:
            logging.error(f"Error during chat request execution: {error}")
            return ""
async def prepare_and_send_request(api_context: dict, document_content: str, total_questions: int) -> dict:
    prompt_for_system = api_context['question_prompt_template'].format(total_questions=total_questions, language=api_context["language"])
    chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': document_content}]
    # The prompt template instructs the model to reply with JSON, so the raw
    # response is parsed directly; json.loads raises on malformed output.
    return json.loads(await execute_chat_request_async(api_context, chat_request_payload))
async def generate_question_batches(api_context: dict):
    document_text = read_file_content(api_context)
    logging.info("Loaded document text from the data directory")
    document_batches = split_text_into_tokenized_chunks(api_context, document_text)
    logging.info(f"Split document text into {len(document_batches)} batches")
    if not document_batches:
        logging.error("No document batches were produced; nothing to generate")
        return []
    questions_per_batch = api_context["total_questions"] // len(document_batches)
    generation_tasks = []
    for batch_index, batch_content in enumerate(document_batches):
        # Give the last batch the full remainder of the integer division so
        # the per-batch counts sum exactly to total_questions.
        if batch_index == len(document_batches) - 1:
            questions_in_current_batch = api_context["total_questions"] - questions_per_batch * (len(document_batches) - 1)
        else:
            questions_in_current_batch = questions_per_batch
        generation_tasks.append(prepare_and_send_request(api_context, batch_content, questions_in_current_batch))
    logging.info(f"Scheduled {len(generation_tasks)} question generation tasks")
    question_generation_results = await asyncio.gather(*generation_tasks)
    logging.info("All question generation batches completed")
    return question_generation_results
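
# Usage sketch: one way these helpers might be driven end to end. The
# api_context keys match the ones consumed above (data_dir, model, language,
# total_questions, question_prompt_template); the model name, prompt template
# wording, and output filename are illustrative assumptions, and
# token_processor may expect additional keys (e.g. a chunk size).
if __name__ == '__main__':
    # The legacy openai SDK reads the API key from the OPENAI_API_KEY
    # environment variable, so no explicit configuration is needed here.
    api_context = {
        'data_dir': './data',
        'model': 'gpt-3.5-turbo',
        'language': 'English',
        'total_questions': 100,
        'question_prompt_template': (
            'Generate {total_questions} question-answer pairs in {language} '
            'based on the following text. Respond with JSON only.'
        ),
    }
    results = asyncio.run(generate_question_batches(api_context))
    with open('generated_questions.json', 'w') as out_file:
        json.dump(results, out_file, indent=2)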