
added experiment results to README

Kai Wu 10 months ago
parent
commit
a65e56c67c

+ 0 - 38
recipes/finetuning/datasets/chatbot_dataset.py

@@ -1,38 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
-
-
-import copy
-import datasets
-from datasets import Dataset, load_dataset, DatasetDict
-import itertools
-
-
-B_INST, E_INST = "[INST]", "[/INST]"
-
-def tokenize_dialog(q_a_pair, tokenizer):
-    question, answer = q_a_pair["Question"], q_a_pair["Answer"]
-    prompt_tokens = tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(question).strip()} {E_INST}", add_special_tokens=False)
-    answer_tokens = tokenizer.encode(f"{answer.strip()} {tokenizer.eos_token}", add_special_tokens=False)
-    sample = {
-            "input_ids": prompt_tokens + answer_tokens,
-            "attention_mask" : [1] * (len(prompt_tokens) + len(answer_tokens)),
-            "labels": [-100] * len(prompt_tokens) + answer_tokens,
-            }
-
-    return sample
-
-
-def get_custom_dataset(dataset_config, tokenizer, split, split_ratio=0.8):
-    dataset_dict = load_dataset('json', data_files=dataset_config.data_path)
-    dataset = dataset_dict['train']
-    dataset = dataset.train_test_split(test_size=1-split_ratio, shuffle=True, seed=42)
-
-    dataset = dataset[split].map(lambda sample: {
-        "Question": sample["Question"],
-        "Answer": sample["Answer"],
-        },
-        batched=True,
-    )
-    dataset = dataset.map(lambda x: tokenize_dialog(x, tokenizer))
-    return dataset

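Note: the deleted tokenize_dialog masked the prompt tokens with -100 so that the training loss is computed only on the answer tokens. A minimal sketch of that labeling pattern, assuming a Hugging Face tokenizer (the model name below is purely illustrative):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")  # illustrative model choice
    prompt_ids = tokenizer.encode(f"{tokenizer.bos_token}[INST] What is RAFT? [/INST]", add_special_tokens=False)
    answer_ids = tokenizer.encode(f"Retrieval Augmented Fine Tuning. {tokenizer.eos_token}", add_special_tokens=False)
    sample = {
        "input_ids": prompt_ids + answer_ids,
        "attention_mask": [1] * (len(prompt_ids) + len(answer_ids)),
        # -100 is the ignore index of the cross-entropy loss, so only the answer tokens are supervised
        "labels": [-100] * len(prompt_ids) + answer_ids,
    }
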
+ 9 - 4
recipes/finetuning/datasets/raft_dataset.py

@@ -50,12 +50,17 @@ def tokenize_dialog(dialog, tokenizer):
 
     return dict(combined_tokens, attention_mask=[1]*len(combined_tokens["input_ids"]))
 def raft_tokenize(q_a_pair, tokenizer):
-    end_tag = "<\/DOCUMENT>\n"
+    end_tag = "</DOCUMENT>"
     # find the last end_tag in the instruction, the rest is the question
-    index =q_a_pair["instruction"].rindex("<\/DOCUMENT>\n")+len(end_tag)
-    question = q_a_pair["instruction"][index:]
+    try:
+        index = q_a_pair["instruction"].rindex(end_tag) + len(end_tag)
+    except ValueError:
+        print(q_a_pair["instruction"])
+        raise Exception("The instruction does not contain the end tag </DOCUMENT>")
+    # all the lines after end_tag are the question
+    question = q_a_pair["instruction"][index:].strip()
     # all the lines before end_tag are the context
-    documents = q_a_pair["instruction"][:index]
+    documents = q_a_pair["instruction"][:index].strip() 
     # output is the label
     answer = q_a_pair["output"]
     system_prompt = "You are a helpful chatbot who can provide an answer to every questions from the user given a relevant context."

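The new raft_tokenize splits each RAFT instruction at the last </DOCUMENT> tag: everything up to and including that tag is the context (oracle plus distractor documents), and the remainder is the question. A minimal sketch of the split on a hand-written instruction string:

    instruction = (
        "<DOCUMENT>Llama 3 is a family of open LLMs released by Meta.</DOCUMENT>\n"
        "<DOCUMENT>RAFT combines retrieval-augmented generation with fine-tuning.</DOCUMENT>\n"
        "What does RAFT combine?"
    )
    end_tag = "</DOCUMENT>"
    index = instruction.rindex(end_tag) + len(end_tag)
    documents = instruction[:index].strip()  # all documents, including distractors
    question = instruction[index:].strip()   # "What does RAFT combine?"
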
File diff suppressed because it is too large
+ 136 - 65
recipes/use_cases/end2end-recipes/raft/README.md


File diff suppressed because it is too large
+ 207 - 0
recipes/use_cases/end2end-recipes/raft/chatbot.md


File diff suppressed because it is too large
+ 0 - 103
recipes/use_cases/end2end-recipes/raft/data/llama_website0613


+ 0 - 164
recipes/use_cases/end2end-recipes/raft/data_urls.xml

@@ -1,164 +0,0 @@
-<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
-<url>
-<loc>http://llama.meta.com/</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/use-policy/</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/responsible-use-guide/</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/llama2/</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/llama2/license/</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/llama2/use-policy/</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/license/</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/code-llama/</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/llama3/</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/llama3/license/</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-guard-2</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/docs/model-cards-and-prompt-formats/meta-code-llama-70b</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-guard-1</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/docs/model-cards-and-prompt-formats/meta-code-llama</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-2</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/docs/getting_the_models</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/docs/getting-the-models/hugging-face</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/docs/getting-the-models/kaggle</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/docs/llama-everywhere</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/docs/llama-everywhere/running-meta-llama-on-linux/</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/docs/llama-everywhere/running-meta-llama-on-windows/</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/docs/llama-everywhere/running-meta-llama-on-mac/</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/docs/llama-everywhere/running-meta-llama-in-the-cloud/</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/docs/how-to-guides/fine-tuning</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/docs/how-to-guides/quantization</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/docs/how-to-guides/prompting</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/docs/how-to-guides/validation</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/docs/integration-guides/meta-code-llama</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/docs/integration-guides/langchain</loc>
-</url>
-<url>
-<loc>http://llama.meta.com/docs/integration-guides/llamaindex</loc>
-</url>
-<url>
-<loc>http://raw.githubusercontent.com/meta-llama/llama-recipes/main/README.md</loc>
-</url>
-<url>
-<loc>http://raw.githubusercontent.com/meta-llama/llama/main/MODEL_CARD.md</loc>
-</url>
-<url>
-<loc>http://raw.githubusercontent.com/meta-llama/llama/main/README.md</loc>
-</url>
-<url>
-<loc>http://raw.githubusercontent.com/meta-llama/llama3/main/MODEL_CARD.md</loc>
-</url>
-<url>
-<loc>http://raw.githubusercontent.com/meta-llama/llama3/main/README.md</loc>
-</url>
-<url>
-<loc>http://raw.githubusercontent.com/meta-llama/codellama/main/MODEL_CARD.md</loc>
-</url>
-<url>
-<loc>http://raw.githubusercontent.com/meta-llama/codellama/main/README.md</loc>
-</url>
-<url>
-<loc>http://raw.githubusercontent.com/meta-llama/PurpleLlama/main/README.md</loc>
-</url>
-<url>
-<loc>http://raw.githubusercontent.com/meta-llama/PurpleLlama/main/Llama-Guard2/MODEL_CARD.md</loc>
-</url>
-<url>
-<loc>http://raw.githubusercontent.com/meta-llama/PurpleLlama/main/Llama-Guard2/README.md</loc>
-</url>
-<url>
-<loc>http://raw.githubusercontent.com/meta-llama/PurpleLlama/main/Llama-Guard/MODEL_CARD.md</loc>
-</url>
-<url>
-<loc>https://hamel.dev/notes/llm/inference/03_inference.html</loc>
-</url>
-<url>
-<loc>https://www.anyscale.com/blog/continuous-batching-llm-inference</loc>
-</url>
-<url>
-<loc>https://github.com/huggingface/peft</loc>
-</url><url>
-<loc>https://github.com/facebookresearch/llama-recipes/blob/main/docs/LLM_finetuning.md</loc>
-</url>
-<url>
-<loc>https://github.com/meta-llama/llama-recipes/blob/main/recipes/finetuning/datasets/README.md</loc>
-</url><url>
-<loc>https://www.databricks.com/blog/efficient-fine-tuning-lora-guide-llms</loc>
-</url>
-<url>
-<loc>https://www.wandb.courses/courses/training-fine-tuning-LLMs</loc>
-</url>
-<url>
-<loc>https://www.snowflake.com/blog/meta-code-llama-testing/</loc>
-</url><url>
-<loc>https://www.phind.com/blog/code-llama-beats-gpt4</loc>
-</url>
-<loc>https://www.anyscale.com/blog/llama-2-is-about-as-factually-accurate-as-gpt-4-for-summaries-and-is-30x-cheaper</loc>
-</url>
-<url>
-<loc>https://ragntune.com/blog/gpt3.5-vs-llama2-finetuning</loc>
-</url><url>
-<loc>https://deci.ai/blog/fine-tune-llama-2-with-lora-for-question-answering/</loc>
-</url>
-<url>
-<loc>https://replicate.com/blog/fine-tune-translation-model-axolotl</loc>
-</url><url>
-<loc>https://huyenchip.com/2023/04/11/llm-engineering.html</loc>
-</url>
-</urlset>

BIN
recipes/use_cases/end2end-recipes/raft/images/Answers_Precision.png


BIN
recipes/use_cases/end2end-recipes/raft/images/LLM_score_comparison.png


BIN
recipes/use_cases/end2end-recipes/raft/images/Num_of_refusal_comparison.png


BIN
recipes/use_cases/end2end-recipes/raft/images/RAFT.png


+ 3 - 7
recipes/use_cases/end2end-recipes/raft/raft.py

@@ -16,12 +16,8 @@ def main(api_config):
         if not chunk_questions_zip:
             logging.warning("No questions generated from text. Please check the api_config or model configuration.")
             return
-        for chunk, questions in chunk_questions_zip:
-            logging.info(f"Chunk: {chunk}, question length: {len(questions)}")
-            for question in questions:
-                logging.info(f"Question: {question}")
         logging.info(f"Successfully generated {sum([len(q) for c,q in chunk_questions_zip])} question/answer pairs.")
-        ds = add_chunk_to_dataset(chunk_questions_zip,api_config,ds)
+        ds = add_chunk_to_dataset(chunk_questions_zip,api_config)
         ds.save_to_disk(args.output)
         logging.info(f"Data successfully written to {api_config['output']}. Process completed.")
         formatter = DatasetConverter()
@@ -40,7 +36,7 @@ def parse_arguments():
     parser.add_argument(
         "-t", "--questions_per_chunk",
         type=int,
-        default=3,
+        default=4,
         help="Specify the number of question pairs to generate per chunk."
     )
     parser.add_argument(
@@ -87,7 +83,7 @@ if __name__ == "__main__":
         api_config["api_key"] = os.environ["API_KEY"]
     logging.info(f"Configuration loaded. Generating {args.questions_per_chunk} question per chunk using model '{args.model}'.")
     logging.info(f"Chunk size: {args.chunk_size}.")
-    logging.info(f"num_distract_docs: {api_config['num_distract_docs']}, oracle_p: {api_config['oracle_p']}")
+    logging.info(f"num_distract_docs: {api_config['num_distract_docs']}, refusal_probability: {api_config['refusal_probability']}")
     logging.info(f"Will use endpoint_url: {args.endpoint_url}.")
     logging.info(f"Output will be written to {args.output}.")
     main(api_config)

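raft.py now calls add_chunk_to_dataset without threading a Dataset through the call and writes the result with save_to_disk. A minimal sketch of the new call flow (it relies on the recipe's own generate_questions and add_chunk_to_dataset, so it is not standalone):

    # chunk_questions_zip is a list of (chunk_text, [questions]) pairs from generate_questions(api_config)
    ds = add_chunk_to_dataset(chunk_questions_zip, api_config)
    ds.save_to_disk(args.output)
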
+ 2 - 2
recipes/use_cases/end2end-recipes/raft/raft.yaml

@@ -46,6 +46,6 @@ chunk_size: 1000
 
 questions_per_chunk: 5
 
-num_distract_docs: 5 # number of distracting documents to add to each chunk
+num_distract_docs: 4 # number of distracting documents to add to each chunk
 
-oracle_p: 0.8 # probability of related documents to be added to each chunk
+refusal_probability: 0.05 # probability of adding a refusal example where the oracle document is replaced and the answer becomes a refusal

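The renamed refusal_probability controls how often a refusal example is emitted for a generated Q/A pair: the oracle document is swapped for a random chunk and the label becomes a canned refusal answer (see raft_utils.py below). A minimal sketch of how the value is consumed:

    import random

    p = 0.05  # refusal_probability from raft.yaml
    if random.uniform(0, 1) <= p:
        # replace the oracle document with a random chunk and label the example with a refusal answer
        cot_answer = ("Sorry, I don't know the answer to this question because "
                      "related documents are not found. Please try again.")
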
+ 24 - 41
recipes/use_cases/end2end-recipes/raft/raft_eval.py

@@ -16,7 +16,6 @@ import re
 import string
 import pandas as pd 
 from langchain.retrievers.document_compressors import FlashrankRerank
-from transformers import AutoTokenizer
 
 
 def generate_answers_model_only(model_name,question_list,api_url="http://localhost:8000/v1",key="EMPTY"):
@@ -48,15 +47,7 @@ def build_retriever(api_config,embedding_model_name,retrieved_docs_num=5):
     loader = DirectoryLoader(api_config['data_dir'])
     docs = loader.load()
     # Split the document into chunks with a specified chunk size
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=api_config["chunk_size"],chunk_overlap=int(api_config["chunk_size"] / 10),add_start_index=True,strip_whitespace=True)
-    # text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
-    #     AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B"),
-    #     chunk_size=api_config["chunk_size"],
-    #     chunk_overlap=int(api_config["chunk_size"] / 10),
-    #     add_start_index=True,
-    #     strip_whitespace=True,
-    #     separators=["\n\n", "\n", ".", " ", ""],
-    # )
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=api_config["chunk_size"],chunk_overlap=int(api_config["chunk_size"] / 10),separators= ["----------","\n\n", "\n", " ", ""],strip_whitespace=True)
     docs_processed = text_splitter.split_documents(docs)
     # Remove duplicates
     unique_texts = {}
@@ -65,7 +56,7 @@ def build_retriever(api_config,embedding_model_name,retrieved_docs_num=5):
         if doc.page_content not in unique_texts:
             unique_texts[doc.page_content] = True
             docs_processed_unique.append(doc)
-
+    logging.info(f"Total number of docs_processed used by vectorstore: {len(docs_processed_unique)}")
     # Store the document into a vector store with a specific embedding model
     embedding_model = HuggingFaceEmbeddings(
         model_name=embedding_model_name,
@@ -78,7 +69,7 @@ def build_retriever(api_config,embedding_model_name,retrieved_docs_num=5):
     )
     return retriever
 def generate_answers_with_RAG(model_name, question_list,api_config,retriever,api_url_overwrite=None):
-    api_url = "http://localhost:"+str(api_config['vllm_endpoint'])+"/v1"
+    api_url = api_config['model_endpoint_url']
     if api_url_overwrite:
         api_url = api_url_overwrite
     key = api_config['api_key']
@@ -206,8 +197,8 @@ def score_single(api_config,generated,reference,questions, run_exact_match=True,
         metric["BERTScore_Precision"] = P
         metric["BERTScore_Recall"] = R
         metric["BERTScore_F1"] = F1
-    if api_config["judge_endpoint"] and run_llm_as_judge:
-        api_url = "http://localhost:"+str(api_config["judge_endpoint"])+"/v1"
+    if api_config["judge_endpoint_url"] and run_llm_as_judge:
+        api_url = api_config["judge_endpoint_url"]
         LLM_judge_score,judge_responses = compute_judge_score(questions, generated, reference, api_config,api_url=api_url)
         metric["LLM_judge_score"] = LLM_judge_score
         metric["LLM_judge_responses"] = judge_responses
@@ -220,7 +211,7 @@ def score_single(api_config,generated,reference,questions, run_exact_match=True,
 def main(api_config):
     # Since the eval set is small, we can run the eval without async functions
     try:
-        api_url = "http://localhost:"+str(api_config["vllm_endpoint"])+"/v1"
+        api_url = api_config["model_endpoint_url"]
         logging.info("Starting to generate answer given the eval set.")
         questions,groud_truth = [],[]
         if api_config["eval_file"].endswith(".parquet"):
@@ -234,15 +225,7 @@ def main(api_config):
                 for index, item in enumerate(eval_file):
                     questions.append(item["question"])
                     groud_truth.append(item["answer"])
-        generated_answers = {
-            "RAFT": [],
-            "RAFT_RAG": [],
-            "Baseline": [],
-            "Baseline_RAG": [],
-            "70B_RAG": [],
-            "70B_Base": [],
-            
-        }
+        generated_answers = {}            
         # build retriver
         retriever = build_retriever(api_config,"sentence-transformers/multi-qa-mpnet-base-cos-v1",api_config["rag_topk"])
         # Generate answers for 8B models
@@ -251,11 +234,11 @@ def main(api_config):
         generated_answers[model_name+"_RAG"] = generate_answers_with_RAG(model_name, questions,api_config,retriever)
         print("Finished generating answers for ", model_name)
         large_model_name = "meta-llama/Meta-Llama-3-70B-Instruct"
-        large_api_url = "http://localhost:"+str(api_config["judge_endpoint"])+"/v1"
-        generated_answers["70B_Base"] = generate_answers_model_only(large_model_name,questions,large_api_url)
-        generated_answers["70B_RAG"] = generate_answers_with_RAG(large_model_name, questions,api_config,retriever,large_api_url)
+        large_api_url = api_config["judge_endpoint_url"]
+        #generated_answers["70B_Base"] = generate_answers_model_only(large_model_name,questions,large_api_url)
+        #generated_answers["70B_RAG"] = generate_answers_with_RAG(large_model_name, questions,api_config,retriever,large_api_url)
         print("Finished generating answers for ", large_model_name)
-        logging.info(f"Successfully generated {len(generated_answers[model_name])} answers for all models.")
+        logging.info(f"Successfully generated {len(generated_answers[model_name+'_RAG'])} answers for all models.")
         # for generate answer from each model, compute the score metric
         all_metrics = []
         output_file = api_config["output_log"]+str(datetime.now().strftime("%Y%m%d_%H%M%S"))
@@ -272,7 +255,7 @@ def main(api_config):
                 fp.write(f"BERTScore Precision: {metric['BERTScore_Precision']:.4f}, Recall: {metric['BERTScore_Recall']:.4f}, F1: {metric['BERTScore_F1']:.4f} \n")
                 fp.write(f"Exact_match_percentage: {metric['Exact_match']} \n")
                 judge_responses = ["None"] * len(questions)
-                if api_config["judge_endpoint"]:
+                if api_config["judge_endpoint_url"]:
                     fp.write(f"LLM_judge_score: {metric['LLM_judge_score']} \n")
                     judge_responses = metric["LLM_judge_responses"]
                     all_metrics.append((model_name,metric['LLM_judge_score'],metric["LLM_judge_responses"]))
@@ -330,16 +313,16 @@ def parse_arguments():
         help="Provide the data folder path to build RAG for evaluation. If not specified, the data_dir in eval_config.yaml will be used."
     )
     parser.add_argument(
-        "-v", "--vllm_endpoint",
-        default=8000,
-        type=int,
-        help="If a port is specified, then use local vllm endpoint for eval."
+        "-u", "--model_endpoint_url",
+        default="http://localhost:8000/v1",
+        type=str,
+        help="The raft model endpoint url for eval."
     )
     parser.add_argument(
-        "-j", "--judge_endpoint",
+        "-j", "--judge_endpoint_url",
         default=None,
-        type=int,
-        help="If a port is specified, then use local vllm endpoint as judge LLM."
+        type=str,
+        help="The large model endpoint url for judge as LLM."
     )
     parser.add_argument(
         "-o", "--output_log",
@@ -371,12 +354,12 @@ if __name__ == "__main__":
     logging.info("Initializing the process and loading configuration...")
     args = parse_arguments()
     api_config = load_config(args.config_path)
-    api_config["vllm_endpoint"] = args.vllm_endpoint
+    api_config["model_endpoint_url"] = args.model_endpoint_url
     if args.data_dir:
         api_config["data_dir"] = args.data_dir
-    if args.raft_model_name:
+    if args.model_name:
         api_config["model_name"] = args.model_name
-    api_config["judge_endpoint"] = args.judge_endpoint
+    api_config["judge_endpoint_url"] = args.judge_endpoint_url
     api_config["output_log"] = args.output_log
     api_config["api_key"] = args.api_key
     api_config["chunk_size"] = args.chunk_size
@@ -384,6 +367,6 @@ if __name__ == "__main__":
     api_config["rerank_topk"] = args.rerank_topk
     if api_config["rag_topk"] < api_config["rerank_topk"]:
         logging.error("The rerank_topk should be smaller than rag_topk.")
-    if api_config["judge_endpoint"]:
-        logging.info(f"Use local vllm service for judge at port: '{args.judge_endpoint}'.")
+    if api_config["judge_endpoint_url"]:
+        logging.info(f"The judge model url is: '{args.judge_endpoint_url}'.")
     main(api_config)

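raft_eval.py now takes full endpoint URLs (-u/--model_endpoint_url, -j/--judge_endpoint_url) instead of local ports, and the retriever's text splitter splits on an explicit "----------" separator before falling back to whitespace. A minimal sketch of that splitter configuration, assuming "----------" marks section boundaries in the evaluation corpus:

    from langchain.text_splitter import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=100,  # chunk_size / 10, as in the diff
        separators=["----------", "\n\n", "\n", " ", ""],
        strip_whitespace=True,
    )
    # short inputs below chunk_size come back as one chunk; real corpora are cut at "----------" first
    chunks = splitter.create_documents(["First section ---------- Second section"])
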
+ 52 - 54
recipes/use_cases/end2end-recipes/raft/raft_utils.py

@@ -4,14 +4,12 @@
 import os
 import logging
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from math import ceil
 from datasets import Dataset
 import random
 from langchain_community.document_loaders import SitemapLoader,DirectoryLoader
 from bs4 import BeautifulSoup
-import copy
 from langchain_openai import ChatOpenAI
-
+import copy
 
 
 # Initialize logging
@@ -32,17 +30,10 @@ def strip_str(s: str) -> str:
     r += 2
     return s[l:min(r, len(s))]
 def clean_documents(raw_text):
-    unwanted= ["Technology",
-    "Getting Started",
-    "Trust & Safety",
-    "Community",
-    "Resources",
-    "Skip to main content",
-    "How-to guides"]
     all_lines = []
     for line in raw_text.split("\n"):
         line = line.strip()
-        if line in unwanted or len(line.split()) == 0:
+        if len(line.split()) == 0:
             continue
         else:
             all_lines.append(line)
@@ -73,7 +64,7 @@ def read_file_content(xml_path: str, data_folder: str) -> str:
         sitemap_loader = SitemapLoader(web_path=xml_path,is_local=True,parsing_function=clean_text)
         sitemap_loader.requests_kwargs = {"verify": False}
         docs = sitemap_loader.load()
-        return "\n".join([doc.page_content for doc in docs])
+        return docs
     elif len(data_folder) != 0:
         if not os.path.exists(data_folder):
             logging.info(f"Error: {data_folder} does not exist")
@@ -81,30 +72,35 @@ def read_file_content(xml_path: str, data_folder: str) -> str:
         # Use langchain to load the documents from data folder
         loader = DirectoryLoader(data_folder)
         docs = loader.load()
-        text = "\n".join([clean_documents(doc.page_content) for doc in docs])
-        return text
+        return docs
 
 
 
 def get_chunks(
-    text: str,
-    chunk_size: int = 512,
+    docs: list,
+    chunk_size: int = 1000,
     api_config: dict = None,
 ) -> list[str]:
     """
-    Takes in a `file_path` and `doctype`, retrieves the document, breaks it down into chunks of size
+    Takes in a list of documents, breaks them down into chunks of size
     `chunk_size`, and returns the chunks.
     """
     chunks = []
-    if  len(text) == 0:
+    if  len(docs) == 0:
         raise TypeError("Can not get chunks from empty text")
     else:
-        num_chunks = ceil(len(text) / chunk_size)
-        logging.info(f"Splitting text into {num_chunks} chunks")
-        text_splitter = RecursiveCharacterTextSplitter(chunk_size=api_config["chunk_size"], chunk_overlap=int(api_config["chunk_size"]/10))
-        chunks = text_splitter.create_documents([text])
-        chunks = [chunk.page_content for chunk in chunks]
-
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=api_config["chunk_size"],chunk_overlap=int(api_config["chunk_size"] / 10),separators= ["----------","\n\n", "\n", " "],strip_whitespace=True)
+        docs_processed = text_splitter.split_documents(docs)
+        logging.info(f"Total number of docs_processed: {len(docs_processed)}")
+        # Remove duplicates
+        unique_texts = {}
+        docs_processed_unique = []
+        for doc in docs_processed:
+            if doc.page_content not in unique_texts and len(doc.page_content) > 100 :
+                unique_texts[doc.page_content] = True
+                docs_processed_unique.append(doc)        
+        chunks = [chunk.page_content for chunk in docs_processed_unique]
+        logging.info(f"Total number of docs_processed_unique: {len(docs_processed_unique)}")
     return chunks
 # read all the files in the data folder, then split them into chunks
 # generate questions for each chunk and return zip of chunk and related questions list
@@ -112,10 +108,10 @@ def generate_questions(api_config):
     # get documents from the data folder or xml file
     api_url = api_config["endpoint_url"]
     key = api_config["api_key"]
-    document_text = read_file_content(api_config["xml_path"],api_config["data_dir"])
-    if len(document_text) == 0:
-        logging.info(f"Error reading files, document_text is {len(document_text)}")
-    document_batches = get_chunks(document_text,api_config["chunk_size"],api_config)
+    documents = read_file_content(api_config["xml_path"],api_config["data_dir"])
+    if len(documents) == 0:
+        logging.info(f"Error reading files, document_text is {len(documents)}")
+    document_batches = get_chunks(documents,api_config["chunk_size"],api_config)
     # use OpenAI API protocol to hanlde the chat request, including local VLLM openai compatible server
     llm = ChatOpenAI(
         openai_api_key=key,
@@ -146,11 +142,16 @@ def generate_questions(api_config):
 def generate_COT(chunk_questions_zip,api_config) -> dict:
     all_tasks = []
     chunk_questions = []
+    question_asked = set()
     for document_content,questions in chunk_questions_zip:
         for question in questions:
-            prompt = api_config['COT_prompt_template'].format(question=question,context=str(document_content))
-            all_tasks.append(prompt)
-            chunk_questions.append((document_content,question))
+            question = question.strip()
+            # avoid asking the same question twice
+            if question not in question_asked:
+                question_asked.add(question)
+                prompt = api_config['COT_prompt_template'].format(question=question,context=str(document_content))
+                all_tasks.append(prompt)
+                chunk_questions.append((document_content,question))
     # use OpenAI API protocol to hanlde the chat request, including local VLLM openai compatible server
     llm = ChatOpenAI(
         openai_api_key=api_config["api_key"],
@@ -170,17 +171,20 @@ def generate_COT(chunk_questions_zip,api_config) -> dict:
 def add_chunk_to_dataset(
     chunk_questions_zip: list,
     api_config: dict,
-    ds,
 ) -> None:
     """
     Given a chunk and related questions lists, create {Q, A, D} triplets and add them to the dataset.
     """
     num_distract = api_config["num_distract_docs"]
-    p = api_config["oracle_p"]
+    p = api_config["refusal_probability"]
     chunks = [chunk for chunk, _ in chunk_questions_zip]
     COT_results = generate_COT(chunk_questions_zip,api_config)
+    logging.info(f"COT generation completed, total num of COT results: {len(COT_results)}")
+    completed, refusal_count = 0, 0
+    data_list = []
     for chunk, q , cot in COT_results:
         # The COT answer will be used as the label in the fine-tuning stage
+
         datapt = {
             "id": None,
             "type": "general",
@@ -190,8 +194,7 @@ def add_chunk_to_dataset(
             "cot_answer": cot
         }
         i = chunks.index(chunk)
-        datapt["id"] = f"seed_task_{0 if not ds else ds.num_rows}"
-
+        datapt["id"] = f"seed_task_{len(data_list)}"
         # add num_distract distractor docs
         docs = [chunk]
         indices = list(range(0, len(chunks)))
@@ -219,29 +222,24 @@ def add_chunk_to_dataset(
         datapt["instruction"] = context
         datapt_copy = copy.deepcopy(datapt)
         # add to dataset
-        if not ds:
-            # init ds
-            datapt["id"] = [datapt["id"]]
-            datapt["type"] = [datapt["type"]]
-            datapt["question"] = [datapt["question"]]
-            datapt["context"] = [datapt["context"]]
-            datapt["oracle_context"] = [datapt["oracle_context"]]
-            datapt["cot_answer"] = [datapt["cot_answer"]]
-            datapt["instruction"] = [datapt["instruction"]]
-            ds = Dataset.from_dict(datapt)
-        else:
-            ds = ds.add_item(datapt)
+        data_list.append(datapt)
         # decides whether to add refusal example where the related documents are not provided
-        oracle = random.uniform(0, 1) < p
-        if not oracle:
+        refusal = random.uniform(0, 1) <= p
+        if refusal:
             doc_copy[0] = chunks[random.sample(indices, 1)[0]]
             random.shuffle(doc_copy)
-            context = ""
+            refusal_context = ""
             for doc in doc_copy:
-                context += "<DOCUMENT>" + str(doc) + "</DOCUMENT>\n"
-            context += q
+                refusal_context += "<DOCUMENT>" + str(doc) + "</DOCUMENT>\n"
+            refusal_context += q
             # This instruction will be used in the fine-tuning stage
-            datapt_copy["instruction"] = context
+            datapt_copy["id"] = f"refusal_task_{len(data_list)}"
+            datapt_copy["instruction"] = refusl_context
             datapt_copy["cot_answer"] = "Sorry, I don't know the answer to this question because related documents are not found. Please try again."
-            ds.add_item(datapt_copy)
+            data_list.append(datapt_copy)
+            refusal_count += 1
+        completed += 1
+        if completed % 100 == 0:
+            logging.info(f"refusal example added: {refusal}, total examples added: {completed}, total examples to be added: {len(COT_results)- completed}")
+    ds = Dataset.from_list(data_list)
     return ds
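add_chunk_to_dataset now accumulates plain dicts and converts them once at the end with Dataset.from_list instead of growing a Dataset row by row. A minimal sketch of that pattern with made-up content and a subset of the recipe's fields:

    from datasets import Dataset

    data_list = []
    for i, (question, cot_answer) in enumerate([("What does RAFT combine?", "Retrieval with fine-tuning.")]):
        data_list.append({
            "id": f"seed_task_{i}",  # refusal examples get "refusal_task_..." ids instead
            "type": "general",
            "question": question,
            "cot_answer": cot_answer,
        })
    # a single conversion at the end replaces the removed Dataset.from_dict / ds.add_item bookkeeping
    ds = Dataset.from_list(data_list)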