@@ -2,14 +2,7 @@
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
-import re
-import string
from transformers import AutoTokenizer
-import asyncio
-import magic
-from PyPDF2 import PdfReader
-import json
-from doc_processor import split_text_into_chunks
import logging
import json
from langchain_community.embeddings import HuggingFaceEmbeddings
@@ -18,6 +11,14 @@ from math import ceil
import datasets
from datasets import Dataset, load_dataset
import random
+from langchain_community.document_loaders import SitemapLoader,DirectoryLoader
+from bs4 import BeautifulSoup
+from langchain_openai import ChatOpenAI
+from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_community.llms import VLLMOpenAI
+from langchain_core.prompts import ChatPromptTemplate
+
+
# Initialize logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def strip_str(s: str) -> str:
@@ -35,82 +36,60 @@ def strip_str(s: str) -> str:
                r = i
    r += 2
    return s[l:min(r, len(s))]
-def read_text_file(file_path):
-    try:
-        with open(file_path, 'r') as f:
-            text = f.read().strip() + ' '
-            if len(text) == 0:
-                print("File is empty ",file_path)
-            return text
-    except Exception as e:
-        logging.error(f"Error reading text file {file_path}: {e}")
-        return ''
-
-def read_pdf_file(file_path):
-    try:
-        with open(file_path, 'rb') as f:
-            pdf_reader = PdfReader(f)
-            num_pages = len(pdf_reader.pages)
-            file_text = [pdf_reader.pages[page_num].extract_text().strip() + ' ' for page_num in range(num_pages)]
-            text = ''.join(file_text)
-            if len(text) == 0:
-                print("File is empty ",file_path)
-            return ''.join(file_text)
-    except Exception as e:
-        logging.error(f"Error reading PDF file {file_path}: {e}")
-        return ''
-
-def read_json_file(file_path):
-    try:
-        with open(file_path, 'r') as f:
-            data = json.load(f)
-            # Assuming each item in the list has a 'question' and 'answer' key
-            # Concatenating question and answer pairs with a space in between and accumulating them into a single string
-            file_text = ' '.join([item['question'].strip() + ' ' + item['answer'].strip() + ' ' for item in data])
-            if len(file_text) == 0:
-                print("File is empty ",file_path)
-            return file_text
-    except Exception as e:
-        logging.error(f"Error reading JSON file {file_path}: {e}")
-        return ''
-
-
-def process_file(file_path):
-    print("starting to process file: ", file_path)
-    file_type = magic.from_file(file_path, mime=True)
-    if file_type in ['text/plain', 'text/markdown', 'JSON']:
-        return read_text_file(file_path)
-    elif file_type == 'application/pdf':
-        return read_pdf_file(file_path)
-    else:
-        logging.warning(f"Unsupported file type {file_type} for file {file_path}")
-        return ''
-def read_file_content(context):
-    file_strings = []
-
-    for root, _, files in os.walk(context['data_dir']):
-        for file in files:
-            file_path = os.path.join(root, file)
-            file_text = process_file(file_path)
-            if file_text:
-                file_strings.append(file_text)
-    text = '\n'.join(file_strings)
-    text = remove_non_printable(text)
-    return text
-
-def remove_non_printable(s):
-    printable = set(string.printable)
-    return ''.join(filter(lambda x: x in printable, s))
+def clean_documents(raw_text):
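+    # Drop common website navigation/boilerplate lines and collapse the remaining lines into one string.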
+    unwanted= ["Technology",
+    "Getting Started",
+    "Trust & Safety",
+    "Community",
+    "Resources",
+    "Skip to main content",
+    "How-to guides"]
+    all_lines = []
+    for line in raw_text.split("\n"):
+        line = line.strip()
+        if line in unwanted or len(line.split()) == 0:
+            continue
+        else:
+            all_lines.append(line)
+    result = " ".join(all_lines)
+    return result
+def clean_text(content: BeautifulSoup) -> str:
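+    # Parsing function passed to SitemapLoader below: strips nav/header elements and "list" divs before extracting text.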
+    # Find all 'nav' and 'header' elements in the BeautifulSoup object
+    nav_elements = content.find_all("nav")
+    header_elements = content.find_all("header")
+    mydivs = content.find_all("div", {"role": "list"})
+    # Remove each 'nav' and 'header' element from the BeautifulSoup object
+    for element in nav_elements + header_elements+mydivs:
+        element.decompose()
+    raw_text = content.get_text("\n")
+    return clean_documents(raw_text)
+# Read the documents either from a sitemap XML file or from a local data folder
+def read_file_content(xml_path: str, data_folder: str) -> str:
+    if xml_path and data_folder:
+        logging.info("Error: both xml_path and data_folder are provided, will only read from the XML sitemap for now")
+    if not xml_path and not data_folder:
+        logging.info("Error: neither xml_path nor data_folder is provided")
+        return ""
+    if xml_path:
+        if not os.path.exists(xml_path):
+            logging.info(f"Error: {xml_path} does not exist")
+            return ""
+        # Use langchain to load the documents from webpage links in the xml file
+        sitemap_loader = SitemapLoader(web_path=xml_path,is_local=True,parsing_function=clean_text)
+        sitemap_loader.requests_kwargs = {"verify": False}
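+        # verify=False disables SSL certificate verification when fetching the pages listed in the sitemap.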
+        docs = sitemap_loader.load()
+        return "\n".join([doc.page_content for doc in docs])
+    elif len(data_folder) != 0:
+        if not os.path.exists(data_folder):
+            logging.info(f"Error: {data_folder} does not exist")
+            return ""
+        # Use langchain to load the documents from data folder
+        loader = DirectoryLoader(data_folder)
+        docs = loader.load()
+        text = "\n".join([clean_documents(doc.page_content) for doc in docs])
+        return text


-async def generate_question_request(chat_service, api_context: dict, document_content: str, num_questions: int) -> dict:
-    if num_questions == 0:
-        logging.info(f"Error: num_questions is 0")
-        return {}
-    prompt_for_system = api_context['question_prompt_template'].format(num_questions=num_questions)
-    chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': str(document_content)}]
-    # parse the result string to a list of dict that has Question, Answer, Context
-    return await chat_service.execute_chat_request_async(api_context, chat_request_payload)

def get_chunks(
    text: str,
@@ -134,55 +113,73 @@
    return chunks
# read all the files in the data folder, then split them into chunks
# generate questions for each chunk and return zip of chunk and related questions list
-async def generate_questions(chat_service, api_context: dict):
-    document_text = read_file_content(api_context)
+def generate_questions(api_config):
+    # get documents from the data folder or xml file
+    api_url = api_config["endpoint_url"]
+    key = api_config["api_key"]
+    document_text = read_file_content(api_config["xml_path"],api_config["data_dir"])
    if len(document_text) == 0:
        logging.info(f"Error reading files, document_text is {len(document_text)}")
-    model_name = "sentence-transformers/all-mpnet-base-v2"
-    embedding_model = HuggingFaceEmbeddings(model_name=model_name)
-    document_batches = get_chunks(document_text,api_context["chunk_size"],embedding_model)
+    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2",model_kwargs={'device': 'cuda'})
+    document_batches = get_chunks(document_text,api_config["chunk_size"],embedding_model)

    batches_count = len(document_batches)
-    total_questions = api_context["questions_per_chunk"] * batches_count
-
-    print(f"Questions per batch: {api_context['questions_per_chunk']}, Total questions: {total_questions}, Batches: {batches_count}")
-    generation_tasks = []
-    for batch_index, batch_content in enumerate(document_batches):
-        print(f"len of batch_content: {len(batch_content)}, batch_index: {batch_index}")
-        #Distribute extra questions across the first few batches
-        if len(batch_content) < 10:
-            logging.info("Context is not enough, ignore this batch")
-        else:
-            print(f"Batch {batch_index + 1} - {api_context['questions_per_chunk']} questions ********")
-            try:
-                task = generate_question_request(chat_service, api_context, batch_content, api_context["questions_per_chunk"])
-                generation_tasks.append(task)
-            except Exception as e:
-                print(f"Error during chat request execution: {e}")
-
-    question_generation_results = await asyncio.gather(*generation_tasks)
+    total_questions = api_config["questions_per_chunk"] * batches_count
+    # use OpenAI API protocol to handle the chat request, including a local vLLM OpenAI-compatible server
+    llm = VLLMOpenAI(
+        openai_api_key=key,
+        openai_api_base=api_url,
+        model_name=api_config["model"],
+        temperature=0.0,
+        max_tokens=250
+    )
+    prompt = api_config['question_prompt_template'].format(num_questions=str(api_config['questions_per_chunk']))
+    system_prompt = SystemMessage(content=prompt)
+    generated_answers = []
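+    # One request per chunk: the shared question-generation system prompt plus the chunk text as the user message.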
+    all_tasks = [[system_prompt, HumanMessage(content=batch)] for batch in document_batches]
+    generated_answers = llm.batch(all_tasks)
+    if len(generated_answers) == 0:
+        logging.error("No model answers generated. Please check the input context or the model configuration for %s", api_config["model"])
+        return []
    final_result = []
-    for result in question_generation_results:
+    for result in generated_answers:
        queries = result.split('\n')
        queries = [strip_str(q) for q in queries]
        queries = [q for q in queries if any(c.isalpha() for c in q)]
-        if len(queries) > int(api_context['questions_per_chunk']):
+        if len(queries) > int(api_config['questions_per_chunk']):
            # As the model may have unrelated question at the begining of the result
            # if queries is more than questions_per_chunk, then we need to truncate it and only keep last questions_per_chunk lines
-            queries = queries[-int(api_context['questions_per_chunk']):]
+            queries = queries[-int(api_config['questions_per_chunk']):]
        final_result.append(queries)
    return list(zip(document_batches,final_result))

-async def generate_COT(chat_service, api_context: dict, document_content: str, question: str) -> dict:
-    prompt = api_context['COT_prompt_template'].format(question=question,context=str(document_content))
-    chat_request_payload = [{"role": "system", "content": "You are a helpful question answerer who can provide an answer given a question and relevant context."}]
-    chat_request_payload.append({"role": "user", "content": prompt})
-    response = await chat_service.execute_chat_request_async(api_context, chat_request_payload)
-    return (document_content,question,response)
-async def add_chunk_to_dataset(
+# Generate COT answer for each question given the chunk context
+def generate_COT(chunk_questions_zip,api_config) -> list:
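+    # Build one COT prompt per (chunk, question) pair so all answers can be generated with a single llm.batch call.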
+    all_tasks = []
+    chunk_questions = []
+    for document_content,questions in chunk_questions_zip:
+        for question in questions:
+            prompt = api_config['COT_prompt_template'].format(question=question,context=str(document_content))
+            all_tasks.append(prompt)
+            chunk_questions.append((document_content,question))
+    # use OpenAI API protocol to handle the chat request, including a local vLLM OpenAI-compatible server
+    llm = VLLMOpenAI(
+        openai_api_key=api_config["api_key"],
+        openai_api_base=api_config["endpoint_url"],
+        model_name=api_config["model"],
+        temperature=0.0,
+        max_tokens=350
+    )
+    generated_answers = llm.batch(all_tasks)
+    COT_results = []
+    # return a list of (chunk, question, generated_answer)
+    for (chunk, question),generated_answer in zip(chunk_questions,generated_answers):
+        COT_results.append((chunk,question,generated_answer))
+    return COT_results
+
+def add_chunk_to_dataset(
    chunk_questions_zip: list,
-    context: dict,
-    chat_service,
+    api_config: dict,
    ds,
    num_distract: int = 3,
    p: float = 0.8,
@@ -192,14 +189,9 @@ async def add_chunk_to_dataset(
    """
    COT_tasks = []
    chunks = [chunk for chunk, _ in chunk_questions_zip]
-    for i, chunk_questions in enumerate(chunk_questions_zip):
-        chunk, questions = chunk_questions
-        # generate COT answer for each question given the chunk context
-        for question in questions:
-            COT_tasks.append(generate_COT(chat_service, context, chunk, question))
-    COT_results = await asyncio.gather(*COT_tasks)
+    COT_results = generate_COT(chunk_questions_zip,api_config)
    for chunk, q , cot in COT_results:
-        # The COT answer will be used in the fine-tuning stage
+        # The COT answer will be used as the label in the fine-tuning stage
        datapt = {
            "id": None,
            "type": "general",