Browse Source

updated functional formatter

khare19yash committed 2 months ago
parent
commit
71c7ece848

+ 4 - 5
src/finetune_pipeline/config.yaml

@@ -2,19 +2,18 @@
 output_dir: "/tmp/finetuning-pipeline/llama3_2_vision/"  # Directory to store output files
 
 data:
-  dataset_id: "data/path"  # Path to the dataset to load
-  is_local: true           # Whether the data is stored locally
+  dataset_id: "itsanmolgupta/mimic-cxr-dataset-cleaned"  # Path to the dataset to load
+  is_local: false           # Whether the data is stored locally
   # formatter_type: "vllm"            # Type of formatter to use ('torchtune', 'vllm', or
   system_prompt: "You are a helpful assistant"  # System prompt to use for the dataset
   # TODO: Key should be old name, value should be new name. Ref: https://huggingface.co/docs/datasets/v4.0.0/en/package_reference/main_classes#datasets.Dataset.rename_columns
   column_mapping:
-    input: "instruction"             # Field containing the input text
-    output: "output"              # Field containing the output text
+    input: "findings"             # Field containing the input text
+    output: "impression"              # Field containing the output text
     image: "image"           # Field containing the image path (optional)
   # Additional arguments to pass to the load_dataset function
   dataset_kwargs:
     split: "validation"                # Dataset split to load
-    shuffle: false                 # Whether to shuffle the dataset
 
 # Training configuration
 finetuning:

+ 45 - 5
src/finetune_pipeline/data/data_loader.py

@@ -35,12 +35,40 @@ except ImportError:
         raise ImportError("The 'datasets' package is required to load data.")
 
 
-def image_to_base64(image: Union[str, Image.Image]):
def is_base64_encoded(s: str) -> bool:
    """Heuristically check whether a string is already base64 encoded.

    Args:
        s: Candidate string.

    Returns:
        True if ``s`` decodes as base64 and round-trips back to itself
        (modulo ``=`` padding), False otherwise.

    Note:
        This is a heuristic — short plain strings made entirely of base64
        characters (e.g. ``"test"``) also pass. Callers that know the
        provenance of the string should not rely on this check alone.
    """
    # The empty string technically round-trips through b64decode/b64encode,
    # but it is never a meaningful payload, so reject it up front.
    if not s:
        return False

    # Fast reject: base64 strings use only this alphabet (plus '=' padding).
    alphabet = (
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        "abcdefghijklmnopqrstuvwxyz"
        "0123456789+/="
    )
    if any(c not in alphabet for c in s):
        return False

    try:
        # validate=True makes b64decode raise on non-alphabet bytes; a
        # length that is not a multiple of 4 raises binascii.Error, which
        # is a ValueError subclass.
        decoded = base64.b64decode(s, validate=True)
    except (ValueError, TypeError):
        return False

    # Round-trip: if re-encoding reproduces the input (allowing for
    # missing '=' padding), it was valid base64.
    re_encoded = base64.b64encode(decoded).decode("utf-8")
    return s == re_encoded or s == re_encoded.rstrip("=")
