
W2 finetuning initial commit

Beto de Paola 1 week ago
commit ae3f0269e9

+ 100 - 0
getting-started/finetuning/vision/11B_full_w2.yaml

@@ -0,0 +1,100 @@
+# Top-level output directory
+output_dir: ./outputs/Llama-3.2-11B-Instruct-w2-full
+
+# Model
+model:
+  _component_: torchtune.models.llama3_2_vision.llama3_2_vision_11b
+  decoder_trainable: True
+  encoder_trainable: True
+  fusion_trainable: True
+  image_size: 560 # Make sure this matches the image_size in tokenizer
+
+# Tokenizer / vision transform
+tokenizer:
+  _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
+  path: ./Llama-3.2-11B-Vision-Instruct/original/tokenizer.model
+  image_size: 560
+  max_seq_len: 8192
+
+# Checkpointing
+checkpointer:
+  _component_: torchtune.training.FullModelHFCheckpointer
+  checkpoint_dir: ./Llama-3.2-11B-Vision-Instruct
+  checkpoint_files:
+    filename_format: model-{}-of-{}.safetensors
+    max_filename: "00005"
+  recipe_checkpoint: null
+  output_dir: ${output_dir}
+  model_type: LLAMA3_VISION
+
+resume_from_checkpoint: false
+save_adapter_weights_only: False # PEFT formatting not available yet; checkpoints are saved in torchtune format only.
+
+# Dataset
+dataset:
+  _component_: torchtune.datasets.multimodal.vqa_dataset
+  source: arrow
+  data_files:
+    train: "fake_w2_us_tax_form_dataset_train30_test70/train/data-00000-of-00001.arrow"
+  split: train
+  column_map:
+    input: input
+    output: ground_truth
+    image: image
+
+# General data handling
+seed: null
+shuffle: true
+collate_fn: torchtune.data.padded_collate_tiled_images_and_mask
+
+# Training loop & hyperparams
+
+epochs: 5
+max_steps_per_epoch: null
+batch_size: 4
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+# explicit optimizer / scheduler / loss
+optimizer:
+  _component_: bitsandbytes.optim.PagedAdamW8bit
+  lr: 2e-5
+optimizer_in_bwd: False  # True saves memory. Requires gradient_accumulation_steps=1
+
+loss:
+  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
+
+clip_grad_norm: 1.0
+compile: false
+
+# Device & memory
+device: cuda
+enable_activation_checkpointing: true
+dtype: bf16
+
+# Logging
+
+metric_logger:
+  _component_: torchtune.training.metric_logging.WandBLogger
+  project: llama3_2_w2_extraction
+  entity: <your_wandb_entity>
+  job_type: full_finetune_single_device
+  group: llama-cookbook
+log_every_n_steps: 5
+save_steps: 100
+log_peak_memory_stats: true
+log_level: INFO
+
+# Profiler (off by default)
+profiler:
+  _component_: torchtune.training.setup_torch_profiler
+  enabled: false
+  output_dir: ${output_dir}/profiling_outputs
+  cpu: true
+  cuda: true
+  profile_memory: false
+  with_stack: false
+  record_shapes: true
+  with_flops: false
+  wait_steps: 5
+  warmup_steps: 3
+  active_steps: 2
+  num_cycles: 1
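
Note: the checkpointer and tokenizer paths above assume the base model has been downloaded to ./Llama-3.2-11B-Vision-Instruct. A minimal sketch of one way to fetch it, assuming huggingface_hub is installed and you have access to the gated repo:

    from huggingface_hub import snapshot_download

    # Pull the HF-format weights plus original/tokenizer.model referenced by the config.
    snapshot_download(
        repo_id="meta-llama/Llama-3.2-11B-Vision-Instruct",
        local_dir="./Llama-3.2-11B-Vision-Instruct",
    )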

+ 118 - 0
getting-started/finetuning/vision/11B_lora_w2.yaml

@@ -0,0 +1,118 @@
+# Top-level output directory
+output_dir: ./outputs/Llama-3.2-11B-Instruct-w2-lora-80
+
+# Model + LoRA settings
+model:
+  _component_: torchtune.models.llama3_2_vision.lora_llama3_2_vision_11b
+  # LoRA hyperparameters
+  lora_rank: 8 # higher rank increases capacity and memory use
+  lora_alpha: 16 # usually alpha = 2 * rank
+  lora_dropout: 0.05
+  image_size: 560 # Make sure this matches the image_size in tokenizer
+  # Which components to train
+  decoder_trainable: "frozen"
+  encoder_trainable: "lora"
+  fusion_trainable: "lora"
+  lora_attn_modules:
+    - 'q_proj'
+    - 'v_proj'
+    - 'output_proj'
+  apply_lora_to_mlp: true
+  apply_lora_to_output: false
+
+# Tokenizer / vision transform
+tokenizer:
+  _component_: torchtune.models.llama3_2_vision.llama3_2_vision_transform
+  path: ./Llama-3.2-11B-Vision-Instruct/original/tokenizer.model
+  image_size: 560
+  max_seq_len: 8192
+
+# Checkpointing
+checkpointer:
+  _component_: torchtune.training.FullModelHFCheckpointer
+  checkpoint_dir: ./Llama-3.2-11B-Vision-Instruct
+  checkpoint_files:
+    filename_format: model-{}-of-{}.safetensors
+    max_filename: "00005"
+  recipe_checkpoint: null
+  output_dir: ${output_dir}
+  model_type: LLAMA3_VISION
+
+resume_from_checkpoint: false
+save_adapter_weights_only: false # PEFT formatting not available yet; checkpoints are saved in torchtune format only.
+
+# Dataset
+dataset:
+  _component_: torchtune.datasets.multimodal.vqa_dataset
+  source: arrow
+  data_files:
+    train: "fake_w2_us_tax_form_dataset_train80_test20/train/data-00000-of-00001.arrow"
+  split: train
+  column_map:
+    input: input
+    output: ground_truth
+    image: image
+
+# General data handling
+seed: null
+shuffle: true
+collate_fn: torchtune.data.padded_collate_tiled_images_and_mask
+
+# Training loop & hyperparams
+
+# Training-loop control
+epochs: 10
+max_steps_per_epoch: null
+batch_size: 4
+gradient_accumulation_steps: 8 # Use to increase effective batch size
+# explicit optimizer / scheduler / loss
+optimizer:
+  _component_: torch.optim.AdamW
+  fused: true
+  weight_decay: 0.01
+  lr: 1e-4
+
+lr_scheduler:
+  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+  num_warmup_steps: 100
+
+loss:
+  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
+
+clip_grad_norm: 1.0
+compile: false
+
+# Device & memory
+device: cuda
+enable_activation_checkpointing: true
+dtype: bf16
+
+# Logging
+
+metric_logger:
+  _component_: torchtune.training.metric_logging.WandBLogger
+  project: llama3_2_w2_extraction
+  entity: <your_wandb_entity>
+  job_type: lora_finetune_single_device
+  group: llama-cookbook
+log_every_n_steps: 5
+save_steps: 100
+log_peak_memory_stats: true
+log_level: INFO
+
+# Profiler (off by default)
+profiler:
+  _component_: torchtune.training.setup_torch_profiler
+  enabled: false
+  output_dir: ${output_dir}/profiling_outputs
+  cpu: true
+  cuda: true
+  profile_memory: false
+  with_stack: false
+  record_shapes: true
+  with_flops: false
+  wait_steps: 5
+  warmup_steps: 3
+  active_steps: 2
+  num_cycles: 1
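
Note: the vqa_dataset column_map in both configs assumes the prepared dataset exposes input, ground_truth, and image columns. A quick sanity check, assuming the split was produced with prepare_w2_dataset.py (added later in this commit):

    from datasets import load_from_disk

    # Load the prepared DatasetDict and confirm the columns the configs map from.
    ds = load_from_disk("fake_w2_us_tax_form_dataset_train80_test20")
    print(ds["train"].column_names)  # expect input, ground_truth, image
    print(ds["train"][0]["input"])   # prompt injected during preparation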

+ 793 - 0
getting-started/finetuning/vision/evaluate.py

@@ -0,0 +1,793 @@
+#!/usr/bin/env python3
+"""
+Script to evaluate a vision-language model on the W2 tax form dataset using an
+OpenAI-compatible API client. Works with any OpenAI-compatible endpoint, such as
+a vLLM server or the Llama API. Supports batch processing.
+Loads images from the provided dataset, sends them to the API server,
+and compares the responses with the expected output.
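+
+Example usage (assuming a local vLLM server exposing the OpenAI-compatible API on port 8001):
+    python evaluate.py --server_url http://localhost:8001 \
+        --model meta-llama/Llama-3.2-11B-Vision-Instruct \
+        --dataset_name fake_w2_us_tax_form_dataset_train80_test20/test \
+        --limit -1 --structured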
+"""
+
+import argparse
+import base64
+import json
+import logging
+import os
+import pathlib
+import re
+import time
+import traceback
+from concurrent.futures import as_completed, ThreadPoolExecutor
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+from datasets import load_dataset, load_from_disk
+from openai import OpenAI
+from PIL import Image
+from pydantic import BaseModel
+from tqdm import tqdm
+
+# Set up logging
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+class W2Form(BaseModel):
+    box_b_employer_identification_number: str
+    box_c_employer_name: str
+    box_c_employer_street_address: str
+    box_c_employer_city_state_zip: str
+    box_a_employee_ssn: str
+    box_e_employee_name: str
+    box_e_employee_street_address: str
+    box_e_employee_city_state_zip: str
+    box_d_control_number: int
+    box_1_wages: float
+    box_2_federal_tax_withheld: float
+    box_3_social_security_wages: float
+    box_4_social_security_tax_withheld: float
+    box_5_medicare_wages: float
+    box_6_medicare_wages_tax_withheld: float
+    box_7_social_security_tips: float
+    box_8_allocated_tips: float
+    box_9_advance_eic_payment: Optional[str]
+    box_10_dependent_care_benefits: float
+    box_11_nonqualified_plans: float
+    box_12a_code: str
+    box_12a_value: float
+    box_12b_code: str
+    box_12b_value: float
+    box_12c_code: str
+    box_12c_value: float
+    box_12d_code: Optional[str]
+    box_12d_value: float
+    box_13_statutary_employee: Optional[str]
+    box_13_retirement_plan: Optional[str]
+    box_13_third_part_sick_pay: Optional[str]
+    box_15_1_state: str
+    box_15_1_employee_state_id: str
+    box_16_1_state_wages: float
+    box_17_1_state_income_tax: float
+    box_18_1_local_wages: float
+    box_19_1_local_income_tax: float
+    box_20_1_locality: str
+    box_15_2_state: str
+    box_15_2_employee_state_id: str
+    box_16_2_state_wages: float
+    box_17_2_state_income_tax: float
+    box_18_2_local_wages: float
+    box_19_2_local_income_tax: float
+    box_20_2_locality: str
+
+
+# ----------- Utilities -----------
+def encode_image_to_base64(image_path: str) -> str:
+    """Encode image to base64 string."""
+    with open(image_path, "rb") as f:
+        return base64.b64encode(f.read()).decode()
+
+
+def create_messages(prompt: str, image_path: str) -> List[Dict]:
+    """Create messages array for API client call."""
+    content = [
+        {"type": "text", "text": prompt},
+        {
+            "type": "image_url",
+            "image_url": {
+                "url": f"data:image/png;base64,{encode_image_to_base64(image_path)}"
+            },
+        },
+    ]
+    return [{"role": "user", "content": content}]
+
+
+def clean_json_string(json_str: str) -> str:
+    """
+    Clean common JSON formatting issues from LLM responses.
+
+    Args:
+        json_str: Raw JSON string that may contain formatting issues
+
+    Returns:
+        Cleaned JSON string
+    """
+    # Remove markdown code block markers
+    json_str = re.sub(r"```(?:json)?\s*", "", json_str)
+    json_str = re.sub(r"\s*```", "", json_str)
+
+    # Fix malformed string patterns like: "field": ",\n" ,
+    # This handles the specific error case where strings are malformed with newlines
+    json_str = re.sub(r':\s*",\s*"\s*,', ': "",', json_str)
+
+    # Fix incomplete string literals with control characters
+    # Pattern: "field": "partial_value\nrest_of_value",
+    json_str = re.sub(r':\s*"([^"]*)\n([^"]*)",', r': "\1\2",', json_str)
+
+    # Fix the specific pattern from the error: "field": "value\n" followed by whitespace and comma
+    json_str = re.sub(r':\s*"([^"]*)\n"\s*,', r': "\1",', json_str)
+
+    # Remove trailing commas in objects and arrays
+    json_str = re.sub(r",(\s*[}\]])", r"\1", json_str)
+
+    # Fix missing quotes around keys (sometimes LLMs output unquoted keys)
+    json_str = re.sub(r"([{,]\s*)([a-zA-Z_][a-zA-Z0-9_]*)\s*:", r'\1"\2":', json_str)
+
+    # Fix single quotes to double quotes (JSON requires double quotes)
+    json_str = re.sub(r"'([^']*)'", r'"\1"', json_str)
+
+    # Remove control characters that are not allowed in JSON strings
+    # Keep only printable ASCII and basic whitespace
+    json_str = "".join(char for char in json_str if ord(char) >= 32 or char in "\t\r ")
+
+    # Fix null-like values that should be proper JSON null
+    json_str = re.sub(r":\s*None\s*,", ": null,", json_str, flags=re.IGNORECASE)
+    json_str = re.sub(r":\s*undefined\s*,", ": null,", json_str, flags=re.IGNORECASE)
+
+    return json_str
+
+
+def extract_json_from_response(response: str) -> Tuple[Dict[str, Any], bool]:
+    """
+    Robust JSON extraction from LLM responses with comprehensive error handling.
+
+    Args:
+        response: Raw response text from LLM
+
+    Returns:
+        Tuple of (extracted_json_dict, has_error)
+    """
+    if not response or not response.strip():
+        logger.warning("Empty response provided")
+        return {}, True
+
+    # Strategy 1: Look for JSON content between triple backticks
+    json_match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", response, re.DOTALL)
+    if json_match:
+        json_str = json_match.group(1)
+    else:
+        # Strategy 2: Look for JSON object pattern (handle nested braces)
+        json_match = re.search(r"\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}", response, re.DOTALL)
+        if json_match:
+            json_str = json_match.group(0)
+        else:
+            # Strategy 3: Find content between first { and last }
+            start_idx = response.find("{")
+            end_idx = response.rfind("}")
+            if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
+                json_str = response[start_idx : end_idx + 1]
+            else:
+                logger.warning("No JSON pattern found in response")
+                logger.debug(f"Response snippet: {response[:200]}...")
+                return {}, True
+
+    # Clean the extracted JSON string
+    original_json_str = json_str
+    json_str = clean_json_string(json_str)
+
+    # Attempt to parse with multiple strategies
+    parsing_strategies = [
+        ("direct", lambda s: json.loads(s)),
+        ("strip_whitespace", lambda s: json.loads(s.strip())),
+        (
+            "fix_escapes",
+            lambda s: json.loads(s.replace("\\\\", "\\").replace('\\"', '"')),
+        ),
+    ]
+
+    for strategy_name, parse_func in parsing_strategies:
+        try:
+            parsed_json = parse_func(json_str)
+
+            # Validate that it's a dictionary (expected for most use cases)
+            if not isinstance(parsed_json, dict):
+                logger.warning(
+                    f"Extracted JSON is not a dictionary: {type(parsed_json)}"
+                )
+                continue
+
+            logger.debug(f"Successfully parsed JSON using strategy: {strategy_name}")
+            return parsed_json, False
+
+        except json.JSONDecodeError as e:
+            logger.debug(f"Strategy '{strategy_name}' failed: {e}")
+            continue
+        except Exception as e:
+            logger.debug(f"Unexpected error in strategy '{strategy_name}': {e}")
+            continue
+
+    # If all strategies fail, log details for debugging
+    logger.error("All JSON parsing strategies failed")
+    logger.debug(f"Original JSON string (first 500 chars): {original_json_str[:500]}")
+    logger.debug(f"Cleaned JSON string (first 500 chars): {json_str[:500]}")
+
+    return {}, True
+
+
+def generate_prompt(structured=True) -> str:
+    """Generate prompt for the model."""
+    json_schema = W2Form.model_json_schema()
+
+    prompt = (
+        "You are an expert document information extraction system. "
+        "I will show you an image of a W-2 tax form. "
+        "Please extract all the information from this form and return it in a JSON format. "
+        "Include all fields such as employee details, employer details, wages, federal income tax withheld, "
+        "social security wages, social security tax withheld, medicare wages and tips, medicare tax withheld, "
+        "and any other information present on the form. "
+    )
+
+    if not structured:
+        prompt += f"Return ONLY the JSON output without any additional text or explanations following this schema {json_schema}"
+
+    return prompt
+
+
+def call_api_client(
+    client: OpenAI,
+    messages: List[Dict],
+    model: str = "meta-llama/Llama-3.2-11B-Vision-Instruct",
+    temperature: float = 0.0,
+    max_tokens: int = 8192,
+    response_format: Optional[Dict] = None,
+    timeout: int = 300,
+    seed: Optional[int] = 42,
+):
+    """
+    Call compatible API server using OpenAI-compatible client.
+    """
+    try:
+        kwargs = {
+            "model": model,
+            "messages": messages,
+            "temperature": temperature,
+            "max_tokens": max_tokens,
+            "timeout": timeout,
+        }
+
+        # Add seed if provided for reproducible generation
+        if seed is not None:
+            kwargs["seed"] = seed
+
+        # Add response format if structured output is enabled
+        if response_format:
+            kwargs["response_format"] = response_format
+
+        logger.debug(f"Making API client call with model: {model}")
+        response = client.chat.completions.create(**kwargs)
+
+        logger.debug(f"Received response with {len(response.choices)} choices")
+        return response
+
+    except Exception as e:
+        logger.error(f"API client call failed: {e}")
+        raise
+
+
+def process_single_sample(
+    client: OpenAI,
+    sample_data: Tuple[int, Dict],
+    output_dir: str,
+    model: str,
+    structured: bool,
+    timeout: int,
+) -> Dict[str, Any]:
+    """Process a single sample using OpenAI SDK."""
+    idx, sample = sample_data
+
+    try:
+        # Get image
+        image = sample["image"]
+
+        # Save image temporarily
+        image_path = get_image_path(image, output_dir, idx)
+        logger.debug(f"Saved image to {image_path}")
+
+        # Generate prompt and messages
+        prompt = generate_prompt(structured)
+        messages = create_messages(prompt, image_path)
+
+        # Prepare response format for structured output
+        response_format = None
+        if structured:
+            json_schema = W2Form.model_json_schema()
+            response_format = {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "W2Form",
+                    "schema": json_schema,
+                    "strict": True,
+                },
+            }
+
+        # Call API client
+        start_time = time.time()
+
+        try:
+            response = call_api_client(
+                client=client,
+                messages=messages,
+                model=model,
+                response_format=response_format,
+                timeout=timeout,
+            )
+
+            content = response.choices[0].message.content
+            usage = response.usage.model_dump() if response.usage else {}
+
+        except Exception as e:
+            logger.error(f"Error calling OpenAI SDK for sample {idx}: {e}")
+            content = ""
+            usage = {}
+
+        processing_time = time.time() - start_time
+
+        # Extract JSON from response
+        extracted_json, json_parsing_error = extract_json_from_response(content)
+
+        # Get ground truth
+        ground_truth_raw = json.loads(sample["ground_truth"])
+
+        # Handle the gt_parse wrapper structure if present
+        if "gt_parse" in ground_truth_raw:
+            ground_truth = ground_truth_raw["gt_parse"]
+        else:
+            ground_truth = ground_truth_raw
+
+        # Normalize for comparison
+        normalized_pred = normalize_json(extracted_json)
+        normalized_gt = normalize_json(ground_truth)
+
+        # Save results
+        result = {
+            "sample_id": idx,
+            "prediction": extracted_json,
+            "ground_truth": ground_truth,
+            "normalized_prediction": normalized_pred,
+            "normalized_gt": normalized_gt,
+            "raw_response": content,
+            "processing_time": processing_time,
+            "json_parsing_error": json_parsing_error,
+            "usage": usage,
+        }
+
+        return result
+
+    except Exception as e:
+        traceback_str = traceback.format_exc()
+        logger.error(f"Error processing sample {idx}: {str(e)} at line {traceback_str}")
+        return {
+            "sample_id": idx,
+            "prediction": {},
+            "ground_truth": {},
+            "normalized_prediction": {},
+            "normalized_gt": {},
+            "raw_response": "",
+            "processing_time": 0.0,
+            "json_parsing_error": True,
+            "usage": {},
+            "error": str(e),
+        }
+
+
+def calculate_metrics(results: List[Dict]) -> Dict[str, Any]:
+    """Calculate accuracy metrics for the predictions."""
+    if not results:
+        logger.error("No results provided")
+        return {"accuracy": 0.0, "field_accuracy": {}}
+
+    # Initialize metrics
+    total_fields = 0
+    correct_fields = 0
+    parse_errors = 0
+    total_records = len(results)
+    logger.info(f"Total records: {total_records}")
+    field_counts = {}
+    field_correct = {}
+
+    for result in results:
+        pred, gt = result["prediction"], result["ground_truth"]
+
+        if result["json_parsing_error"]:
+            parse_errors += 1
+            total_fields += len(gt)
+            continue
+
+        for field in gt.keys():
+            # Count total occurrences of this field
+            field_counts[field] = field_counts.get(field, 0) + 1
+            total_fields += 1
+
+            # Check if field is correct
+            if field in pred and pred[field] == gt[field]:
+                correct_fields += 1
+                field_correct[field] = field_correct.get(field, 0) + 1
+
+    # Calculate overall accuracy
+    accuracy = correct_fields / total_fields if total_fields > 0 else 0.0
+    errors = parse_errors / total_records if total_records > 0 else 0.0
+
+    # Calculate per-field accuracy
+    field_accuracy = {}
+    for field in field_counts:
+        field_accuracy[field] = field_correct.get(field, 0) / field_counts[field]
+
+    return {
+        "accuracy": accuracy,
+        "field_accuracy": field_accuracy,
+        "parse_error": errors,
+    }
+
+
+def normalize_field_value(value: Any) -> str:
+    """Normalize field values for comparison."""
+    if value is None:
+        return ""
+
+    # Convert to string and normalize
+    value_str = str(value).strip().lower()
+
+    # Remove common separators in numbers
+    value_str = value_str.replace(",", "").replace(" ", "")
+
+    # Try to convert to float for numeric comparison
+    try:
+        value_float = float(value_str)
+        return str(value_float)
+    except ValueError:
+        return value_str
+
+
+def normalize_json(json_obj: Dict) -> Dict:
+    """Normalize JSON object for comparison."""
+    normalized = {}
+
+    for key, value in json_obj.items():
+        # Normalize key (lowercase, remove spaces)
+        norm_key = key.lower().replace(" ", "_")
+
+        # Normalize value
+        if isinstance(value, dict):
+            normalized[norm_key] = normalize_json(value)
+        elif isinstance(value, list):
+            normalized[norm_key] = [normalize_field_value(v) for v in value]
+        else:
+            normalized[norm_key] = normalize_field_value(value)
+
+    return normalized
+
+
+def get_image_path(image: Image.Image, output_dir: str, idx: int) -> str:
+    """Get the path to save the image."""
+    # Create a temporary file for the image
+    temp_dir = pathlib.Path(output_dir) / "temp"
+    os.makedirs(temp_dir, exist_ok=True)
+    image_path = temp_dir / f"temp_{idx}.png"
+    image_path = str(image_path.resolve())
+    image.save(image_path)
+    return image_path
+
+
+def vllm_openai_sdk_evaluation(
+    test_set,
+    output_dir: str,
+    server_url: str = "http://localhost:8001",
+    api_key: str = "default-blank-localhost",
+    model: str = "meta-llama/Llama-3.2-11B-Vision-Instruct",
+    structured: bool = True,
+    timeout: int = 300,
+    max_workers: int = 10,
+):
+    """
+    Evaluate the W2 extraction task using OpenAI SDK with batch processing.
+    """
+    # Initialize OpenAI client
+    client = OpenAI(
+        api_key=api_key,  # vLLM doesn't require a real API key
+        base_url=f"{server_url}/v1",
+    )
+
+    # Prepare sample data for batch processing
+    sample_data = [(idx, sample) for idx, sample in enumerate(test_set)]
+
+    results = []
+
+    # Use ThreadPoolExecutor for concurrent processing
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        # Submit all tasks
+        future_to_sample = {
+            executor.submit(
+                process_single_sample,
+                client,
+                data,
+                output_dir,
+                model,
+                structured,
+                timeout,
+            ): data[0]
+            for data in sample_data
+        }
+
+        # Collect results with progress bar
+        for future in tqdm(
+            as_completed(future_to_sample),
+            total=len(sample_data),
+            desc="Processing samples with OpenAI SDK (batch)",
+        ):
+            sample_idx = future_to_sample[future]
+            try:
+                result = future.result()
+                results.append(result)
+            except Exception as e:
+                logger.error(f"Exception in sample {sample_idx}: {e}")
+                # Add error result
+                results.append(
+                    {
+                        "sample_id": sample_idx,
+                        "prediction": {},
+                        "ground_truth": {},
+                        "normalized_prediction": {},
+                        "normalized_gt": {},
+                        "raw_response": "",
+                        "processing_time": 0.0,
+                        "json_parsing_error": True,
+                        "usage": {},
+                        "error": str(e),
+                    }
+                )
+
+    # Sort results by sample_id to maintain order
+    results.sort(key=lambda x: x["sample_id"])
+
+    return results
+
+
+def vllm_openai_sdk_sequential_evaluation(
+    test_set,
+    output_dir: str,
+    server_url: str = "http://localhost:8001",
+    api_key: str = "default-blank-localhost",
+    model: str = "meta-llama/Llama-3.2-11B-Vision-Instruct",
+    structured: bool = True,
+    timeout: int = 300,
+):
+    """
+    Evaluate the W2 extraction task using OpenAI SDK sequentially (for debugging).
+    """
+    # Initialize OpenAI client
+    client = OpenAI(
+        api_key=api_key,  # vLLM doesn't require a real API key
+        base_url=f"{server_url}/v1",
+    )
+
+    results = []
+
+    for idx, sample in enumerate(
+        tqdm(test_set, desc="Processing samples with OpenAI SDK (sequential)")
+    ):
+        result = process_single_sample(
+            client, (idx, sample), output_dir, model, structured, timeout
+        )
+        results.append(result)
+
+    return results
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Evaluate vision-language model on W2 tax form dataset using OpenAI SDK"
+    )
+    parser.add_argument(
+        "--server_url",
+        type=str,
+        default="http://localhost:8001",
+        help="URL of the vLLM HTTP server",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="meta-llama/Llama-3.2-11B-Vision-Instruct",
+        help="Model name to use for inference",
+    )
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default="singhsays/fake-w2-us-tax-form-dataset",
+        help="Name of the Huggingface dataset",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="./w2_evaluation_results",
+        help="Directory to save evaluation results",
+    )
+    parser.add_argument(
+        "--limit",
+        type=int,
+        default=10,
+        help="Number of samples to evaluate (default: 10, use -1 for all)",
+    )
+    parser.add_argument(
+        "--structured",
+        action="store_true",
+        default=False,
+        help="Whether to use structured output (JSON schema)",
+    )
+    parser.add_argument(
+        "--timeout",
+        type=int,
+        default=300,
+        help="Timeout for SDK requests in seconds",
+    )
+    parser.add_argument(
+        "--max_workers",
+        type=int,
+        default=10,
+        help="Maximum number of concurrent workers for batch processing",
+    )
+    parser.add_argument(
+        "--sequential",
+        action="store_true",
+        default=False,
+        help="Process samples sequentially instead of in parallel (for debugging)",
+    )
+
+    args = parser.parse_args()
+
+    # Create output directory
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # Load dataset
+    logger.info(f"Loading dataset: {args.dataset_name}")
+    test_set = None
+    if Path(args.dataset_name, "state.json").exists():
+        test_set = load_from_disk(args.dataset_name)
+    else:
+        dataset = load_dataset(args.dataset_name)
+        if "test" not in dataset:
+            logger.error("Dataset does not have a test split")
+            return 1
+        test_set = dataset["test"]
+
+    logger.info(f"Loaded test set with {len(test_set)} samples")
+
+    # Limit number of samples if specified
+    if args.limit > 0 and args.limit < len(test_set):
+        test_set = test_set.select(range(args.limit))
+        logger.info(f"Limited to {args.limit} samples")
+
+    # Get API key from environment variable
+    api_key = os.getenv("LLAMA_API_KEY") or os.getenv("OPENAI_API_KEY")
+
+    if not api_key:
+        logger.warning(
+            "No API key found. Please set the LLAMA_API_KEY or OPENAI_API_KEY environment variable for public APIs."
+        )
+        api_key = "default-blank-localhost"
+
+    # Test server connection
+    try:
+        client = OpenAI(
+            api_key=api_key,
+            base_url=f"{args.server_url}/v1",
+        )
+        # Test with a simple call
+        models = client.models.list()
+        logger.info(f"Successfully connected to vLLM server at {args.server_url}")
+        logger.info(f"Available models: {[model.id for model in models.data]}")
+    except Exception as e:
+        logger.error(f"Failed to connect to vLLM server at {args.server_url}: {e}")
+        logger.error("Make sure the vLLM server is running and accessible")
+        return 1
+
+    # Run evaluation
+    if args.sequential:
+        logger.info("Running sequential evaluation...")
+        results = vllm_openai_sdk_sequential_evaluation(
+            test_set=test_set,
+            output_dir=args.output_dir,
+            server_url=args.server_url,
+            api_key=api_key,
+            model=args.model,
+            structured=args.structured,
+            timeout=args.timeout,
+        )
+    else:
+        logger.info(f"Running batch evaluation with {args.max_workers} workers...")
+        results = vllm_openai_sdk_evaluation(
+            test_set=test_set,
+            output_dir=args.output_dir,
+            server_url=args.server_url,
+            api_key=api_key,
+            model=args.model,
+            structured=args.structured,
+            timeout=args.timeout,
+            max_workers=args.max_workers,
+        )
+
+    # Save detailed results
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    try:
+        results_file = os.path.join(args.output_dir, f"results_{timestamp}.json")
+        with open(results_file, "w") as f:
+            json.dump(results, f, indent=2)
+        logger.info(f"Detailed results saved to {results_file}")
+
+    except Exception as e:
+        logger.error(f"Error saving detailed results: {str(e)}")
+        return 1
+
+    # Calculate metrics
+    metrics = calculate_metrics(results)
+
+    # Save evaluation summary
+    output_file = os.path.join(args.output_dir, f"evaluation_results_{timestamp}.json")
+    arguments = {
+        "server_url": args.server_url,
+        "model": args.model,
+        "output_dir": args.output_dir,
+        "dataset_name": args.dataset_name,
+        "limit": args.limit,
+        "structured": args.structured,
+        "timeout": args.timeout,
+        "max_workers": args.max_workers,
+        "sequential": args.sequential,
+        "prompt": generate_prompt(args.structured),
+    }
+
+    summary = {
+        "arguments": arguments,
+        "metrics": metrics,
+        "timestamp": timestamp,
+        "total_samples": len(results),
+    }
+
+    with open(output_file, "w") as f:
+        json.dump(summary, f, indent=2)
+
+    # Print summary
+    logger.info("=" * 50)
+    logger.info("EVALUATION SUMMARY")
+    logger.info("=" * 50)
+    logger.info(f"Overall accuracy: {metrics['accuracy']:.4f}")
+    logger.info(f"Parse error rate: {metrics['parse_error']:.4f}")
+    logger.info("Field-level accuracy:")
+    field_accuracy = metrics["field_accuracy"]
+    for field, acc in sorted(field_accuracy.items(), key=lambda x: x[1], reverse=True):
+        logger.info(f"  {field}: {acc:.4f}")
+
+    logger.info(f"Results saved to {output_file}")
+
+    # Clean up temp directory if it exists
+    temp_dir = os.path.join(args.output_dir, "temp")
+    if os.path.exists(temp_dir):
+        import shutil
+
+        shutil.rmtree(temp_dir)
+
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())

+ 237 - 0
getting-started/finetuning/vision/prepare_w2_dataset.py

@@ -0,0 +1,237 @@
+#!/usr/bin/env python3
+"""
+Script to modify the dataset by removing the top-level 'gt_parse' attribute from the ground_truth column
+and keeping all the keys under it. Also supports custom train-test splits.
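+
+Example usage (with defaults, writes to ./fake_w2_us_tax_form_dataset_train80_test20):
+    python prepare_w2_dataset.py --train-ratio 0.8 --seed 42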
+"""
+
+import argparse
+import json
+import logging
+
+from datasets import load_dataset
+
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(levelname)s - %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+)
+logger = logging.getLogger(__name__)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Prepare W2 dataset with custom train-test splits"
+    )
+    parser.add_argument(
+        "--train-ratio",
+        type=float,
+        default=0.8,
+        help="Ratio of data to use for training (default: 0.8, i.e., 80%% train, 20%% test)",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default=None,
+        help="Custom output directory name. If not provided, will use 'fake_w2_us_tax_form_dataset_train{train_ratio}_test{1 - train_ratio}'",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="Random seed for dataset splitting (default: 42)",
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default="Parse this W-2 form and extract all fields into a single level json.",
+        help="Custom prompt to use for the input field (default: Parse this W-2 form...)",
+    )
+    parser.add_argument(
+        "--dataset-name",
+        type=str,
+        default="singhsays/fake-w2-us-tax-form-dataset",
+        help="Dataset name from HuggingFace Hub (default: singhsays/fake-w2-us-tax-form-dataset)",
+    )
+    parser.add_argument(
+        "--skip-validation",
+        action="store_true",
+        help="Skip validation split loading (useful if dataset doesn't have validation split)",
+    )
+    return parser.parse_args()
+
+
+# Define a function to modify the ground_truth column
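+# Illustrative: a ground_truth value of '{"gt_parse": {...}}' becomes just '{...}'.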
+def remove_gt_parse_wrapper(example):
+    try:
+        # Parse the ground_truth JSON
+        ground_truth = json.loads(example["ground_truth"])
+
+        # Check if gt_parse exists in the ground_truth
+        if "gt_parse" in ground_truth:
+            # Replace the ground_truth with just the contents of gt_parse
+            example["ground_truth"] = json.dumps(ground_truth["gt_parse"])
+        else:
+            logger.warning("No 'gt_parse' key found in ground_truth, keeping original")
+
+        return example
+    except json.JSONDecodeError as e:
+        logger.error(f"Failed to parse ground_truth JSON: {e}")
+        logger.error(f"Problematic data: {example.get('ground_truth', 'N/A')}")
+        # Return the example unchanged if we can't parse it
+        return example
+    except Exception as e:
+        logger.error(f"Unexpected error in remove_gt_parse_wrapper: {e}")
+        return example
+
+
+def validate_dataset(dataset):
+    """Validate the loaded dataset has required columns."""
+    required_columns = ["ground_truth", "image"]
+    missing_columns = [
+        col for col in required_columns if col not in dataset.column_names
+    ]
+
+    if missing_columns:
+        raise ValueError(f"Dataset missing required columns: {missing_columns}")
+
+    logger.info(f"Dataset validation passed. Columns: {dataset.column_names}")
+
+
+def validate_train_ratio(train_ratio):
+    """Validate that train ratio is between 0 and 1 (exclusive)."""
+    if train_ratio <= 0 or train_ratio >= 1:
+        raise ValueError("Train ratio must be between 0 and 1 (exclusive)")
+    return True
+
+
+def create_output_directory_name(train_ratio, test_ratio, output_dir=None):
+    """Create output directory name based on the split ratio if not provided."""
+    if output_dir is None:
+        # Round to 2 decimal places before converting to int to avoid floating point precision issues
+        train_pct = int(round(train_ratio * 100, 2))
+        test_pct = int(round(test_ratio * 100, 2))
+        return f"fake_w2_us_tax_form_dataset_train{train_pct}_test{test_pct}"
+    return output_dir
+
+
+def load_dataset_safely(dataset_name, split="train+test"):
+    """Load dataset with proper error handling."""
+    try:
+        return load_dataset(dataset_name, split=split)
+    except Exception as e:
+        logger.error(f"Failed to load dataset '{dataset_name}': {e}")
+        raise
+
+
+def create_splits(all_data, train_ratio, seed):
+    """Create train-test splits from the dataset."""
+    logger.info(f"Creating new splits with train ratio: {train_ratio}")
+    return all_data.train_test_split(train_size=train_ratio, seed=seed)
+
+
+def load_validation_split(dataset_name, split_ds, skip_validation=False):
+    """Load validation split if not skipped."""
+    if skip_validation:
+        logger.info("Skipping validation split as requested")
+        return split_ds
+
+    try:
+        split_ds["validation"] = load_dataset(dataset_name, split="validation")
+        logger.info(
+            f"Loaded validation split with {len(split_ds['validation'])} examples"
+        )
+    except Exception as e:
+        logger.warning(
+            f"Could not load validation split: {e}. Continuing without validation split."
+        )
+
+    return split_ds
+
+
+def apply_transformations(split_ds, prompt):
+    """Apply data transformations to the dataset."""
+    logger.info("Modifying dataset...")
+    modified_ds = split_ds.map(remove_gt_parse_wrapper)
+
+    logger.info(f"Adding custom prompt: {prompt}")
+    modified_ds = modified_ds.map(lambda _: {"input": prompt})
+
+    return modified_ds
+
+
+def log_dataset_statistics(all_data, modified_ds):
+    """Log comprehensive dataset statistics."""
+    logger.info("\n=== Dataset Statistics ===")
+    logger.info(f"Total examples: {len(all_data)}")
+    logger.info(
+        f"Train split: {len(modified_ds['train'])} examples ({len(modified_ds['train'])/len(all_data)*100:.1f}%)"
+    )
+    logger.info(
+        f"Test split: {len(modified_ds['test'])} examples ({len(modified_ds['test'])/len(all_data)*100:.1f}%)"
+    )
+    if "validation" in modified_ds:
+        logger.info(f"Validation split: {len(modified_ds['validation'])} examples")
+
+
+def save_dataset(modified_ds, output_dir):
+    """Save the modified dataset to disk."""
+    logger.info(f"Saving modified dataset to '{output_dir}'...")
+    modified_ds.save_to_disk(output_dir)
+    logger.info(f"Done! Modified dataset saved to '{output_dir}'")
+
+
+def main():
+    try:
+        args = parse_args()
+
+        # Validate train ratio
+        validate_train_ratio(args.train_ratio)
+
+        train_ratio = args.train_ratio
+        test_ratio = 1 - train_ratio
+
+        # Create output directory name
+        output_dir = create_output_directory_name(
+            train_ratio, test_ratio, args.output_dir
+        )
+
+        logger.info(f"Using train-test split: {train_ratio:.2f}-{test_ratio:.2f}")
+        logger.info(f"Output directory will be: {output_dir}")
+        logger.info(f"Dataset: {args.dataset_name}")
+
+        # Load the dataset with error handling
+        logger.info("Loading dataset...")
+        all_data = load_dataset_safely(args.dataset_name, "train+test")
+
+        validate_dataset(all_data)
+        logger.info(f"Loaded {len(all_data)} examples from dataset")
+
+        # Create splits
+        split_ds = create_splits(all_data, train_ratio, args.seed)
+
+        # Load validation split
+        split_ds = load_validation_split(
+            args.dataset_name, split_ds, args.skip_validation
+        )
+
+        # Apply transformations
+        modified_ds = apply_transformations(split_ds, args.prompt)
+
+        # Log statistics
+        log_dataset_statistics(all_data, modified_ds)
+
+        # Save the modified dataset
+        save_dataset(modified_ds, output_dir)
+
+    except Exception as e:
+        logger.error(f"Script failed with error: {e}")
+        raise
+
+
+if __name__ == "__main__":
+    main()