
changed yaml to get langchain working

Kai Wu 1 year ago
Parent
Commit
cd5ae9ec63

File diff not shown because it is too large
+ 13 - 15
recipes/use_cases/end2end-recipes/raft/README.md


+ 0 - 80
recipes/use_cases/end2end-recipes/raft/chat_utils.py

@@ -1,80 +0,0 @@
-import asyncio
-import logging
-from abc import ABC, abstractmethod
-from octoai.client import OctoAI
-from functools import partial
-from openai import OpenAI
-import json
-# Configure logging to include the timestamp, log level, and message
-logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
-# Since OctoAI has different naming for llama models, create this mapping to get huggingface offical model name given OctoAI names.
-MODEL_NAME_MAPPING={"meta-llama-3-70b-instruct":"meta-llama/Meta-Llama-3-70B-Instruct",
-"meta-llama-3-8b-instruct":"meta-llama/Meta-Llama-3-8B-Instruct","llama-2-7b-chat":"meta-llama/Llama-2-7b-chat-hf"
-,"llama-2-70b-chat":"meta-llama/Llama-2-70b-chat-hf"}
-# Manage rate limits with throttling
-rate_limit_threshold = 2000
-allowed_concurrent_requests = int(rate_limit_threshold * 0.75)
-request_limiter = asyncio.Semaphore(allowed_concurrent_requests)
-class ChatService(ABC):
-    @abstractmethod
-    async def execute_chat_request_async(self, api_context: dict, chat_request):
-        pass
-def strip_str(s: str) -> str:
-    """
-    Helper function for helping format strings returned by GPT-4.
-    """
-    l, r = 0, len(s)-1
-    beg_found = False
-    for i in range(len(s)):
-        if s[i].isalpha():
-            if not beg_found:
-                l = i
-                beg_found = True
-            else:
-                r = i
-    r += 2
-    return s[l:min(r, len(s))]
-# Please implement your own chat service class here.
-# The class should inherit from the ChatService class and implement the execute_chat_request_async method.
-# The following are two example chat service classes that you can use as a reference.
-class OctoAIChatService(ChatService):
-    async def execute_chat_request_async(self, api_context: dict, chat_request):
-        async with request_limiter:
-            try:
-                event_loop = asyncio.get_running_loop()
-                client = OctoAI(api_context['api_key'])
-                api_chat_call = partial(
-                    client.chat.completions.create,
-                    model=api_context['model'],
-                    messages=chat_request,
-                    temperature=0.0
-                )
-                response = await event_loop.run_in_executor(None, api_chat_call)
-                assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
-                return assistant_response
-            except Exception as error:
-                logging.error(f"Error during chat request execution: {error}",exc_info=True)
-                return ""
-# Use the local vllm openai compatible server for generating question/answer pairs to make API call syntax consistent
-# please read for more detail:https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html.
-class VllmChatService(ChatService):
-    async def execute_chat_request_async(self, api_context: dict, chat_request):
-        try:
-            event_loop = asyncio.get_running_loop()
-            if api_context["model"] in MODEL_NAME_MAPPING:
-                model_name = MODEL_NAME_MAPPING[api_context['model']]
-            else:
-                model_name = api_context['model']
-            client = OpenAI(api_key=api_context['api_key'], base_url="http://localhost:"+ str(api_context['endpoint'])+"/v1")
-            api_chat_call = partial(
-                client.chat.completions.create,
-                model=model_name,
-                messages=chat_request,
-                temperature=0.0
-            )
-            response = await event_loop.run_in_executor(None, api_chat_call)
-            assistant_response = next((choice.message.content for choice in response.choices if choice.message.role == 'assistant'), "")
-            return assistant_response
-        except Exception as error:
-            logging.error(f"Error during chat request execution: {error}",exc_info=True)
-            return ""
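With these service classes removed, chat calls now go through LangChain's OpenAI-compatible client (see raft_utils.py further down in this diff). A minimal sketch of the equivalent call, assuming a local vLLM server at http://localhost:8001/v1 and the default "EMPTY" key; the model name and prompt are illustrative:

from langchain_community.llms import VLLMOpenAI

# Endpoint, key and model name are illustrative defaults; adjust to your deployment.
llm = VLLMOpenAI(
    openai_api_key="EMPTY",
    openai_api_base="http://localhost:8001/v1",
    model_name="meta-llama/Meta-Llama-3-70B-Instruct",
    temperature=0.0,
    max_tokens=250,
)
print(llm.invoke("Briefly describe Meta Llama 3."))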

+ 0 - 8
recipes/use_cases/end2end-recipes/raft/config.py

@@ -8,12 +8,4 @@ def load_config(config_path: str = "./config.yaml"):
     # Read the YAML configuration file
     with open(config_path, "r") as file:
         config = yaml.safe_load(file)
-    # Set the API key from the environment variable
-    try:
-        config["api_key"] = os.environ["OCTOAI_API_TOKEN"]
-    except KeyError:
-        print("API token did not found, please set the OCTOAI_API_TOKEN environment variable if using OctoAI, otherwise set api_key to default EMPTY")
-        # local Vllm endpoint did not need API key, so set the API key to "EMPTY" if OCTOAI_API_TOKEN not found
-        config["api_key"] = "EMPTY"
     return config
-
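load_config is now a plain YAML read; the API key comes from the command line or the API_KEY environment variable in raft.py instead (see below). A minimal sketch of the new flow, with an illustrative config path:

import os
from config import load_config

api_config = load_config("./raft.yaml")  # just yaml.safe_load, no token lookup
# "EMPTY" is fine for a local vLLM endpoint; API_KEY overrides it when set.
api_config["api_key"] = os.environ.get("API_KEY", "EMPTY")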

+ 137 - 0
recipes/use_cases/end2end-recipes/raft/data_urls.xml

@@ -0,0 +1,137 @@
+<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
+<url>
+<loc>http://llama.meta.com/</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/use-policy/</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/responsible-use-guide/</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/llama2/</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/llama2/license/</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/llama2/use-policy/</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/license/</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/code-llama/</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/llama3/</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/llama3/license/</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-guard-2</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/docs/model-cards-and-prompt-formats/meta-code-llama-70b</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-guard-1</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/docs/model-cards-and-prompt-formats/meta-code-llama</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-2</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/docs/getting_the_models</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/docs/getting-the-models/hugging-face</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/docs/getting-the-models/kaggle</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/docs/llama-everywhere</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/docs/llama-everywhere/running-meta-llama-on-linux/</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/docs/llama-everywhere/running-meta-llama-on-windows/</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/docs/llama-everywhere/running-meta-llama-on-mac/</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/docs/llama-everywhere/running-meta-llama-in-the-cloud/</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/docs/how-to-guides/fine-tuning</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/docs/how-to-guides/quantization</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/docs/how-to-guides/prompting</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/docs/how-to-guides/validation</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/docs/integration-guides/meta-code-llama</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/docs/integration-guides/langchain</loc>
+</url>
+<url>
+<loc>http://llama.meta.com/docs/integration-guides/llamaindex</loc>
+</url>
+<url>
+<loc>http://raw.githubusercontent.com/meta-llama/llama-recipes/main/README.md</loc>
+</url>
+<url>
+<loc>http://raw.githubusercontent.com/meta-llama/llama/main/MODEL_CARD.md</loc>
+</url>
+<url>
+<loc>http://raw.githubusercontent.com/meta-llama/llama/main/README.md</loc>
+</url>
+<url>
+<loc>http://raw.githubusercontent.com/meta-llama/llama/main/LICENSE.md</loc>
+</url>
+<url>
+<loc>http://raw.githubusercontent.com/meta-llama/llama3/main/MODEL_CARD.md</loc>
+</url>
+<url>
+<loc>http://raw.githubusercontent.com/meta-llama/llama3/main/README.md</loc>
+</url>
+<url>
+<loc>http://raw.githubusercontent.com/meta-llama/llama3/main/LICENSE.md</loc>
+</url>
+<url>
+<loc>http://raw.githubusercontent.com/meta-llama/codellama/main/MODEL_CARD.md</loc>
+</url>
+<url>
+<loc>http://raw.githubusercontent.com/meta-llama/codellama/main/README.md</loc>
+</url>
+<url>
+<loc>http://raw.githubusercontent.com/meta-llama/PurpleLlama/main/README.md</loc>
+</url>
+<url>
+<loc>http://raw.githubusercontent.com/meta-llama/PurpleLlama/main/Llama-Guard2/MODEL_CARD.md</loc>
+</url>
+<url>
+<loc>http://raw.githubusercontent.com/meta-llama/PurpleLlama/main/Llama-Guard2/README.md</loc>
+</url>
+<url>
+<loc>http://raw.githubusercontent.com/meta-llama/PurpleLlama/main/Llama-Guard/MODEL_CARD.md</loc>
+</url>
+<url>
+<loc>http://raw.githubusercontent.com/meta-llama/PurpleLlama/main/Llama-Guard/README.md</loc>
+</url>
+</urlset>
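This sitemap is what raft_utils.py feeds to LangChain's SitemapLoader to pull down every listed page. A minimal sketch of that loading step (clean_text, the parsing_function used in raft_utils.py, is omitted here):

from langchain_community.document_loaders import SitemapLoader

# Load the local sitemap file; certificate verification is disabled as in raft_utils.py.
sitemap_loader = SitemapLoader(web_path="data_urls.xml", is_local=True)
sitemap_loader.requests_kwargs = {"verify": False}
docs = sitemap_loader.load()
print(f"Loaded {len(docs)} pages from the sitemap")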

+ 0 - 47
recipes/use_cases/end2end-recipes/raft/doc_processor.py

@@ -1,47 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
-
-# Assuming result_average_token is a constant, use UPPER_CASE for its name to follow Python conventions
-AVERAGE_TOKENS_PER_RESULT = 100
-
-def get_token_limit_for_model(model: str) -> int:
-    """Returns the token limit for a given model."""
-    if model == "llama-2-13b-chat" or model == "llama-2-70b-chat":
-        return 4096
-    else:
-        return 8192
-
-def calculate_num_tokens_for_message(encoded_text) -> int:
-    """Calculates the number of tokens used by a message."""
-    # Added 3 to account for priming with assistant's reply, as per original comment
-    return len(encoded_text) + 3
-
-
-def split_text_into_chunks(context: dict, text: str, tokenizer) -> list[str]:
-    """Splits a long text into substrings based on token length constraints, adjusted for question generation."""
-    # Adjusted approach to calculate max tokens available for text chunks
-    encoded_text = tokenizer(text, return_tensors="pt", padding=True)["input_ids"]
-    encoded_text = encoded_text.squeeze()
-    model_token_limit = get_token_limit_for_model(context["model"])
-
-    tokens_for_questions = calculate_num_tokens_for_message(encoded_text)
-    estimated_tokens_per_question = AVERAGE_TOKENS_PER_RESULT
-    estimated_total_question_tokens = estimated_tokens_per_question * context["total_questions"]
-    # Ensure there's a reasonable minimum chunk size
-    max_tokens_for_text = max(model_token_limit - tokens_for_questions - estimated_total_question_tokens, model_token_limit // 10)
-
-    chunks, current_chunk = [], []
-    print(f"Splitting text into chunks of {max_tokens_for_text} tokens, encoded_text {len(encoded_text)}", flush=True)
-    for token in encoded_text:
-        if len(current_chunk) >= max_tokens_for_text:
-            chunks.append(tokenizer.decode(current_chunk).strip())
-            current_chunk = []
-        else:
-            current_chunk.append(token)
-
-    if current_chunk:
-        chunks.append(tokenizer.decode(current_chunk).strip())
-
-    print(f"Number of chunks in the processed text: {len(chunks)}", flush=True)
-
-    return chunks
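The token-window splitter above is gone; chunking now happens in raft_utils.get_chunks, which is called with the document text, a chunk size, and a HuggingFace embedding model. Its body is not shown in this diff; a plausible sketch, assuming LangChain's SemanticChunker (an assumption, not confirmed by the hunks below):

from math import ceil
from langchain_experimental.text_splitter import SemanticChunker

def get_chunks_sketch(text: str, chunk_size: int, embedding_model) -> list[str]:
    # Aim for roughly len(text) / chunk_size chunks and let the semantic splitter
    # pick breakpoints based on embedding similarity between sentences.
    num_chunks = ceil(len(text) / chunk_size)
    splitter = SemanticChunker(embedding_model, number_of_chunks=num_chunks)
    return [doc.page_content for doc in splitter.create_documents([text])]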

+ 1 - 0
recipes/use_cases/end2end-recipes/raft/eval_raft.py

@@ -8,6 +8,7 @@ from config import load_config
 import json
 from itertools import chain
 from langchain_community.llms import VLLMOpenAI
+
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.vectorstores import FAISS
 from langchain.text_splitter import RecursiveCharacterTextSplitter

+ 32 - 31
recipes/use_cases/end2end-recipes/raft/raft.py

@@ -1,15 +1,10 @@
-import mdc
-from mdc import MDC
 import logging
 from typing import Literal, Any
-from openai import OpenAI
 import json
 import random
-import os, shutil
+import os
 import argparse
-import asyncio
 from raft_utils import generate_questions, add_chunk_to_dataset
-from chat_utils import OctoAIChatService, VllmChatService
 from format import DatasetConverter, datasetFormats, outputDatasetTypes
 from config import load_config
 
@@ -17,27 +12,23 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
 
 NUM_DISTRACT_DOCS = 5 # number of distracting documents to add to each chunk
 ORCALE_P = 0.8 # probability of related documents to be added to each chunk
-async def main(context):
+def main(api_config):
     ds = None
-    if context["endpoint"]:
-        chat_service = VllmChatService()
-    else:
-        chat_service = OctoAIChatService()
     try:
         logging.info("Starting to generate question pair.")
         # Generate questions as list for each chunk
-        chunk_questions_zip = await generate_questions(chat_service, context)
+        chunk_questions_zip = generate_questions(api_config)
         if not chunk_questions_zip:
-            logging.warning("No questions generated from text. Please check the input context or model configuration.")
+            logging.warning("No questions generated from text. Please check the api_config or model configuration.")
             return
         for chunk, questions in chunk_questions_zip:
             logging.info(f"Chunk: {chunk}, question length: {len(questions)}")
             for question in questions:
                 logging.info(f"Question: {question}")
         logging.info(f"Successfully generated {sum([len(q) for c,q in chunk_questions_zip])} question/answer pairs.")
-        ds = await add_chunk_to_dataset(chunk_questions_zip,context, chat_service,ds,NUM_DISTRACT_DOCS, ORCALE_P)
+        ds = add_chunk_to_dataset(chunk_questions_zip,api_config,ds,NUM_DISTRACT_DOCS, ORCALE_P)
         ds.save_to_disk(args.output)
-        logging.info(f"Data successfully written to {context['output']}. Process completed.")
+        logging.info(f"Data successfully written to {api_config['output']}. Process completed.")
         formatter = DatasetConverter()
 
         # Extract format specific params
@@ -49,7 +40,7 @@ async def main(context):
 def parse_arguments():
     # Define command line arguments for the script
     parser = argparse.ArgumentParser(
-        description="Generate question/answer pairs from documentation."
+        description="Generate RAFT question/answer/context pairs from documentation."
     )
     parser.add_argument(
         "-t", "--questions_per_chunk",
@@ -59,8 +50,7 @@ def parse_arguments():
     )
     parser.add_argument(
         "-m", "--model",
-        choices=["meta-llama-3-70b-instruct","meta-llama-3-8b-instruct","llama-2-13b-chat", "llama-2-70b-chat"],
-        default="meta-llama-3-70b-instruct",
+        default="meta-llama/Meta-Llama-3-70B-Instruct",
         help="Select the model to use for generation."
     )
     parser.add_argument(
@@ -69,10 +59,16 @@ def parse_arguments():
         help="Set the configuration file path that has system prompt along with language, dataset path and number of questions."
     )
     parser.add_argument(
-        "-v", "--vllm_endpoint",
-        default=None,
-        type=int,
-        help="If a port is specified, then use local vllm endpoint for generating question/answer pairs."
+        "-u", "--endpoint_url",
+        default="http://localhost:8001/v1",
+        type=str,
+        help="LLM API url for generating question/answer pairs."
+    )
+    parser.add_argument(
+        "-k", "--api_key",
+        default="EMPTY",
+        type=str,
+        help="LLM API key for generating question/answer pairs."
     )
     parser.add_argument("--chunk_size", type=int, default=512, help="The size of each chunk in number of tokens")
     parser.add_argument("-o","--output", type=str, default="./", help="The path at which to save the dataset")
@@ -84,13 +80,18 @@ if __name__ == "__main__":
     logging.info("Initializing the process and loading configuration...")
     args = parse_arguments()
 
-    context = load_config(args.config_path)
-    context["questions_per_chunk"] = args.questions_per_chunk
-    context["model"] = args.model
-    context["chunk_size"] = args.chunk_size
-    context["endpoint"] = args.vllm_endpoint
-    context["output"] = args.output
+    api_config = load_config(args.config_path)
+    api_config["questions_per_chunk"] = args.questions_per_chunk
+    api_config["model"] = args.model
+    api_config["chunk_size"] = args.chunk_size
+    api_config["endpoint_url"] = args.endpoint_url
+    api_config["output"] = args.output
+    api_config["api_key"] = args.api_key
+    # if API_KEY is defined in the system environment, use it as the API key
+    if os.environ.get('API_KEY') is not None:
+        api_config["api_key"] = os.environ["API_KEY"]
     logging.info(f"Configuration loaded. Generating {args.questions_per_chunk} question per chunk using model '{args.model}'.")
-    if context["endpoint"]:
-        logging.info(f"Use local vllm service at port: '{args.vllm_endpoint}'.")
-    asyncio.run(main(context))
+    logging.info(f"Chunk size: {args.chunk_size}.")
+    logging.info(f"Will use endpoint_url: {args.endpoint_url}.")
+    logging.info(f"Output will be written to {args.output}.")
+    main(api_config)
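With the OctoAI branch removed, raft.py always targets an OpenAI-compatible endpoint. A typical invocation against a local vLLM server (all values are illustrative; API_KEY in the environment overrides -k):

python raft.py -u "http://localhost:8001/v1" -k "EMPTY" -m "meta-llama/Meta-Llama-3-70B-Instruct" -t 3 --chunk_size 512 -o ./output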

+ 36 - 26
recipes/use_cases/end2end-recipes/raft/raft.yaml

@@ -1,32 +1,42 @@
 COT_prompt_template: >
-  Question: {question}\nContext: {context}\n
-        Answer this question using the information given in the context above. Here is things to pay attention to:
-        - First provide step-by-step reasoning on how to answer the question.
-        - In the reasoning, if you need to copy paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This would mean that things outside of ##begin_quote## and ##end_quote## are not directly copy paste from the context.
-        - End your response with final answer in the form <ANSWER>: $answer, the answer should be succinct.
-        You MUST begin your final answer with the tag "<ANSWER>:
+  <|begin_of_text|><|start_header_id|>system<|end_header_id|> Answer the following question using the information given in the context below. Here are things to pay attention to:
+    - First provide step-by-step reasoning on how to answer the question.
+    - In the reasoning, if you need to copy paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This would mean that things outside of ##begin_quote## and ##end_quote## are not directly copy paste from the context.
+    - End your response with final answer in the form <ANSWER>: $answer, the answer should be succinct.
+    You MUST begin your final answer with the tag "<ANSWER>:<|eot_id|>
+  <|start_header_id|>user<|end_header_id|>
+  Question: {question}\nContext: {context}\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>
 
-# question_prompt_template: >
-#   You are a synthetic question-answer pair generator. Given a chunk of context about
-#   some topic(s), generate {num_questions} example questions a user could ask and would be answered
-#   \using information from the chunk. For example, if the given context was a Wikipedia
-#   paragraph about the United States, an example question could be 'How many states are
-#   in the United States?
-#   The questions should be able to be answered in a few words or less. Include only the
-#   questions in your response.
 question_prompt_template: >
-  You are a language model skilled in creating quiz questions.
-  You will be provided with a document,
-  read it and please generate question and answer pairs that are most likely be asked by a user of Llama language models,
-  which includes LLama, Llama2, Meta Llama3, Code Llama, Meta Llama Guard 1,	Meta Llama Guard 2,
-  Output only the questions related to Llama:
-  please make sure you follow those rules:
-  1. Generate {num_questions} question answer pairs, you can generate less answer if there is nothing related to model, training, fine-tuning and evaluation details of Llama language models, .
-  2. The questions can be answered based *solely* on the given passage.
-  3. Avoid asking questions with similar meaning.
-  4. Never use any abbreviation.
-  5. Include only the questions in your response.
+  You are a synthetic question-answer pair generator. Given a chunk of context about
+  some topic(s), generate {num_questions} example questions a user could ask and would be answered
+  using information from the chunk. For example, if the given context was a Wikipedia
+  paragraph about the United States, an example question could be 'How many states are
+  in the United States?'
+  The questions should be able to be answered in 100 words or less. Include only the
+  questions in your response.
+
+# question_prompt_template: >
+#   You are a language model skilled in creating quiz questions.
+#   You will be provided with a document,
+#   read it and please generate question and answer pairs that are most likely be asked by a user of Llama language models
+#   which includes LLama, Llama2, Meta Llama3, Code Llama, Meta Llama Guard 1,	Meta Llama Guard 2
+#   Output only the questions related to Llama:
+#   please make sure you follow those rules:
+#   1. Generate {num_questions} question answer pairs, you can generate less answer if there is nothing related to model, training, fine-tuning and evaluation details of Llama language models, .
+#   2. The questions can be answered based *solely* on the given passage.
+#   3. Avoid asking questions with similar meaning.
+#   4. Never use any abbreviation.
+#   5. Include only the questions in your response.
 
 data_dir: "./data"
 
-num_questions: 2
+xml_path: ""
+
+chunk_size: 512
+
+questions_per_chunk: 3
+
+num_distract_docs: 5 # number of distracting documents to add to each chunk
+
+orcale_p: 0.8 # probability of related documents to be added to each chunk
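The COT template is now wrapped in Llama 3 chat markup, and the chunking/generation knobs live in the YAML itself. A minimal sketch of how the scripts consume it, with an illustrative question and context chunk:

from config import load_config

api_config = load_config("./raft.yaml")
prompt = api_config["COT_prompt_template"].format(
    question="What license does Meta Llama 3 use?",  # illustrative
    context="Meta Llama 3 is released under the Meta Llama 3 Community License.",  # illustrative chunk
)
print(prompt)  # ready to send to the vLLM endpoint via VLLMOpenAI, as in generate_COT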

+ 115 - 123
recipes/use_cases/end2end-recipes/raft/raft_utils.py

@@ -2,14 +2,7 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import os
-import re
-import string
 from transformers import  AutoTokenizer
-import asyncio
-import magic
-from PyPDF2 import PdfReader
-import json
-from doc_processor import split_text_into_chunks
 import logging
 import json
 from langchain_community.embeddings import HuggingFaceEmbeddings
@@ -18,6 +11,14 @@ from math import ceil
 import datasets
 from datasets import Dataset, load_dataset
 import random
+from langchain_community.document_loaders import SitemapLoader,DirectoryLoader
+from bs4 import BeautifulSoup
+from langchain_openai import ChatOpenAI
+from langchain_core.messages import HumanMessage, SystemMessage
+from langchain_community.llms import VLLMOpenAI
+from langchain_core.prompts import ChatPromptTemplate
+
+
 # Initialize logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 def strip_str(s: str) -> str:
@@ -35,82 +36,60 @@ def strip_str(s: str) -> str:
                 r = i
     r += 2
     return s[l:min(r, len(s))]
-def read_text_file(file_path):
-    try:
-        with open(file_path, 'r') as f:
-            text = f.read().strip() + ' '
-            if len(text) == 0:
-                print("File is empty ",file_path)
-            return text
-    except Exception as e:
-        logging.error(f"Error reading text file {file_path}: {e}")
-    return ''
-
-def read_pdf_file(file_path):
-    try:
-        with open(file_path, 'rb') as f:
-            pdf_reader = PdfReader(f)
-            num_pages = len(pdf_reader.pages)
-            file_text = [pdf_reader.pages[page_num].extract_text().strip() + ' ' for page_num in range(num_pages)]
-            text = ''.join(file_text)
-            if len(text) == 0:
-                print("File is empty ",file_path)
-            return ''.join(file_text)
-    except Exception as e:
-        logging.error(f"Error reading PDF file {file_path}: {e}")
-    return ''
-
-def read_json_file(file_path):
-    try:
-        with open(file_path, 'r') as f:
-            data = json.load(f)
-            # Assuming each item in the list has a 'question' and 'answer' key
-            # Concatenating question and answer pairs with a space in between and accumulating them into a single string
-            file_text = ' '.join([item['question'].strip() + ' ' + item['answer'].strip() + ' ' for item in data])
-            if len(file_text) == 0:
-                print("File is empty ",file_path)
-            return file_text
-    except Exception as e:
-        logging.error(f"Error reading JSON file {file_path}: {e}")
-    return ''
-
-
-def process_file(file_path):
-    print("starting to process file: ", file_path)
-    file_type = magic.from_file(file_path, mime=True)
-    if file_type in ['text/plain', 'text/markdown', 'JSON']:
-        return read_text_file(file_path)
-    elif file_type == 'application/pdf':
-        return read_pdf_file(file_path)
-    else:
-        logging.warning(f"Unsupported file type {file_type} for file {file_path}")
-        return ''
-def read_file_content(context):
-    file_strings = []
-
-    for root, _, files in os.walk(context['data_dir']):
-        for file in files:
-            file_path = os.path.join(root, file)
-            file_text = process_file(file_path)
-            if file_text:
-                file_strings.append(file_text)
-    text = '\n'.join(file_strings)
-    text = remove_non_printable(text)
-    return text
-
-def remove_non_printable(s):
-    printable = set(string.printable)
-    return ''.join(filter(lambda x: x in printable, s))
+def clean_documents(raw_text):
+    unwanted= ["Technology",
+    "Getting Started",
+    "Trust & Safety",
+    "Community",
+    "Resources",
+    "Skip to main content",
+    "How-to guides"]
+    all_lines = []
+    for line in raw_text.split("\n"):
+        line = line.strip()
+        if line in unwanted or len(line.split()) == 0:
+            continue
+        else:
+            all_lines.append(line)
+    result = " ".join(all_lines)
+    return result
+def clean_text(content: BeautifulSoup) -> str:
+    # Find all 'nav' and 'header' elements in the BeautifulSoup object
+    nav_elements = content.find_all("nav")
+    header_elements = content.find_all("header")
+    mydivs = content.find_all("div", {"role": "list"})
+    # Remove each 'nav' and 'header' element from the BeautifulSoup object
+    for element in nav_elements + header_elements+mydivs:
+        element.decompose()
+    raw_text = content.get_text("\n")
+    return clean_documents(raw_text)
+# Read
+def read_file_content(xml_path: str, data_folder: str) -> str:
+    if xml_path and data_folder:
+        logging.info(f"Error: both xml_path and data_folder are provided, will only read from xml for now")
+    if not xml_path and not data_folder:
+        logging.info(f"Error: both xml_path and data_folder are not provided")
+        return ""
+    if xml_path:
+        if not os.path.exists(xml_path):
+            logging.info(f"Error: {xml_path} does not exist")
+            return ""
+        # Use langchain to load the documents from webpage links in the xml file
+        sitemap_loader = SitemapLoader(web_path=xml_path,is_local=True,parsing_function=clean_text)
+        sitemap_loader.requests_kwargs = {"verify": False}
+        docs = sitemap_loader.load()
+        return "\n".join([doc.page_content for doc in docs])
+    elif len(data_folder) != 0:
+        if not os.path.exists(data_folder):
+            logging.info(f"Error: {data_folder} does not exist")
+            return ""
+        # Use langchain to load the documents from data folder
+        loader = DirectoryLoader(data_folder)
+        docs = loader.load()
+        text = "\n".join([clean_documents(doc.page_content) for doc in docs])
+        return text
 
 
-async def generate_question_request(chat_service, api_context: dict, document_content: str, num_questions: int) -> dict:
-    if num_questions == 0:
-        logging.info(f"Error: num_questions is 0")
-        return {}
-    prompt_for_system = api_context['question_prompt_template'].format(num_questions=num_questions)
-    chat_request_payload = [{'role': 'system', 'content': prompt_for_system}, {'role': 'user', 'content': str(document_content)}]
-    # parse the result string to a list of dict that has Question, Answer, Context
-    return await chat_service.execute_chat_request_async(api_context, chat_request_payload)
 
 def get_chunks(
     text: str,
@@ -134,55 +113,73 @@ def get_chunks(
     return chunks
 # read all the files in the data folder, then split them into chunks
 # generate questions for each chunk and return zip of chunk and related questions list
-async def generate_questions(chat_service, api_context: dict):
-    document_text = read_file_content(api_context)
+def generate_questions(api_config):
+    # get documents from the data folder or xml file
+    api_url = api_config["endpoint_url"]
+    key = api_config["api_key"]
+    document_text = read_file_content(api_config["xml_path"],api_config["data_dir"])
     if len(document_text) == 0:
         logging.info(f"Error reading files, document_text is {len(document_text)}")
-    model_name = "sentence-transformers/all-mpnet-base-v2"
-    embedding_model = HuggingFaceEmbeddings(model_name=model_name)
-    document_batches = get_chunks(document_text,api_context["chunk_size"],embedding_model)
+    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2",model_kwargs={'device': 'cuda'})
+    document_batches = get_chunks(document_text,api_config["chunk_size"],embedding_model)
 
     batches_count = len(document_batches)
-    total_questions = api_context["questions_per_chunk"] * batches_count
-
-    print(f"Questions per batch: {api_context['questions_per_chunk']}, Total questions: {total_questions}, Batches: {batches_count}")
-    generation_tasks = []
-    for batch_index, batch_content in enumerate(document_batches):
-        print(f"len of batch_content: {len(batch_content)}, batch_index: {batch_index}")
-        #Distribute extra questions across the first few batches
-        if len(batch_content) < 10:
-            logging.info("Context is not enough, ignore this batch")
-        else:
-            print(f"Batch {batch_index + 1} - {api_context['questions_per_chunk']} questions ********")
-            try:
-                task = generate_question_request(chat_service, api_context, batch_content, api_context["questions_per_chunk"])
-                generation_tasks.append(task)
-            except Exception as e:
-                print(f"Error during chat request execution: {e}")
-
-    question_generation_results = await asyncio.gather(*generation_tasks)
+    total_questions = api_config["questions_per_chunk"] * batches_count
+    # use the OpenAI API protocol to handle the chat request, including a local vLLM OpenAI-compatible server
+    llm = VLLMOpenAI(
+        openai_api_key=key,
+        openai_api_base=api_url,
+        model_name=api_config["model"],
+        temperature=0.0,
+        max_tokens=250
+        )
+    prompt = api_config['question_prompt_template'].format(num_questions=str(api_config['questions_per_chunk']))
+    system_prompt = SystemMessage(content=prompt)
+    generated_answers = []
+    all_tasks = [[system_prompt, HumanMessage(content=batch)] for batch in document_batches]
+    generated_answers = llm.batch(all_tasks)
+    if len(generated_answers) == 0:
+        logging.error("No model answers generated. Please check the input context or model configuration in ",model_name)
+        return []
     final_result = []
-    for result in question_generation_results:
+    for result in generated_answers:
         queries = result.split('\n')
         queries = [strip_str(q) for q in queries]
         queries = [q for q in queries if any(c.isalpha() for c in q)]
-        if len(queries) > int(api_context['questions_per_chunk']):
+        if len(queries) > int(api_config['questions_per_chunk']):
             # As the model may have unrelated question at the begining of the result
             # if queries is more than questions_per_chunk, then we need to truncate it and only keep last questions_per_chunk lines
-            queries = queries[-int(api_context['questions_per_chunk']):]
+            queries = queries[-int(api_config['questions_per_chunk']):]
         final_result.append(queries)
     return list(zip(document_batches,final_result))
 
-async def generate_COT(chat_service, api_context: dict, document_content: str, question: str) -> dict:
-    prompt = api_context['COT_prompt_template'].format(question=question,context=str(document_content))
-    chat_request_payload = [{"role": "system", "content": "You are a helpful question answerer who can provide an answer given a question and relevant context."}]
-    chat_request_payload.append({"role": "user", "content": prompt})
-    response = await chat_service.execute_chat_request_async(api_context, chat_request_payload)
-    return (document_content,question,response)
-async def add_chunk_to_dataset(
+# Generate COT answer for each question given the chunk context
+def generate_COT(chunk_questions_zip,api_config) -> dict:
+    all_tasks = []
+    chunk_questions = []
+    for document_content,questions in chunk_questions_zip:
+        for question in questions:
+            prompt = api_config['COT_prompt_template'].format(question=question,context=str(document_content))
+            all_tasks.append(prompt)
+            chunk_questions.append((document_content,question))
+    # use the OpenAI API protocol to handle the chat request, including a local vLLM OpenAI-compatible server
+    llm = VLLMOpenAI(
+        openai_api_key=api_config["api_key"],
+        openai_api_base=api_config["endpoint_url"],
+        model_name=api_config["model"],
+        temperature=0.0,
+        max_tokens=350
+        )
+    generated_answers = llm.batch(all_tasks)
+    COT_results = []
+    # return a list of (chunk, question, generated_answer)
+    for (chunk, question),generated_answer in zip(chunk_questions,generated_answers):
+        COT_results.append((chunk,question,generated_answer))
+    return COT_results
+
+def add_chunk_to_dataset(
     chunk_questions_zip: list,
-    context: dict,
-    chat_service,
+    api_config: dict,
     ds,
     num_distract: int = 3,
     p: float = 0.8,
@@ -192,14 +189,9 @@ async def add_chunk_to_dataset(
     """
     COT_tasks = []
     chunks = [chunk for chunk, _ in chunk_questions_zip]
-    for i, chunk_questions in enumerate(chunk_questions_zip):
-        chunk, questions = chunk_questions
-        # generate COT answer for each question given the chunk context
-        for question in questions:
-            COT_tasks.append(generate_COT(chat_service, context, chunk, question))
-    COT_results = await asyncio.gather(*COT_tasks)
+    COT_results = generate_COT(chunk_questions_zip,api_config)
     for chunk, q , cot in COT_results:
-        # The COT answer will be used in the fine-tuning stage
+        # The COT answer will be used as the label in the fine-tuning stage
         datapt = {
             "id": None,
             "type": "general",