format.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. # file copied from https://github.com/ShishirPatil/gorilla/blob/main/raft/format.py
  2. from abc import ABC, abstractmethod
  3. import argparse
  4. from datasets import Dataset, load_dataset
  5. from typing import Dict, Literal, Any, get_args
  6. """
  7. This module converts raw HuggingFace Datasets into files suitable for fine-tuning completion and chat models.
  8. """
  9. OutputDatasetType = Literal["parquet", "jsonl"]
  10. outputDatasetTypes = list(get_args(OutputDatasetType))
  11. InputDatasetType = Literal["arrow", "jsonl"]
  12. inputDatasetTypes = list(get_args(InputDatasetType))
  13. DatasetFormat = Literal["hf", "completion", "chat"]
  14. datasetFormats = list(get_args(DatasetFormat))
  15. def get_args() -> argparse.Namespace:
  16. """
  17. Parses and returns the arguments specified by the user's command
  18. """
  19. parser = argparse.ArgumentParser()
  20. parser.add_argument("--input", type=str, required=True, help="Input HuggingFace dataset file")
  21. parser.add_argument("--input-type", type=str, default="arrow", help="Format of the input dataset. Defaults to arrow.", choices=inputDatasetTypes)
  22. parser.add_argument("--output", type=str, required=True, help="Output file")
  23. parser.add_argument("--output-format", type=str, required=True, help="Format to convert the dataset to", choices=datasetFormats)
  24. parser.add_argument("--output-type", type=str, default="jsonl", help="Type to export the dataset to. Defaults to jsonl.", choices=outputDatasetTypes)
  25. parser.add_argument("--output-chat-system-prompt", type=str, help="The system prompt to use when the output format is chat")
  26. args = parser.parse_args()
  27. return args
  28. class DatasetFormatter(ABC):
  29. """
  30. Base class for dataset formatters. Formatters rename columns, remove and add
  31. columns to match the expected target format structure. HF, Chat or Completion models file formats.
  32. https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
  33. """
  34. @abstractmethod
  35. def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
  36. pass
  37. class DatasetExporter(ABC):
  38. """
  39. Base class for dataset exporters. Exporters export dataset to different file types, JSONL, Parquet, ...
  40. """
  41. @abstractmethod
  42. def export(self, ds: Dataset, output_path: str):
  43. pass
  44. class DatasetConverter():
  45. """
  46. Entry point class. It resolves which DatasetFormatter and which DatasetExporter to use and runs them.
  47. """
  48. formats: Dict[DatasetFormat, DatasetFormatter]
  49. exporters: Dict[OutputDatasetType, Any]
  50. def __init__(self) -> None:
  51. self.formats = {
  52. "hf": HuggingFaceDatasetFormatter(),
  53. "completion": OpenAiCompletionDatasetFormatter(),
  54. "chat": OpenAiChatDatasetFormatter()
  55. }
  56. self.exporters = {
  57. "parquet": ParquetDatasetExporter(),
  58. "jsonl": JsonlDatasetExporter()
  59. }
  60. def convert(self, ds: Dataset, format: DatasetFormat, output_path: str, output_type: OutputDatasetType, params: Dict[str, str]):
  61. if not format in self.formats:
  62. raise Exception(f"Output Format {format} is not supported, pleased select one of {self.formats.keys()}")
  63. if not output_type in self.exporters:
  64. raise Exception(f"Output Type {output_type} is not supported, pleased select one of {self.exporters.keys()}")
  65. formatter = self.formats[format]
  66. newds = formatter.format(ds, params)
  67. exporter = self.exporters[output_type]
  68. exporter.export(newds, output_path)
  69. class HuggingFaceDatasetFormatter(DatasetFormatter):
  70. """
  71. Returns the HuggingFace Dataset as is
  72. """
  73. def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
  74. return ds
  75. def _remove_all_columns_but(ds: Dataset, keep_columns) -> Dataset:
  76. """
  77. HF Dataset doesn't have a way to copy only specific columns of a Dataset so this help
  78. removes all columns but the ones specified.
  79. """
  80. remove_columns = list(ds.column_names)
  81. for keep in keep_columns:
  82. remove_columns.remove(keep)
  83. ds = ds.remove_columns(remove_columns)
  84. return ds
  85. class OpenAiCompletionDatasetFormatter(DatasetFormatter):
  86. """
  87. Returns the Dataset in the OpenAI Completion Fine-tuning file format with two fields "prompt" and "completion".
  88. https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
  89. """
  90. def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
  91. newds = ds.rename_columns({'question': 'prompt', 'cot_answer': 'completion'})
  92. return _remove_all_columns_but(newds, ['prompt', 'completion'])
  93. class OpenAiChatDatasetFormatter(OpenAiCompletionDatasetFormatter):
  94. """
  95. Returns the Dataset in the OpenAI Chat Fine-tuning file format with one field "messages".
  96. https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
  97. """
  98. def format(self, ds: Dataset, params: Dict[str, str]) -> Dataset:
  99. newds = super().format(ds, params)
  100. def format_messages(row):
  101. messages = []
  102. if 'system_prompt' in params:
  103. system_prompt = params['system_prompt']
  104. messages.append({ "role": "system", "content": system_prompt})
  105. messages.extend([{ "role": "user", "content": row['prompt']}, { "role": "assistant", "content": row['completion']}])
  106. chat_row = {"messages": messages}
  107. return chat_row
  108. newds = newds.map(format_messages)
  109. return _remove_all_columns_but(newds, ['messages'])
  110. def append_extension(path: str, extension: str) -> str:
  111. suffix = "." + extension
  112. if not path.endswith(suffix):
  113. path = path + suffix
  114. return path
  115. class JsonlDatasetExporter(DatasetExporter):
  116. """
  117. Exports the Dataset to a JSONL file
  118. """
  119. def export(self, ds: Dataset, output_path: str):
  120. ds.to_json(append_extension(output_path, "jsonl"))
  121. class ParquetDatasetExporter(DatasetExporter):
  122. """
  123. Exports the Dataset to a Parquet file
  124. """
  125. def export(self, ds: Dataset, output_path: str):
  126. ds.to_parquet(append_extension(output_path, "parquet"))
  127. def main():
  128. """
  129. When raft.py is executed from the command line.
  130. """
  131. args = get_args()
  132. ds = load_dataset(args.input_type, data_files={"train": args.input})['train']
  133. formatter = DatasetConverter()
  134. if args.output_chat_system_prompt and args.output_format != "chat":
  135. raise Exception("Parameter --output-chat-system-prompt can only be used with --output-format chat")
  136. format_params = {}
  137. if args.output_chat_system_prompt:
  138. format_params['system_prompt'] = args.output_chat_system_prompt
  139. formatter.convert(ds=ds, format=args.output_format, output_path=args.output, output_type=args.output_type, params=format_params)
# Run the conversion only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()