generator_utils.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import re
  5. import string
  6. from transformers import AutoTokenizer
  7. import asyncio
  8. import magic
  9. from PyPDF2 import PdfReader
  10. import json
  11. from doc_processor import split_text_into_chunks
  12. import logging
  13. import json
  14. # Initialize logging
  15. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
  16. def read_text_file(file_path):
  17. try:
  18. with open(file_path, 'r') as f:
  19. text = f.read().strip() + ' '
  20. if len(text) == 0:
  21. print("File is empty ",file_path)
  22. return text
  23. except Exception as e:
  24. logging.error(f"Error reading text file {file_path}: {e}")
  25. return ''
  26. def read_pdf_file(file_path):
  27. try:
  28. with open(file_path, 'rb') as f:
  29. pdf_reader = PdfReader(f)
  30. num_pages = len(pdf_reader.pages)
  31. file_text = [pdf_reader.pages[page_num].extract_text().strip() + ' ' for page_num in range(num_pages)]
  32. text = ''.join(file_text)
  33. if len(text) == 0:
  34. print("File is empty ",file_path)
  35. return ''.join(file_text)
  36. except Exception as e:
  37. logging.error(f"Error reading PDF file {file_path}: {e}")
  38. return ''
  39. def read_json_file(file_path):
  40. try:
  41. with open(file_path, 'r') as f:
  42. data = json.load(f)
  43. # Assuming each item in the list has a 'question' and 'answer' key
  44. # Concatenating question and answer pairs with a space in between and accumulating them into a single string
  45. file_text = ' '.join([item['question'].strip() + ' ' + item['answer'].strip() + ' ' for item in data])
  46. if len(file_text) == 0:
  47. print("File is empty ",file_path)
  48. return file_text
  49. except Exception as e:
  50. logging.error(f"Error reading JSON file {file_path}: {e}")
  51. return ''
  52. def process_file(file_path):
  53. print("starting to process file: ", file_path)
  54. file_type = magic.from_file(file_path, mime=True)
  55. if file_type in ['text/plain', 'text/markdown', 'JSON']:
  56. return read_text_file(file_path)
  57. elif file_type == 'application/pdf':
  58. return read_pdf_file(file_path)
  59. else:
  60. logging.warning(f"Unsupported file type {file_type} for file {file_path}")
  61. return ''
  62. def remove_non_printable(s):
  63. printable = set(string.printable)
  64. return ''.join(filter(lambda x: x in printable, s))
  65. def read_file_content(context):
  66. file_strings = []
  67. for root, _, files in os.walk(context['data_dir']):
  68. for file in files:
  69. file_path = os.path.join(root, file)
  70. file_text = process_file(file_path)
  71. if file_text:
  72. file_strings.append(file_text)
  73. text = '\n'.join(file_strings)
  74. return remove_non_printable(text)
  75. # clean the text by removing all parts that did not contain any alphanumeric characters
  76. def clean(s):
  77. result = []
  78. for item in s.split('"'):
  79. if any(c.isalnum() for c in item):
  80. result.append(item)
  81. return " ".join(result)
  82. # given a response string, return a string that can be saved as json.
  83. def parse_qac_to_json(response_string):
  84. split_lines = response_string.split("\n")
  85. start,mid,end = None,None,None
  86. # must use set to avoid duplicate question/answer pairs due to async function calls
  87. qa_set = set()
  88. for i in range(len(split_lines)):
  89. line = split_lines[i]
  90. # starting to find "Question"
  91. if not start:
  92. # Once found, set start to this line number
  93. if '"Question":' in line:
  94. start = i
  95. else:
  96. # "Question" has been found, find "Answer", once found, set end to this line number
  97. if '"Answer":' in line:
  98. mid = i
  99. elif '"Context":' in line:
  100. end = i
  101. # found Question means we have reached the end of the question, so add it to qa_list
  102. elif '"Question":' in line:
  103. question = " ".join(split_lines[start:mid]).split('"Question":')[1]
  104. answer = " ".join(split_lines[mid:end]).split('"Answer":')[1]
  105. context = " ".join(split_lines[end:i]).split('"Context":')[1]
  106. start,mid,end = i,None,None
  107. qa_set.add((clean(question), clean(answer),clean(context)))
  108. # adding last question back to qa_list
  109. if start and mid and end:
  110. question = " ".join(split_lines[start:mid]).split('"Question":')[1]
  111. answer = " ".join(split_lines[mid:end]).split('"Answer":')[1]
  112. context = " ".join(split_lines[end:]).split('"Context":')[1]
  113. start,mid,end = i,None,None
  114. qa_set.add((clean(question), clean(answer),clean(context)))
  115. qa_list = [{"Question": q, "Answer":a, "Context":c} for q,a,c in qa_set]
  116. return json.dumps(qa_list, indent=4)
  117. def parse_qa_to_json(response_string):
  118. split_lines = response_string.split("\n")
  119. start,end = None,None
  120. # must use set to avoid duplicate question/answer pairs due to async function calls
  121. qa_set = set()
  122. for i in range(len(split_lines)):
  123. line = split_lines[i]
  124. # starting to find "Question"
  125. if not start:
  126. # Once found, set start to this line number
  127. if '"Question":' in line:
  128. start = i
  129. else:
  130. # "Question" has been found, find "Answer", once found, set end to this line number
  131. if '"Answer":' in line:
  132. end = i
  133. # found Question means we have reached the end of the question, so add it to qa_list
  134. elif '"Question":' in line:
  135. question = " ".join(split_lines[start:end]).split('"Question":')[1]
  136. answer = " ".join(split_lines[end:i]).split('"Answer":')[1]
  137. start,end = i,None
  138. qa_set.add((clean(question), clean(answer)))
  139. # adding last question back to qa_list
  140. if start and end:
  141. question = " ".join(split_lines[start:end]).split('"Question":')[1]
  142. answer = " ".join(split_lines[end:]).split('"Answer":')[1]
  143. qa_set.add((clean(question), clean(answer)))
  144. qa_list = [{"Question": q, "Answer":a} for q,a in qa_set]
  145. return qa_list
  146. async def prepare_and_send_request(chat_service, api_context: dict, document_content: str, num_questions: int) -> dict:
  147. prompt_for_system = api_context['question_prompt_template'].format(num_questions=num_questions, language=api_context["language"])
  148. chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': document_content}]
  149. result = await chat_service.execute_chat_request_async(api_context, chat_request_payload)
  150. # parse the result string to a list of dict that has Question, Answer, Context
  151. result = parse_qac_to_json(result)
  152. if not result:
  153. return {}
  154. return json.loads(await chat_service.execute_chat_request_async(api_context, chat_request_payload,eval=False))
  155. # This function is used to evaluate the quality of generated QA pairs. Return the original QA pair if the model eval result is YES. Otherwise, return an empty dict.
  156. async def data_curation_request(chat_service, api_context: dict, document_content: dict) -> dict:
  157. prompt_for_system = api_context['curation_prompt_template'].format(language=api_context["language"])
  158. chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': f"Question: {document_content['Question']} \n Answer: {document_content['Answer']}\n Context: {document_content['Context']} "}]
  159. result = await chat_service.execute_chat_request_async(api_context, chat_request_payload)
  160. if not result:
  161. return {}
  162. # no parsing needed, just return the loads the result as a dict
  163. result = json.loads(result)
  164. if "Result" not in result:
  165. print("Error: eval response does not contain answer")
  166. print(document_content,result)
  167. return {}
  168. # Send back the original QA pair is the model eval result is YES
  169. if result["Result"] == "YES":
  170. return document_content
  171. else:
  172. print(document_content,result)
  173. return {}
  174. async def generate_question_batches(chat_service, api_context: dict):
  175. document_text = read_file_content(api_context)
  176. if len(document_text)== 0:
  177. logging.error(f"Error reading files, document_text is empty")
  178. if api_context["model"] in ["meta-llama-3-70b-instruct","meta-llama-3-8b-instruct"]:
  179. tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", pad_token="</s>", padding_side="right")
  180. else:
  181. tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
  182. document_batches = split_text_into_chunks(api_context, document_text, tokenizer)
  183. total_questions = api_context["total_questions"]
  184. batches_count = len(document_batches)
  185. # each batch should have at least 1 question
  186. base_questions_per_batch = max(total_questions // batches_count,1)
  187. extra_questions = total_questions % batches_count
  188. print(f"Questions per batch: {base_questions_per_batch} (+1 for the first {extra_questions} batches), Total questions: {total_questions}, Batches: {batches_count}")
  189. generation_tasks = []
  190. for batch_index, batch_content in enumerate(document_batches):
  191. print(f"len of batch_content: {len(batch_content)}, batch_index: {batch_index}")
  192. #Distribute extra questions across the first few batches
  193. questions_in_current_batch = base_questions_per_batch + (1 if batch_index < extra_questions else 0)
  194. print(f"Batch {batch_index + 1} - {questions_in_current_batch} questions ********")
  195. try:
  196. result = prepare_and_send_request(chat_service, api_context, batch_content, questions_in_current_batch)
  197. generation_tasks.append(result)
  198. except Exception as e:
  199. print(f"Error during chat request execution: {e}")
  200. question_generation_results = await asyncio.gather(*generation_tasks)
  201. return question_generation_results
  202. async def generate_data_curation(chat_service, api_context: dict, generated_questions: list):
  203. eval_tasks = []
  204. for batch_index, batch_content in enumerate(generated_questions):
  205. try:
  206. result = data_curation_request(chat_service, api_context, batch_content)
  207. eval_tasks.append(result)
  208. except Exception as e:
  209. print(f"Error during data eval request execution: {e}")
  210. eval_results = await asyncio.gather(*eval_tasks)
  211. curated_data = []
  212. for item in eval_results:
  213. # if the item is not empty, add it to the curated data list
  214. if item:
  215. curated_data.append(item)
  216. return curated_data