
updated vllm formatter for images

khare19yash, 1 month ago
Commit 38ca3797bc

+ 11 - 9
src/finetune_pipeline/config.yaml

@@ -82,7 +82,7 @@ formatter:
   column_mapping:
     input: "query"             # Field containing the input text
     output: "response"              # Field containing the output text
-    image: null           # Field containing the image path (optional)
+    image: "images"           # Field containing the image path (optional)
 
   # Additional arguments to pass to the load_dataset function
   dataset_kwargs:
@@ -90,14 +90,16 @@ formatter:
 
 # Training configuration
 finetuning:
-  model_path: "/home/yashkhare/workspace/Llama-3.1-8B-Instruct" # Path to the model checkpoint
-  tokenizer_path: "/home/yashkhare/workspace/Llama-3.1-8B-Instruct/original/tokenizer.model" # Path to the tokenizer
+  #formatter_type: "torchtune"            # Type of formatter to use ('torchtune', 'vllm', or 'openai')
+  model_path: "/home/yashkhare/workspace/Llama-3.2-11B-Vision-Instruct" # Path to the model checkpoint
+  tokenizer_path: "/home/yashkhare/workspace/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model" # Path to the tokenizer
   output_dir: /home/yashkhare/workspace/finetuning-pipeline/model_outputs  # Directory to store checkpoints
   log_dir: /home/yashkhare/workspace/finetuning-pipeline/logs  # Directory to store logs
   strategy: "lora"               # Training strategy ('fft' or 'lora')
   num_epochs: 1                 # Number of training epochs
-  batch_size: 4                 # Batch size per device for training
-  torchtune_config: "llama3_1/8B_lora"             # TorchTune-specific configuration
+  max_steps_per_epoch: null
+  batch_size: 8                 # Batch size per device for training
+  torchtune_config: "llama3_2_vision/11B_lora"             # TorchTune-specific configuration
   num_processes_per_node: 8             # TorchTune-specific configuration
   distributed: true             # Whether to use distributed training
 
@@ -105,7 +107,7 @@ finetuning:
 # vLLM Inference configuration
 inference:
   # Model configuration
-  model_path: "/home/yashkhare/workspace/finetuning-pipeline/model_outputs/epoch_0" # Path to the model checkpoint
+  model_path: "/home/yashkhare/workspace/Llama-3.2-11B-Vision-Instruct" # Path to the model checkpoint
   quantization: null            # Quantization method (awq, gptq, squeezellm)
   dtype: "auto"                 # Data type for model weights (half, float, bfloat16, auto)
   trust_remote_code: false      # Trust remote code when loading the model
@@ -115,8 +117,8 @@ inference:
   host: "0.0.0.0"               # Host to run the server on
 
   # Performance configuration
-  tensor_parallel_size: 1       # Number of GPUs to use for tensor parallelism
-  max_model_len: 512           # Maximum sequence length
+  tensor_parallel_size: 8       # Number of GPUs to use for tensor parallelism
+  max_model_len: 8192           # Maximum sequence length
   max_num_seqs: 1              # Maximum number of sequences
   gpu_memory_utilization: 0.95   # Fraction of GPU memory to use
   enforce_eager: false          # Enforce eager execution
@@ -126,7 +128,7 @@ inference:
     split: "validation"               # Dataset split to load
     formatter_type: "vllm"            # Type of formatter to use ('torchtune', 'vllm', or 'openai')
     format_data: true                # Whether to format the inference dataset
-    max_samples: null                 # Maximum number of samples to load (null for all)
+    max_samples: 10                 # Maximum number of samples to load (null for all)
     is_local: false                   # Whether the data is stored locally
 
   # Additional vLLM parameters (optional)

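Taken together, the config changes point the formatter at the dataset's "images" column and switch both fine-tuning and inference to Llama-3.2-11B-Vision-Instruct, with the vLLM engine scaled up accordingly (tensor_parallel_size 8, max_model_len 8192). As a rough sketch of how the inference block is consumed, assuming the pipeline reads the YAML with PyYAML (the variable names below are illustrative, not taken from this repo):

import yaml

# Minimal sketch: load the updated config and collect the vLLM engine settings.
# Key names mirror the YAML above; everything else here is illustrative.
with open("src/finetune_pipeline/config.yaml") as f:
    config = yaml.safe_load(f)

inference = config["inference"]
engine_kwargs = {
    "model": inference["model_path"],                               # Llama-3.2-11B-Vision-Instruct
    "tensor_parallel_size": inference["tensor_parallel_size"],      # 8
    "max_model_len": inference["max_model_len"],                    # 8192
    "max_num_seqs": inference["max_num_seqs"],                      # 1
    "gpu_memory_utilization": inference["gpu_memory_utilization"],  # 0.95
}
print(engine_kwargs)
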
+ 24 - 2
src/finetune_pipeline/data/data_loader.py

@@ -132,6 +132,23 @@ def get_formatter(formatter_type: str) -> Formatter:
     return formatter_map[formatter_type.lower()]()
 
 
+def get_image_path(img: str) -> str:
+    """
+    Get the image path from the image URL.
+
+    Args:
+        img: Image URL
+
+    Returns:
+        str: Image path
+    """
+
+    img_name = img.split("/")[-2]
+    img_id = img.split("/")[-1]
+    img_dir = "/home/yashkhare/workspace/IU-Xray/mnt/bn/haiyang-dataset-lq/medical/home/yisiyang/ysy/medical_dataset/iu_xray/iu_xray/images"
+    return f"{img_dir}/{img_name}/{img_id}"
+
+
 def convert_to_conversations(data, column_mapping: Optional[Dict] = None):
     """
     Convert data to a list of Conversation objects.
@@ -180,12 +197,17 @@ def convert_to_conversations(data, column_mapping: Optional[Dict] = None):
                 # Handle list of images
                 for img in image:
                     if img:  # Check if image path is not empty
+                        img_path = get_image_path(img)
                         user_content.append(
-                            {"type": "image", "image_url": {"url": img}}
+                            {"type": "image_url", "image_url": {"url": img_path}}
                         )
+                        break
             else:
                 # Handle single image
-                user_content.append({"type": "image", "image_url": {"url": image}})
+                img_path = get_image_path(image)
+                user_content.append(
+                    {"type": "image_url", "image_url": {"url": img_path}}
+                )
 
         user_message = {"role": "user", "content": user_content}
 

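get_image_path re-roots an incoming image URL under the hard-coded IU-Xray directory by keeping only its last two path components (study folder and file name). Note that for a list of images only the first non-empty entry is attached because of the break, and the content type is now "image_url" with a local path that the vLLM formatter later inlines as base64. A quick illustration with a made-up URL (only the target directory comes from the code above):

# Hypothetical input URL; only the last two path segments are kept.
img = "https://example.org/iu_xray/images/CXR_hypothetical_study/0.png"

img_name = img.split("/")[-2]  # "CXR_hypothetical_study"
img_id = img.split("/")[-1]    # "0.png"
img_dir = "/home/yashkhare/workspace/IU-Xray/mnt/bn/haiyang-dataset-lq/medical/home/yisiyang/ysy/medical_dataset/iu_xray/iu_xray/images"

print(f"{img_dir}/{img_name}/{img_id}")
# -> .../iu_xray/images/CXR_hypothetical_study/0.png
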
+ 24 - 1
src/finetune_pipeline/data/formatter.py

@@ -1,7 +1,13 @@
+import base64
 from abc import ABC, abstractmethod
 from typing import Dict, List, Optional, TypedDict, Union
 
 
+def image_to_base64(image_path):
+    with open(image_path, "rb") as img:
+        return base64.b64encode(img.read()).decode("utf-8")
+
+
 class MessageContent(TypedDict, total=False):
     """Type definition for message content in LLM requests."""
 
@@ -224,7 +230,24 @@ class vLLMFormatter(Formatter):
         Returns:
             str: Formatted message in vLLM format
         """
-        return message
+        contents = []
+        vllm_message = {}
+
+        for content in message["content"]:
+            if content["type"] == "text":
+                contents.append(content["text"])
+            elif content["type"] == "image_url":
+                base64_image = image_to_base64(content["image_url"]["url"])
+                img_content = {
+                    "type": "image_url",
+                    "image_url": {"url": f"data:image/jpg;base64,{base64_image}"},
+                }
+                contents.append(img_content)
+            else:
+                raise ValueError(f"Unknown content type: {content['type']}")
+        vllm_message["role"] = message["role"]
+        vllm_message["content"] = contents
+        return vllm_message
 
 
 class OpenAIFormatter(Formatter):

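With this change, vLLMFormatter.format_message returns an OpenAI-style chat message in which text parts are appended as bare strings while image parts are re-emitted as base64 data URLs read from disk. For a message carrying one text part and one image_url part, the output has roughly the following shape (the prompt text is a placeholder and the base64 payload is elided):

# Expected shape of vLLMFormatter.format_message output for a message with
# one text part and one image_url part (placeholder text, payload elided).
{
    "role": "user",
    "content": [
        "Describe the findings in this chest X-ray.",  # text parts become bare strings
        {
            "type": "image_url",
            "image_url": {"url": "data:image/jpg;base64,<...>"},
        },
    ],
}
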
+ 3 - 1
src/finetune_pipeline/finetuning/custom_sft_dataset.py

@@ -1,17 +1,19 @@
 """
 Custom SFT dataset for fine-tuning.
 """
+
 from torchtune.data import OpenAIToMessages
 from torchtune.datasets import SFTDataset
 from torchtune.modules.transforms import Transform
 
+
 def custom_sft_dataset(
     model_transform: Transform,
     *,
     dataset_path: str = "/tmp/train.json",
     train_on_input: bool = False,
     split: str = "train",
-
+    source: str = "json",
 ) -> SFTDataset:
     """
     Creates a custom SFT dataset for fine-tuning.

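The hunk adds a source parameter (defaulting to "json") but cuts off before the function body. Based on torchtune's SFTDataset and OpenAIToMessages interfaces, a plausible construction is sketched below; the function name, the data_files forwarding, and the body as a whole are assumptions, not code from this commit:

from torchtune.data import OpenAIToMessages
from torchtune.datasets import SFTDataset
from torchtune.modules.transforms import Transform


def custom_sft_dataset_sketch(
    model_transform: Transform,
    *,
    dataset_path: str = "/tmp/train.json",
    train_on_input: bool = False,
    split: str = "train",
    source: str = "json",
) -> SFTDataset:
    # Assumed body: `source` and `data_files` are passed through SFTDataset to
    # datasets.load_dataset, with OpenAI-style records converted to torchtune messages.
    return SFTDataset(
        source=source,
        data_files=dataset_path,
        message_transform=OpenAIToMessages(train_on_input=train_on_input),
        model_transform=model_transform,
        split=split,
    )
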
+ 8 - 1
src/finetune_pipeline/inference/run_inference.py

@@ -281,7 +281,9 @@ def run_vllm_batch_inference_on_dataset(
     seed: Optional[int] = None,
     structured: bool = False,
     gpu_memory_utilization: float = 0.95,
-    max_model_len: int = 4096,
+    max_model_len: int = 512,
+    max_num_seqs: int = 1,
+    tensor_parallel_size: int = 1,
 ) -> Dict[str, Any]:
     """
     Run inference on evaluation data using a vLLM server.
@@ -309,6 +311,8 @@ def run_vllm_batch_inference_on_dataset(
             gpu_memory_utilization=gpu_memory_utilization,
             max_model_len=max_model_len,
             seed=seed,
+            max_num_seqs=max_num_seqs,
+            tensor_parallel_size=tensor_parallel_size,
         )
     except Exception as e:
         logger.error(f"Failed to initialize vLLM model: {e}")
@@ -413,6 +417,7 @@ def main():
     # Performance parameters
     gpu_memory_utilization = inference_config.get("gpu_memory_utilization", 0.95)
     max_model_len = inference_config.get("max_model_len", 512)
+    max_num_seqs = inference_config.get("max_num_seqs", 1)
     tensor_parallel_size = inference_config.get("tensor_parallel_size", 1)
     dtype = inference_config.get("dtype", "auto")
     trust_remote_code = inference_config.get("trust_remote_code", False)
@@ -439,6 +444,8 @@ def main():
         structured,
         gpu_memory_utilization,
         max_model_len,
+        max_num_seqs,
+        tensor_parallel_size,
     )
 
     save_inference_results(results, results_path)

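The two new arguments are threaded from the config into the vLLM engine initialization shown above; both correspond to standard vllm.LLM constructor keywords. A self-contained sketch of what they control (the model path and prompt are placeholders, not values from this repo):

from vllm import LLM, SamplingParams

# Standalone sketch of the engine settings the new arguments control.
llm = LLM(
    model="/path/to/Llama-3.2-11B-Vision-Instruct",  # placeholder path
    tensor_parallel_size=8,       # shard the model across 8 GPUs
    max_model_len=8192,           # maximum sequence length
    max_num_seqs=1,               # cap on concurrently scheduled sequences
    gpu_memory_utilization=0.95,  # fraction of GPU memory to reserve
)
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
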
+ 1 - 1
src/finetune_pipeline/run_pipeline.py

@@ -122,7 +122,7 @@ def run_finetuning(config_path: str, formatted_data_paths: List[str]) -> str:
 
         # Get the path to the latest checkpoint of the fine-tuned model
         model_output_dir = finetuning_config.get("output_dir", config.get("output_dir"))
-        epochs = finetuning_config.get("epochs", 1)
+        epochs = finetuning_config.get("num_epochs", 1)
         checkpoint_path = os.path.join(model_output_dir, f"epoch_{epochs-1}")
         logger.info(
             f"Fine-tuning complete. Latest checkpoint saved to {checkpoint_path}"