raft_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
import re
import string
from transformers import AutoTokenizer
import asyncio
import magic
from PyPDF2 import PdfReader
import json
from doc_processor import split_text_into_chunks
import logging
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_experimental.text_splitter import SemanticChunker
from datasets import Dataset
from math import ceil
import random

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def strip_str(s: str) -> str:
    """
    Helper function for formatting strings returned by GPT-4.
    Trims everything before the first and after the last alphabetic character.
    """
    l, r = 0, len(s) - 1
    beg_found = False
    for i in range(len(s)):
        if s[i].isalpha():
            if not beg_found:
                l = i
                beg_found = True
            else:
                r = i
    r += 2
    return s[l:min(r, len(s))]
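
# Illustrative example (not part of the original module): strip_str trims
# everything before the first alphabetic character and keeps at most one
# non-alphabetic character after the last one, e.g.
#   strip_str("1. What is RAFT? ")  ->  "What is RAFT?"
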
def read_text_file(file_path):
    try:
        with open(file_path, 'r') as f:
            text = f.read().strip()
        if len(text) == 0:
            print("File is empty ", file_path)
        return text + ' '
    except Exception as e:
        logging.error(f"Error reading text file {file_path}: {e}")
        return ''
def read_pdf_file(file_path):
    try:
        with open(file_path, 'rb') as f:
            pdf_reader = PdfReader(f)
            num_pages = len(pdf_reader.pages)
            file_text = [pdf_reader.pages[page_num].extract_text().strip() + ' ' for page_num in range(num_pages)]
            text = ''.join(file_text)
            if len(text.strip()) == 0:
                print("File is empty ", file_path)
            return text
    except Exception as e:
        logging.error(f"Error reading PDF file {file_path}: {e}")
        return ''
def read_json_file(file_path):
    try:
        with open(file_path, 'r') as f:
            data = json.load(f)
            # Assuming each item in the list has a 'question' and 'answer' key,
            # concatenate the question-answer pairs with a space in between and accumulate them into a single string
            file_text = ' '.join([item['question'].strip() + ' ' + item['answer'].strip() + ' ' for item in data])
            if len(file_text) == 0:
                print("File is empty ", file_path)
            return file_text
    except Exception as e:
        logging.error(f"Error reading JSON file {file_path}: {e}")
        return ''
def process_file(file_path):
    print("starting to process file: ", file_path)
    file_type = magic.from_file(file_path, mime=True)
    if file_type in ['text/plain', 'text/markdown']:
        return read_text_file(file_path)
    elif file_type == 'application/json':
        return read_json_file(file_path)
    elif file_type == 'application/pdf':
        return read_pdf_file(file_path)
    else:
        logging.warning(f"Unsupported file type {file_type} for file {file_path}")
        return ''
def read_file_content(context):
    file_strings = []
    for root, _, files in os.walk(context['data_dir']):
        for file in files:
            file_path = os.path.join(root, file)
            file_text = process_file(file_path)
            if file_text:
                file_strings.append(file_text)
    text = '\n'.join(file_strings)
    return remove_non_printable(text)
def remove_non_printable(s):
    printable = set(string.printable)
    return ''.join(filter(lambda x: x in printable, s))
async def generate_question_request(chat_service, api_context: dict, document_content: str, num_questions: int) -> dict:
    if num_questions == 0:
        logging.error("Error: num_questions is 0")
        return {}
    prompt_for_system = api_context['question_prompt_template'].format(num_questions=num_questions)
    chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': str(document_content)}]
    # parse the result string into a list of dicts that have Question, Answer, Context
    return await chat_service.execute_chat_request_async(api_context, chat_request_payload)
def get_chunks(
    text: str,
    chunk_size: int = 512,
    embedding_model: HuggingFaceEmbeddings = None
) -> list[str]:
    """
    Takes a string of document text, breaks it into semantic chunks of roughly
    `chunk_size` characters using `embedding_model`, and returns the chunks.
    """
    chunks = []
    if len(text) == 0:
        raise TypeError("Cannot get chunks from empty text")
    else:
        num_chunks = ceil(len(text) / chunk_size)
        logging.info(f"Splitting text into {num_chunks} chunks")
        text_splitter = SemanticChunker(embedding_model, number_of_chunks=num_chunks)
        chunks = text_splitter.create_documents([text])
        chunks = [chunk.page_content for chunk in chunks]
    return chunks
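
# Example usage (sketch): split a document with the same sentence-transformers
# embedding model that generate_questions below uses.
#   embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
#   chunks = get_chunks(document_text, chunk_size=512, embedding_model=embeddings)
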
# Read all the files in the data folder, split them into chunks,
# generate questions for each chunk and return a list of question lists.
async def generate_questions(chat_service, api_context: dict):
    document_text = read_file_content(api_context)
    if len(document_text) == 0:
        logging.error("Error reading files, document_text is empty")
        return []
    model_name = "sentence-transformers/all-mpnet-base-v2"
    embedding_model = HuggingFaceEmbeddings(model_name=model_name)
    document_batches = get_chunks(document_text, api_context["chunk_size"], embedding_model)
    batches_count = len(document_batches)
    total_questions = api_context["questions_per_chunk"] * batches_count
    print(f"Questions per batch: {api_context['questions_per_chunk']}, Total questions: {total_questions}, Batches: {batches_count}")
    generation_tasks = []
    for batch_index, batch_content in enumerate(document_batches):
        print(f"len of batch_content: {len(batch_content)}, batch_index: {batch_index}")
        print(f"Batch {batch_index + 1} - {api_context['questions_per_chunk']} questions ********")
        try:
            task = generate_question_request(chat_service, api_context, batch_content, api_context["questions_per_chunk"])
            generation_tasks.append(task)
        except Exception as e:
            print(f"Error during chat request execution: {e}")
    question_generation_results = await asyncio.gather(*generation_tasks)
    final_result = []
    for result in question_generation_results:
        queries = result.split('\n')
        queries = [strip_str(q) for q in queries]
        queries = [q for q in queries if any(c.isalpha() for c in q)]
        if len(queries) > int(api_context['questions_per_chunk']):
            # The model may emit unrelated text at the beginning of the result;
            # if there are more lines than questions_per_chunk, keep only the last questions_per_chunk lines.
            queries = queries[-int(api_context['questions_per_chunk']):]
        final_result.append(queries)
    return final_result
def add_chunk_to_dataset(
    chunks: list[str],
    chunk: str,
    x: int = 5,
    num_distract: int = 3,
    p: float = 0.8,
    model: str = None
) -> None:
    """
    Given a chunk, create {Q, A, D} triplets and add them to the dataset.
    Assumes the module-level globals `ds`, `client` and `doctype`, as well as the
    `generate_instructions`, `generate_instructions_gen` and `generate_label`
    helpers, are defined elsewhere.
    """
    global ds
    i = chunks.index(chunk)
    qs = generate_instructions(client, chunk, x, model) if doctype == "api" else generate_instructions_gen(client, chunk, x, model)
    for q in qs:
        datapt = {
            "id": None,
            "type": None,
            "question": None,
            "context": None,
            "oracle_context": None,
            "cot_answer": None
        }
        datapt["id"] = f"seed_task_{0 if not ds else ds.num_rows}"
        datapt["type"] = "api call" if doctype == "api" else "general"
        datapt["question"] = q
        # add num_distract distractor docs
        docs = [chunk]
        indices = list(range(0, len(chunks)))
        indices.remove(i)
        for j in random.sample(indices, num_distract):
            docs.append(chunks[j])
        # decide whether to add the oracle document
        oracle = random.uniform(0, 1) < p
        if not oracle:
            docs[0] = chunks[random.sample(indices, 1)[0]]
        random.shuffle(docs)
        d = {
            "title": [],
            "sentences": []
        }
        d["title"].append(["placeholder_title"] * (num_distract + 1))
        d["sentences"].append(docs)
        datapt["context"] = d
        datapt["oracle_context"] = chunk
        # add answer to q
        datapt["cot_answer"] = generate_label(client, q, chunk, doctype, model=model)
        # construct model instruction
        context = ""
        for doc in docs:
            context += "<DOCUMENT>" + str(doc) + "</DOCUMENT>\n"
        context += q
        datapt["instruction"] = context
        # add to dataset
        if not ds:
            # init ds
            datapt["id"] = [datapt["id"]]
            datapt["type"] = [datapt["type"]]
            datapt["question"] = [datapt["question"]]
            datapt["context"] = [datapt["context"]]
            datapt["oracle_context"] = [datapt["oracle_context"]]
            datapt["cot_answer"] = [datapt["cot_answer"]]
            datapt["instruction"] = [datapt["instruction"]]
            ds = Dataset.from_dict(datapt)
        else:
            ds = ds.add_item(datapt)
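
# For reference, each datapoint appended above has roughly this shape (sketch
# derived from the code; field values are placeholders):
#   {
#       "id": "seed_task_42",
#       "type": "general",                # or "api call"
#       "question": "...",
#       "context": {"title": [["placeholder_title", ...]], "sentences": [[doc_0, ..., doc_n]]},
#       "oracle_context": "<the source chunk>",
#       "cot_answer": "...",
#       "instruction": "<DOCUMENT>...</DOCUMENT>\n...<question>",
#   }
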
# This function is used to evaluate the quality of generated QA pairs.
# It returns the judge's parsed result (expected to contain a "Result" field), or an empty dict on failure.
async def LLM_judge_request(chat_service, api_context: dict, document_content: dict) -> dict:
    prompt_for_system = api_context['judge_prompt_template'].format(language=api_context["language"])
    chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': f"Question: {document_content['Question']} \n Teacher's Answer: {document_content['Ground_truth']}\n Student's Answer: {document_content['Generated_answer']} "}]
    result = await chat_service.execute_chat_request_async(api_context, chat_request_payload)
    if not result:
        return {}
    # no parsing needed, just load the result string as a dict
    result = json.loads(result)
    if "Result" not in result:
        print("Error: eval response does not contain answer")
        print(document_content, result)
        return {}
    return result
async def generate_LLM_eval(chat_service, api_context: dict, judge_list: list):
    eval_tasks = []
    for batch_index, batch_content in enumerate(judge_list):
        try:
            result = LLM_judge_request(chat_service, api_context, batch_content)
            eval_tasks.append(result)
        except Exception as e:
            print(f"Error during data eval request execution: {e}")
    judge_results = await asyncio.gather(*eval_tasks)
    return judge_results
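
# Example end-to-end driver (illustrative sketch only; the chat service object and
# the api_context keys shown here are assumptions based on how the functions above
# are called, not part of this module):
#
#   async def main():
#       api_context = {
#           "data_dir": "./data",
#           "chunk_size": 512,
#           "questions_per_chunk": 3,
#           "question_prompt_template": "Generate {num_questions} questions ...",
#       }
#       # chat_service can be any object exposing
#       # execute_chat_request_async(api_context, chat_request_payload)
#       questions = await generate_questions(chat_service, api_context)
#       print(questions)
#
#   asyncio.run(main())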