raft_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
import logging
from langchain.text_splitter import RecursiveCharacterTextSplitter
from datasets import Dataset
import random
from langchain_community.document_loaders import SitemapLoader, DirectoryLoader
from bs4 import BeautifulSoup
from langchain_openai import ChatOpenAI
import copy

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def strip_str(s: str) -> str:
    """
    Helper function for formatting strings returned by the model (e.g. GPT-4):
    strips leading and trailing non-alphabetic characters.
    """
    l, r = 0, len(s) - 1
    beg_found = False
    for i in range(len(s)):
        if s[i].isalpha():
            if not beg_found:
                l = i
                beg_found = True
            else:
                r = i
    r += 2
    return s[l:min(r, len(s))]
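
# Illustrative example (input string assumed, not from the original file):
# strip_str("1. What is RAFT?") returns "What is RAFT?" -- the leading list
# numbering is dropped and one trailing character after the last letter is kept.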

def clean_documents(raw_text):
    all_lines = []
    for line in raw_text.split("\n"):
        line = line.strip()
        if len(line.split()) == 0:
            continue
        else:
            all_lines.append(line)
    result = " ".join(all_lines)
    return result

def clean_text(content: BeautifulSoup) -> str:
    # Find all 'nav' and 'header' elements in the BeautifulSoup object
    nav_elements = content.find_all("nav")
    header_elements = content.find_all("header")
    mydivs = content.find_all("div", {"role": "list"})
    # Remove each 'nav', 'header' and list 'div' element from the BeautifulSoup object
    for element in nav_elements + header_elements + mydivs:
        element.decompose()
    raw_text = content.get_text("\n")
    return clean_documents(raw_text)

# Read documents either from the webpage links listed in a sitemap xml file or from files in a local data folder
def read_file_content(xml_path: str, data_folder: str) -> list:
    if xml_path and data_folder:
        logging.warning("Both xml_path and data_folder are provided, will only read from xml for now")
    if not xml_path and not data_folder:
        logging.error("Neither xml_path nor data_folder is provided")
        return []
    if xml_path:
        if not os.path.exists(xml_path):
            logging.error(f"Error: {xml_path} does not exist")
            return []
        # Use langchain to load the documents from webpage links in the xml file
        sitemap_loader = SitemapLoader(web_path=xml_path, is_local=True, parsing_function=clean_text)
        sitemap_loader.requests_kwargs = {"verify": False}
        docs = sitemap_loader.load()
        return docs
    elif len(data_folder) != 0:
        if not os.path.exists(data_folder):
            logging.error(f"Error: {data_folder} does not exist")
            return []
        # Use langchain to load the documents from the data folder
        loader = DirectoryLoader(data_folder)
        docs = loader.load()
        return docs
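
# Illustrative calls (paths are placeholders): read_file_content("", "./data")
# loads every readable file under ./data via DirectoryLoader and returns a list
# of langchain Document objects; read_file_content("sitemap.xml", "") instead
# loads and cleans each page listed in the local sitemap file.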

def get_chunks(
    docs: list,
    chunk_size: int = 1000,
    api_config: dict = None,
) -> list[str]:
    """
    Takes in a list of documents, breaks them down into chunks of size
    api_config["chunk_size"] (with 10% overlap), and returns the chunks.
    """
    chunks = []
    if len(docs) == 0:
        raise TypeError("Cannot get chunks from empty text")
    else:
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=api_config["chunk_size"],
            chunk_overlap=int(api_config["chunk_size"] / 10),
            separators=["----------", "\n\n", "\n", " "],
            strip_whitespace=True,
        )
        docs_processed = text_splitter.split_documents(docs)
        logging.info(f"Total number of docs_processed: {len(docs_processed)}")
        # Remove chunks that are exact duplicates or too short to be useful
        unique_texts = {}
        docs_processed_unique = []
        for doc in docs_processed:
            if doc.page_content not in unique_texts and len(doc.page_content) > 100:
                unique_texts[doc.page_content] = True
                docs_processed_unique.append(doc)
        chunks = [chunk.page_content for chunk in docs_processed_unique]
        logging.info(f"Total number of docs_processed_unique: {len(docs_processed_unique)}")
    return chunks
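
# Illustrative sizing (assumed chunk_size of 1000): the splitter produces pieces of
# at most ~1000 characters with int(1000 / 10) = 100 characters of overlap between
# consecutive pieces; pieces of 100 characters or fewer, and exact duplicates, are dropped.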

# Read all the files in the data folder, split them into chunks,
# generate questions for each chunk, and return a zip of chunks and their related question lists
def generate_questions(api_config):
    # get documents from the data folder or xml file
    api_url = api_config["endpoint_url"]
    key = api_config["api_key"]
    documents = read_file_content(api_config["xml_path"], api_config["data_dir"])
    if len(documents) == 0:
        logging.error("Error reading files, no documents were loaded")
        return []
    document_batches = get_chunks(documents, api_config["chunk_size"], api_config)
    # Use the OpenAI API protocol to handle the chat request, including a local VLLM OpenAI-compatible server
    llm = ChatOpenAI(
        openai_api_key=key,
        openai_api_base=api_url,
        model_name=api_config["model"],
        temperature=0.0,
        max_tokens=500
    )
    all_tasks = [
        api_config['question_prompt_template'].format(
            num_questions=str(api_config['questions_per_chunk']),
            context=document,
        )
        for document in document_batches
    ]
    generated_answers = llm.batch(all_tasks)
    generated_answers = [item.content for item in generated_answers]
    if len(generated_answers) == 0:
        logging.error(f"No model answers generated. Please check the input context or model configuration for {api_config['model']}")
        return []
    final_result = []
    for result in generated_answers:
        queries = result.split('\n')
        queries = [strip_str(q) for q in queries]
        queries = [q for q in queries if any(c.isalpha() for c in q)]
        if len(queries) > int(api_config['questions_per_chunk']):
            # The model may emit unrelated lines at the beginning of the result;
            # if there are more lines than questions_per_chunk, keep only the last questions_per_chunk lines
            queries = queries[-int(api_config['questions_per_chunk']):]
        final_result.append(queries)
    return list(zip(document_batches, final_result))
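
# Shape of the returned value (illustrative):
# [(chunk_text_1, ["question 1", "question 2", ...]), (chunk_text_2, [...]), ...]
# where each question list holds at most api_config["questions_per_chunk"] entries.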

# Generate a COT answer for each question given the chunk context
def generate_COT(chunk_questions_zip, api_config) -> list:
    all_tasks = []
    chunk_questions = []
    question_asked = set()
    for document_content, questions in chunk_questions_zip:
        for question in questions:
            question = question.strip()
            # avoid asking the same question twice
            if question not in question_asked:
                question_asked.add(question)
                prompt = api_config['COT_prompt_template'].format(question=question, context=str(document_content))
                all_tasks.append(prompt)
                chunk_questions.append((document_content, question))
    # Use the OpenAI API protocol to handle the chat request, including a local VLLM OpenAI-compatible server
    llm = ChatOpenAI(
        openai_api_key=api_config["api_key"],
        openai_api_base=api_config["endpoint_url"],
        model_name=api_config["model"],
        temperature=0.0,
        max_tokens=500
    )
    generated_answers = llm.batch(all_tasks)
    generated_answers = [item.content for item in generated_answers]
    COT_results = []
    # return a list of (chunk, question, generated_answer)
    for (chunk, question), generated_answer in zip(chunk_questions, generated_answers):
        COT_results.append((chunk, question, generated_answer))
    return COT_results
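
# Shape of the returned value (illustrative): a list of
# (chunk_text, question_text, chain_of_thought_answer_text) triples, one per unique question.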

def add_chunk_to_dataset(
    chunk_questions_zip: list,
    api_config: dict,
) -> Dataset:
    """
    Given chunks and their related question lists, create {Q, A, D} triplets and add them to the dataset.
    """
    num_distract = api_config["num_distract_docs"]
    p = api_config["refusal_probability"]
    chunks = [chunk for chunk, _ in chunk_questions_zip]
    COT_results = generate_COT(chunk_questions_zip, api_config)
    logging.info(f"COT generation completed, total num of COT results: {len(COT_results)}")
    completed, refusal_count = 0, 0
    data_list = []
    for chunk, q, cot in COT_results:
        # The COT answer will be used as the label in the fine-tuning stage
        datapt = {
            "id": None,
            "type": "general",
            "question": q,
            "context": None,
            "oracle_context": None,
            "cot_answer": cot
        }
        i = chunks.index(chunk)
        datapt["id"] = f"seed_task_{len(data_list)}"
        # add num_distract distractor docs
        docs = [chunk]
        indices = list(range(0, len(chunks)))
        indices.remove(i)
        for j in random.sample(indices, num_distract):
            docs.append(chunks[j])
        doc_copy = docs.copy()
        random.shuffle(docs)
        d = {
            "title": [],
            "sentences": []
        }
        d["title"].append(["placeholder_title"] * (num_distract + 1))
        d["sentences"].append(docs)
        datapt["context"] = d
        datapt["oracle_context"] = chunk
        # construct model instruction
        context = ""
        for doc in docs:
            context += "<DOCUMENT>" + str(doc) + "</DOCUMENT>\n"
        context += q
        # This instruction will be used in the fine-tuning stage
        datapt["instruction"] = context
        datapt_copy = copy.deepcopy(datapt)
        # add to dataset
        data_list.append(datapt)
        # decide whether to add a refusal example where the related document is not provided
        add_refusal = random.uniform(0, 1) <= p
        if add_refusal:
            # replace the oracle doc with a random distractor so no related document is present
            doc_copy[0] = chunks[random.sample(indices, 1)[0]]
            random.shuffle(doc_copy)
            refusal_context = ""
            for doc in doc_copy:
                refusal_context += "<DOCUMENT>" + str(doc) + "</DOCUMENT>\n"
            refusal_context += q
            # This instruction will be used in the fine-tuning stage
            datapt_copy["id"] = f"refusal_task_{len(data_list)}"
            datapt_copy["instruction"] = refusal_context
            datapt_copy["cot_answer"] = "Sorry, I don't know the answer to this question because related documents are not found. Please try again."
            data_list.append(datapt_copy)
            refusal_count += 1
        completed += 1
        if completed % 100 == 0:
            logging.info(f"refusal examples added: {refusal_count}, total examples added: {completed}, remaining examples to be added: {len(COT_results) - completed}")
    ds = Dataset.from_list(data_list)
    return ds
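
# ---------------------------------------------------------------------------
# Minimal end-to-end usage sketch (not part of the original module). It only
# wires together the functions above; the endpoint URL, API key, model name,
# prompt templates, paths, output location, and numeric settings below are
# placeholder assumptions and must be replaced with a real configuration,
# while the config keys themselves mirror the ones this module reads.
if __name__ == "__main__":
    example_api_config = {
        "endpoint_url": "http://localhost:8000/v1",      # assumed OpenAI-compatible server
        "api_key": "EMPTY",                              # placeholder key
        "model": "meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model name
        "xml_path": "",                                  # or a path to a local sitemap xml file
        "data_dir": "./data",                            # placeholder folder of raw documents
        "chunk_size": 1000,
        "questions_per_chunk": 3,
        "num_distract_docs": 3,
        "refusal_probability": 0.05,
        # Placeholder templates; they only need the format fields used above.
        "question_prompt_template": "Generate {num_questions} questions about the following context:\n{context}",
        "COT_prompt_template": "Context:\n{context}\n\nQuestion: {question}\nAnswer step by step:",
    }
    chunk_questions = generate_questions(example_api_config)
    if chunk_questions:
        dataset = add_chunk_to_dataset(chunk_questions, example_api_config)
        dataset.save_to_disk("./raft_dataset")           # assumed output location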