Ver código fonte

first commit

Ubuntu 3 meses atrás
pai
commit
81ceb2b62b

+ 0 - 0
src/finetune_pipeline/augmentation.py


+ 83 - 0
src/finetune_pipeline/formatter.py

@@ -0,0 +1,83 @@
+class Message:
+    """
+    Plain container for a single message.
+
+    The three constructor arguments are stored unchanged as attributes:
+    ``role``, ``content`` and ``type`` (the latter holds ``content_type``).
+    """
+
+    def __init__(self, role, content, content_type):
+        # ``content_type`` is exposed under the attribute name ``type``.
+        self.role, self.content, self.type = role, content, content_type
+
+
+class Formatter:
+    """
+    Base class for formatters that convert messages to different formats.
+
+    The instance keeps the message it was constructed with in
+    ``self.message``. Every hook below raises ``NotImplementedError`` in
+    this base class and is meant to be overridden by subclasses.
+    """
+
+    def __init__(self, message):
+        self.message = message
+
+    def format_data(self, data):
+        """
+        Format ``data``. This method should be overridden by subclasses.
+
+        Raises:
+            NotImplementedError: Always, in the base class.
+        """
+        # The message names this method so the traceback points at the right hook.
+        raise NotImplementedError("Subclasses must implement format_data()")
+
+    def format_sample(self, sample):
+        """
+        Format a sample. This method should be overridden by subclasses.
+
+        Raises:
+            NotImplementedError: Always, in the base class.
+        """
+        raise NotImplementedError("Subclasses must implement format_sample()")
+
+    def read_data(self, data):
+        """
+        Read ``data``. This method should be overridden by subclasses.
+
+        Raises:
+            NotImplementedError: Always, in the base class.
+        """
+        raise NotImplementedError("Subclasses must implement read_data()")
+
+
+class TorchtuneFormatter(Formatter):
+    """
+    Formatter for Torchtune format.
+    """
+
+    # Class-level placeholder; nothing in this class assigns or reads it yet.
+    data = None
+
+    def read_data(self, data):
+        """
+        Read ``data`` for Torchtune formatting. Not implemented yet.
+
+        Raises:
+            NotImplementedError: Always.
+        """
+        # This is a concrete class, so the message names the method that is
+        # actually missing instead of pointing at subclasses/format_sample().
+        raise NotImplementedError(
+            "TorchtuneFormatter.read_data() is not implemented"
+        )
+
+    def format(self):
+        """
+        Format the message in Torchtune format.
+
+        Returns:
+            dict: ``{"role": ..., "content": ...}`` built from the ``role``
+            and ``content`` attributes of ``self.message``.
+        """
+        return {"role": self.message.role, "content": self.message.content}
+
+
+class vLLMFormatter(Formatter):
+    """
+    Formatter that renders the stored message as a vLLM-style text line.
+    """
+
+    def format(self):
+        """
+        Render the message as a single ``"<role>: <content>"`` string.
+
+        Returns:
+            str: The message's ``role`` and ``content`` joined by ``": "``.
+        """
+        msg = self.message
+        return f"{msg.role}: {msg.content}"
+
+
+class OpenAIFormatter(Formatter):
+    """
+    Formatter for OpenAI format.
+    """
+
+    def format(self):
+        """
+        Format the message in OpenAI format. Not implemented yet.
+
+        Raises:
+            NotImplementedError: Always.
+        """
+        # The message names this method; it previously pointed at
+        # format_sample() on subclasses, which is not what is missing here.
+        raise NotImplementedError("OpenAIFormatter.format() is not implemented")