|
@@ -79,7 +79,7 @@ def load_data(data_path: str, is_local: bool = False, **kwargs):
|
|
|
**kwargs: Additional arguments to pass to the load_dataset function
|
|
**kwargs: Additional arguments to pass to the load_dataset function
|
|
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
|
- Dataset object from the datasets library
|
|
|
|
|
|
|
+ Dataset object from the datasets library with all splits
|
|
|
|
|
|
|
|
Raises:
|
|
Raises:
|
|
|
ImportError: If the datasets package is not installed
|
|
ImportError: If the datasets package is not installed
|
|
@@ -174,9 +174,18 @@ def convert_to_conversations(data, column_mapping: Optional[Dict] = None):
|
|
|
user_content = [
|
|
user_content = [
|
|
|
{"type": "text", "text": input_text},
|
|
{"type": "text", "text": input_text},
|
|
|
]
|
|
]
|
|
|
- # Add image to user content
|
|
|
|
|
|
|
+ # Add image(s) to user content
|
|
|
if image is not None:
|
|
if image is not None:
|
|
|
- user_content.append({"type": "image", "image_url": {"url": image}})
|
|
|
|
|
|
|
+ if isinstance(image, list):
|
|
|
|
|
+ # Handle list of images
|
|
|
|
|
+ for img in image:
|
|
|
|
|
+ if img: # Check if image path is not empty
|
|
|
|
|
+ user_content.append(
|
|
|
|
|
+ {"type": "image", "image_url": {"url": img}}
|
|
|
|
|
+ )
|
|
|
|
|
+ else:
|
|
|
|
|
+ # Handle single image
|
|
|
|
|
+ user_content.append({"type": "image", "image_url": {"url": image}})
|
|
|
|
|
|
|
|
user_message = {"role": "user", "content": user_content}
|
|
user_message = {"role": "user", "content": user_content}
|
|
|
|
|
|
|
@@ -197,7 +206,7 @@ def convert_to_conversations(data, column_mapping: Optional[Dict] = None):
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_formatted_data(
|
|
def save_formatted_data(
|
|
|
- formatted_data: List[Any], output_dir: str, formatter_type: str
|
|
|
|
|
|
|
+ formatted_data: List[Any], output_dir: str, formatter_type: str, split: str
|
|
|
) -> str:
|
|
) -> str:
|
|
|
"""
|
|
"""
|
|
|
Save formatted data to a JSON file.
|
|
Save formatted data to a JSON file.
|
|
@@ -215,7 +224,7 @@ def save_formatted_data(
|
|
|
|
|
|
|
|
# Define the output file path
|
|
# Define the output file path
|
|
|
formatted_data_path = os.path.join(
|
|
formatted_data_path = os.path.join(
|
|
|
- output_dir, f"{formatter_type}_formatted_data.json"
|
|
|
|
|
|
|
+ output_dir, f"{split}_{formatter_type}_formatted_data.json"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
# Save the formatted data
|
|
# Save the formatted data
|
|
@@ -237,7 +246,7 @@ def save_formatted_data(
|
|
|
return formatted_data_path
|
|
return formatted_data_path
|
|
|
|
|
|
|
|
|
|
|
|
|
-def save_conversation_data(conversation_data: List, output_dir: str) -> str:
|
|
|
|
|
|
|
+def save_conversation_data(conversation_data: List, output_dir: str, split: str) -> str:
|
|
|
"""
|
|
"""
|
|
|
Save conversation data to a JSON file.
|
|
Save conversation data to a JSON file.
|
|
|
|
|
|
|
@@ -252,7 +261,7 @@ def save_conversation_data(conversation_data: List, output_dir: str) -> str:
|
|
|
os.makedirs(output_dir, exist_ok=True)
|
|
os.makedirs(output_dir, exist_ok=True)
|
|
|
|
|
|
|
|
# Define the output file path
|
|
# Define the output file path
|
|
|
- conversation_data_path = os.path.join(output_dir, "conversation_data.json")
|
|
|
|
|
|
|
+ conversation_data_path = os.path.join(output_dir, f"{split}_conversation_data.json")
|
|
|
|
|
|
|
|
# Convert Conversation objects to a serializable format
|
|
# Convert Conversation objects to a serializable format
|
|
|
serializable_conversations = []
|
|
serializable_conversations = []
|
|
@@ -267,37 +276,103 @@ def save_conversation_data(conversation_data: List, output_dir: str) -> str:
|
|
|
return conversation_data_path
|
|
return conversation_data_path
|
|
|
|
|
|
|
|
|
|
|
|
|
def format_data(
    data,
    formatter_type: str,
    output_dir: str,
    column_mapping: Optional[Dict] = None,
    dataset_kwargs: Optional[Dict] = None,
):
    """
    Format every split of the data with the specified formatter and save the results.

    A non-empty ``dict`` (or ``dict`` subclass) is treated as a mapping of
    split name -> dataset, and each split is processed in iteration order.
    Anything else (including an empty dict) is processed as one dataset whose
    split name is ``dataset_kwargs["split"]`` when that key is present, and
    ``"default"`` otherwise.

    Args:
        data: Dataset with multiple splits to format or a single dataset
        formatter_type: Type of formatter to use ('torchtune', 'vllm', or 'openai')
        output_dir: Directory to save the formatted data
        column_mapping: Optional mapping of column names
        dataset_kwargs: Optional dataset kwargs that may contain split information

    Returns:
        Tuple containing (formatted_data_paths, conversation_data_paths) where
        each is a list of paths to saved files, one entry per processed split
    """
    # A single isinstance + truthiness test covers the previous
    # hasattr/callable/len(keys()) chain: every dict has a callable `keys`,
    # and a dict is truthy exactly when it has at least one key.
    if isinstance(data, dict) and data:
        # Dataset has splits (train, validation, test, etc.)
        named_splits = ((name, data[name]) for name in data)
    else:
        # Dataset doesn't have explicit splits, treat it as a single dataset
        # named after the requested split when one was given.
        single_name = (dataset_kwargs or {}).get("split", "default")
        named_splits = [(single_name, data)]

    formatted_data_paths = []
    conversation_data_paths = []
    for split, split_data in named_splits:
        formatted_path, conversation_path = _format_and_save_split(
            split_data, split, formatter_type, output_dir, column_mapping
        )
        formatted_data_paths.append(formatted_path)
        conversation_data_paths.append(conversation_path)

    return formatted_data_paths, conversation_data_paths


def _format_and_save_split(
    split_data,
    split: str,
    formatter_type: str,
    output_dir: str,
    column_mapping: Optional[Dict] = None,
):
    """
    Convert, format, and save one split; return (formatted_path, conversation_path).
    """
    # First convert the data to conversations
    conversations = convert_to_conversations(split_data, column_mapping)

    # Then get the formatter and format the conversations. The formatter is
    # looked up once per split, so no formatter instance is reused across splits.
    formatter = get_formatter(formatter_type)
    formatted_data = formatter.format_data(conversations)
    print(
        f"Loaded and formatted data for split '{split}': {len(formatted_data)} samples"
    )

    # Save the formatted data first, then the conversation data
    formatted_data_path = save_formatted_data(
        formatted_data, output_dir, formatter_type, split
    )
    conversation_data_path = save_conversation_data(conversations, output_dir, split)

    return formatted_data_path, conversation_data_path
|
|
|
|
|
|
|
|
|
|
|
|
|
-def load_and_format_data(formatter_config: Dict):
|
|
|
|
|
|
|
+def load_and_format_data(formatter_config: Dict, output_dir: str):
|
|
|
"""
|
|
"""
|
|
|
Load and format data based on the configuration.
|
|
Load and format data based on the configuration.
|
|
|
|
|
|
|
|
Args:
|
|
Args:
|
|
|
formatter_config: Dictionary containing formatter configuration parameters
|
|
formatter_config: Dictionary containing formatter configuration parameters
|
|
|
|
|
+ output_dir: Directory to save the formatted data
|
|
|
|
|
|
|
|
Returns:
|
|
Returns:
|
|
|
- Formatted data in the specified format
|
|
|
|
|
|
|
+ Tuple containing (formatted_data_paths, conversation_data_paths) where each is a list of paths to saved files
|
|
|
"""
|
|
"""
|
|
|
|
|
|
|
|
# Extract parameters from config
|
|
# Extract parameters from config
|
|
@@ -316,11 +391,11 @@ def load_and_format_data(formatter_config: Dict):
|
|
|
data = load_data(data_path, is_local, **dataset_kwargs)
|
|
data = load_data(data_path, is_local, **dataset_kwargs)
|
|
|
|
|
|
|
|
# Format the data
|
|
# Format the data
|
|
|
- formatted_data, conversation_data = format_data(
|
|
|
|
|
- data, formatter_type, column_mapping
|
|
|
|
|
|
|
+ formatted_data_paths, conversation_data_paths = format_data(
|
|
|
|
|
+ data, formatter_type, output_dir, column_mapping, dataset_kwargs
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- return formatted_data, conversation_data
|
|
|
|
|
|
|
+ return formatted_data_paths, conversation_data_paths
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
@@ -341,14 +416,9 @@ if __name__ == "__main__":
|
|
|
# Read the configuration
|
|
# Read the configuration
|
|
|
config = read_config(args.config)
|
|
config = read_config(args.config)
|
|
|
formatter_config = config.get("formatter", {})
|
|
formatter_config = config.get("formatter", {})
|
|
|
- output_dir = config.get("output_dir")
|
|
|
|
|
|
|
+ output_dir = config.get("output_dir", "/tmp/finetune-pipeline/data/")
|
|
|
|
|
|
|
|
# Load and format the data
|
|
# Load and format the data
|
|
|
- formatted_data, conversation_data = load_and_format_data(formatter_config)
|
|
|
|
|
- print(f"Loaded and formatted data: {len(formatted_data)} samples")
|
|
|
|
|
-
|
|
|
|
|
- # Save the data if output_dir is provided
|
|
|
|
|
- if output_dir:
|
|
|
|
|
- formatter_type = formatter_config.get("type", "torchtune")
|
|
|
|
|
- save_formatted_data(formatted_data, output_dir, formatter_type)
|
|
|
|
|
- save_conversation_data(conversation_data, output_dir)
|
|
|
|
|
|
|
+ formatted_data_paths, conversation_data_paths = load_and_format_data(
|
|
|
|
|
+ formatter_config, output_dir
|
|
|
|
|
+ )
|