Kaynağa Gözat

moved read_data to data loader

khare19yash 3 ay önce
ebeveyn
işleme
0b0bfff19b

+ 68 - 5
src/finetune_pipeline/data/data_loader.py

@@ -135,6 +135,70 @@ def get_formatter(formatter_type: str) -> Formatter:
     return formatter_map[formatter_type.lower()]()
 
 
+def convert_to_conversations(data, column_mapping: Optional[Dict] = None):
+    """
+    Build one two-message Conversation (user, then assistant) per item in data.
+
+    Args:
+        data: Iterable of items that expose .get(key, default).
+        column_mapping: Maps 'input', 'output' and optionally 'image' to item keys.
+
+    Returns:
+        list: One Conversation per item, in iteration order.
+    """
+    # Import here to avoid circular imports
+    from .formatter import Conversation
+
+    # No mapping given: read the item keys literally named 'input', 'output', 'image'
+    if column_mapping is None:
+        column_mapping = {"input": "input", "output": "output", "image": "image"}
+
+    # 'input' and 'output' must be mapped (ValueError otherwise); 'image' is optional
+    required_fields = ["input", "output"]
+    for field in required_fields:
+        if field not in column_mapping:
+            raise ValueError(f"Column mapping must include '{field}' field")
+
+    conversations = []
+    for item in data:
+        # Read each field via the mapping: a missing image stays None, missing text becomes ""
+        image_field = column_mapping.get("image")
+        input_field = column_mapping.get("input")
+        output_field = column_mapping.get("output")
+
+        image = item.get(image_field, None) if image_field else None
+        input_text = item.get(input_field, "")
+        output_label = item.get(output_field, "")
+
+        # Start an empty conversation for this item
+        conversation = Conversation()
+
+        # The user turn always carries the input text as its first content part
+        user_content = [
+            {"type": "text", "text": input_text},
+        ]
+        # Append an image part only when the item supplied one; the value is passed through as-is
+        if image is not None:
+            user_content.append({"type": "image", "image_url": {"url": image}})
+
+        user_message = {"role": "user", "content": user_content}
+
+        # The assistant turn carries the expected output as a single text part
+        assistant_content = [
+            {"type": "text", "text": output_label},
+        ]
+        assistant_message = {"role": "assistant", "content": assistant_content}
+
+        # Add the user turn first, then the assistant reply
+        conversation.add_message(user_message)
+        conversation.add_message(assistant_message)
+
+        # Add the conversation to the list
+        conversations.append(conversation)
+
+    return conversations
+
+
 def format_data(data, formatter_type: str, column_mapping: Optional[Dict] = None):
     """
     Format the data using the specified formatter.
@@ -147,12 +211,11 @@ def format_data(data, formatter_type: str, column_mapping: Optional[Dict] = None
     Returns:
         Formatted data in the specified format
     """
-    formatter = get_formatter(formatter_type)
+    # First convert the data to conversations
+    conversations = convert_to_conversations(data, column_mapping)
 
-    # Read the data and convert to conversations
-    conversations = formatter.read_data(data, column_mapping)
-
-    # Format the conversations
+    # Then get the formatter and format the conversations
+    formatter = get_formatter(formatter_type)
     formatted_data = formatter.format_data(conversations)
 
     return formatted_data

+ 19 - 116
src/finetune_pipeline/data/formatter.py

@@ -106,68 +106,7 @@ class Formatter(ABC):
         """
         pass
 
-    def read_data(self, data, column_mapping=None):
-        """
-        Read Hugging Face data and convert it to a list of Conversation objects.
-
-        Args:
-            data: Hugging Face dataset or iterable data
-            column_mapping: Optional dictionary mapping custom column names to standard field names.
-                            Expected keys are 'input', 'output', and 'image'.
-                            Example: {'input': 'question', 'output': 'answer', 'image': 'image_path'}
-                            If None, default field names ('input', 'output', 'image') will be used.
-
-        Returns:
-            list: List of Conversation objects, each containing a list of Message objects
-        """
-        # Default column mapping if none provided
-        if column_mapping is None:
-            column_mapping = {"input": "input", "output": "output", "image": "image"}
-
-        # Validate column mapping
-        required_fields = ["input", "output"]
-        for field in required_fields:
-            if field not in column_mapping:
-                raise ValueError(f"Column mapping must include '{field}' field")
-
-        conversations = []
-        for item in data:
-            # Extract fields from the Hugging Face dataset item using the column mapping
-            image_field = column_mapping.get("image")
-            input_field = column_mapping.get("input")
-            output_field = column_mapping.get("output")
-
-            image = item.get(image_field, None) if image_field else None
-            input_text = item.get(input_field, "")
-            output_label = item.get(output_field, "")
-
-            # Create a new conversation
-            conversation = Conversation()
-
-            # Create user content and user message
-            user_content = [
-                {"type": "text", "text": input_text},
-            ]
-            # Add image to user content
-            if image is not None:
-                user_content.append({"type": "image", "image_url": {"url": image}})
-
-            user_message = {"role": "user", "content": user_content}
-
-            # Create assistant message with text content
-            assistant_content = [
-                {"type": "text", "text": output_label},
-            ]
-            assistant_message = {"role": "assistant", "content": assistant_content}
-
-            # Add messages to the conversation
-            conversation.add_message(user_message)
-            conversation.add_message(assistant_message)
-
-            # Add the conversation to the list
-            conversations.append(conversation)
-
-        return conversations
+    # The read_data function has been moved to convert_to_conversations in data_loader.py
 
 
 class TorchtuneFormatter(Formatter):
@@ -175,39 +114,27 @@ class TorchtuneFormatter(Formatter):
     Formatter for Torchtune format.
     """
 
