raft_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
import os
import logging
from langchain.text_splitter import RecursiveCharacterTextSplitter
from math import ceil
from datasets import Dataset
import random
from langchain_community.document_loaders import SitemapLoader, DirectoryLoader
from bs4 import BeautifulSoup
import copy
from langchain_openai import ChatOpenAI

# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
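
# These helpers implement the RAFT (Retrieval Augmented Fine-Tuning) data
# generation flow used by this recipe: load documents from a sitemap xml or a
# local folder, split them into chunks, ask an OpenAI-compatible endpoint to
# generate questions for each chunk, generate chain-of-thought (COT) answers,
# and assemble {question, COT answer, oracle + distractor documents} records
# into a Hugging Face `datasets.Dataset`.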

def strip_str(s: str) -> str:
    """
    Helper function for formatting strings returned by GPT-4.
    """
    l, r = 0, len(s) - 1
    beg_found = False
    for i in range(len(s)):
        if s[i].isalpha():
            if not beg_found:
                l = i
                beg_found = True
            else:
                r = i
    r += 2
    return s[l:min(r, len(s))]

def clean_documents(raw_text):
    # Drop navigation/boilerplate lines and empty lines, then join the rest into one string
    unwanted = [
        "Technology",
        "Getting Started",
        "Trust & Safety",
        "Community",
        "Resources",
        "Skip to main content",
        "How-to guides",
    ]
    all_lines = []
    for line in raw_text.split("\n"):
        line = line.strip()
        if line in unwanted or len(line.split()) == 0:
            continue
        else:
            all_lines.append(line)
    result = " ".join(all_lines)
    return result

def clean_text(content: BeautifulSoup) -> str:
    # Find all 'nav', 'header' and list 'div' elements in the BeautifulSoup object
    nav_elements = content.find_all("nav")
    header_elements = content.find_all("header")
    mydivs = content.find_all("div", {"role": "list"})
    # Remove each of these elements from the BeautifulSoup object
    for element in nav_elements + header_elements + mydivs:
        element.decompose()
    raw_text = content.get_text("\n")
    return clean_documents(raw_text)

# Read the raw document text, either from webpages listed in a sitemap xml file or from a local data folder
def read_file_content(xml_path: str, data_folder: str) -> str:
    if xml_path and data_folder:
        logging.info("Error: both xml_path and data_folder are provided; only the xml file will be read for now")
    if not xml_path and not data_folder:
        logging.info("Error: neither xml_path nor data_folder is provided")
        return ""
    if xml_path:
        if not os.path.exists(xml_path):
            logging.info(f"Error: {xml_path} does not exist")
            return ""
        # Use langchain to load the documents from webpage links in the xml file
        sitemap_loader = SitemapLoader(web_path=xml_path, is_local=True, parsing_function=clean_text)
        sitemap_loader.requests_kwargs = {"verify": False}
        docs = sitemap_loader.load()
        return "\n".join([doc.page_content for doc in docs])
    elif len(data_folder) != 0:
        if not os.path.exists(data_folder):
            logging.info(f"Error: {data_folder} does not exist")
            return ""
        # Use langchain to load the documents from the data folder
        loader = DirectoryLoader(data_folder)
        docs = loader.load()
        text = "\n".join([clean_documents(doc.page_content) for doc in docs])
        return text

def get_chunks(
    text: str,
    chunk_size: int = 512,
    api_config: dict = None,
) -> list[str]:
    """
    Takes the loaded document `text`, breaks it down into chunks of roughly
    `chunk_size` characters, and returns the chunks.
    """
    chunks = []
    if len(text) == 0:
        raise TypeError("Can not get chunks from empty text")
    else:
        num_chunks = ceil(len(text) / chunk_size)
        logging.info(f"Splitting text into {num_chunks} chunks")
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=api_config["chunk_size"], chunk_overlap=int(api_config["chunk_size"] / 10))
        chunks = text_splitter.create_documents([text])
        chunks = [chunk.page_content for chunk in chunks]
    return chunks

# Read all the files in the data folder, split them into chunks,
# generate questions for each chunk and return a zip of (chunk, related questions list)
def generate_questions(api_config):
    # get documents from the data folder or xml file
    api_url = api_config["endpoint_url"]
    key = api_config["api_key"]
    document_text = read_file_content(api_config["xml_path"], api_config["data_dir"])
    if len(document_text) == 0:
        logging.info(f"Error reading files, document_text length is {len(document_text)}")
    document_batches = get_chunks(document_text, api_config["chunk_size"], api_config)
    # use the OpenAI API protocol to handle the chat request, including a local vLLM OpenAI-compatible server
    llm = ChatOpenAI(
        openai_api_key=key,
        openai_api_base=api_url,
        model_name=api_config["model"],
        temperature=0.0,
        max_tokens=500
    )
    all_tasks = [api_config['question_prompt_template'].format(num_questions=str(api_config['questions_per_chunk']), context=document) for document in document_batches]
    generated_answers = llm.batch(all_tasks)
    generated_answers = [item.content for item in generated_answers]
    if len(generated_answers) == 0:
        logging.error(f"No model answers generated. Please check the input context or the configuration of model {api_config['model']}")
        return []
    final_result = []
    for result in generated_answers:
        queries = result.split('\n')
        queries = [strip_str(q) for q in queries]
        queries = [q for q in queries if any(c.isalpha() for c in q)]
        if len(queries) > int(api_config['questions_per_chunk']):
            # As the model may emit unrelated text at the beginning of the result,
            # keep only the last questions_per_chunk lines when there are too many
            queries = queries[-int(api_config['questions_per_chunk']):]
        final_result.append(queries)
    return list(zip(document_batches, final_result))

# Generate a COT answer for each question given the chunk context
def generate_COT(chunk_questions_zip, api_config) -> list:
    all_tasks = []
    chunk_questions = []
    for document_content, questions in chunk_questions_zip:
        for question in questions:
            prompt = api_config['COT_prompt_template'].format(question=question, context=str(document_content))
            all_tasks.append(prompt)
            chunk_questions.append((document_content, question))
    # use the OpenAI API protocol to handle the chat request, including a local vLLM OpenAI-compatible server
    llm = ChatOpenAI(
        openai_api_key=api_config["api_key"],
        openai_api_base=api_config["endpoint_url"],
        model_name=api_config["model"],
        temperature=0.0,
        max_tokens=500
    )
    generated_answers = llm.batch(all_tasks)
    generated_answers = [item.content for item in generated_answers]
    COT_results = []
    # return a list of (chunk, question, generated_answer) tuples
    for (chunk, question), generated_answer in zip(chunk_questions, generated_answers):
        COT_results.append((chunk, question, generated_answer))
    return COT_results
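
# `num_distract_docs` is the number of distractor chunks packed alongside the oracle
# chunk in each training context. `oracle_p` controls refusal examples: with
# probability 1 - oracle_p an extra datapoint is added whose context has the oracle
# chunk swapped out for a random chunk and whose answer is a refusal, so the model
# learns to decline when the relevant document is missing.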

def add_chunk_to_dataset(
    chunk_questions_zip: list,
    api_config: dict,
    ds,
):
    """
    Given a chunk and related questions lists, create {Q, A, D} triplets and add them to the dataset.
    """
    num_distract = api_config["num_distract_docs"]
    p = api_config["oracle_p"]
    chunks = [chunk for chunk, _ in chunk_questions_zip]
    COT_results = generate_COT(chunk_questions_zip, api_config)
    for chunk, q, cot in COT_results:
        # The COT answer will be used as the label in the fine-tuning stage
        datapt = {
            "id": None,
            "type": "general",
            "question": q,
            "context": None,
            "oracle_context": None,
            "cot_answer": cot
        }
        i = chunks.index(chunk)
        datapt["id"] = f"seed_task_{0 if not ds else ds.num_rows}"
        # add num_distract distractor docs
        docs = [chunk]
        indices = list(range(0, len(chunks)))
        indices.remove(i)
        for j in random.sample(indices, num_distract):
            docs.append(chunks[j])
        doc_copy = docs.copy()
        random.shuffle(docs)
        d = {
            "title": [],
            "sentences": []
        }
        d["title"].append(["placeholder_title"] * (num_distract + 1))
        d["sentences"].append(docs)
        datapt["context"] = d
        datapt["oracle_context"] = chunk
        # construct model instruction
        context = ""
        for doc in docs:
            context += "<DOCUMENT>" + str(doc) + "</DOCUMENT>\n"
        context += q
        # This instruction will be used in the fine-tuning stage
        datapt["instruction"] = context
        datapt_copy = copy.deepcopy(datapt)
        # add to dataset
        if not ds:
            # init ds
            datapt["id"] = [datapt["id"]]
            datapt["type"] = [datapt["type"]]
            datapt["question"] = [datapt["question"]]
            datapt["context"] = [datapt["context"]]
            datapt["oracle_context"] = [datapt["oracle_context"]]
            datapt["cot_answer"] = [datapt["cot_answer"]]
            datapt["instruction"] = [datapt["instruction"]]
            ds = Dataset.from_dict(datapt)
        else:
            ds = ds.add_item(datapt)
        # decide whether to add a refusal example where the related documents are not provided
        oracle = random.uniform(0, 1) < p
        if not oracle:
            doc_copy[0] = chunks[random.sample(indices, 1)[0]]
            random.shuffle(doc_copy)
            context = ""
            for doc in doc_copy:
                context += "<DOCUMENT>" + str(doc) + "</DOCUMENT>\n"
            context += q
            # This instruction will be used in the fine-tuning stage
            datapt_copy["instruction"] = context
            datapt_copy["cot_answer"] = "Sorry, I don't know the answer to this question because related documents are not found. Please try again."
            ds = ds.add_item(datapt_copy)
    return ds
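

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the recipe's entry point):
# it wires the helpers above together under an assumed `api_config`. The
# endpoint URL, model name, paths and prompt templates below are placeholder
# assumptions; in practice these values come from the recipe's configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_api_config = {
        "endpoint_url": "http://localhost:8000/v1",  # assumed OpenAI-compatible server, e.g. vLLM
        "api_key": "EMPTY",                          # assumed placeholder key for a local server
        "model": "meta-llama/Meta-Llama-3-8B-Instruct",  # assumed model name
        "xml_path": "",                              # leave empty to read from data_dir instead of a sitemap xml
        "data_dir": "./data",                        # assumed local folder containing the source documents
        "chunk_size": 512,
        "questions_per_chunk": 3,
        "num_distract_docs": 4,
        "oracle_p": 0.8,
        # The templates must expose the placeholders used by the helpers above:
        # {num_questions}/{context} for question generation, {question}/{context} for COT answers.
        "question_prompt_template": "Generate {num_questions} questions based on the following context:\n{context}",
        "COT_prompt_template": "Using the context below, answer the question step by step.\nQuestion: {question}\nContext: {context}",
    }
    chunk_questions = generate_questions(example_api_config)
    dataset = add_chunk_to_dataset(chunk_questions, example_api_config, None)
    if dataset is not None:
        dataset.save_to_disk("./raft_dataset")       # assumed output location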