# data_loader.py
  1. """
  2. Data loader module for loading and formatting data from Hugging Face.
  3. """
  4. import json
  5. import os
  6. from pathlib import Path
  7. from typing import Dict, Optional
# Optional dependency: PyYAML. Importing it may fail in minimal installs,
# so record availability in HAS_YAML instead of failing at import time;
# read_config() checks this flag before parsing YAML files.
try:
    import yaml
    HAS_YAML = True
except ImportError:
    HAS_YAML = False
# Optional dependency: the Hugging Face 'datasets' package. Same pattern:
# record availability in HAS_DATASETS rather than failing at import time.
try:
    from datasets import load_dataset, load_from_disk
    HAS_DATASETS = True
except ImportError:
    HAS_DATASETS = False
    # Define dummy functions to avoid "possibly unbound" errors: the module-level
    # names are always bound, and static analyzers / callers see a consistent API.
    # They raise only if actually called without the dependency installed.
    def load_dataset(*args, **kwargs):
        raise ImportError("The 'datasets' package is required to load data.")
    def load_from_disk(*args, **kwargs):
        raise ImportError("The 'datasets' package is required to load data.")
  25. from .formatter import Formatter, OpenAIFormatter, TorchtuneFormatter, vLLMFormatter
  26. def read_config(config_path: str) -> Dict:
  27. """
  28. Read the configuration file (supports both JSON and YAML formats).
  29. Args:
  30. config_path: Path to the configuration file
  31. Returns:
  32. dict: Configuration parameters
  33. Raises:
  34. ValueError: If the file format is not supported
  35. ImportError: If the required package for the file format is not installed
  36. """
  37. file_extension = Path(config_path).suffix.lower()
  38. with open(config_path, "r") as f:
  39. if file_extension in [".json"]:
  40. config = json.load(f)
  41. elif file_extension in [".yaml", ".yml"]:
  42. if not HAS_YAML:
  43. raise ImportError(
  44. "The 'pyyaml' package is required to load YAML files. "
  45. "Please install it with 'pip install pyyaml'."
  46. )
  47. # Only use yaml if it's available (HAS_YAML is True here)
  48. import yaml # This import will succeed because we've already checked HAS_YAML
  49. config = yaml.safe_load(f)
  50. else:
  51. raise ValueError(
  52. f"Unsupported config file format: {file_extension}. "
  53. f"Supported formats are: .json, .yaml, .yml"
  54. )
  55. return config
  56. def load_data(data_path: str, is_local: bool = False, **kwargs):
  57. """
  58. Load data from Hugging Face Hub or local disk.
  59. Args:
  60. data_path: Path to the dataset (either a Hugging Face dataset ID or a local path)
  61. is_local: Whether the data is stored locally
  62. **kwargs: Additional arguments to pass to the load_dataset function
  63. Returns:
  64. Dataset object from the datasets library
  65. Raises:
  66. ImportError: If the datasets package is not installed
  67. ValueError: If data_path is None or empty
  68. """
  69. if not HAS_DATASETS:
  70. raise ImportError(
  71. "The 'datasets' package is required to load data. "
  72. "Please install it with 'pip install datasets'."
  73. )
  74. if not data_path:
  75. raise ValueError("data_path must be provided")
  76. if is_local:
  77. # Load from local disk
  78. dataset = load_from_disk(data_path)
  79. else:
  80. # Load from Hugging Face Hub
  81. dataset = load_dataset(data_path, **kwargs)
  82. return dataset
  83. def get_formatter(formatter_type: str) -> Formatter:
  84. """
  85. Get the appropriate formatter based on the formatter type.
  86. Args:
  87. formatter_type: Type of formatter to use ('torchtune', 'vllm', or 'openai')
  88. Returns:
  89. Formatter: Formatter instance
  90. Raises:
  91. ValueError: If the formatter type is not supported
  92. """
  93. formatter_map = {
  94. "torchtune": TorchtuneFormatter,
  95. "vllm": vLLMFormatter,
  96. "openai": OpenAIFormatter,
  97. }
  98. if formatter_type.lower() not in formatter_map:
  99. raise ValueError(
  100. f"Unsupported formatter type: {formatter_type}. "
  101. f"Supported types are: {', '.join(formatter_map.keys())}"
  102. )
  103. return formatter_map[formatter_type.lower()]()
  104. def convert_to_conversations(data, column_mapping: Optional[Dict] = None):
  105. """
  106. Convert data to a list of Conversation objects.
  107. Args:
  108. data: Data to convert
  109. column_mapping: Optional mapping of column names
  110. Returns:
  111. list: List of Conversation objects
  112. """
  113. # Import here to avoid circular imports
  114. from .formatter import Conversation
  115. # Default column mapping if none provided
  116. if column_mapping is None:
  117. column_mapping = {"input": "input", "output": "output", "image": "image"}
  118. # Validate column mapping
  119. required_fields = ["input", "output"]
  120. for field in required_fields:
  121. if field not in column_mapping:
  122. raise ValueError(f"Column mapping must include '{field}' field")
  123. conversations = []
  124. for item in data:
  125. # Extract fields from the dataset item using the column mapping
  126. image_field = column_mapping.get("image")
  127. input_field = column_mapping.get("input")
  128. output_field = column_mapping.get("output")
  129. image = item.get(image_field, None) if image_field else None
  130. input_text = item.get(input_field, "")
  131. output_label = item.get(output_field, "")
  132. # Create a new conversation
  133. conversation = Conversation()
  134. # Create user content and user message
  135. user_content = [
  136. {"type": "text", "text": input_text},
  137. ]
  138. # Add image to user content
  139. if image is not None:
  140. user_content.append({"type": "image", "image_url": {"url": image}})
  141. user_message = {"role": "user", "content": user_content}
  142. # Create assistant message with text content
  143. assistant_content = [
  144. {"type": "text", "text": output_label},
  145. ]
  146. assistant_message = {"role": "assistant", "content": assistant_content}
  147. # Add messages to the conversation
  148. conversation.add_message(user_message)
  149. conversation.add_message(assistant_message)
  150. # Add the conversation to the list
  151. conversations.append(conversation)
  152. return conversations
  153. def format_data(data, formatter_type: str, column_mapping: Optional[Dict] = None):
  154. """
  155. Format the data using the specified formatter.
  156. Args:
  157. data: Data to format
  158. formatter_type: Type of formatter to use ('torchtune', 'vllm', or 'openai')
  159. column_mapping: Optional mapping of column names
  160. Returns:
  161. Formatted data in the specified format
  162. """
  163. # First convert the data to conversations
  164. conversations = convert_to_conversations(data, column_mapping)
  165. # Then get the formatter and format the conversations
  166. formatter = get_formatter(formatter_type)
  167. formatted_data = formatter.format_data(conversations)
  168. return formatted_data
  169. def load_and_format_data(config_path: str):
  170. """
  171. Load and format data based on the configuration.
  172. Args:
  173. config_path: Path to the configuration file
  174. Returns:
  175. Formatted data in the specified format
  176. """
  177. # Read the configuration
  178. config = read_config(config_path)
  179. # Extract parameters from config
  180. data_path = config.get("data_path")
  181. if not data_path:
  182. raise ValueError("data_path must be specified in the config file")
  183. is_local = config.get("is_local", False)
  184. formatter_type = config.get("formatter_type", "torchtune")
  185. column_mapping = config.get("column_mapping")
  186. dataset_kwargs = config.get("dataset_kwargs", {})
  187. # Load the data
  188. data = load_data(data_path, is_local, **dataset_kwargs)
  189. # Format the data
  190. formatted_data = format_data(data, formatter_type, column_mapping)
  191. return formatted_data
  192. if __name__ == "__main__":
  193. # Example usage
  194. import argparse
  195. parser = argparse.ArgumentParser(
  196. description="Load and format data from Hugging Face"
  197. )
  198. parser.add_argument(
  199. "--config",
  200. type=str,
  201. required=True,
  202. help="Path to the configuration file (JSON or YAML)",
  203. )
  204. args = parser.parse_args()
  205. formatted_data = load_and_format_data(args.config)
  206. print(f"Loaded and formatted data: {len(formatted_data)} samples")