+
+
def image_to_base64(image: Union[str, list, Image.Image]):
    """Convert an image to a base64-encoded string.

    Args:
        image: A file path, an already-base64-encoded string, a PIL Image,
            or a list of any of those.

    Returns:
        A base64-encoded string, or a list of them for list input.

    Raises:
        TypeError: If ``image`` is not a supported type.
    """
    if isinstance(image, str):
        # Already-encoded strings pass through untouched.
        if is_base64_encoded(image):
            return image
        # Otherwise, treat the string as a file path on disk.
        with open(image, "rb") as img:
            return base64.b64encode(img.read()).decode("utf-8")
    elif isinstance(image, Image.Image):
        # Serialize to a real image file format before base64-encoding.
        # The previous Image.tobytes() call produced raw pixel data, which
        # no endpoint can decode behind a "data:image/...;base64," URL.
        import io

        buffer = io.BytesIO()
        # PNG fallback: always encodable (JPEG rejects RGBA) — confirm the
        # downstream "data:image/jpg" MIME hint tolerates PNG payloads.
        image.save(buffer, format=image.format or "PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
    elif isinstance(image, list):
        return [image_to_base64(img) for img in image]
    raise TypeError(f"Unsupported image type: {type(image)!r}")
 
 
 def read_config(config_path: str) -> Dict:
@@ -78,7 +106,7 @@ def read_config(config_path: str) -> Dict:
     return config
 
 
-def load_dataset(
+def load_data(
     data_path: str,
     is_local: bool = False,
     column_mapping: Optional[Dict] = None,
@@ -124,10 +152,15 @@ def load_dataset(
     if column_mapping is None:
         column_mapping = {"input": "input", "output": "output", "image": "image"}
 
+    ## change column mapping
     required_fields = ["input", "output"]
     for field in required_fields:
         if field not in column_mapping:
             raise ValueError(f"Column mapping must include '{field}' field")
+
+    print(f"Column Mapping: {column_mapping}")
+
+    ## switch the key:val of column_mapping for renaming
     dataset = dataset.rename_columns(column_mapping)
 
     return dataset
@@ -165,7 +198,12 @@ def convert_to_encoded_messages(
             image = [image]
         for img in image:
             b64_img = image_to_base64(img)
-            user_content.append({"type": "image_url", "image_url": {"url": b64_img}})
+            user_content.append(
+                {
+                    "type": "image_url",
+                    "image_url": {"url": f"data:image/jpg;base64,{b64_img}"},
+                }
+            )
 
     messages.append({"role": "user", "content": user_content})
 
@@ -242,6 +280,8 @@ def get_hf_dataset(
     """
 
     # If config_path is provided, load from config file
+
+    dataset_kwargs = {}
     if config_path:
         config = read_config(config_path)
         output_dir = config.get("output_dir", "/tmp/finetuning-pipeline/outputs")
@@ -256,7 +296,7 @@ def get_hf_dataset(
     else:
         # Use individual parameters passed to the function
         if dataset_kwargs is None:
-            dataset_kwargs = {}
+            dataset_kwargs = {"split": "train"}
 
     # Validate required parameters
     if not dataset_id:
@@ -265,7 +305,7 @@ def get_hf_dataset(
         )
 
     # Load the dataset
-    dataset = load_dataset(
+    dataset = load_data(
         data_path=dataset_id,
         is_local=is_local,
         column_mapping=column_mapping,

+ 50 - 83
src/finetune_pipeline/data/formatter_functional.py

@@ -1,10 +1,24 @@
-import base64
-from typing import Dict, List
+from typing import Dict, List, Union
 
 
-def image_to_base64(image_path):
-    with open(image_path, "rb") as img:
-        return base64.b64encode(img.read()).decode("utf-8")
def format_message_torchtune(message: Dict) -> Dict:
    """Return the message unchanged — Torchtune consumes the native format."""
    # Identity transform: Torchtune accepts role/content dicts as-is.
    return message
+
+
def format_message_openai(message: Dict) -> Dict:
    """Translate a message's content parts into the OpenAI input format.

    ``text`` parts become ``input_text`` and ``image_url`` parts become
    ``input_image``; any other part type raises ValueError.
    """
    converted = []
    for part in message["content"]:
        part_type = part["type"]
        if part_type == "text":
            converted.append({"type": "input_text", "text": part["text"]})
        elif part_type == "image_url":
            converted.append(
                {"type": "input_image", "image_url": part["image_url"]["url"]}
            )
        else:
            raise ValueError(f"Unknown content type: {part['type']}")
    return {"role": message["role"], "content": converted}
 
 
 def format_message_vllm(message: Dict) -> Dict:
@@ -16,10 +30,11 @@ def format_message_vllm(message: Dict) -> Dict:
         if content["type"] == "text":
             contents.append(content)
         elif content["type"] == "image_url" or content["type"] == "image":
-            base64_image = image_to_base64(content["image_url"]["url"])
             img_content = {
                 "type": "image_url",
-                "image_url": {"url": f"data:image/jpg;base64,{base64_image}"},
+                "image_url": {
+                    "url": f"data:image/jpg;base64,{content["image_url"]["url"]}"
+                },
             }
             contents.append(img_content)
         else:
@@ -29,85 +44,37 @@ def format_message_vllm(message: Dict) -> Dict:
     return vllm_message
 
 
-def format_conversation_vllm(conversation) -> Dict:
-    """Format a conversation in vLLM format."""
-    formatted_messages = []
-    for message in conversation.messages:
-        role = message["role"]
-        if role != "assistant":
-            formatted_messages.append(format_message_vllm(message))
-    return {"messages": formatted_messages}
-
-
-# TODO: Remove
-def format_conversation_openai(conversation) -> Dict:
-    """Format a conversation in OpenAI format."""
-    formatted_messages = []
-    for message in conversation.messages:
-        formatted_messages.append(format_message_openai(message))
-    return {"messages": formatted_messages}
-
-
-# TODO: Remove
-def format_data_torchtune(data: List[Conversation]) -> List[Dict]:
-    """Format data in Torchtune format."""
-    if data is None:
-        raise ValueError("No data provided to format_data()")
-
-    return [format_conversation_torchtune(conversation) for conversation in data]
-
-
-def format_data_vllm(data: List[Conversation]) -> List[Dict]:
-    """Format data in vLLM format."""
-    if data is None:
-        raise ValueError("No data provided to format_data()")
-
-    return [format_conversation_vllm(conversation) for conversation in data]
-
-
-# TODO: Remove
-def format_data_openai(data: List[Conversation]) -> List[Dict]:
-    """Format data in OpenAI format."""
-    if data is None:
-        raise ValueError("No data provided to format_data()")
-
-    return [format_conversation_openai(conversation) for conversation in data]
-
-
-# Dictionary to map format names to functions for easy dispatch
-FORMATTERS = {
-    "torchtune": {
-        "data": format_data_torchtune,
-        "conversation": format_conversation_torchtune,
-        "message": format_message_torchtune,
-    },
-    "vllm": {
-        "data": format_data_vllm,
-        "conversation": format_conversation_vllm,
-        "message": format_message_vllm,
-    },
-    "openai": {
-        "data": format_data_openai,
-        "conversation": format_conversation_openai,
-        "message": format_message_openai,
-    },
-}
-
-
-def format_data(data: List[Conversation], format_type: str) -> List[Dict]:
def apply_format(data: Union[List[Dict], List[List[Dict]]], format_func) -> List[Dict]:
    """
    Apply the format function to the data.

    Args:
        data: Either a flat list of message dictionaries or a list of
            conversations (each conversation being a list of message
            dictionaries). Only ``data[0]`` is inspected to decide which
            shape the input has, so mixed lists are not supported.
        format_func: Function that formats a single message dictionary.

    Returns:
        For a flat message list: a list of formatted message dicts.
        For a list of conversations: a list of ``{"messages": [...]}``
        dicts, one per conversation.
    """
    if not data:
        return []

    if isinstance(data[0], list):
        # List of conversations: format each message and wrap every
        # conversation in the {"messages": ...} envelope used downstream.
        return [
            {"messages": [format_func(message) for message in conversation]}
            for conversation in data
        ]

    # Flat list of messages.
    return [format_func(message) for message in data]