|
@@ -6,6 +6,7 @@ import json
|
|
|
import os
|
|
|
from pathlib import Path
|
|
|
from typing import Any, Dict, List, Optional, Union
|
|
|
+import pandas as pd
|
|
|
|
|
|
# Try to import yaml, but don't fail if it's not available
|
|
|
try:
|
|
@@ -17,7 +18,7 @@ except ImportError:
|
|
|
|
|
|
# Try to import datasets, but don't fail if it's not available
|
|
|
try:
|
|
|
- from datasets import load_dataset, load_from_disk
|
|
|
+ from datasets import load_dataset, load_from_disk, Dataset
|
|
|
|
|
|
HAS_DATASETS = True
|
|
|
except ImportError:
|
|
@@ -94,9 +95,15 @@ def load_data(data_path: str, is_local: bool = False, **kwargs):
|
|
|
if not data_path:
|
|
|
raise ValueError("data_path must be provided")
|
|
|
|
|
|
+ dataset = None
|
|
|
if is_local:
|
|
|
# Load from local disk
|
|
|
- dataset = load_from_disk(data_path)
|
|
|
+ file_extension = Path(data_path).suffix.lower()
|
|
|
+ if file_extension in [".csv"]:
|
|
|
+ data = pd.read_csv(data_path)
|
|
|
+ dataset = Dataset.from_pandas(data)
|
|
|
+ else:
|
|
|
+ dataset = load_from_disk(data_path)
|
|
|
else:
|
|
|
# Load from Hugging Face Hub
|
|
|
dataset = load_dataset(data_path, **kwargs)
|
|
@@ -132,24 +139,10 @@ def get_formatter(formatter_type: str) -> Formatter:
|
|
|
return formatter_map[formatter_type.lower()]()
|
|
|
|
|
|
|
|
|
-def get_image_path(img: str) -> str:
|
|
|
- """
|
|
|
- Get the image path from the image URL.
|
|
|
-
|
|
|
- Args:
|
|
|
- img: Image URL
|
|
|
-
|
|
|
- Returns:
|
|
|
- str: Image path
|
|
|
- """
|
|
|
-
|
|
|
- img_name = img.split("/")[-2]
|
|
|
- img_id = img.split("/")[-1]
|
|
|
- img_dir = "/home/yashkhare/workspace/IU-Xray/mnt/bn/haiyang-dataset-lq/medical/home/yisiyang/ysy/medical_dataset/iu_xray/iu_xray/images"
|
|
|
- return f"{img_dir}/{img_name}/{img_id}"
|
|
|
|
|
|
-
|
|
|
-def convert_to_conversations(data, column_mapping: Optional[Dict] = None):
|
|
|
+def convert_to_conversations(
|
|
|
+ data, column_mapping: Optional[Dict] = None, system_prompt: Optional[str] = None
|
|
|
+) -> List[Any]:
|
|
|
"""
|
|
|
Convert data to a list of Conversation objects.
|
|
|
|
|
@@ -187,6 +180,13 @@ def convert_to_conversations(data, column_mapping: Optional[Dict] = None):
|
|
|
# Create a new conversation
|
|
|
conversation = Conversation()
|
|
|
|
|
|
+ system_message = None
|
|
|
+ if system_prompt:
|
|
|
+ system_message = {
|
|
|
+ "role": "system",
|
|
|
+ "content": [{"type": "text", "text": system_prompt}],
|
|
|
+ }
|
|
|
+
|
|
|
# Create user content and user message
|
|
|
user_content = [
|
|
|
{"type": "text", "text": input_text},
|
|
@@ -197,16 +197,14 @@ def convert_to_conversations(data, column_mapping: Optional[Dict] = None):
|
|
|
# Handle list of images
|
|
|
for img in image:
|
|
|
if img: # Check if image path is not empty
|
|
|
- img_path = get_image_path(img)
|
|
|
user_content.append(
|
|
|
- {"type": "image_url", "image_url": {"url": img_path}}
|
|
|
+ {"type": "image_url", "image_url": {"url": img}}
|
|
|
)
|
|
|
break
|
|
|
else:
|
|
|
# Handle single image
|
|
|
- img_path = get_image_path(image)
|
|
|
user_content.append(
|
|
|
- {"type": "image_url", "image_url": {"url": img_path}}
|
|
|
+ {"type": "image_url", "image_url": {"url": image}}
|
|
|
)
|
|
|
|
|
|
user_message = {"role": "user", "content": user_content}
|
|
@@ -218,6 +216,8 @@ def convert_to_conversations(data, column_mapping: Optional[Dict] = None):
|
|
|
assistant_message = {"role": "assistant", "content": assistant_content}
|
|
|
|
|
|
# Add messages to the conversation
|
|
|
+ if system_message:
|
|
|
+ conversation.add_message(system_message)
|
|
|
conversation.add_message(user_message)
|
|
|
conversation.add_message(assistant_message)
|
|
|
|
|
@@ -303,6 +303,7 @@ def format_data(
|
|
|
formatter_type: str,
|
|
|
output_dir: str,
|
|
|
column_mapping: Optional[Dict] = None,
|
|
|
+ system_prompt: Optional[str] = None,
|
|
|
dataset_kwargs: Optional[Dict] = None,
|
|
|
):
|
|
|
"""
|
|
@@ -385,7 +386,7 @@ def format_data(
|
|
|
return formatted_data_paths, conversation_data_paths
|
|
|
|
|
|
|
|
|
-def load_and_format_data(formatter_config: Dict, output_dir: str):
|
|
|
+def load_and_format_data(data_config: Dict, output_dir: str):
|
|
|
"""
|
|
|
Load and format data based on the configuration.
|
|
|
|
|
@@ -398,23 +399,24 @@ def load_and_format_data(formatter_config: Dict, output_dir: str):
|
|
|
"""
|
|
|
|
|
|
# Extract parameters from config
|
|
|
- data_path = formatter_config.get("data_path")
|
|
|
+ data_path = data_config.get("data_path")
|
|
|
if not data_path:
|
|
|
raise ValueError(
|
|
|
"data_path must be specified in the formatter section of the config file"
|
|
|
)
|
|
|
|
|
|
- is_local = formatter_config.get("is_local", False)
|
|
|
- formatter_type = formatter_config.get("type", "torchtune")
|
|
|
- column_mapping = formatter_config.get("column_mapping")
|
|
|
- dataset_kwargs = formatter_config.get("dataset_kwargs", {})
|
|
|
+ is_local = data_config.get("is_local", False)
|
|
|
+ formatter_type = data_config.get("formatter_type", "torchtune")
|
|
|
+ column_mapping = data_config.get("column_mapping")
|
|
|
+ dataset_kwargs = data_config.get("dataset_kwargs", {})
|
|
|
+ system_prompt = data_config.get("system_prompt", None)
|
|
|
|
|
|
# Load the data
|
|
|
data = load_data(data_path, is_local, **dataset_kwargs)
|
|
|
|
|
|
# Format the data
|
|
|
formatted_data_paths, conversation_data_paths = format_data(
|
|
|
- data, formatter_type, output_dir, column_mapping, dataset_kwargs
|
|
|
+ data, formatter_type, output_dir, column_mapping, system_prompt, dataset_kwargs
|
|
|
)
|
|
|
|
|
|
return formatted_data_paths, conversation_data_paths
|
|
@@ -437,10 +439,10 @@ if __name__ == "__main__":
|
|
|
|
|
|
# Read the configuration
|
|
|
config = read_config(args.config)
|
|
|
- formatter_config = config.get("formatter", {})
|
|
|
+ data_config = config.get("data", {})
|
|
|
output_dir = config.get("output_dir", "/tmp/finetune-pipeline/data/")
|
|
|
output_data_dir = os.path.join(output_dir, "data")
|
|
|
# Load and format the data
|
|
|
formatted_data_paths, conversation_data_paths = load_and_format_data(
|
|
|
- formatter_config, output_data_dir
|
|
|
+ data_config, output_data_dir
|
|
|
)
|