123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175 |
- # file copied from https://github.com/ShishirPatil/gorilla/blob/main/raft/format.py
- from abc import ABC, abstractmethod
- import argparse
- from datasets import Dataset, load_dataset
- from typing import Dict, Literal, Any, get_args
- """
- This file converts raw HuggingFace Datasets into files suitable for fine-tuning completion and chat models.
- """
# Closed sets of values accepted on the command line, declared as Literal types.
# NOTE: the list mirrors below are derived with typing.get_args and must be
# evaluated before the module-level get_args() function is defined, because
# that definition rebinds the name imported from typing.
OutputDatasetType = Literal["parquet", "jsonl"]
InputDatasetType = Literal["arrow", "jsonl"]
DatasetFormat = Literal["hf", "completion", "chat"]

# Plain-list views of the Literal members, usable as argparse `choices`.
outputDatasetTypes = [*get_args(OutputDatasetType)]
inputDatasetTypes = [*get_args(InputDatasetType)]
datasetFormats = [*get_args(DatasetFormat)]
def get_args() -> argparse.Namespace:
    """
    Build the command-line parser and return the arguments parsed from the
    user's command line.

    Returns:
        The argparse.Namespace produced by parsing the process arguments.
    """
    cli = argparse.ArgumentParser()

    # Input side: where the dataset comes from and how it is stored.
    cli.add_argument(
        "--input",
        type=str,
        required=True,
        help="Input HuggingFace dataset file",
    )
    cli.add_argument(
        "--input-type",
        type=str,
        default="arrow",
        help="Format of the input dataset. Defaults to arrow.",
        choices=inputDatasetTypes,
    )

    # Output side: destination, target schema and file type.
    cli.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output file",
    )
    cli.add_argument(
        "--output-format",
        type=str,
        required=True,
        help="Format to convert the dataset to",
        choices=datasetFormats,
    )
    cli.add_argument(
        "--output-type",
        type=str,
        default="jsonl",
        help="Type to export the dataset to. Defaults to jsonl.",
        choices=outputDatasetTypes,
    )
    cli.add_argument(
        "--output-chat-system-prompt",
        type=str,
        help="The system prompt to use when the output format is chat",
    )

    return cli.parse_args()
class DatasetFormatter(ABC):
    """
    Abstract interface implemented by every dataset formatter.

    A formatter reshapes a Dataset (renaming, removing and adding columns) so
    that it matches a target file format structure: HF, Chat or Completion.
    https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
    """

    @abstractmethod
    def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
        """
        Return the dataset converted to the formatter's target structure.

        Args:
            ds: the dataset to convert.
            params: formatter-specific options.
        """
class DatasetExporter(ABC):
    """
    Abstract interface implemented by every dataset exporter.

    An exporter writes a Dataset out as a specific file type (JSONL, Parquet, ...).
    """

    @abstractmethod
    def export(self, ds: Dataset, output_path: str):
        """
        Write the dataset to output_path.

        Args:
            ds: the dataset to write.
            output_path: destination path of the exported file.
        """
class DatasetConverter:
    """
    Entry point class. It resolves which DatasetFormatter and which DatasetExporter to use and runs them.
    """

    # Registry of formatters, keyed by output format name.
    formats: Dict[DatasetFormat, DatasetFormatter]
    # Registry of exporters, keyed by output file type.
    exporters: Dict[OutputDatasetType, DatasetExporter]

    def __init__(self) -> None:
        """Register one formatter per output format and one exporter per output type."""
        self.formats = {
            "hf": HuggingFaceDatasetFormatter(),
            "completion": OpenAiCompletionDatasetFormatter(),
            "chat": OpenAiChatDatasetFormatter(),
        }
        self.exporters = {
            "parquet": ParquetDatasetExporter(),
            "jsonl": JsonlDatasetExporter(),
        }

    def convert(self, ds: Dataset, format: DatasetFormat, output_path: str, output_type: OutputDatasetType, params: Dict[str, str]) -> None:
        """
        Format the dataset and export it to a file.

        Args:
            ds: the dataset to convert.
            format: key of the formatter to apply (one of self.formats).
            output_path: destination path handed to the exporter.
            output_type: key of the exporter to use (one of self.exporters).
            params: options forwarded to the formatter.

        Raises:
            ValueError: if format or output_type is not a registered key.
        """
        # Validate both keys up front so that nothing is formatted or written
        # when the request cannot be fully satisfied.
        if format not in self.formats:
            raise ValueError(f"Output Format {format} is not supported, please select one of {list(self.formats)}")

        if output_type not in self.exporters:
            raise ValueError(f"Output Type {output_type} is not supported, please select one of {list(self.exporters)}")

        formatter = self.formats[format]
        newds = formatter.format(ds, params)

        exporter = self.exporters[output_type]
        exporter.export(newds, output_path)