-    def __init__(self, data=None):
+    def __init__(self):
         """
         Initialize the formatter.
-
-        Args:
-            data: Optional data to initialize with
         """
         super().__init__()
-        self.conversation_data = None
-        if data is not None:
-            self.conversation_data = self.read_data(data)
 
-    def format_data(self, data=None):
+    def format_data(self, data):
         """
         Format the data.
 
         Args:
-            data: List of Conversation objects. If None, uses the conversation_data stored during initialization.
+            data: List of Conversation objects.
 
         Returns:
             list: List of formatted data
         """
-        # Use stored conversation_data if no data is provided
-        data_to_format = data if data is not None else self.conversation_data
-
-        # Check if we have data to format
-        if data_to_format is None:
-            raise ValueError(
-                "No data to format. Either provide data to format_data() or initialize with data."
-            )
+        if data is None:
+            raise ValueError("No data provided to format_data()")
 
         formatted_data = []
-        for conversation in data_to_format:
+        for conversation in data:
             formatted_data.append(self.format_conversation(conversation))
         return formatted_data
 
@@ -246,39 +173,27 @@ class vLLMFormatter(Formatter):
     Formatter for vLLM format.
     """
 
-    def __init__(self, data=None):
+    def __init__(self):
         """
         Initialize the formatter.
-
-        Args:
-            data: Optional data to initialize with
         """
         super().__init__()
-        self.conversation_data = None
-        if data is not None:
-            self.conversation_data = self.read_data(data)
 
-    def format_data(self, data=None):
+    def format_data(self, data):
         """
         Format the data.
 
         Args:
-            data: List of Conversation objects. If None, uses the conversation_data stored during initialization.
+            data: List of Conversation objects.
 
         Returns:
             list: List of formatted data in vLLM format
         """
-        # Use stored conversation_data if no data is provided
-        data_to_format = data if data is not None else self.conversation_data
-
-        # Check if we have data to format
-        if data_to_format is None:
-            raise ValueError(
-                "No data to format. Either provide data to format_data() or initialize with data."
-            )
+        if data is None:
+            raise ValueError("No data provided to format_data()")
 
         formatted_data = []
-        for conversation in data_to_format:
+        for conversation in data:
             formatted_data.append(self.format_conversation(conversation))
         return formatted_data
 
@@ -327,39 +242,27 @@ class OpenAIFormatter(Formatter):
     Formatter for OpenAI format.
     """
 
-    def __init__(self, data=None):
+    def __init__(self):
         """
         Initialize the formatter.
-
-        Args:
-            data: Optional data to initialize with
         """
         super().__init__()
-        self.conversation_data = None
-        if data is not None:
-            self.conversation_data = self.read_data(data)
 
-    def format_data(self, data=None):
+    def format_data(self, data):
         """
         Format the data.
 
         Args:
-            data: List of Conversation objects. If None, uses the conversation_data stored during initialization.
+            data: List of Conversation objects.
 
         Returns:
             dict: Formatted data in OpenAI format
         """
-        # Use stored conversation_data if no data is provided
-        data_to_format = data if data is not None else self.conversation_data
-
-        # Check if we have data to format
-        if data_to_format is None:
-            raise ValueError(
-                "No data to format. Either provide data to format_data() or initialize with data."
-            )
+        if data is None:
+            raise ValueError("No data provided to format_data()")
 
         formatted_data = []
-        for conversation in data_to_format:
+        for conversation in data:
             formatted_data.append(self.format_conversation(conversation))
         return formatted_data