# generator_utils.py
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import re
  5. import string
  6. from transformers import AutoTokenizer
  7. import asyncio
  8. import magic
  9. from PyPDF2 import PdfReader
  10. import json
  11. from doc_processor import split_text_into_chunks
  12. import logging
  13. import json
  14. # Initialize logging
  15. logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
  16. def read_text_file(file_path):
  17. try:
  18. with open(file_path, 'r') as f:
  19. text = f.read().strip() + ' '
  20. if len(text) == 0:
  21. print("File is empty ",file_path)
  22. return text
  23. except Exception as e:
  24. logging.error(f"Error reading text file {file_path}: {e}")
  25. return ''
  26. def read_pdf_file(file_path):
  27. try:
  28. with open(file_path, 'rb') as f:
  29. pdf_reader = PdfReader(f)
  30. num_pages = len(pdf_reader.pages)
  31. file_text = [pdf_reader.pages[page_num].extract_text().strip() + ' ' for page_num in range(num_pages)]
  32. text = ''.join(file_text)
  33. if len(text) == 0:
  34. print("File is empty ",file_path)
  35. return ''.join(file_text)
  36. except Exception as e:
  37. logging.error(f"Error reading PDF file {file_path}: {e}")
  38. return ''
  39. def read_json_file(file_path):
  40. try:
  41. with open(file_path, 'r') as f:
  42. data = json.load(f)
  43. # Assuming each item in the list has a 'question' and 'answer' key
  44. # Concatenating question and answer pairs with a space in between and accumulating them into a single string
  45. file_text = ' '.join([item['question'].strip() + ' ' + item['answer'].strip() + ' ' for item in data])
  46. if len(file_text) == 0:
  47. print("File is empty ",file_path)
  48. return file_text
  49. except Exception as e:
  50. logging.error(f"Error reading JSON file {file_path}: {e}")
  51. return ''
  52. def process_file(file_path):
  53. print("starting to process file: ", file_path)
  54. file_type = magic.from_file(file_path, mime=True)
  55. if file_type in ['text/plain', 'text/markdown', 'JSON']:
  56. return read_text_file(file_path)
  57. elif file_type == 'application/pdf':
  58. return read_pdf_file(file_path)
  59. else:
  60. logging.warning(f"Unsupported file type {file_type} for file {file_path}")
  61. return ''
  62. def remove_non_printable(s):
  63. printable = set(string.printable)
  64. return ''.join(filter(lambda x: x in printable, s))
  65. def read_file_content(context):
  66. file_strings = []
  67. for root, _, files in os.walk(context['data_dir']):
  68. for file in files:
  69. file_path = os.path.join(root, file)
  70. file_text = process_file(file_path)
  71. if file_text:
  72. file_strings.append(file_text)
  73. text = '\n'.join(file_strings)
  74. text = remove_non_printable(text)
  75. with open(context['data_dir'] + '/' + 'all_text.txt', 'w') as f:
  76. f.write(text)
  77. return remove_non_printable(text)
  78. # clean the text by removing all parts that did not contain any alphanumeric characters
  79. def clean(s):
  80. result = []
  81. for item in s.split('"'):
  82. if any(c.isalnum() for c in item):
  83. result.append(item)
  84. return " ".join(result)
  85. # given a response string, return a string that can be saved as json.
  86. def parse_qac_to_json(response_string):
  87. split_lines = response_string.split("\n")
  88. start,mid,end = None,None,None
  89. # must use set to avoid duplicate question/answer pairs due to async function calls
  90. qa_set = set()
  91. for i in range(len(split_lines)):
  92. line = split_lines[i]
  93. # starting to find "Question"
  94. if not start:
  95. # Once found, set start to this line number
  96. if '"Question":' in line:
  97. start = i
  98. else:
  99. # "Question" has been found, find "Answer", once found, set end to this line number
  100. if '"Answer":' in line:
  101. mid = i
  102. elif '"Context":' in line:
  103. end = i
  104. # found Question means we have reached the end of the question, so add it to qa_list
  105. elif '"Question":' in line:
  106. question = " ".join(split_lines[start:mid]).split('"Question":')[1]
  107. answer = " ".join(split_lines[mid:end]).split('"Answer":')[1]
  108. context = " ".join(split_lines[end:i]).split('"Context":')[1]
  109. start,mid,end = i,None,None
  110. qa_set.add((clean(question), clean(answer),clean(context)))
  111. # adding last question back to qa_list
  112. if start and mid and end:
  113. question = " ".join(split_lines[start:mid]).split('"Question":')[1]
  114. answer = " ".join(split_lines[mid:end]).split('"Answer":')[1]
  115. context = " ".join(split_lines[end:]).split('"Context":')[1]
  116. start,mid,end = i,None,None
  117. qa_set.add((clean(question), clean(answer),clean(context)))
  118. qa_list = [{"Question": q, "Answer":a, "Context":c} for q,a,c in qa_set]
  119. return qa_list
  120. def parse_qa_to_json(response_string):
  121. split_lines = response_string.split("\n")
  122. start,end = None,None
  123. # must use set to avoid duplicate question/answer pairs due to async function calls
  124. qa_set = set()
  125. for i in range(len(split_lines)):
  126. line = split_lines[i]
  127. # starting to find "Question"
  128. if not start:
  129. # Once found, set start to this line number
  130. if '"Question":' in line:
  131. start = i
  132. else:
  133. # "Question" has been found, find "Answer", once found, set end to this line number
  134. if '"Answer":' in line:
  135. end = i
  136. # found Question means we have reached the end of the question, so add it to qa_list
  137. elif '"Question":' in line:
  138. question = " ".join(split_lines[start:end]).split('"Question":')[1]
  139. answer = " ".join(split_lines[end:i]).split('"Answer":')[1]
  140. start,end = i,None
  141. qa_set.add((clean(question), clean(answer)))
  142. # adding last question back to qa_list
  143. if start and end:
  144. question = " ".join(split_lines[start:end]).split('"Question":')[1]
  145. answer = " ".join(split_lines[end:]).split('"Answer":')[1]
  146. qa_set.add((clean(question), clean(answer)))
  147. qa_list = [{"Question": q, "Answer":a} for q,a in qa_set]
  148. return qa_list
  149. async def prepare_and_send_request(chat_service, api_context: dict, document_content: str, num_questions: int) -> dict:
  150. if num_questions == 0:
  151. logging.info(f"Error: num_questions is 0")
  152. return {}
  153. prompt_for_system = api_context['question_prompt_template'].format(num_questions=num_questions, language=api_context["language"])
  154. chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': document_content}]
  155. # parse the result string to a list of dict that has Question, Answer, Context
  156. return await chat_service.execute_chat_request_async(api_context, chat_request_payload)
  157. # This function is used to evaluate the quality of generated QA pairs. Return the original QA pair if the model eval result is YES. Otherwise, return an empty dict.
  158. async def data_curation_request(chat_service, api_context: dict, document_content: dict) -> dict:
  159. prompt_for_system = api_context['curation_prompt_template'].format(language=api_context["language"])
  160. chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': f"Question: {document_content['Question']} \n Answer: {document_content['Answer']}\n Context: {document_content['Context']} "}]
  161. result = await chat_service.execute_chat_request_async(api_context, chat_request_payload)
  162. if not result:
  163. return {}
  164. # no parsing needed, just return the loads the result as a dict
  165. result = json.loads(result)
  166. if "Result" not in result:
  167. print("Error: eval response does not contain answer")
  168. print(document_content,result)
  169. return {}
  170. # Send back the original QA pair is the model eval result is YES
  171. if result["Result"] == "YES":
  172. return document_content
  173. else:
  174. print(document_content,result)
  175. return {}
  176. async def generate_question_batches(chat_service, api_context: dict):
  177. document_text = read_file_content(api_context)
  178. if len(document_text)== 0:
  179. logging.error(f"Error reading files, document_text is empty")
  180. if api_context["model"] in ["meta-llama-3-70b-instruct","meta-llama-3-8b-instruct"]:
  181. tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", pad_token="</s>", padding_side="right")
  182. else:
  183. tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
  184. document_batches = split_text_into_chunks(api_context, document_text, tokenizer)
  185. total_questions = api_context["total_questions"]
  186. batches_count = len(document_batches)
  187. # each batch should have at least 1 question
  188. base_questions_per_batch = max(total_questions // batches_count,1)
  189. extra_questions = total_questions % batches_count
  190. print(f"Questions per batch: {base_questions_per_batch} (+1 for the first {extra_questions} batches), Total questions: {total_questions}, Batches: {batches_count}")
  191. generation_tasks = []
  192. for batch_index, batch_content in enumerate(document_batches):
  193. print(f"len of batch_content: {len(batch_content)}, batch_index: {batch_index}")
  194. #Distribute extra questions across the first few batches
  195. questions_in_current_batch = base_questions_per_batch + (1 if batch_index < extra_questions else 0)
  196. print(f"Batch {batch_index + 1} - {questions_in_current_batch} questions ********")
  197. try:
  198. task = prepare_and_send_request(chat_service, api_context, batch_content, questions_in_current_batch)
  199. generation_tasks.append(task)
  200. except Exception as e:
  201. print(f"Error during chat request execution: {e}")
  202. question_generation_results = await asyncio.gather(*generation_tasks)
  203. final_result = []
  204. for result in question_generation_results:
  205. parsed_json = parse_qac_to_json(result)
  206. final_result.extend(parsed_json)
  207. return final_result
  208. async def generate_data_curation(chat_service, api_context: dict, evaluation_list: list):
  209. eval_tasks = []
  210. for batch_index, batch_content in enumerate(evaluation_list):
  211. try:
  212. result = data_curation_request(chat_service, api_context, batch_content)
  213. eval_tasks.append(result)
  214. except Exception as e:
  215. print(f"Error during data eval request execution: {e}")
  216. eval_results = await asyncio.gather(*eval_tasks)
  217. curated_data = []
  218. for item in eval_results:
  219. # if the item is not empty, add it to the curated data list
  220. if item:
  221. curated_data.append(item)
  222. return curated_data
  223. # This function is used to evaluate the quality of generated QA pairs. Return the original QA pair if the model eval result is YES. Otherwise, return an empty dict.
  224. async def LLM_judge_request(chat_service, api_context: dict, document_content: dict) -> dict:
  225. prompt_for_system = api_context['judge_prompt_template'].format(language=api_context["language"])
  226. chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': f"Question: {document_content['Question']} \n Teacher's Answer: {document_content['Ground_truth']}\n Student's Answer: {document_content['Generated_answer']} "}]
  227. result = await chat_service.execute_chat_request_async(api_context, chat_request_payload)
  228. if not result:
  229. return {}
  230. # no parsing needed, just return the loads the result as a dict
  231. result = json.loads(result)
  232. if "Result" not in result:
  233. print("Error: eval response does not contain answer")
  234. print(document_content,result)
  235. return {}
  236. return result
  237. async def generate_LLM_eval(chat_service, api_context: dict, judge_list: list):
  238. eval_tasks = []
  239. for batch_index, batch_content in enumerate(judge_list):
  240. try:
  241. result = LLM_judge_request(chat_service, api_context, batch_content)
  242. eval_tasks.append(result)
  243. except Exception as e:
  244. print(f"Error during data eval request execution: {e}")
  245. judge_results = await asyncio.gather(*eval_tasks)
  246. return judge_results