Browse source

Merge branch 'main' into readme-fix

Suraj Subramanian 8 months ago
Parent
commit
f7bdf4da51

+ 13 - 0
.github/scripts/spellcheck_conf/wordlist.txt

@@ -1400,6 +1400,19 @@ sqlite
 customerservice
 fn
 ExecuTorch
+LLMScore
+RecursiveCharacterTextSplitter
+TPD
+TPM
+Tianjun
+Zhang
+distractor
+distractors
+frac
+numRefusal
+totalQA
+DirectoryLoader
+SitemapLoader
 nf
 quant
 DLAI

+ 97 - 0
recipes/quickstart/finetuning/datasets/raft_dataset.py

@@ -0,0 +1,97 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
+
+
+import copy
+from datasets import load_dataset
+import itertools
+
+# Check whether the system prompt or user prompt header token sequence appears in the current token list
+def check_header(targets,seq):
+    for i in range(len(seq)-3):
+        if seq[i:i+3] in targets:
+            return True
+    return False
+# Replace the 3-token target header sequence in seq with -100 so it is masked out of the loss
+def replace_target(target,seq):
+    for i in range(len(seq)-3):
+        if seq[i:i+3] == target:
+            seq[i],seq[i+1],seq[i+2] = -100,-100,-100
+    return seq
+def tokenize_dialog(dialog, tokenizer):
+    # If the vocab size is 128,000 or above, use the chat template to generate the tokens, since the tokenizer is from a Llama 3 family model
+    if tokenizer.vocab_size >= 128000:
+        dialog_tokens = tokenizer.apply_chat_template(dialog)
+        eot_indices = [i for i,n in enumerate(dialog_tokens) if n == 128009]
+        labels = copy.copy(dialog_tokens)
+        last_idx = 0
+        # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
+        # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
+        prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
+        for n, idx in enumerate(eot_indices):
+            current_seq = labels[last_idx:idx+1]
+            if check_header(prompt_header_seqs,current_seq):
+                # found prompt header, indicating that this seq should be masked
+                labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
+            else:
+                last_idx = idx
+        # Lastly mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
+        assistant_header_seq = [128006, 78191, 128007]
+        labels = replace_target(assistant_header_seq,labels)
+        dialog_tokens = [dialog_tokens]
+        labels_tokens = [labels]
+    else:
+        raise Exception("This raft_dataset only supports Llama 3 family models, please make sure the tokenizer is from Llama 3 family models.")
+
+    combined_tokens = {
+        "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
+        "labels": list(itertools.chain(*(t for t in labels_tokens))),
+    }
+
+    return dict(combined_tokens, attention_mask=[1]*len(combined_tokens["input_ids"]))
+def raft_tokenize(q_a_pair, tokenizer):
+    end_tag = "</DOCUMENT>"
+    # find the last end_tag in the instruction, the rest is the question
+    try:
+        index = q_a_pair["instruction"].rindex(end_tag) + len(end_tag)
+    except ValueError:
+        print(q_a_pair["instruction"])
+        raise Exception("The instruction does not contain the end tag </DOCUMENT>")
+    # all the lines after end_tag are the question
+    question = q_a_pair["instruction"][index:].strip()
+    # all the lines before end_tag are the context
+    documents = q_a_pair["instruction"][:index].strip() 
+    # output is the label
+    answer = q_a_pair["output"]
+    system_prompt = "You are a helpful chatbot who can provide an answer to every questions from the user given a relevant context."
+    user_prompt = """
+        Question: {question}\nContext: {context}\n
+        Answer this question using the information given by multiple documents in the context above. Here are the things to pay attention to:
+        - The context contains many documents, each document starts with <DOCUMENT> and ends with </DOCUMENT>.
+        - First provide step-by-step reasoning on how to answer the question.
+        - In the reasoning, if you need to copy-paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This would mean that things outside of ##begin_quote## and ##end_quote## are not directly copied from the context.
+        - End your response with the final answer in the form <ANSWER>: $answer, the answer should be less than 60 words.
+        You MUST begin your final answer with the tag "<ANSWER>:".
+    """.format(question=question, context=documents)
+
+    chat = [
+    {"role": "system", "content": system_prompt},
+    {"role": "user", "content": user_prompt},
+    {"role": "assistant", "content": answer}
+    ]
+    return tokenize_dialog(chat, tokenizer)
+
+
+def get_custom_dataset(dataset_config, tokenizer, split, split_ratio=0.9):
+    # load_dataset will return DatasetDict that contains all the data in the train set
+    dataset_dict = load_dataset('json', data_files=dataset_config.data_path)
+    dataset = dataset_dict['train']
+    dataset = dataset.train_test_split(test_size=1-split_ratio, shuffle=True, seed=42)
+
+    dataset = dataset[split].map(lambda sample: {
+        "instruction": sample["instruction"],
+        "output": sample["cot_answer"],
+        },
+        batched=True,
+    )
+    dataset = dataset.map(lambda x: raft_tokenize(x, tokenizer))
+    return dataset
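
For reference, a minimal sketch of exercising this custom dataset on its own; the data path, dataclass name, and model id below are illustrative assumptions (it assumes a RAFT JSONL file produced by raft.py and access to a Llama 3 tokenizer):

# Hypothetical smoke test for raft_dataset.get_custom_dataset (paths and model id are assumptions).
from dataclasses import dataclass
from transformers import AutoTokenizer
from raft_dataset import get_custom_dataset

@dataclass
class RaftDatasetConfig:
    # Path to the JSONL output written by raft.py with --output-format hf (adjust to your setup)
    data_path: str = "./output/raft.jsonl"

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
train_ds = get_custom_dataset(RaftDatasetConfig(), tokenizer, split="train")
print(train_ds[0].keys())  # should include input_ids, labels and attention_mask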

File diff suppressed because it is too large
+ 1 - 1
recipes/responsible_ai/prompt_guard/prompt_guard_tutorial.ipynb


File diff suppressed because it is too large
+ 243 - 0
recipes/use_cases/end2end-recipes/RAFT-Chatbot/README.md


+ 10 - 0
recipes/use_cases/end2end-recipes/RAFT-Chatbot/config.py

@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import yaml
+
+def load_config(config_path: str = "./config.yaml"):
+    # Read the YAML configuration file
+    with open(config_path, "r") as file:
+        config = yaml.safe_load(file)
+    return config

File diff suppressed because it is too large
+ 287 - 0
recipes/use_cases/end2end-recipes/RAFT-Chatbot/eval_llama.json


+ 174 - 0
recipes/use_cases/end2end-recipes/RAFT-Chatbot/format.py

@@ -0,0 +1,174 @@
+# file copied from https://github.com/ShishirPatil/gorilla/blob/main/raft/format.py
+from abc import ABC, abstractmethod
+import argparse
+from datasets import Dataset, load_dataset
+from typing import Dict, Literal, Any, get_args
+
+"""
+This file converts raw HuggingFace Datasets into files suitable for fine-tuning completion and chat models.
+"""
+
+OutputDatasetType = Literal["parquet", "jsonl"]
+outputDatasetTypes = list(get_args(OutputDatasetType))
+
+InputDatasetType = Literal["arrow", "jsonl"]
+inputDatasetTypes = list(get_args(InputDatasetType))
+
+DatasetFormat = Literal["hf", "completion", "chat"]
+datasetFormats = list(get_args(DatasetFormat))
+
+def get_args() -> argparse.Namespace:
+    """
+    Parses and returns the arguments specified by the user's command
+    """
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--input", type=str, required=True, help="Input HuggingFace dataset file")
+    parser.add_argument("--input-type", type=str, default="arrow", help="Format of the input dataset. Defaults to arrow.", choices=inputDatasetTypes)
+    parser.add_argument("--output", type=str, required=True, help="Output file")
+    parser.add_argument("--output-format", type=str, required=True, help="Format to convert the dataset to", choices=datasetFormats)
+    parser.add_argument("--output-type", type=str, default="jsonl", help="Type to export the dataset to. Defaults to jsonl.", choices=outputDatasetTypes)
+    parser.add_argument("--output-chat-system-prompt", type=str, help="The system prompt to use when the output format is chat")
+
+    args = parser.parse_args()
+    return args
+
+class DatasetFormatter(ABC):
+    """
+    Base class for dataset formatters. Formatters rename, remove and add columns
+    to match the expected structure of the target format: HF, Chat or Completion file formats.
+    https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
+    """
+    @abstractmethod
+    def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
+        pass
+
+class DatasetExporter(ABC):
+    """
+    Base class for dataset exporters. Exporters write the dataset out to different file types such as JSONL and Parquet.
+    """
+    @abstractmethod
+    def export(self, ds: Dataset, output_path: str):
+        pass
+
+class DatasetConverter():
+    """
+    Entry point class. It resolves which DatasetFormatter and which DatasetExporter to use and runs them.
+    """
+    formats: Dict[DatasetFormat, DatasetFormatter]
+    exporters: Dict[OutputDatasetType, Any]
+
+    def __init__(self) -> None:
+        self.formats = {
+            "hf": HuggingFaceDatasetFormatter(),
+            "completion": OpenAiCompletionDatasetFormatter(),
+            "chat": OpenAiChatDatasetFormatter()
+        }
+        self.exporters = {
+            "parquet": ParquetDatasetExporter(),
+            "jsonl": JsonlDatasetExporter()
+        }
+
+    def convert(self, ds: Dataset, format: DatasetFormat, output_path: str, output_type: OutputDatasetType, params: Dict[str, str]):
+        if format not in self.formats:
+            raise Exception(f"Output Format {format} is not supported, please select one of {self.formats.keys()}")
+
+        if output_type not in self.exporters:
+            raise Exception(f"Output Type {output_type} is not supported, please select one of {self.exporters.keys()}")
+
+        formatter = self.formats[format]
+        newds = formatter.format(ds, params)
+        exporter = self.exporters[output_type]
+        exporter.export(newds, output_path)
+
+class HuggingFaceDatasetFormatter(DatasetFormatter):
+    """
+    Returns the HuggingFace Dataset as is
+    """
+    def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
+        return ds
+
+def _remove_all_columns_but(ds: Dataset, keep_columns) -> Dataset:
+    """
+    HF Dataset doesn't have a way to copy only specific columns of a Dataset, so this helper
+    removes all columns but the ones specified.
+    """
+    remove_columns = list(ds.column_names)
+    for keep in keep_columns:
+        remove_columns.remove(keep)
+    ds = ds.remove_columns(remove_columns)
+    return ds
+
+class OpenAiCompletionDatasetFormatter(DatasetFormatter):
+    """
+    Returns the Dataset in the OpenAI Completion Fine-tuning file format with two fields "prompt" and "completion".
+    https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
+    """
+    def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
+        newds = ds.rename_columns({'question': 'prompt', 'cot_answer': 'completion'})
+        return _remove_all_columns_but(newds, ['prompt', 'completion'])
+
+class OpenAiChatDatasetFormatter(OpenAiCompletionDatasetFormatter):
+    """
+    Returns the Dataset in the OpenAI Chat Fine-tuning file format with one field "messages".
+    https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
+    """
+    def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
+        newds = super().format(ds, params)
+
+        def format_messages(row):
+            messages = []
+            if 'system_prompt' in params:
+                system_prompt = params['system_prompt']
+                messages.append({ "role": "system", "content": system_prompt})
+            messages.extend([{ "role": "user", "content": row['prompt']}, { "role": "assistant", "content": row['completion']}])
+            chat_row = {"messages": messages}
+            return chat_row
+
+        newds = newds.map(format_messages)
+        return _remove_all_columns_but(newds, ['messages'])
+
+def append_extension(path: str, extension: str) -> str:
+    suffix = "." + extension
+    if not path.endswith(suffix):
+        path = path + suffix
+    return path
+
+
+class JsonlDatasetExporter(DatasetExporter):
+    """
+    Exports the Dataset to a JSONL file
+    """
+
+    def export(self, ds: Dataset, output_path: str):
+        ds.to_json(append_extension(output_path, "jsonl"))
+
+
+class ParquetDatasetExporter(DatasetExporter):
+    """
+    Exports the Dataset to a Parquet file
+    """
+
+    def export(self, ds: Dataset, output_path: str):
+        ds.to_parquet(append_extension(output_path, "parquet"))
+
+
+def main():
+    """
+    When format.py is executed from the command line.
+    """
+    args = get_args()
+    ds = load_dataset(args.input_type, data_files={"train": args.input})['train']
+    formatter = DatasetConverter()
+
+    if args.output_chat_system_prompt and args.output_format != "chat":
+        raise Exception("Parameter --output-chat-system-prompt can only be used with --output-format chat")
+
+    format_params = {}
+    if args.output_chat_system_prompt:
+        format_params['system_prompt'] = args.output_chat_system_prompt
+
+    formatter.convert(ds=ds, format=args.output_format, output_path=args.output, output_type=args.output_type, params=format_params)
+
+if __name__ == "__main__":
+    main()
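
As a usage note, the converter above can also be driven from Python rather than the CLI. A minimal sketch, assuming a RAFT JSONL file with "question" and "cot_answer" columns (the input path and system prompt are placeholders):

# Convert a RAFT dataset to the OpenAI chat fine-tuning JSONL format (illustrative paths).
from datasets import load_dataset
from format import DatasetConverter

ds = load_dataset("json", data_files={"train": "./output/raft.jsonl"})["train"]
converter = DatasetConverter()
converter.convert(
    ds=ds,
    format="chat",                     # one of: hf, completion, chat
    output_path="./output/raft_chat",  # ".jsonl" is appended automatically
    output_type="jsonl",
    params={"system_prompt": "You are a helpful assistant."},
)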

BIN
recipes/use_cases/end2end-recipes/RAFT-Chatbot/images/Answers_Precision.png


BIN
recipes/use_cases/end2end-recipes/RAFT-Chatbot/images/LLM_score_comparison.png


BIN
recipes/use_cases/end2end-recipes/RAFT-Chatbot/images/Num_of_refusal_comparison.png


BIN
recipes/use_cases/end2end-recipes/RAFT-Chatbot/images/RAFT.png


+ 89 - 0
recipes/use_cases/end2end-recipes/RAFT-Chatbot/raft.py

@@ -0,0 +1,89 @@
+import logging
+import os
+import argparse
+from raft_utils import generate_questions, add_chunk_to_dataset
+from format import DatasetConverter, datasetFormats, outputDatasetTypes
+from config import load_config
+
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+def main(api_config):
+    ds = None
+    try:
+        logging.info("Starting to generate question pair.")
+        # Generate questions as list for each chunk
+        chunk_questions_zip = generate_questions(api_config)
+        if not chunk_questions_zip:
+            logging.warning("No questions generated from text. Please check the api_config or model configuration.")
+            return
+        logging.info(f"Successfully generated {sum([len(q) for c,q in chunk_questions_zip])} question/answer pairs.")
+        ds = add_chunk_to_dataset(chunk_questions_zip,api_config)
+        ds.save_to_disk(args.output)
+        logging.info(f"Data successfully written to {api_config['output']}. Process completed.")
+        formatter = DatasetConverter()
+
+        # Extract format specific params
+        format_params = {}
+        formatter.convert(ds=ds, format=args.output_format, output_path=args.output+"raft", output_type=args.output_type, params=format_params)
+    except Exception as e:
+        logging.error(f"An unexpected error occurred during the process: {e}",exc_info=True)
+
+def parse_arguments():
+    # Define command line arguments for the script
+    parser = argparse.ArgumentParser(
+        description="Generate RAFT question/answer/context pairs from documentation."
+    )
+    parser.add_argument(
+        "-t", "--questions_per_chunk",
+        type=int,
+        default=4,
+        help="Specify the number of question pairs to generate per chunk."
+    )
+    parser.add_argument(
+        "-m", "--model",
+        default="meta-llama/Meta-Llama-3-70B-Instruct",
+        help="Select the model to use for generation."
+    )
+    parser.add_argument(
+        "-c", "--config_path",
+        default="./raft.yaml",
+        help="Set the configuration file path that has system prompt along with language, dataset path and number of questions."
+    )
+    parser.add_argument(
+        "-u", "--endpoint_url",
+        default="http://localhost:8001/v1",
+        type=str,
+        help="LLM API url for generating question/answer pairs."
+    )
+    parser.add_argument(
+        "-k", "--api_key",
+        default="EMPTY",
+        type=str,
+        help="LLM API key for generating question/answer pairs."
+    )
+    parser.add_argument("--chunk_size", type=int, default=1000, help="The size of each chunk in number of tokens")
+    parser.add_argument("-o","--output", type=str, default="./output/", help="The path at which to save the dataset")
+    parser.add_argument("--output-format", type=str, default="hf", help="Format to convert the dataset to. Defaults to hf.", choices=datasetFormats)
+    parser.add_argument("--output-type", type=str, default="jsonl", help="Type to export the dataset to. Defaults to jsonl.", choices=outputDatasetTypes)
+    return parser.parse_args()
+
+if __name__ == "__main__":
+    logging.info("Initializing the process and loading configuration...")
+    args = parse_arguments()
+
+    api_config = load_config(args.config_path)
+    api_config["questions_per_chunk"] = args.questions_per_chunk
+    api_config["model"] = args.model
+    api_config["chunk_size"] = args.chunk_size
+    api_config["endpoint_url"] = args.endpoint_url
+    api_config["output"] = args.output
+    api_config["api_key"] = args.api_key
+    # if API_KEY is defined in the system environment, use it as the API key
+    if os.environ.get('API_KEY') is not None:
+        api_config["api_key"] = os.environ["API_KEY"]
+    logging.info(f"Configuration loaded. Generating {args.questions_per_chunk} question per chunk using model '{args.model}'.")
+    logging.info(f"Chunk size: {args.chunk_size}.")
+    logging.info(f"num_distract_docs: {api_config['num_distract_docs']}, refusal_probability: {api_config['refusal_probability']}")
+    logging.info(f"Will use endpoint_url: {args.endpoint_url}.")
+    logging.info(f"Output will be written to {args.output}.")
+    main(api_config)

+ 51 - 0
recipes/use_cases/end2end-recipes/RAFT-Chatbot/raft.yaml

@@ -0,0 +1,51 @@
+COT_prompt_template: >
+  <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a helpful chatbot who can provide an answer to every question from the user given a relevant context.<|eot_id|>
+  <|start_header_id|>user<|end_header_id|>
+  Question: {question}\nContext: {context}\n
+  Answer this question using the information given by multiple documents in the context above. Here are the things to pay attention to:
+  - The context contains many documents, each document starts with <DOCUMENT> and ends with </DOCUMENT>.
+  - First provide step-by-step reasoning on how to answer the question.
+  - In the reasoning, if you need to copy-paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This would mean that things outside of ##begin_quote## and ##end_quote## are not directly copied from the context.
+  - End your response with the final answer in the form <ANSWER>: $answer, the answer should be less than 60 words.
+  You MUST begin your final answer with the tag "<ANSWER>:". <|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+question_prompt_template: >
+  <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a synthetic question-answer pair generator. Given a chunk of context about
+  some topic(s), generate {num_questions} example questions a user could ask and would be answered
+  using information from the chunk. For example, if the given context was a Wikipedia
+  paragraph about the United States, an example question could be 'How many states are
+  in the United States?'
+  Your questions should be formulated in the same style as questions that users could ask in a search engine.
+  This means that your questions MUST NOT mention something like "according to the passage" or "context".
+  The questions should be able to be answered in 60 words or less. Include only the questions in your response.<|eot_id|>
+  <|start_header_id|>user<|end_header_id|>
+  Context: {context}\n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+# question_prompt_template: >
+#   <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a language model skilled in creating quiz questions.
+#   You will be provided with a document,
+#   read it and please generate factoid question and answer pairs that are most likely to be asked by a user of Llama language models,
+#   which include Llama, Llama 2, Meta Llama 3, Code Llama, Meta Llama Guard 1 and Meta Llama Guard 2.
+#   Your factoid questions should be answerable with a specific, concise piece of factual information from the context.
+#   Your factoid questions should be formulated in the same style as questions users could ask in a search engine.
+#   This means that your factoid questions MUST NOT mention something like "according to the passage" or "context".
+#   please make sure you follow those rules:
+#   1. Generate {num_questions} question-answer pairs; you can generate fewer if there is nothing related to
+#   model, training, fine-tuning and evaluation details of Llama language models,
+#   2. The questions can be answered based *solely* on the given passage.
+#   3. Avoid asking questions with similar meaning.
+#   4. Never use any abbreviation.
+#   5. The questions should be able to be answered in 60 words or less. Include only the questions in your response. <|eot_id|>
+#   <|start_header_id|>user<|end_header_id|>
+#   Context: {context}\n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
+data_dir: "./data"
+
+xml_path: ""
+
+chunk_size: 1000
+
+questions_per_chunk: 5
+
+num_distract_docs: 4 # number of distracting documents to add to each chunk
+
+refusal_probability: 0.05 # probability of adding a refusal example whose context does not contain the oracle (related) document
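
As a rough sanity check on these settings, the expected dataset size can be estimated with a quick back-of-envelope calculation; the chunk count below is an assumption that depends on your corpus and chunk_size, and it assumes every chunk yields the full questions_per_chunk questions:

# Back-of-envelope estimate of how many examples raft.py will produce under this config.
num_chunks = 100                 # illustrative assumption
questions_per_chunk = 5
refusal_probability = 0.05

regular_examples = num_chunks * questions_per_chunk
expected_refusal_examples = regular_examples * refusal_probability
print(regular_examples, expected_refusal_examples)  # 500 regular examples plus ~25 refusal examples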

+ 336 - 0
recipes/use_cases/end2end-recipes/RAFT-Chatbot/raft_eval.py

@@ -0,0 +1,336 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
+import logging
+import evaluate
+import argparse
+from config import load_config
+import json
+from langchain_openai import ChatOpenAI
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.vectorstores import FAISS
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.vectorstores.utils import DistanceStrategy
+from datetime import datetime
+from langchain_community.document_loaders import DirectoryLoader
+import re
+import string
+import pandas as pd 
+
+
+def generate_answers_model_only(model_name,question_list,api_url="http://localhost:8000/v1",key="EMPTY"):
+    # Query the model directly (no retrieval) through an OpenAI-compatible endpoint
+
+    llm = ChatOpenAI(
+        openai_api_key=key,
+        openai_api_base=api_url,
+        model_name=model_name,
+        temperature=0.0,
+        max_tokens=1000
+        )
+
+    all_tasks = [api_config['eval_prompt_template'].format(question=question) for question in question_list]
+    generated_answers = llm.batch(all_tasks)
+    generated_answers = [ item.content for item in generated_answers]
+    if len(generated_answers) == 0:
+        logging.error("No model answers generated. Please check the input context or model configuration in ",model_name)
+        return []
+    return clean_text_list(generated_answers)
+def format_docs_raft(docs):
+    context = ""
+    for doc in docs:
+        context += "\n<DOCUMENT>" + str(doc.page_content) + "</DOCUMENT>\n"
+    return context
+def build_retriever(api_config,embedding_model_name,retrieved_docs_num=5):
+    # Use langchain to load the documents from data directory
+    loader = DirectoryLoader(api_config['data_dir'])
+    docs = loader.load()
+    # Split the document into chunks with a specified chunk size
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=api_config["chunk_size"],chunk_overlap=int(api_config["chunk_size"] / 10),separators= ["----------","\n\n", "\n", " ", ""],strip_whitespace=True)
+    docs_processed = text_splitter.split_documents(docs)
+    # Remove duplicates
+    unique_texts = {}
+    docs_processed_unique = []
+    for doc in docs_processed:
+        if doc.page_content not in unique_texts:
+            unique_texts[doc.page_content] = True
+            docs_processed_unique.append(doc)
+    logging.info(f"Total number of docs_processed used by vectorstore: {len(docs_processed_unique)}")
+    # Store the document into a vector store with a specific embedding model
+    embedding_model = HuggingFaceEmbeddings(
+        model_name=embedding_model_name,
+        model_kwargs={"device": "cuda"},
+        encode_kwargs={"normalize_embeddings": True},  # Set `True` for cosine similarity
+    )
+    vectorstore = FAISS.from_documents(docs_processed_unique, embedding_model, distance_strategy=DistanceStrategy.COSINE)
+    retriever = vectorstore.as_retriever(
+        search_kwargs={"k": retrieved_docs_num},
+    )
+    return retriever
+def generate_answers_with_RAG(model_name, question_list,api_config,retriever,api_url_overwrite=None):
+    api_url = api_config['model_endpoint_url']
+    if api_url_overwrite:
+        api_url = api_url_overwrite
+    key = api_config['api_key']
+    # Load the RAFT model
+    llm = ChatOpenAI(
+        openai_api_key=key,
+        openai_api_base=api_url,
+        model_name=model_name,
+        temperature=0.0,
+        max_tokens=1000
+        )
+    all_tasks = []
+    for q in question_list:
+        # retrieve the top K documents
+        retrieved_docs = retriever.invoke(q)        
+        # format the documents into a string
+        documents = format_docs_raft(retrieved_docs)
+        # create a prompt
+        text = api_config["RAG_prompt_template"].format(context=documents,question=q)
+        all_tasks.append(text)
+    generated_answers = llm.batch(all_tasks)
+    generated_answers = [ item.content for item in generated_answers]
+    if len(generated_answers) == 0:
+        logging.error("No RAG answers generated. Please check the input context or model configuration in ",model_name)
+        return []
+    return clean_text_list(generated_answers)
+def compute_rouge_score(generated : list, reference: list):
+    rouge_score = evaluate.load('rouge')
+    return rouge_score.compute(
+        predictions=generated,
+        references=reference,
+        use_stemmer=True,
+        use_aggregator=True
+    )
+def clean_text_list(text_list):
+    result = []
+    for text in text_list:
+        # for the RAFT model, the answer will start with <ANSWER>
+        index = text.rfind("<ANSWER>")
+        if index!= -1:
+            text = text[index:]
+            text = text.replace("</ANSWER>:","")
+        text = text.replace("begin_quote","")
+        text = text.replace("end_quote","")
+        text = text.replace("##","")
+        text = text.strip()
+        result.append(text)
+    return result
+
+def normalize_answer(s):
+
+    def remove_articles(text):
+        return re.sub(r'\b(a|an|the)\b', ' ', text)
+
+    def white_space_fix(text):
+        return ' '.join(text.split())
+
+    def remove_punc(text):
+        exclude = set(string.punctuation)
+        return ''.join(ch for ch in text if ch not in exclude)
+
+    def lower(text):
+        return text.lower()
+
+    return white_space_fix(remove_articles(remove_punc(lower(s))))
+def exact_match_score(prediction, ground_truth):
+    """Computes EM score for a single prediction and ground truth answer."""
+    num_match = 0
+    assert len(prediction) == len(ground_truth), "Answer length does not match prediction length."
+    assert(len(ground_truth) > 0)
+    for idx, (pred,gold) in enumerate(zip(prediction, ground_truth)):
+        if (normalize_answer(pred) == normalize_answer(gold)):
+            num_match += 1
+    return num_match/len(ground_truth)
+def compute_judge_score(questions: list, generated : list, reference: list, api_config,api_url="http://localhost:8001/v1",key="EMPTY"):
+    correct_num = 0
+    model_name = "meta-llama/Meta-Llama-3-70B-Instruct"
+    llm = ChatOpenAI(
+        openai_api_key=key,
+        openai_api_base=api_url,
+        model_name=model_name,
+        max_tokens=1000,
+        temperature=0.0)
+    all_tasks = []
+    for question,prediction,gold in zip(questions, generated,reference):
+        message = api_config['judge_prompt_template'].format(question=question,prediction=prediction,gold=gold)
+        all_tasks.append(message)
+    judge_responses = llm.batch(all_tasks)
+    judge_responses = ["YES" in item.content for item in judge_responses]
+    correct_num = sum(judge_responses)
+    return correct_num/len(questions),judge_responses
+def score_single(api_config,generated,reference,questions, run_exact_match=True,run_rouge=True, run_llm_as_judge=True):
+    # set metric to default -1, means no metric is computed
+    metric = {
+        "Rouge_score": -1,
+        "LLM_judge_score": -1,
+        "Exact_match": -1
+    }
+    if run_rouge:
+        rouge_score = compute_rouge_score(generated,reference)
+        metric["Rouge_score"] = rouge_score
+        print("Rouge_score:",rouge_score)
+    if api_config["judge_endpoint_url"] and run_llm_as_judge:
+        api_url = api_config["judge_endpoint_url"]
+        LLM_judge_score,judge_responses = compute_judge_score(questions, generated, reference, api_config,api_url=api_url)
+        metric["LLM_judge_score"] = LLM_judge_score
+        metric["LLM_judge_responses"] = judge_responses
+        print(f"LLM_judge_score: {LLM_judge_score}")
+    if run_exact_match:
+        exact_match = exact_match_score(generated,reference)
+        print(f"Exact_match_percentage: {exact_match:.4f}")
+        metric["Exact_match"] = exact_match
+    return metric
+def main(api_config):
+    # Since the eval set is small, we can run the eval without async functions
+    try:
+        api_url = api_config["model_endpoint_url"]
+        logging.info("Starting to generate answer given the eval set.")
+        questions,groud_truth = [],[]
+        if api_config["eval_file"].endswith(".parquet"):
+            eval_file = pd.read_parquet(api_config["eval_file"],filters=[('source', '=', 'pt_discuss_forum')])
+            for index, item in eval_file.iterrows():
+                questions.append(item["question"]+"\nDetails:\n"+item["context"])
+                ground_truth.append(item["answer"])
+        else:
+            with open(api_config["eval_file"]) as fp:
+                eval_file = json.load(fp)
+                for index, item in enumerate(eval_file):
+                    questions.append(item["question"])
+                    ground_truth.append(item["answer"])
+        generated_answers = {}            
+        # build retriever
+        retriever = build_retriever(api_config,"sentence-transformers/multi-qa-mpnet-base-cos-v1",api_config["rag_topk"])
+        # Generate answers for 8B models
+        model_name = api_config["model_name"]
+        generated_answers[model_name] = generate_answers_model_only(model_name,questions,api_url)
+        generated_answers[model_name+"_RAG"] = generate_answers_with_RAG(model_name, questions,api_config,retriever)
+        print("Finished generating answers for ", model_name)
+        large_model_name = "meta-llama/Meta-Llama-3-70B-Instruct"
+        large_api_url = api_config["judge_endpoint_url"]
+        generated_answers["70B_Base"] = generate_answers_model_only(large_model_name,questions,large_api_url)
+        generated_answers["70B_RAG"] = generate_answers_with_RAG(large_model_name, questions,api_config,retriever,large_api_url)
+        print("Finished generating answers for ", large_model_name)
+        logging.info(f"Successfully generated {len(generated_answers[model_name+'_RAG'])} answers for all models.")
+        # for generate answer from each model, compute the score metric
+        all_metrics = []
+        output_file = api_config["output_log"]+str(datetime.now().strftime("%Y%m%d_%H%M%S"))
+
+        for model_name,model_answer in generated_answers.items():
+            if len(model_answer) != len(ground_truth):
+                print(f"The length of {model_name} answers is not equal to the length of the ground truth.")
+                continue
+            metric = score_single(api_config,model_answer,ground_truth,questions)
+            print(f"The eval result for {model_name} is: {metric}")
+            with open(output_file,"a") as fp:
+                fp.write(f"Eval_result for {model_name} \n")
+                fp.write(f"Rouge_score: {metric['Rouge_score']} \n")
+                fp.write(f"Exact_match_percentage: {metric['Exact_match']} \n")
+                judge_responses = ["None"] * len(questions)
+                if api_config["judge_endpoint_url"]:
+                    fp.write(f"LLM_judge_score: {metric['LLM_judge_score']} \n")
+                    judge_responses = metric["LLM_judge_responses"]
+                    all_metrics.append((model_name,metric['LLM_judge_score'],metric["LLM_judge_responses"]))
+                fp.write(f"QA details: \n")
+                for item in zip(questions,model_answer,ground_truth,judge_responses):
+                    fp.write(f"question: {item[0]} \n")
+                    fp.write(f"generated_answers: {item[1]} \n")
+                    fp.write(f"ground_truth: {item[2]} \n")
+                    fp.write(f"LLM_judge_response: {item[3]} \n")
+                    fp.write("\n")
+                fp.write("\n------------------------------------\n")
+        # Now we want to take a closer look at the questions that are not answered the same by all the models.
+        judge_zip = list(zip(*[item[-1] for item in all_metrics]))
+        model_names = [item[0] for item in all_metrics]
+        with open(output_file,"a") as fp:
+            for item in all_metrics:
+                fp.write(f"Model_Name: {item[0]}, LLM_SCORE: {item[1]} \n")
+            for idx,item in enumerate(judge_zip):
+                # if all the responses are "YES", then we skip this question
+                if sum(item) == len(item):
+                    continue 
+                else:
+                    fp.write(f"Comparing interested question: {questions[idx]} \n")
+                    fp.write(f"groud_truth: {groud_truth[idx]} \n")
+                    for i in range(len(model_names)):
+                        fp.write(f"{item[i]} {model_names[i]}_answers: {generated_answers[model_names[i]][idx]} \n")
+                    fp.write("------------------------\n")
+            fp.write(json.dumps(all_metrics))
+        print("Finished evaluating the model.")
+
+
+        logging.info(f"Eval successfully, the eval result is saved to {api_config['output_log']}.")
+        # Saving the eval result to a log file
+    except Exception as e:
+        logging.error(f"An unexpected error occurred during the process: {e}",exc_info=True)
+
+def parse_arguments():
+    # Define command line arguments for the script
+    parser = argparse.ArgumentParser(
+        description="Generate question/answer pairs from documentation."
+    )
+    parser.add_argument(
+        "-m", "--model_name",
+        default=None,
+        help="Provide the model_name to use for evaluation. If not specified, the model_path in eval_config.yaml will be used."
+    )
+    parser.add_argument(
+        "-c", "--config_path",
+        default="raft_eval_config.yaml",
+        help="Set the configuration file path that has system prompt along with language, evalset path."
+    )
+    parser.add_argument(
+        "-d", "--data_dir",
+        default=None,
+        help="Provide the data folder path to build RAG for evaluation. If not specified, the data_dir in eval_config.yaml will be used."
+    )
+    parser.add_argument(
+        "-u", "--model_endpoint_url",
+        default="http://localhost:8000/v1",
+        type=str,
+        help="The raft model endpoint url for eval."
+    )
+    parser.add_argument(
+        "-j", "--judge_endpoint_url",
+        default=None,
+        type=str,
+        help="The large model endpoint url for judge as LLM."
+    )
+    parser.add_argument(
+        "-o", "--output_log",
+        default="./eval_result",
+        help="save the eval result to a log file. Default is eval_result[timestamp].log"
+    )
+    parser.add_argument(
+        "-k", "--api_key",
+        default="EMPTY",
+        type=str,
+        help="LLM API key for generating question/answer pairs."
+    )
+    parser.add_argument(
+        "-r", "--rag_topk",
+        default=5,
+        type=int,
+        help="set the number of top k documents the RAG needs to retrive."
+    )
+    parser.add_argument("--chunk_size", type=int, default=1000, help="The character size of each chunk used in RAG")
+    return parser.parse_args()
+
+if __name__ == "__main__":
+    logging.info("Initializing the process and loading configuration...")
+    args = parse_arguments()
+    api_config = load_config(args.config_path)
+    api_config["model_endpoint_url"] = args.model_endpoint_url
+    if args.data_dir:
+        api_config["data_dir"] = args.data_dir
+    if args.model_name:
+        api_config["model_name"] = args.model_name
+    api_config["judge_endpoint_url"] = args.judge_endpoint_url
+    api_config["output_log"] = args.output_log
+    api_config["api_key"] = args.api_key
+    api_config["chunk_size"] = args.chunk_size
+    api_config["rag_topk"] = args.rag_topk
+    if api_config["judge_endpoint_url"]:
+        logging.info(f"The judge model url is: '{args.judge_endpoint_url}'.")
+    main(api_config)
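
For a quick feel of how the scoring helpers above behave, here is a small sketch (it assumes raft_eval.py and its dependencies are importable; the strings are made up):

# normalize_answer ignores articles, punctuation, extra whitespace and casing before comparing.
from raft_eval import normalize_answer, exact_match_score

print(normalize_answer("The  answer is: 70B."))                        # -> "answer is 70b"
print(exact_match_score(["70B parameters!"], ["the 70b parameters"]))  # -> 1.0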

+ 37 - 0
recipes/use_cases/end2end-recipes/RAFT-Chatbot/raft_eval_config.yaml

@@ -0,0 +1,37 @@
+eval_prompt_template: >
+  <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an AI assistant skilled in answering questions related to Llama language models,
+  which include Llama, Llama 2, Meta Llama 3, Code Llama, Meta Llama Guard 1 and Meta Llama Guard 2.
+  Below is a question from a Llama user, please answer it to the best of your knowledge.
+  The returned answer should be no more than 60 words. Please return the answer as text directly without any special tokens.<|eot_id|>
+  <|start_header_id|>user<|end_header_id|>
+  Question:{question} \n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
+judge_prompt_template: >
+    <|begin_of_text|><|start_header_id|>system<|end_header_id|>You have been provided with a question, a teacher's answer and a student's answer below.
+    Given that question, you need to score how good the student's answer is compared to
+    the teacher's answer. If the student's answer is correct based on the teacher's answer, then return YES, else return NO.
+    Here are the grading criteria to follow:
+    1. Review it carefully to make sure that the keywords and numerical values are exactly the same.
+    2. Ensure that the student's answer does not contain any conflicting statements.
+    3. It is OK if the student's answer contains more information than the ground truth answer, as long as it is factually accurate relative to the ground truth answer.
+    YES means that the student's answer meets all of the criteria.
+    NO means that the student's answer does not meet all of the criteria. This is the lowest possible score you can give.
+    Only respond with "YES" or "NO", do not respond with anything else.<|eot_id|>
+    <|start_header_id|>user<|end_header_id|>
+    Question: {question} \n Teacher's Answer: {gold} \n Student's Answer: {prediction} <|eot_id|><|start_header_id|>assistant<|end_header_id|>
+RAG_prompt_template: >
+  <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a helpful chatbot who can provide an answer to every question from the user given a relevant context.<|eot_id|>
+  <|start_header_id|>user<|end_header_id|>
+  Question: {question}\nContext: {context}\n
+  Answer this question using the information given by multiple documents in the context above. Here are the things to pay attention to:
+  - The context contains many documents, each document starts with <DOCUMENT> and ends with </DOCUMENT>.
+  - First provide step-by-step reasoning on how to answer the question.
+  - In the reasoning, if you need to copy-paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This would mean that things outside of ##begin_quote## and ##end_quote## are not directly copied from the context.
+  - End your response with the final answer in the form <ANSWER>: $answer, the answer should be less than 60 words.
+  You MUST begin your final answer with the tag "<ANSWER>:". <|eot_id|><|start_header_id|>assistant<|end_header_id|>
+eval_file: "./eval_llama.json"
+
+model_name: "raft-8b"
+
+data_dir: "./data"
+
+rag_topk: 5

+ 245 - 0
recipes/use_cases/end2end-recipes/RAFT-Chatbot/raft_utils.py

@@ -0,0 +1,245 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import os
+import logging
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from datasets import Dataset
+import random
+from langchain_community.document_loaders import SitemapLoader,DirectoryLoader
+from bs4 import BeautifulSoup
+from langchain_openai import ChatOpenAI
+import copy
+
+
+# Initialize logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+def strip_str(s: str) -> str:
+    """
+    Helper function to trim a model-generated string down to its core text, from the first alphabetic character to just past the last one.
+    """
+    l, r = 0, len(s)-1
+    beg_found = False
+    for i in range(len(s)):
+        if s[i].isalpha():
+            if not beg_found:
+                l = i
+                beg_found = True
+            else:
+                r = i
+    r += 2
+    return s[l:min(r, len(s))]
+def clean_documents(raw_text):
+    all_lines = []
+    for line in raw_text.split("\n"):
+        line = line.strip()
+        if len(line.split()) == 0:
+            continue
+        else:
+            all_lines.append(line)
+    result = " ".join(all_lines)
+    return result
+def clean_text(content: BeautifulSoup) -> str:
+    # Find all 'nav' and 'header' elements in the BeautifulSoup object
+    nav_elements = content.find_all("nav")
+    header_elements = content.find_all("header")
+    mydivs = content.find_all("div", {"role": "list"})
+    # Remove each 'nav' and 'header' element from the BeautifulSoup object
+    for element in nav_elements + header_elements+mydivs:
+        element.decompose()
+    raw_text = content.get_text("\n")
+    return clean_documents(raw_text)
+# Read documents either from webpage links listed in a sitemap XML file or from a local data folder
+def read_file_content(xml_path: str, data_folder: str) -> str:
+    if xml_path and data_folder:
+        logging.info(f"Error: both xml_path and data_folder are provided, will only read from xml for now")
+    if not xml_path and not data_folder:
+        logging.info(f"Error: both xml_path and data_folder are not provided")
+        return ""
+    if xml_path:
+        if not os.path.exists(xml_path):
+            logging.info(f"Error: {xml_path} does not exist")
+            return ""
+        # Use langchain to load the documents from webpage links in the xml file
+        sitemap_loader = SitemapLoader(web_path=xml_path,is_local=True,parsing_function=clean_text)
+        sitemap_loader.requests_kwargs = {"verify": False}
+        docs = sitemap_loader.load()
+        return docs
+    elif len(data_folder) != 0:
+        if not os.path.exists(data_folder):
+            logging.info(f"Error: {data_folder} does not exist")
+            return ""
+        # Use langchain to load the documents from data folder
+        loader = DirectoryLoader(data_folder)
+        docs = loader.load()
+        return docs
+
+
+
+def get_chunks(
+    docs: list,
+    chunk_size: int = 1000,
+    api_config: dict = None,
+) -> list[str]:
+    """
+    Takes in a list of documents, breaks them down into chunks of size
+    `chunk_size`, and returns the chunks.
+    """
+    chunks = []
+    if len(docs) == 0:
+        raise TypeError("Cannot get chunks from an empty document list")
+    else:
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size,chunk_overlap=int(chunk_size / 10),separators=["----------","\n\n", "\n", " "],strip_whitespace=True)
+        docs_processed = text_splitter.split_documents(docs)
+        logging.info(f"Total number of docs_processed: {len(docs_processed)}")
+        # Remove duplicates
+        unique_texts = {}
+        docs_processed_unique = []
+        for doc in docs_processed:
+            if doc.page_content not in unique_texts and len(doc.page_content) > 100 :
+                unique_texts[doc.page_content] = True
+                docs_processed_unique.append(doc)        
+        chunks = [chunk.page_content for chunk in docs_processed_unique]
+        logging.info(f"Total number of docs_processed_unique: {len(docs_processed_unique)}")
+    return chunks
+# read all the files in the data folder, then split them into chunks
+# generate questions for each chunk and return zip of chunk and related questions list
+def generate_questions(api_config):
+    # get documents from the data folder or xml file
+    api_url = api_config["endpoint_url"]
+    key = api_config["api_key"]
+    documents = read_file_content(api_config["xml_path"],api_config["data_dir"])
+    if len(documents) == 0:
+        logging.info(f"Error reading files, document_text is {len(documents)}")
+    document_batches = get_chunks(documents,api_config["chunk_size"],api_config)
+    # use the OpenAI API protocol to handle the chat request, including a local vLLM OpenAI-compatible server
+    llm = ChatOpenAI(
+        openai_api_key=key,
+        openai_api_base=api_url,
+        model_name=api_config["model"],
+        temperature=0.0,
+        max_tokens=500
+        )
+    all_tasks = [api_config['question_prompt_template'].format(num_questions=str(api_config['questions_per_chunk']),context=document) for document in document_batches]
+    generated_answers = llm.batch(all_tasks)
+    generated_answers = [ item.content for item in generated_answers]
+    if len(generated_answers) == 0:
+        logging.error("No model answers generated. Please check the input context or model configuration in ",api_config["model"])
+        return []
+    final_result = []
+    for result in generated_answers:
+        queries = result.split('\n')
+        queries = [strip_str(q) for q in queries]
+        queries = [q for q in queries if any(c.isalpha() for c in q)]
+        if len(queries) > int(api_config['questions_per_chunk']):
+            # As the model may emit unrelated text at the beginning of the result,
+            # if there are more queries than questions_per_chunk, truncate and keep only the last questions_per_chunk lines
+            queries = queries[-int(api_config['questions_per_chunk']):]
+        final_result.append(queries)
+    return list(zip(document_batches,final_result))
+
+# Generate COT answer for each question given the chunk context
+def generate_COT(chunk_questions_zip,api_config) -> dict:
+    all_tasks = []
+    chunk_questions = []
+    question_asked = set()
+    for document_content,questions in chunk_questions_zip:
+        for question in questions:
+            question = question.strip()
+            # avoid asking the same question twice
+            if question not in question_asked:
+                question_asked.add(question)
+                prompt = api_config['COT_prompt_template'].format(question=question,context=str(document_content))
+                all_tasks.append(prompt)
+                chunk_questions.append((document_content,question))
+    # use the OpenAI API protocol to handle the chat request, including a local vLLM OpenAI-compatible server
+    llm = ChatOpenAI(
+        openai_api_key=api_config["api_key"],
+        openai_api_base=api_config["endpoint_url"],
+        model_name=api_config["model"],
+        temperature=0.0,
+        max_tokens=500
+        )
+    generated_answers = llm.batch(all_tasks)
+    generated_answers = [ item.content for item in generated_answers]
+    COT_results = []
+    # return a list of (chunk, question, generated_answer)
+    for (chunk, question),generated_answer in zip(chunk_questions,generated_answers):
+        COT_results.append((chunk,question,generated_answer))
+    return COT_results
+
+def add_chunk_to_dataset(
+    chunk_questions_zip: list,
+    api_config: dict,
+) -> Dataset:
+    """
+    Given chunks and their related question lists, create {Q, A, D} triplets and return them as a Dataset.
+    """
+    num_distract = api_config["num_distract_docs"]
+    p = api_config["refusal_probability"]
+    chunks = [chunk for chunk, _ in chunk_questions_zip]
+    COT_results = generate_COT(chunk_questions_zip,api_config)
+    logging.info(f"COT generation completed, total num of COT results: {len(COT_results)}")
+    completed,refusal= 0,0
+    data_list = []
+    for chunk, q , cot in COT_results:
+        # The COT answer will be used as the label in the fine-tuning stage
+
+        datapt = {
+            "id": None,
+            "type": "general",
+            "question": q,
+            "context": None,
+            "oracle_context": None,
+            "cot_answer": cot
+        }
+        i = chunks.index(chunk)
+        datapt["id"] = f"seed_task_{len(data_list)}"
+        # add num_distract distractor docs
+        docs = [chunk]
+        indices = list(range(0, len(chunks)))
+        indices.remove(i)
+        for j in random.sample(indices, num_distract):
+            docs.append(chunks[j])
+        doc_copy = docs.copy()
+        random.shuffle(docs)
+        d = {
+            "title": [],
+            "sentences": []
+        }
+
+        d["title"].append(["placeholder_title"]*(num_distract+1))
+        d["sentences"].append(docs)
+        datapt["context"] = d
+        datapt["oracle_context"] = chunk
+
+        # construct model instruction
+        context = ""
+        for doc in docs:
+            context += "<DOCUMENT>" + str(doc) + "</DOCUMENT>\n"
+        context += q
+        # This instruction will be used in the fine-tuning stage
+        datapt["instruction"] = context
+        datapt_copy = copy.deepcopy(datapt)
+        # add to dataset
+        data_list.append(datapt)
+        # decide whether to also add a refusal example where the oracle document is not provided
+        if random.uniform(0, 1) <= p:
+            doc_copy[0] = chunks[random.sample(indices, 1)[0]]
+            random.shuffle(doc_copy)
+            refusal_context = ""
+            for doc in doc_copy:
+                refusal_context += "<DOCUMENT>" + str(doc) + "</DOCUMENT>\n"
+            refusal_context += q
+            # This instruction will be used in the fine-tuning stage
+            datapt_copy["id"] = f"refusal_task_{len(data_list)}"
+            datapt_copy["instruction"] = refusal_context
+            datapt_copy["cot_answer"] = "Sorry, I don't know the answer to this question because related documents are not found. Please try again."
+            data_list.append(datapt_copy)
+            refusal += 1
+        completed += 1
+        if completed % 100 == 0:
+            logging.info(f"refusal example added: {refusal}, total examples added: {completed}, total examples to be added: {len(COT_results)- completed}")
+    ds = Dataset.from_list(data_list)
+    return ds
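
To make the output format concrete, each record produced by add_chunk_to_dataset has roughly the following shape (an illustrative sketch with placeholder values, assuming num_distract_docs is 4):

# Illustrative shape of a single RAFT datapoint produced above (values are placeholders).
example_datapt = {
    "id": "seed_task_0",
    "type": "general",
    "question": "How many states are in the United States?",
    "context": {
        "title": [["placeholder_title"] * 5],  # one title per oracle + 4 distractor docs
        "sentences": [["<doc 1>", "<doc 2>", "<doc 3>", "<doc 4>", "<doc 5>"]],  # oracle + distractors, shuffled
    },
    "oracle_context": "<the chunk the question was generated from>",
    "cot_answer": "<chain-of-thought reasoning ending with '<ANSWER>: ...'>",
    "instruction": "<DOCUMENT>...</DOCUMENT>\n<DOCUMENT>...</DOCUMENT>\n...<the question appended at the end>",
}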

+ 10 - 1
requirements.txt

@@ -19,4 +19,13 @@ chardet
 openai
 typing-extensions==4.8.0
 tabulate
-codeshield
+evaluate
+rouge_score
+pyyaml==6.0.1
+faiss-gpu
+unstructured[pdf]
+langchain_openai
+langchain
+langchain_community
+sentence_transformers
+codeshield

+ 6 - 6
src/llama_recipes/configs/datasets.py

@@ -3,28 +3,27 @@
 
 from dataclasses import dataclass
 
-    
+
 @dataclass
 class samsum_dataset:
     dataset: str =  "samsum_dataset"
     train_split: str = "train"
     test_split: str = "validation"
-    
-    
+
+
 @dataclass
 class grammar_dataset:
     dataset: str = "grammar_dataset"
-    train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv" 
+    train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv"
     test_split: str = "src/llama_recipes/datasets/grammar_dataset/grammar_validation.csv"
 
-    
+
 @dataclass
 class alpaca_dataset:
     dataset: str = "alpaca_dataset"
     train_split: str = "train"
     test_split: str = "val"
     data_path: str = "src/llama_recipes/datasets/alpaca_data.json"
-    
 
 @dataclass
 class custom_dataset:
@@ -32,6 +31,7 @@ class custom_dataset:
     file: str = "recipes/quickstart/finetuning/datasets/custom_dataset.py"
     train_split: str = "train"
     test_split: str = "validation"
+    data_path: str = ""
     
 @dataclass
 class llamaguard_toxicchat_dataset: