raft_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
import re
import string
from transformers import AutoTokenizer
import asyncio
import magic
from PyPDF2 import PdfReader
import json
from doc_processor import split_text_into_chunks
import logging
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_experimental.text_splitter import SemanticChunker
from math import ceil
import datasets
from datasets import Dataset, load_dataset
import random

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def strip_str(s: str) -> str:
    """
    Helper function to clean up strings returned by GPT-4.
    """
    l, r = 0, len(s) - 1
    beg_found = False
    for i in range(len(s)):
        if s[i].isalpha():
            if not beg_found:
                l = i
                beg_found = True
            else:
                r = i
    # keep one extra character after the last letter (e.g. a trailing question mark)
    r += 2
    return s[l:min(r, len(s))]

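# Example of strip_str behavior (illustrative): leading numbering is dropped and the
# character right after the last letter is kept:
#   strip_str("1. Question?")  ->  "Question?"
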
def read_text_file(file_path):
    try:
        with open(file_path, 'r') as f:
            text = f.read().strip()
            if len(text) == 0:
                print("File is empty ", file_path)
            return text + ' '
    except Exception as e:
        logging.error(f"Error reading text file {file_path}: {e}")
        return ''

def read_pdf_file(file_path):
    try:
        with open(file_path, 'rb') as f:
            pdf_reader = PdfReader(f)
            num_pages = len(pdf_reader.pages)
            file_text = [pdf_reader.pages[page_num].extract_text().strip() + ' ' for page_num in range(num_pages)]
            text = ''.join(file_text)
            if len(text.strip()) == 0:
                print("File is empty ", file_path)
            return text
    except Exception as e:
        logging.error(f"Error reading PDF file {file_path}: {e}")
        return ''

def read_json_file(file_path):
    try:
        with open(file_path, 'r') as f:
            data = json.load(f)
            # Assuming each item in the list has a 'question' and 'answer' key
            # Concatenating question and answer pairs with a space in between and accumulating them into a single string
            file_text = ' '.join([item['question'].strip() + ' ' + item['answer'].strip() + ' ' for item in data])
            if len(file_text) == 0:
                print("File is empty ", file_path)
            return file_text
    except Exception as e:
        logging.error(f"Error reading JSON file {file_path}: {e}")
        return ''

def process_file(file_path):
    print("starting to process file: ", file_path)
    file_type = magic.from_file(file_path, mime=True)
    if file_type in ['text/plain', 'text/markdown']:
        return read_text_file(file_path)
    elif file_type == 'application/json':
        return read_json_file(file_path)
    elif file_type == 'application/pdf':
        return read_pdf_file(file_path)
    else:
        logging.warning(f"Unsupported file type {file_type} for file {file_path}")
        return ''

def read_file_content(context):
    file_strings = []
    for root, _, files in os.walk(context['data_dir']):
        for file in files:
            file_path = os.path.join(root, file)
            file_text = process_file(file_path)
            if file_text:
                file_strings.append(file_text)
    text = '\n'.join(file_strings)
    return remove_non_printable(text)

def remove_non_printable(s):
    printable = set(string.printable)
    return ''.join(filter(lambda x: x in printable, s))

async def generate_question_request(chat_service, api_context: dict, document_content: str, num_questions: int) -> dict:
    if num_questions == 0:
        logging.error("Error: num_questions is 0")
        return {}
    prompt_for_system = api_context['question_prompt_template'].format(num_questions=num_questions)
    chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': str(document_content)}]
    # the result string is later parsed into a list of questions
    return await chat_service.execute_chat_request_async(api_context, chat_request_payload)

def get_chunks(
    text: str,
    chunk_size: int = 512,
    embedding_model=None
) -> list[str]:
    """
    Takes in a `text` string, breaks it down into semantically split chunks of roughly
    `chunk_size` characters each, and returns the chunks.
    """
    chunks = []
    if len(text) == 0:
        raise TypeError("Cannot get chunks from empty text")
    else:
        num_chunks = ceil(len(text) / chunk_size)
        logging.info(f"Splitting text into {num_chunks} chunks")
        text_splitter = SemanticChunker(embedding_model, number_of_chunks=num_chunks)
        chunks = text_splitter.create_documents([text])
        chunks = [chunk.page_content for chunk in chunks]
    return chunks

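# Illustrative usage of get_chunks, mirroring the embedding model used in
# generate_questions below (`document_text` is a placeholder variable):
#   embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
#   chunks = get_chunks(document_text, chunk_size=512, embedding_model=embedding_model)
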
# Read all the files in the data folder, split them into chunks,
# generate questions for each chunk and return a zip of the chunks and their related question lists
async def generate_questions(chat_service, api_context: dict):
    document_text = read_file_content(api_context)
    if len(document_text) == 0:
        logging.error("Error reading files, document_text is empty")
        return []
    model_name = "sentence-transformers/all-mpnet-base-v2"
    embedding_model = HuggingFaceEmbeddings(model_name=model_name)
    document_batches = get_chunks(document_text, api_context["chunk_size"], embedding_model)
    batches_count = len(document_batches)
    total_questions = api_context["questions_per_chunk"] * batches_count
    print(f"Questions per batch: {api_context['questions_per_chunk']}, Total questions: {total_questions}, Batches: {batches_count}")
    generation_tasks = []
    batches_used = []
    for batch_index, batch_content in enumerate(document_batches):
        print(f"len of batch_content: {len(batch_content)}, batch_index: {batch_index}")
        # Skip batches whose content is too short to generate meaningful questions from
        if len(batch_content) < 10:
            logging.info("Context is not enough, ignoring this batch")
        else:
            print(f"Batch {batch_index + 1} - {api_context['questions_per_chunk']} questions ********")
            try:
                task = generate_question_request(chat_service, api_context, batch_content, api_context["questions_per_chunk"])
                generation_tasks.append(task)
                batches_used.append(batch_content)
            except Exception as e:
                print(f"Error during chat request execution: {e}")
    question_generation_results = await asyncio.gather(*generation_tasks)
    final_result = []
    for result in question_generation_results:
        queries = result.split('\n')
        queries = [strip_str(q) for q in queries]
        queries = [q for q in queries if any(c.isalpha() for c in q)]
        if len(queries) > int(api_context['questions_per_chunk']):
            # The model may emit unrelated lines at the beginning of the result;
            # if there are more lines than questions_per_chunk, keep only the last questions_per_chunk of them
            queries = queries[-int(api_context['questions_per_chunk']):]
        final_result.append(queries)
    # pair each batch that was actually sent for generation with its question list
    return list(zip(batches_used, final_result))

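# Illustrative shape of the return value of generate_questions (placeholder text):
#   [("<chunk 1 text>", ["question 1", "question 2", ...]),
#    ("<chunk 2 text>", ["question 1", ...]), ...]
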
async def generate_COT(chat_service, api_context: dict, document_content: str, question: str) -> tuple:
    prompt = api_context['COT_prompt_template'].format(question=question, context=str(document_content))
    chat_request_payload = [{"role": "system", "content": "You are a helpful question answerer who can provide an answer given a question and relevant context."}]
    chat_request_payload.append({"role": "user", "content": prompt})
    response = await chat_service.execute_chat_request_async(api_context, chat_request_payload)
    return (document_content, question, response)

async def add_chunk_to_dataset(
    chunk_questions_zip: list,
    context: dict,
    chat_service,
    ds,
    num_distract: int = 3,
    p: float = 0.8,
) -> Dataset:
    """
    Given a list of (chunk, questions) pairs, create {Q, A, D} triplets and add them to the dataset.
    """
    COT_tasks = []
    chunks = [chunk for chunk, _ in chunk_questions_zip]
    for i, chunk_questions in enumerate(chunk_questions_zip):
        chunk, questions = chunk_questions
        # generate a COT answer for each question given the chunk context
        for question in questions:
            COT_tasks.append(generate_COT(chat_service, context, chunk, question))
    COT_results = await asyncio.gather(*COT_tasks)
    for chunk, q, cot in COT_results:
        datapt = {
            "id": None,
            "type": "general",
            "question": q,
            "context": None,
            "oracle_context": None,
            "cot_answer": cot
        }
        i = chunks.index(chunk)
        datapt["id"] = f"seed_task_{0 if not ds else ds.num_rows}"
        # add num_distract distractor docs
        docs = [chunk]
        indices = list(range(0, len(chunks)))
        indices.remove(i)
        for j in random.sample(indices, num_distract):
            docs.append(chunks[j])
        # with probability (1 - p), replace the oracle document with another distractor
        oracle = random.uniform(0, 1) < p
        if not oracle:
            docs[0] = chunks[random.sample(indices, 1)[0]]
        random.shuffle(docs)
        d = {
            "title": [],
            "sentences": []
        }
        d["title"].append(["placeholder_title"] * (num_distract + 1))
        d["sentences"].append(docs)
        datapt["context"] = d
        datapt["oracle_context"] = chunk
        # construct the model instruction: all documents followed by the question
        instruction = ""
        for doc in docs:
            instruction += "<DOCUMENT>" + str(doc) + "</DOCUMENT>\n"
        instruction += q
        datapt["instruction"] = instruction
        # add to dataset
        if not ds:
            # init ds
            datapt["id"] = [datapt["id"]]
            datapt["type"] = [datapt["type"]]
            datapt["question"] = [datapt["question"]]
            datapt["context"] = [datapt["context"]]
            datapt["oracle_context"] = [datapt["oracle_context"]]
            datapt["cot_answer"] = [datapt["cot_answer"]]
            datapt["instruction"] = [datapt["instruction"]]
            ds = Dataset.from_dict(datapt)
        else:
            ds = ds.add_item(datapt)
    return ds

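# Illustrative shape of a single generated dataset row (field values are placeholders):
# {
#     "id": "seed_task_0",
#     "type": "general",
#     "question": "<generated question>",
#     "context": {"title": [["placeholder_title", ...]], "sentences": [["<oracle or distractor doc>", ...]]},
#     "oracle_context": "<the source chunk>",
#     "cot_answer": "<chain-of-thought answer from the model>",
#     "instruction": "<DOCUMENT>...</DOCUMENT>\n...<DOCUMENT>...</DOCUMENT>\n<question>"
# }
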
# This function is used to evaluate the quality of generated QA pairs. It returns the judge model's
# verdict as a dict if the response contains a "Result" field, otherwise an empty dict.
async def LLM_judge_request(chat_service, api_context: dict, document_content: dict) -> dict:
    prompt_for_system = api_context['judge_prompt_template'].format(language=api_context["language"])
    chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': f"Question: {document_content['Question']} \n Teacher's Answer: {document_content['Ground_truth']}\n Student's Answer: {document_content['Generated_answer']} "}]
    result = await chat_service.execute_chat_request_async(api_context, chat_request_payload)
    if not result:
        return {}
    # no parsing needed, just load the result string as a dict
    result = json.loads(result)
    if "Result" not in result:
        print("Error: eval response does not contain answer")
        print(document_content, result)
        return {}
    return result

async def generate_LLM_eval(chat_service, api_context: dict, judge_list: list):
    eval_tasks = []
    for batch_index, batch_content in enumerate(judge_list):
        try:
            result = LLM_judge_request(chat_service, api_context, batch_content)
            eval_tasks.append(result)
        except Exception as e:
            print(f"Error during data eval request execution: {e}")
    judge_results = await asyncio.gather(*eval_tasks)
    return judge_results
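
# Illustrative end-to-end usage sketch, assuming the caller supplies a `chat_service`
# client and an `api_context` dict carrying data_dir, chunk_size, questions_per_chunk
# and the prompt templates; a minimal example of how the helpers above compose.
async def build_raft_dataset(chat_service, api_context: dict):
    # generate (chunk, questions) pairs from the documents under api_context["data_dir"]
    chunk_questions_zip = await generate_questions(chat_service, api_context)
    # turn them into {Q, A, D} triplets and accumulate them into a Hugging Face Dataset
    ds = await add_chunk_to_dataset(chunk_questions_zip, api_context, chat_service, ds=None)
    return ds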