# file copied from https://github.com/ShishirPatil/gorilla/blob/main/raft/format.py
"""
Convert raw HuggingFace Datasets into files suitable for fine-tuning
completion and chat models.
"""

from abc import ABC, abstractmethod
import argparse
from typing import Any, Dict, Iterable, Literal
# Imported under an alias because this module defines its own ``get_args()``
# (the CLI argument parser) which would otherwise shadow the typing helper.
from typing import get_args as get_literal_values

from datasets import Dataset, load_dataset

OutputDatasetType = Literal["parquet", "jsonl"]
outputDatasetTypes = list(get_literal_values(OutputDatasetType))

InputDatasetType = Literal["arrow", "jsonl"]
inputDatasetTypes = list(get_literal_values(InputDatasetType))

DatasetFormat = Literal["hf", "completion", "chat"]
datasetFormats = list(get_literal_values(DatasetFormat))

# Maps a --input-type value to the loader name handed to ``load_dataset``.
# NOTE(review): the packaged loader for JSON Lines files is believed to be
# named "json" (there is no "jsonl" loader) -- confirm against the installed
# `datasets` version. Values without an entry are passed through unchanged.
_INPUT_TYPE_TO_LOADER: Dict[str, str] = {"jsonl": "json"}


def get_args() -> argparse.Namespace:
    """
    Parses and returns the arguments specified by the user's command line.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--input", type=str, required=True, help="Input HuggingFace dataset file")
    parser.add_argument("--input-type", type=str, default="arrow", help="Format of the input dataset. Defaults to arrow.", choices=inputDatasetTypes)
    parser.add_argument("--output", type=str, required=True, help="Output file")
    parser.add_argument("--output-format", type=str, required=True, help="Format to convert the dataset to", choices=datasetFormats)
    parser.add_argument("--output-type", type=str, default="jsonl", help="Type to export the dataset to. Defaults to jsonl.", choices=outputDatasetTypes)
    parser.add_argument("--output-chat-system-prompt", type=str, help="The system prompt to use when the output format is chat")

    return parser.parse_args()


class DatasetFormatter(ABC):
    """
    Base class for dataset formatters. Formatters rename, remove and add
    columns to match the expected target format structure: HF, Chat or
    Completion model file formats.
    https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
    """

    @abstractmethod
    def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
        """
        Return `ds` restructured for the target format.

        `params` carries formatter-specific options; in this file only the
        chat formatter reads it (key 'system_prompt').
        """
        pass


class DatasetExporter(ABC):
    """
    Base class for dataset exporters. Exporters export a dataset to
    different file types: JSONL, Parquet, ...
    """

    @abstractmethod
    def export(self, ds: Dataset, output_path: str):
        """Write `ds` to `output_path`."""
        pass


class DatasetConverter:
    """
    Entry point class. It resolves which DatasetFormatter and which
    DatasetExporter to use and runs them.
    """

    # Formatter instance for each supported --output-format value.
    formats: Dict[DatasetFormat, DatasetFormatter]
    # Exporter instance for each supported --output-type value.
    exporters: Dict[OutputDatasetType, DatasetExporter]

    def __init__(self) -> None:
        self.formats = {
            "hf": HuggingFaceDatasetFormatter(),
            "completion": OpenAiCompletionDatasetFormatter(),
            "chat": OpenAiChatDatasetFormatter(),
        }
        self.exporters = {
            "parquet": ParquetDatasetExporter(),
            "jsonl": JsonlDatasetExporter(),
        }

    def convert(self, ds: Dataset, format: DatasetFormat, output_path: str, output_type: OutputDatasetType, params: Dict[str, str]):
        """
        Format `ds` with the formatter registered for `format`, then write
        the result to `output_path` with the exporter registered for
        `output_type`.

        Raises:
            ValueError: if `format` or `output_type` has no registered handler.
        """
        # Validate both selectors up front so nothing is formatted or written
        # when the request cannot be fulfilled.
        if format not in self.formats:
            raise ValueError(f"Output Format {format} is not supported, please select one of {list(self.formats)}")

        if output_type not in self.exporters:
            raise ValueError(f"Output Type {output_type} is not supported, please select one of {list(self.exporters)}")

        formatter = self.formats[format]
        newds = formatter.format(ds, params)
        exporter = self.exporters[output_type]
        exporter.export(newds, output_path)


class HuggingFaceDatasetFormatter(DatasetFormatter):
    """
    Returns the HuggingFace Dataset as is.
    """

    def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
        return ds


def _remove_all_columns_but(ds: Dataset, keep_columns: Iterable[str]) -> Dataset:
    """
    Return `ds` with every column removed except those named in
    `keep_columns`.

    Raises:
        ValueError: if a requested column is not present in `ds`.
    """
    keep = list(keep_columns)
    existing = list(ds.column_names)

    missing = [name for name in keep if name not in existing]
    if missing:
        raise ValueError(f"Cannot keep columns {missing}: not in dataset columns {existing}")

    keep_set = set(keep)
    remove_columns = [name for name in existing if name not in keep_set]
    return ds.remove_columns(remove_columns)


class OpenAiCompletionDatasetFormatter(DatasetFormatter):
    """
    Returns the Dataset in the OpenAI Completion Fine-tuning file format with
    two fields "prompt" and "completion".
    https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
    """

    def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
        # The input is expected to provide 'question' and 'cot_answer' columns.
        newds = ds.rename_columns({'question': 'prompt', 'cot_answer': 'completion'})
        return _remove_all_columns_but(newds, ['prompt', 'completion'])


class OpenAiChatDatasetFormatter(OpenAiCompletionDatasetFormatter):
    """
    Returns the Dataset in the OpenAI Chat Fine-tuning file format with one
    field "messages".
    https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
    """

    def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
        # Reuse the completion formatter to obtain 'prompt'/'completion' columns.
        newds = super().format(ds, params)

        def format_messages(row):
            """Build the chat 'messages' list for a single row."""
            messages = []
            # The system message is only emitted when a prompt was supplied.
            if 'system_prompt' in params:
                system_prompt = params['system_prompt']
                messages.append({"role": "system", "content": system_prompt})
            messages.extend([
                {"role": "user", "content": row['prompt']},
                {"role": "assistant", "content": row['completion']},
            ])
            chat_row = {"messages": messages}
            return chat_row

        newds = newds.map(format_messages)
        return _remove_all_columns_but(newds, ['messages'])


def append_extension(path: str, extension: str) -> str:
    """
    Return `path` ending in "." + `extension`, appending the suffix only
    when `path` does not already end with it.
    """
    suffix = "." + extension
    if not path.endswith(suffix):
        path = path + suffix
    return path


class JsonlDatasetExporter(DatasetExporter):
    """
    Exports the Dataset to a JSONL file.
    """

    def export(self, ds: Dataset, output_path: str):
        ds.to_json(append_extension(output_path, "jsonl"))


class ParquetDatasetExporter(DatasetExporter):
    """
    Exports the Dataset to a Parquet file.
    """

    def export(self, ds: Dataset, output_path: str):
        ds.to_parquet(append_extension(output_path, "parquet"))


def main():
    """
    Entry point when this file is executed from the command line.
    """
    args = get_args()

    # Check argument consistency before the (potentially slow) dataset load
    # so an invalid invocation fails immediately.
    if args.output_chat_system_prompt and args.output_format != "chat":
        raise Exception("Parameter --output-chat-system-prompt can only be used with --output-format chat")

    format_params = {}
    if args.output_chat_system_prompt:
        format_params['system_prompt'] = args.output_chat_system_prompt

    loader_name = _INPUT_TYPE_TO_LOADER.get(args.input_type, args.input_type)
    ds = load_dataset(loader_name, data_files={"train": args.input})['train']

    converter = DatasetConverter()
    converter.convert(ds=ds, format=args.output_format, output_path=args.output, output_type=args.output_type, params=format_params)


if __name__ == "__main__":
    main()