# generator_utils.py
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import re
  5. from transformers import AutoTokenizer
  6. import asyncio
  7. import magic
  8. from PyPDF2 import PdfReader
  9. import json
  10. from doc_processor import split_text_into_chunks
  11. import logging
  12. import json
  13. # Initialize logging
  14. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
  15. def read_text_file(file_path):
  16. try:
  17. with open(file_path, 'r') as f:
  18. text = f.read().strip() + ' '
  19. if len(text) == 0:
  20. print("File is empty ",file_path)
  21. return text
  22. except Exception as e:
  23. logging.error(f"Error reading text file {file_path}: {e}")
  24. return ''
  25. def read_pdf_file(file_path):
  26. try:
  27. with open(file_path, 'rb') as f:
  28. pdf_reader = PdfReader(f)
  29. num_pages = len(pdf_reader.pages)
  30. file_text = [pdf_reader.pages[page_num].extract_text().strip() + ' ' for page_num in range(num_pages)]
  31. text = ''.join(file_text)
  32. if len(text) == 0:
  33. print("File is empty ",file_path)
  34. return ''.join(file_text)
  35. except Exception as e:
  36. logging.error(f"Error reading PDF file {file_path}: {e}")
  37. return ''
  38. def read_json_file(file_path):
  39. try:
  40. with open(file_path, 'r') as f:
  41. data = json.load(f)
  42. # Assuming each item in the list has a 'question' and 'answer' key
  43. # Concatenating question and answer pairs with a space in between and accumulating them into a single string
  44. file_text = ' '.join([item['question'].strip() + ' ' + item['answer'].strip() + ' ' for item in data])
  45. if len(file_text) == 0:
  46. print("File is empty ",file_path)
  47. return file_text
  48. except Exception as e:
  49. logging.error(f"Error reading JSON file {file_path}: {e}")
  50. return ''
  51. def process_file(file_path):
  52. print("starting to process file: ", file_path)
  53. file_type = magic.from_file(file_path, mime=True)
  54. if file_type in ['text/plain', 'text/markdown', 'JSON']:
  55. return read_text_file(file_path)
  56. elif file_type == 'application/pdf':
  57. return read_pdf_file(file_path)
  58. else:
  59. logging.warning(f"Unsupported file type {file_type} for file {file_path}")
  60. return ''
  61. def read_file_content(context):
  62. file_strings = []
  63. for root, _, files in os.walk(context['data_dir']):
  64. for file in files:
  65. file_path = os.path.join(root, file)
  66. file_text = process_file(file_path)
  67. if file_text:
  68. file_strings.append(file_text)
  69. text = ' '.join(file_strings)
  70. if len(text) == 0:
  71. logging.error(f"Error reading files, text is empty")
  72. return ' '.join(file_strings)
  73. # clean the text by removing all parts that did not contain any alphanumeric characters
  74. def clean(s):
  75. result = []
  76. for item in s.split('"'):
  77. if any(c.isalnum() for c in item):
  78. result.append(item)
  79. return " ".join(result)
  80. def parse_qa_to_json(response_string):
  81. split_lines = response_string.split("\n")
  82. start,end = None,None
  83. # must use set to avoid duplicate question/answer pairs due to async function calls
  84. qa_set = set()
  85. for i in range(len(split_lines)):
  86. line = split_lines[i]
  87. # starting to find "Question"
  88. if not start:
  89. # Once found, set start to this line number
  90. if '"Question":' in line:
  91. start = i
  92. else:
  93. # "Question" has been found, find "Answer", once found, set end to this line number
  94. if '"Answer":' in line:
  95. end = i
  96. # found Question means we have reached the end of the question, so add it to qa_list
  97. elif '"Question":' in line:
  98. question = " ".join(split_lines[start:end]).split('"Question":')[1]
  99. answer = " ".join(split_lines[end:i]).split('"Answer":')[1]
  100. start,end = i,None
  101. qa_set.add((clean(question), clean(answer)))
  102. # adding last question back to qa_list
  103. if start and end:
  104. question = " ".join(split_lines[start:end]).split('"Question":')[1]
  105. answer = " ".join(split_lines[end:]).split('"Answer":')[1]
  106. qa_set.add((clean(question), clean(answer)))
  107. qa_list = [{"question": q, "answer":a} for q,a in qa_set]
  108. return json.dumps(qa_list, indent=4)
  109. async def prepare_and_send_request(chat_service, api_context: dict, document_content: str, num_questions: int) -> dict:
  110. prompt_for_system = api_context['question_prompt_template'].format(num_questions=num_questions, language=api_context["language"])
  111. chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': document_content}]
  112. result = await chat_service.execute_chat_request_async(api_context, chat_request_payload)
  113. if not result:
  114. return {}
  115. return json.loads(await chat_service.execute_chat_request_async(api_context, chat_request_payload))
  116. async def generate_question_batches(chat_service, api_context: dict):
  117. document_text = read_file_content(api_context)
  118. if len(document_text)== 0:
  119. logging.error(f"Error reading files, document_text is empty")
  120. if api_context["model"] in ["meta-llama-3-70b-instruct","meta-llama-3-8b-instruct"]:
  121. tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", pad_token="</s>", padding_side="right")
  122. else:
  123. tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
  124. document_batches = split_text_into_chunks(api_context, document_text, tokenizer)
  125. total_questions = api_context["total_questions"]
  126. batches_count = len(document_batches)
  127. # each batch should have at least 1 question
  128. base_questions_per_batch = max(total_questions // batches_count,1)
  129. extra_questions = total_questions % batches_count
  130. print(f"Questions per batch: {base_questions_per_batch} (+1 for the first {extra_questions} batches), Total questions: {total_questions}, Batches: {batches_count}")
  131. generation_tasks = []
  132. for batch_index, batch_content in enumerate(document_batches):
  133. print(f"len of batch_content: {len(batch_content)}, batch_index: {batch_index}")
  134. #Distribute extra questions across the first few batches
  135. questions_in_current_batch = base_questions_per_batch + (1 if batch_index < extra_questions else 0)
  136. print(f"Batch {batch_index + 1} - {questions_in_current_batch} questions ********")
  137. try:
  138. result = prepare_and_send_request(chat_service, api_context, batch_content, questions_in_current_batch)
  139. generation_tasks.append(result)
  140. except Exception as e:
  141. print(f"Error during chat request execution: {e}")
  142. question_generation_results = await asyncio.gather(*generation_tasks)
  143. return question_generation_results