# format.py
  import argparse
  from abc import ABC, abstractmethod
  from typing import Any, Dict, Iterable, Literal, get_args

  from datasets import Dataset, load_dataset
  """
  Convert raw HuggingFace datasets into files suitable for fine-tuning
  OpenAI-style completion and chat models.
  """

  # Closed sets of values accepted on the command line. Each Literal alias is
  # the single source of truth; the plain list next to it is derived from the
  # alias and handed to argparse as `choices`.
  # NOTE: `get_args` on these lines is still typing.get_args; the argparse
  # helper of the same name defined later in the file only rebinds the name
  # after these statements have run.
  OutputDatasetType = Literal["parquet", "jsonl"]
  outputDatasetTypes = [*get_args(OutputDatasetType)]

  InputDatasetType = Literal["arrow", "jsonl"]
  inputDatasetTypes = [*get_args(InputDatasetType)]

  DatasetFormat = Literal["hf", "completion", "chat"]
  datasetFormats = [*get_args(DatasetFormat)]
  def get_args() -> argparse.Namespace:
      """
      Build the converter's command-line parser and parse the process arguments.

      Returns:
          Namespace with ``input``, ``input_type``, ``output``, ``output_format``,
          ``output_type`` and ``output_chat_system_prompt``.

      NOTE: this function has the same name as ``typing.get_args`` (imported at
      the top of the file) and replaces that module-level binding once defined.
      """
      # (flag, add_argument keyword options) in the order they appear in --help.
      option_specs = (
          ("--input", dict(type=str, required=True, help="Input HuggingFace dataset file")),
          ("--input-type", dict(type=str, default="arrow", help="Format of the input dataset. Defaults to arrow.", choices=inputDatasetTypes)),
          ("--output", dict(type=str, required=True, help="Output file")),
          ("--output-format", dict(type=str, required=True, help="Format to convert the dataset to", choices=datasetFormats)),
          ("--output-type", dict(type=str, default="jsonl", help="Type to export the dataset to. Defaults to jsonl.", choices=outputDatasetTypes)),
          ("--output-chat-system-prompt", dict(type=str, help="The system prompt to use when the output format is chat")),
      )
      cli = argparse.ArgumentParser()
      for flag, options in option_specs:
          cli.add_argument(flag, **options)
      return cli.parse_args()
  class DatasetFormatter(ABC):
      """
      Abstract base for formatters that reshape a HuggingFace Dataset into a target
      fine-tuning layout (HF pass-through, OpenAI completion or OpenAI chat) by
      renaming, dropping and adding columns.

      https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
      """

      @abstractmethod
      def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
          """Return a dataset derived from ``ds``; ``params`` carries formatter-specific options."""
  class DatasetExporter(ABC):
      """
      Abstract base for exporters that write a Dataset out as a particular file
      type (JSONL, Parquet, ...).
      """

      @abstractmethod
      def export(self, ds: Dataset, output_path: str):
          """Write ``ds`` to a file derived from ``output_path``."""
  class DatasetConverter:
      """
      Entry point class. It resolves which DatasetFormatter and which
      DatasetExporter to use for a conversion and runs them in sequence.
      """

      # Formatters keyed by the --output-format value.
      formats: Dict[DatasetFormat, DatasetFormatter]
      # Exporters keyed by the --output-type value.
      exporters: Dict[OutputDatasetType, DatasetExporter]

      def __init__(self) -> None:
          # The concrete classes are defined further down the module; they are
          # only looked up here, when an instance is created.
          self.formats = {
              "hf": HuggingFaceDatasetFormatter(),
              "completion": OpenAiCompletionDatasetFormatter(),
              "chat": OpenAiChatDatasetFormatter(),
          }
          self.exporters = {
              "parquet": ParquetDatasetExporter(),
              "jsonl": JsonlDatasetExporter(),
          }

      def convert(self, ds: Dataset, format: DatasetFormat, output_path: str, output_type: OutputDatasetType, params: Dict[str, str]):
          """
          Reshape ``ds`` with the formatter registered for ``format`` and write the
          result with the exporter registered for ``output_type``.

          Args:
              ds: Dataset to convert.
              format: Formatter key ("hf", "completion" or "chat").
              output_path: Destination path handed to the exporter.
              output_type: Exporter key ("parquet" or "jsonl").
              params: Formatter-specific options (e.g. "system_prompt" for chat).

          Raises:
              ValueError: If ``format`` or ``output_type`` is not registered. Both
                  keys are checked before any formatting or writing happens.
          """
          # Validate both keys up front so nothing is written for a bad request.
          if format not in self.formats:
              raise ValueError(f"Output Format {format} is not supported, please select one of {list(self.formats)}")
          if output_type not in self.exporters:
              raise ValueError(f"Output Type {output_type} is not supported, please select one of {list(self.exporters)}")
          formatter = self.formats[format]
          exporter = self.exporters[output_type]
          exporter.export(formatter.format(ds, params), output_path)
  class HuggingFaceDatasetFormatter(DatasetFormatter):
      """
      Pass-through formatter for the "hf" target: the dataset is returned exactly
      as it was loaded.
      """

      def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
          # No reshaping is needed; ``params`` exists only to satisfy the
          # DatasetFormatter interface and is ignored.
          return ds
  def _remove_all_columns_but(ds: Dataset, keep_columns: Iterable[str]) -> Dataset:
      """
      Return ``ds`` restricted to the columns named in ``keep_columns``.

      Implemented as ``Dataset.remove_columns`` on the complement of
      ``keep_columns``.

      Args:
          ds: Source dataset.
          keep_columns: Names of the columns to retain. Duplicate names are allowed.

      Returns:
          The result of ``ds.remove_columns`` with every other column dropped.

      Raises:
          ValueError: If a name in ``keep_columns`` is not a column of ``ds``.
      """
      keep = set(keep_columns)
      missing = keep.difference(ds.column_names)
      if missing:
          # Name the offending columns instead of letting list.remove fail with
          # an anonymous "x not in list".
          raise ValueError(f"Columns {sorted(missing)} not found in dataset columns {list(ds.column_names)}")
      # Drop everything else, in the dataset's own column order.
      remove_columns = [name for name in ds.column_names if name not in keep]
      return ds.remove_columns(remove_columns)
  class OpenAiCompletionDatasetFormatter(DatasetFormatter):
      """
      Produces the OpenAI completion fine-tuning layout: exactly two columns,
      "prompt" and "completion".
      https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
      """

      def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
          # Source column -> target column. The input is expected to carry
          # "question" and "cot_answer" columns, since both are renamed here.
          column_map = {'question': 'prompt', 'cot_answer': 'completion'}
          renamed = ds.rename_columns(column_map)
          return _remove_all_columns_but(renamed, list(column_map.values()))
  class OpenAiChatDatasetFormatter(OpenAiCompletionDatasetFormatter):
      """
      Produces the OpenAI chat fine-tuning layout: a single "messages" column
      holding, per row, an optional system message followed by one user and
      one assistant message.
      https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
      """

      def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
          # Start from the completion layout so every row has "prompt" and
          # "completion" to build the conversation from.
          completion_ds = super().format(ds, params)
          with_system = 'system_prompt' in params

          def to_chat(row):
              conversation = []
              if with_system:
                  conversation.append({"role": "system", "content": params['system_prompt']})
              conversation += [
                  {"role": "user", "content": row['prompt']},
                  {"role": "assistant", "content": row['completion']},
              ]
              return {"messages": conversation}

          chat_ds = completion_ds.map(to_chat)
          # map() adds "messages" alongside the old columns; keep only it.
          return _remove_all_columns_but(chat_ds, ['messages'])
  def append_extension(path: str, extension: str) -> str:
      """
      Return ``path`` guaranteed to end with ``.<extension>``.

      The suffix is appended only when ``path`` does not already end with it, so
      applying this to an already-suffixed path is a no-op.

      Args:
          path: File path chosen by the user.
          extension: Extension to enforce, with or without a leading dot
              (``"jsonl"`` and ``".jsonl"`` are equivalent).

      Returns:
          ``path`` itself if it already ends with the suffix, else ``path`` + suffix.
      """
      # Normalize a dotted extension; otherwise ".jsonl" would yield "out..jsonl".
      suffix = "." + extension.lstrip(".")
      return path if path.endswith(suffix) else path + suffix
  class JsonlDatasetExporter(DatasetExporter):
      """
      Writes the Dataset to a JSONL file, ensuring the path ends in ".jsonl".
      """

      def export(self, ds: Dataset, output_path: str):
          destination = append_extension(output_path, "jsonl")
          ds.to_json(destination)
  class ParquetDatasetExporter(DatasetExporter):
      """
      Writes the Dataset to a Parquet file, ensuring the path ends in ".parquet".
      """

      def export(self, ds: Dataset, output_path: str):
          destination = append_extension(output_path, "parquet")
          ds.to_parquet(destination)
  def main():
      """
      Command-line entry point for this script: parse the arguments, load the
      input dataset, convert it to the requested format and write it out.

      Raises:
          Exception: If --output-chat-system-prompt is given without
              --output-format chat.
      """
      args = get_args()
      # Check option combinations before loading the input, which may be large.
      if args.output_chat_system_prompt and args.output_format != "chat":
          raise Exception("Parameter --output-chat-system-prompt can only be used with --output-format chat")

      # HF `datasets` loads JSON Lines through its "json" builder; there is no
      # "jsonl" builder, so translate the CLI value before calling load_dataset.
      builder = "json" if args.input_type == "jsonl" else args.input_type
      ds = load_dataset(builder, data_files={"train": args.input})['train']

      format_params: Dict[str, str] = {}
      if args.output_chat_system_prompt:
          format_params['system_prompt'] = args.output_chat_system_prompt

      converter = DatasetConverter()
      converter.convert(ds=ds, format=args.output_format, output_path=args.output, output_type=args.output_type, params=format_params)
  if __name__ == "__main__":
      # Run the CLI only when executed as a script, not on import.
      main()