class HuggingFaceDatasetFormatter(DatasetFormatter):
    """
    Pass-through formatter: hands the HuggingFace Dataset back untouched.
    """

    def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
        """Return ds unchanged; params is not used."""
        return ds
def _remove_all_columns_but(ds: Dataset, keep_columns) -> Dataset:
    """
    Return ds with every column removed except the ones named in keep_columns.

    This helper computes the complement of keep_columns among ds.column_names
    and passes it to ds.remove_columns.

    Args:
        ds: the dataset to prune.
        keep_columns: iterable of column names to keep. Duplicate names are
            tolerated.

    Returns:
        The dataset returned by ds.remove_columns.

    Raises:
        ValueError: if a name in keep_columns is not a column of ds.
    """
    keep = set(keep_columns)
    existing = list(ds.column_names)

    # Fail with an explicit message rather than an opaque list.remove() error.
    missing = keep.difference(existing)
    if missing:
        raise ValueError(f"Cannot keep columns {sorted(missing)}: not found in dataset columns {existing}")

    to_remove = [name for name in existing if name not in keep]
    return ds.remove_columns(to_remove)
class OpenAiCompletionDatasetFormatter(DatasetFormatter):
    """
    Produces the OpenAI Completion fine-tuning layout: exactly two fields, "prompt" and "completion".
    https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
    """

    def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
        """
        Rename "question" to "prompt" and "cot_answer" to "completion", then
        drop every other column. params is not used.
        """
        column_mapping = {'question': 'prompt', 'cot_answer': 'completion'}
        renamed = ds.rename_columns(column_mapping)
        # The renamed target names are the only columns that survive.
        kept = list(column_mapping.values())
        return _remove_all_columns_but(renamed, kept)
class OpenAiChatDatasetFormatter(OpenAiCompletionDatasetFormatter):
    """
    Produces the OpenAI Chat fine-tuning layout: a single field "messages".
    https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
    """

    def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
        """
        Build one "messages" list per row from the completion-format columns.

        The dataset is first converted by the parent formatter, so each row
        exposes "prompt" and "completion". If params contains "system_prompt",
        a system message is placed first, followed by the user and assistant
        messages.
        """
        completion_ds = super().format(ds, params)

        def format_messages(row):
            # Optional system message goes first, only when one was supplied.
            conversation = (
                [{"role": "system", "content": params['system_prompt']}]
                if 'system_prompt' in params
                else []
            )
            conversation.append({"role": "user", "content": row['prompt']})
            conversation.append({"role": "assistant", "content": row['completion']})
            return {"messages": conversation}

        chat_ds = completion_ds.map(format_messages)
        return _remove_all_columns_but(chat_ds, ['messages'])
def append_extension(path: str, extension: str) -> str:
    """
    Return path ending with "." + extension, appending that suffix only when
    path does not already end with it.

    Args:
        path: the file path to complete.
        extension: the extension, without the leading dot.
    """
    dotted = "." + extension
    return path if path.endswith(dotted) else path + dotted
class JsonlDatasetExporter(DatasetExporter):
    """
    Writes the Dataset out as a JSONL file.
    """

    def export(self, ds: Dataset, output_path: str):
        """Write ds with to_json, making sure the path ends with ".jsonl"."""
        target = append_extension(output_path, "jsonl")
        ds.to_json(target)
class ParquetDatasetExporter(DatasetExporter):
    """
    Writes the Dataset out as a Parquet file.
    """

    def export(self, ds: Dataset, output_path: str):
        """Write ds with to_parquet, making sure the path ends with ".parquet"."""
        target = append_extension(output_path, "parquet")
        ds.to_parquet(target)
def main():
    """
    Command-line entry point: parse the arguments, load the input dataset,
    then format and export it.

    Raises:
        ValueError: if --output-chat-system-prompt is combined with an output
            format other than chat.
    """
    args = get_args()

    # Reject inconsistent flags before any dataset I/O happens.
    if args.output_chat_system_prompt and args.output_format != "chat":
        raise ValueError("Parameter --output-chat-system-prompt can only be used with --output-format chat")

    # The datasets library names its JSON Lines loader "json"; passing "jsonl"
    # verbatim as the builder name does not resolve to a packaged loader.
    input_type = "json" if args.input_type == "jsonl" else args.input_type
    ds = load_dataset(input_type, data_files={"train": args.input})['train']

    format_params = {}
    if args.output_chat_system_prompt:
        format_params['system_prompt'] = args.output_chat_system_prompt

    converter = DatasetConverter()
    converter.convert(ds=ds, format=args.output_format, output_path=args.output, output_type=args.output_type, params=format_params)


if __name__ == "__main__":
    main()
|