generator_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
import re
import asyncio
import json
import logging

import magic
from transformers import AutoTokenizer
from octoai.client import Client
from PyPDF2 import PdfReader

from doc_processor import split_text_into_chunks

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def read_text_file(file_path):
    # Read a plain-text file; the trailing space keeps files separated once concatenated.
    try:
        with open(file_path, 'r') as f:
            return f.read().strip() + ' '
    except Exception as e:
        logging.error(f"Error reading text file {file_path}: {e}")
        return ''

def read_pdf_file(file_path):
    try:
        with open(file_path, 'rb') as f:
            pdf_reader = PdfReader(f)
            # extract_text() may return None for pages with no extractable text,
            # so guard before stripping.
            file_text = [(page.extract_text() or '').strip() + ' ' for page in pdf_reader.pages]
            return ''.join(file_text)
    except Exception as e:
        logging.error(f"Error reading PDF file {file_path}: {e}")
        return ''

def read_json_file(file_path):
    try:
        with open(file_path, 'r') as f:
            data = json.load(f)
        # Assuming each item in the list has 'question' and 'answer' keys,
        # concatenate the pairs into a single space-separated string.
        file_text = ' '.join([item['question'].strip() + ' ' + item['answer'].strip() + ' ' for item in data])
        return file_text
    except Exception as e:
        logging.error(f"Error reading JSON file {file_path}: {e}")
        return ''
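
# Expected JSON shape for read_json_file, inferred from the keys accessed
# above and shown here for illustration:
#   [
#       {"question": "...", "answer": "..."},
#       ...
#   ]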

def process_file(file_path):
    # Dispatch on the libmagic-detected MIME type; JSON files are reported as
    # 'application/json', so route them to the JSON reader.
    file_type = magic.from_file(file_path, mime=True)
    if file_type in ['text/plain', 'text/markdown']:
        return read_text_file(file_path)
    elif file_type == 'application/json':
        return read_json_file(file_path)
    elif file_type == 'application/pdf':
        return read_pdf_file(file_path)
    else:
        logging.warning(f"Unsupported file type {file_type} for file {file_path}")
        return ''

def read_file_content(context):
    # Walk the data directory and concatenate the text of every readable file.
    file_strings = []
    for root, _, files in os.walk(context['data_dir']):
        for file in files:
            file_path = os.path.join(root, file)
            file_text = process_file(file_path)
            if file_text:
                file_strings.append(file_text)
    return ' '.join(file_strings)

def parse_qa_to_json(response_string):
    # Capture question-answer pairs flexibly: the numbering prefix (e.g. "1.")
    # is optional, and DOTALL lets answers span multiple lines.
    pattern = re.compile(
        r"(?:\d+\.)?\s*Question:\s*(.*?)\nAnswer:\s*(.*?)(?=\n(?:\d+\.)?\s*Question:|\Z)",
        re.DOTALL
    )
    # Find all matches in the response string
    matches = pattern.findall(response_string)
    # Convert matches to a structured format
    qa_list = [{"question": match[0].strip(), "answer": match[1].strip()} for match in matches]
    # Convert the list to a JSON string
    return json.dumps(qa_list, indent=4)
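
# For illustration only (the sample text is assumed, not from the source),
# the pattern above matches model output shaped like:
#
#   1. Question: What license covers this code?
#   Answer: The Llama 2 Community License Agreement.
#
# which parse_qa_to_json serializes as:
#   [{"question": "What license covers this code?",
#     "answer": "The Llama 2 Community License Agreement."}]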

async def prepare_and_send_request(chat_service, api_context: dict, document_content: str, total_questions: int) -> dict:
    # Build a system prompt requesting `total_questions` questions, send the
    # document chunk as the user turn, and parse the model's JSON reply.
    prompt_for_system = api_context['question_prompt_template'].format(total_questions=total_questions, language=api_context["language"])
    chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': document_content}]
    return json.loads(await chat_service.execute_chat_request_async(api_context, chat_request_payload))

async def generate_question_batches(chat_service, api_context: dict):
    document_text = read_file_content(api_context)
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
    document_batches = split_text_into_chunks(api_context, document_text, tokenizer)

    total_questions = api_context["total_questions"]
    batches_count = len(document_batches)
    # Spread the requested questions evenly across batches; e.g. 10 questions
    # over 3 batches yields 4, 3, 3 (the remainder goes to the earliest batches).
    base_questions_per_batch = total_questions // batches_count
    extra_questions = total_questions % batches_count
    print(f"Questions per batch: {base_questions_per_batch} (+1 for the first {extra_questions} batches), Total questions: {total_questions}, Batches: {batches_count}")

    generation_tasks = []
    for batch_index, batch_content in enumerate(document_batches):
        print(f"len of batch_content: {len(batch_content)}, batch_index: {batch_index}")
        # Distribute the extra questions across the first few batches
        questions_in_current_batch = base_questions_per_batch + (1 if batch_index < extra_questions else 0)
        print(f"Batch {batch_index + 1} - {questions_in_current_batch} questions ********")
        generation_tasks.append(prepare_and_send_request(chat_service, api_context, batch_content, questions_in_current_batch))

    question_generation_results = await asyncio.gather(*generation_tasks)
    return question_generation_results
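
# A minimal driver sketch (not part of the original file). _EchoChatService is
# a hypothetical stand-in for the real OctoAI-backed service: anything exposing
# execute_chat_request_async(api_context, payload) works here. The api_context
# keys are inferred from how they are used above; split_text_into_chunks may
# require additional keys (see doc_processor.py), and loading the Llama 2
# tokenizer requires access to the gated Hugging Face model.
if __name__ == "__main__":
    class _EchoChatService:
        async def execute_chat_request_async(self, api_context, chat_request_payload):
            # Return an empty JSON list so the pipeline can run offline.
            return "[]"

    example_context = {
        "data_dir": "./data",  # must contain at least one readable document
        "language": "English",
        "total_questions": 10,
        "question_prompt_template": "Generate {total_questions} question-answer pairs in {language}.",
    }
    results = asyncio.run(generate_question_batches(_EchoChatService(), example_context))
    print(results)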