
add code for measuring skills transferability across LoRA and FFT finetunes

Suraj Subramanian 4 days ago
parent
commit
6ab852ab4e
25 changed files with 112015 additions and 0 deletions
  1. end-to-end-use-cases/transferability/README.md (+130, -0)
  2. end-to-end-use-cases/transferability/config.yaml (+60, -0)
  3. end-to-end-use-cases/transferability/experiments/w2_ocr/grader_logs/base_model_evaluation_log.json (+22033, -0)
  4. end-to-end-use-cases/transferability/experiments/w2_ocr/grader_logs/finetuned_full_fusion+encoder+decoder_evaluation_log.json (+22033, -0)
  5. end-to-end-use-cases/transferability/experiments/w2_ocr/grader_logs/finetuned_full_fusion_evaluation_log.json (+22033, -0)
  6. end-to-end-use-cases/transferability/experiments/w2_ocr/grader_logs/finetuned_lora_64_evaluation_log.json (+22033, -0)
  7. end-to-end-use-cases/transferability/experiments/w2_ocr/grader_logs/finetuned_lora_8_evaluation_log.json (+22033, -0)
  8. end-to-end-use-cases/transferability/transferability/data/__init__.py (+3, -0)
  9. end-to-end-use-cases/transferability/transferability/data/__main__.py (+14, -0)
  10. end-to-end-use-cases/transferability/transferability/data/dataset_builder.py (+228, -0)
  11. end-to-end-use-cases/transferability/transferability/datasets/__init__.py (+3, -0)
  12. end-to-end-use-cases/transferability/transferability/datasets/torchtune_format.py (+34, -0)
  13. end-to-end-use-cases/transferability/transferability/evals/__init__.py (+5, -0)
  14. end-to-end-use-cases/transferability/transferability/evals/__main__.py (+14, -0)
  15. end-to-end-use-cases/transferability/transferability/evals/eval_grid.py (+251, -0)
  16. end-to-end-use-cases/transferability/transferability/evals/grader.py (+116, -0)
  17. end-to-end-use-cases/transferability/transferability/evals/inference.py (+187, -0)
  18. end-to-end-use-cases/transferability/transferability/evals/json_grading_utils.py (+240, -0)
  19. end-to-end-use-cases/transferability/transferability/evals/shift_analysis.py (+44, -0)
  20. end-to-end-use-cases/transferability/transferability/finetune/8b_full.yaml (+94, -0)
  21. end-to-end-use-cases/transferability/transferability/finetune/8b_lora.yaml (+103, -0)
  22. end-to-end-use-cases/transferability/transferability/finetune/__init__.py (+3, -0)
  23. end-to-end-use-cases/transferability/transferability/finetune/__main__.py (+14, -0)
  24. end-to-end-use-cases/transferability/transferability/finetune/finetune_grid.py (+232, -0)
  25. end-to-end-use-cases/transferability/transferability/utils.py (+75, -0)

+ 130 - 0
end-to-end-use-cases/transferability/README.md

@@ -0,0 +1,130 @@
+# Transferability Research Tool
+
+A Python package for evaluating model transferability across vision-language tasks through systematic fine-tuning and evaluation.
+
+## Directory Structure
+
+```
+./
+├── config.yaml                    # Main configuration file
+├── experiments/                   # Output directory for all experiments
+│   └── <experiment_name>/
+│       ├── formatted_datasets/    # Processed datasets ready for training
+│       ├── finetuned_checkpoints/ # Fine-tuned model checkpoints
+│       ├── finetune_logs/         # Training logs
+│       ├── grader_logs/           # Evaluation logs per model
+│       └── eval_grid_results.json # Final evaluation results
+└── transferability/               # Source code package
+    ├── __init__.py               # Package entry points
+    ├── __main__.py               # Main CLI entry point
+    ├── data/                     # Dataset processing
+    │   ├── __init__.py
+    │   ├── __main__.py           # Module CLI entry point
+    │   └── dataset_builder.py
+    ├── datasets/                 # Dataset format utilities
+    │   ├── __init__.py
+    │   └── torchtune_format.py   # TorchTune dataset format
+    ├── evals/                    # Evaluation utilities
+    │   ├── __init__.py
+    │   ├── __main__.py           # Module CLI entry point
+    │   ├── eval_grid.py          # Main evaluation grid runner
+    │   ├── grader.py             # Task-specific graders
+    │   ├── inference.py          # Model inference utilities
+    │   ├── json_grading_utils.py # JSON grading utilities
+    │   └── shift_analysis.py     # Distribution shift analysis
+    ├── finetune/                 # Fine-tuning utilities
+    │   ├── __init__.py
+    │   ├── __main__.py           # Module CLI entry point
+    │   ├── finetune_grid.py      # Main fine-tuning grid runner
+    │   ├── 8b_full.yaml          # TorchTune config for full fine-tuning
+    │   └── 8b_lora.yaml          # TorchTune config for LoRA fine-tuning
+    └── utils.py                  # Shared utilities
+```
+
+## Usage
+
+Run individual components as Python modules:
+
+```bash
+# Prepare datasets
+python -m transferability.data ./experiments/my_experiment
+
+# Run fine-tuning grid
+python -m transferability.finetune ./experiments/my_experiment
+
+# Run evaluation grid
+python -m transferability.evals ./experiments/my_experiment
+```
+
+
+## Configuration
+
+Edit `config.yaml` to configure your tasks, datasets, and training parameters:
+
+```yaml
+task1:
+  dataset: your/huggingface/dataset
+  system_prompt: "Your system prompt"
+  user_prompt: "Your user prompt"
+  image_column: image
+  assistant_text_column: ground_truth
+  grader: JSONGrader
+  sample_percent: 0.01
+
+task2:
+  # Similar structure for second task
+
+finetuning:
+  model_path: /path/to/your/base/model
+  tokenizer_path: /path/to/tokenizer
+  epochs: 1
+  batch_size: 8
+  # Fine-tuning strategy flags
+  fusion: false
+  fusion+encoder: false
+  fusion+decoder: false
+  fusion+encoder+decoder: true
+  lora_ranks: [8, 16, 32]
+
+evals:
+  nb_eval_samples: null  # null = use all samples
+  checkpoint_to_eval: -1  # -1 = use latest checkpoint
+  model_server_args:
+    tensor_parallel_size: 2
+    max_model_len: 4096
+```
+
+## Workflow
+
+1. **Configure**: Edit `config.yaml` with your tasks and model paths
+2. **Prepare Data**: Download and format datasets from HuggingFace
+3. **Fine-tune**: Train models using different strategies (LoRA, full fine-tuning)
+4. **Evaluate**: Test all models on all tasks and generate results (see the end-to-end example below)
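+
+A typical end-to-end run simply chains the three module commands from the Usage section against the same experiment directory (the `./experiments/my_experiment` path is illustrative):
+
+```bash
+EXPERIMENT=./experiments/my_experiment
+
+python -m transferability.data "$EXPERIMENT"       # 2. Prepare Data
+python -m transferability.finetune "$EXPERIMENT"   # 3. Fine-tune
+python -m transferability.evals "$EXPERIMENT"      # 4. Evaluate
+```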
+
+## Key Features
+
+- **Modular Design**: Each component can be run independently
+- **Multiple Execution Methods**: Module-level, package-level, or direct imports
+- **Configurable Tasks**: Define tasks via YAML configuration
+- **Grid Search**: Automatically train multiple model variants
+- **Comprehensive Evaluation**: Test transferability across tasks
+- **Rich Logging**: Detailed logs and metrics for analysis
+
+## Output Structure
+
+Each experiment creates:
+- `formatted_datasets/`: HuggingFace datasets converted to training format
+- `finetuned_checkpoints/`: Model checkpoints for each training configuration
+- `finetune_logs/`: Training metrics and logs
+- `grader_logs/`: Per-model evaluation details
+- `eval_grid_results.json`: Summary of all evaluation results (see the example entry below)
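+
+Each entry in `eval_grid_results.json` records one (model, task) pair along with its topline metric and the grader's full metrics dict. A representative entry looks roughly like this (the scores are made-up illustrative values):
+
+```json
+{
+  "model": "finetuned_lora_8",
+  "task": "task1",
+  "topline_metric": "accuracy",
+  "score": 0.82,
+  "metrics": {
+    "accuracy": 0.82,
+    "error_rate": 0.04,
+    "total_samples": 50
+  }
+}
+```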
+
+## Next Steps
+
+The package is structured for execution as Python modules. Possible next steps:
+
+1. Remove the hardcoded experiment paths from the `__main__` sections
+2. Add more sophisticated CLI argument parsing
+3. Add configuration validation
+4. Add progress tracking and resumption capabilities
+5. Add visualization utilities for results analysis

The file diff was not shown because the file is too large
+ 60 - 0
end-to-end-use-cases/transferability/config.yaml


The file diff was not shown because the file is too large
+ 22033 - 0
end-to-end-use-cases/transferability/experiments/w2_ocr/grader_logs/base_model_evaluation_log.json


The file diff was not shown because the file is too large
+ 22033 - 0
end-to-end-use-cases/transferability/experiments/w2_ocr/grader_logs/finetuned_full_fusion+encoder+decoder_evaluation_log.json


The file diff was not shown because the file is too large
+ 22033 - 0
end-to-end-use-cases/transferability/experiments/w2_ocr/grader_logs/finetuned_full_fusion_evaluation_log.json


The file diff was not shown because the file is too large
+ 22033 - 0
end-to-end-use-cases/transferability/experiments/w2_ocr/grader_logs/finetuned_lora_64_evaluation_log.json


The file diff was not shown because the file is too large
+ 22033 - 0
end-to-end-use-cases/transferability/experiments/w2_ocr/grader_logs/finetuned_lora_8_evaluation_log.json


+ 3 - 0
end-to-end-use-cases/transferability/transferability/data/__init__.py

@@ -0,0 +1,3 @@
+"""Dataset building and processing utilities."""
+
+from .dataset_builder import run_dataset_builder

+ 14 - 0
end-to-end-use-cases/transferability/transferability/data/__main__.py

@@ -0,0 +1,14 @@
+"""Entry point for running dataset builder as a module."""
+
+if __name__ == "__main__":
+    import sys
+
+    from .dataset_builder import run_dataset_builder
+
+    if len(sys.argv) < 2:
+        print("Usage: python -m transferability.data <experiment_dir>")
+        print("Example: python -m transferability.data ./experiments/my_experiment")
+        sys.exit(1)
+
+    experiment_dir = sys.argv[1]
+    run_dataset_builder(experiment_dir)

+ 228 - 0
end-to-end-use-cases/transferability/transferability/data/dataset_builder.py

@@ -0,0 +1,228 @@
+"""
+Data loader module for loading and formatting data from Hugging Face.
+"""
+
+import json
+import os
+from pathlib import Path
+from typing import Dict, Optional
+
+from datasets import concatenate_datasets, load_dataset, load_from_disk
+
+from ..utils import image_to_base64_url, load_config
+
+
+def load_hf_dataset(data_path: str, is_local: bool = False, **kwargs):
+    """
+    Load data from Hugging Face Hub or local disk.
+
+    Args:
+        data_path: Path to the dataset (either a Hugging Face dataset ID or a local path)
+        is_local: Whether the data is stored locally
+        **kwargs: Additional arguments to pass to the load_dataset function
+
+    Returns:
+        Dataset object from the datasets library with all splits
+
+    Raises:
+        ImportError: If the datasets package is not installed
+        ValueError: If data_path is None or empty
+    """
+    if not data_path:
+        raise ValueError("data_path must be provided")
+
+    dataset = None
+    if is_local:
+        # Load from local disk
+        file_extension = Path(data_path).suffix.lower()
+        if file_extension in [".csv"]:
+            dataset = load_dataset("csv", data_files=data_path, **kwargs)
+        else:
+            dataset = load_from_disk(data_path, **kwargs)
+    else:
+        # Load from Hugging Face Hub
+        dataset = load_dataset(data_path, **kwargs)
+
+    return dataset
+
+
+def resplit_dataset(dataset, train_percent: float, sample_percent: float = 1.0):
+    if isinstance(dataset, dict):
+        dataset = concatenate_datasets(list(dataset.values()))
+
+    # sample
+    dataset = dataset.take(int(len(dataset) * sample_percent))
+
+    # resplit into "train" and "test" splits
+    if train_percent == 0.0:
+        # hf datasets doesn't allow empty splits; this will create a singleton split
+        splits = dataset.train_test_split(train_size=1)
+    elif train_percent == 1.0:
+        splits = dataset.train_test_split(test_size=1)
+    else:
+        splits = dataset.train_test_split(train_size=train_percent)
+    return splits
+
+
+def convert_to_encoded_messages(
+    example: Dict,
+    image_column: str = None,
+    user_text_column: str = None,
+    assistant_text_column: str = None,
+    system_prompt: Optional[str] = None,
+    user_prompt: Optional[str] = None,
+) -> Dict:
+    image = example.get(image_column, None)
+    user_text = example.get(user_text_column, "")
+    assistant_text = example.get(assistant_text_column, "")
+
+    messages = []
+
+    # Create system message if system_prompt is provided
+    if system_prompt:
+        messages.append(
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": system_prompt}],
+            }
+        )
+
+    # Create user content and user message
+    user_content = []
+    if user_prompt or user_text:
+        # Join the optional prompt and per-example text, skipping empty values
+        # so a missing user_prompt does not raise a TypeError.
+        user_text = "\n".join(t for t in (user_prompt, user_text) if t)
+        user_content.append(
+            {"type": "text", "text": user_text},
+        )
+
+    # Add image(s) to user content
+    if image is not None:
+        if not isinstance(image, list):
+            image = [image]
+        for img in image:
+            b64_img_url = image_to_base64_url(img)
+            user_content.append(
+                {
+                    "type": "image_url",
+                    "image_url": {"url": b64_img_url},
+                }
+            )
+
+    messages.append({"role": "user", "content": user_content})
+
+    # Create assistant message with text content
+    if assistant_text:
+        messages.append(
+            {
+                "role": "assistant",
+                "content": [{"type": "text", "text": assistant_text}],
+            }
+        )
+    # Serialize to string and return. This is required because datasets.map adds extra keys to each dict in messages
+    example["messages"] = json.dumps(messages)
+    return example
+
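+# Illustrative shape of the serialized conversation (placeholders, not real values):
+# json.loads(example["messages"]) ==
+# [
+#   {"role": "system", "content": [{"type": "text", "text": "<system_prompt>"}]},
+#   {"role": "user", "content": [
+#       {"type": "text", "text": "<user_prompt>\n<user_text>"},
+#       {"type": "image_url", "image_url": {"url": "<base64 image URL>"}}]},
+#   {"role": "assistant", "content": [{"type": "text", "text": "<assistant_text>"}]},
+# ]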
+
+def save_encoded_dataset(encoded_dataset, output_dir: str, split: str):
+    # Create the output directory if it doesn't exist
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Define the output file path
+    conversation_data_path = os.path.join(output_dir, f"{split}_conversation_data.json")
+
+    if not "messages" in encoded_dataset.column_names:
+        raise RuntimeError
+    messages = [json.loads(x) for x in encoded_dataset["messages"]]
+    with open(conversation_data_path, "w") as f:
+        json.dump(messages, f, indent=2)
+
+
+def convert_hf_dataset(
+    dataset: str,
+    output_dir: str,
+    is_local: bool = False,
+    system_prompt: Optional[str] = None,
+    user_prompt: Optional[str] = None,
+    image_column: Optional[str] = None,
+    user_text_column: Optional[str] = None,
+    assistant_text_column: Optional[str] = None,
+    resplit_train_percent: float = 0.7,
+    sample_percent: float = 1.0,
+):
+    """
+    Load and format data based on either a configuration file or individual parameters.
+
+    Args:
+        output_dir: Directory to save the formatted data
+        dataset: Path/ID to the dataset to load
+        is_local: Whether the data is stored locally
+        system_prompt: System prompt to use for the dataset
+        user_prompt: User prompt to prepend to user text
+        image_column: Name of the column containing images
+        user_text_column: Name of the column containing user text
+        assistant_text_column: Name of the column containing assistant responses
+        resplit_train_percent: Percentage of data to use for training split
+        sample_percent: Percentage of total data to sample
+
+    Returns:
+        str: Path to the output directory containing the formatted data
+    """
+
+    # Load the dataset
+    dataset = load_hf_dataset(data_path=dataset, is_local=is_local)
+
+    # Concatenate and resplit the dataset into 'train' and 'test' splits
+    dataset_splits = resplit_dataset(
+        dataset, resplit_train_percent, sample_percent=sample_percent
+    )
+
+    # Process each split
+    for split_name, split_dataset in dataset_splits.items():
+        # Apply the conversion function with all parameters
+        encoded_dataset = split_dataset.map(
+            lambda example: convert_to_encoded_messages(
+                example,
+                image_column=image_column,
+                user_text_column=user_text_column,
+                assistant_text_column=assistant_text_column,
+                system_prompt=system_prompt,
+                user_prompt=user_prompt,
+            )
+        )
+
+        # Save the encoded dataset
+        save_encoded_dataset(encoded_dataset, output_dir, split_name)
+
+    return output_dir
+
+
+def run_dataset_builder(experiment_dir: str):
+    script_dir = Path(__file__).parent.parent.parent
+    config_path = script_dir / "config.yaml"
+
+    # Load configuration
+    config = load_config(config_path)
+
+    # Set output directory
+    output_dir = Path(experiment_dir) / "formatted_datasets"
+
+    # Build the formatted dataset for each task
+    for task in ["task1", "task2"]:
+        convert_hf_dataset(
+            dataset=config[task].get("dataset"),
+            output_dir=output_dir / task,
+            is_local=config[task].get("is_local"),
+            system_prompt=config[task].get("system_prompt"),
+            user_prompt=config[task].get("user_prompt"),
+            image_column=config[task].get("image_column"),
+            user_text_column=config[task].get("user_text_column"),
+            assistant_text_column=config[task].get("assistant_text_column"),
+            resplit_train_percent=config[task].get("resplit_train_percent", 0.7),
+            sample_percent=config[task].get("sample_percent", 1.0),
+        )
+
+
+if __name__ == "__main__":
+    run_dataset_builder(
+        "/data/users/subramen/fbsource/fbcode/users/subramen/internal-llama-cookbook/end-to-end-use-cases/transferability/experiments/test01"
+    )

+ 3 - 0
end-to-end-use-cases/transferability/transferability/datasets/__init__.py

@@ -0,0 +1,3 @@
+"""Dataset utilities and formats."""
+
+from .torchtune_format import sft_dataset

+ 34 - 0
end-to-end-use-cases/transferability/transferability/datasets/torchtune_format.py

@@ -0,0 +1,34 @@
+"""
+Custom SFT dataset for fine-tuning.
+"""
+
+from torchtune.data import OpenAIToMessages
+from torchtune.datasets import SFTDataset
+from torchtune.modules.transforms import Transform
+
+
+def sft_dataset(
+    model_transform: Transform,
+    *,
+    dataset_path: str,
+) -> SFTDataset:
+    """
+    Creates a custom SFT dataset for fine-tuning.
+
+    Args:
+        model_transform: Model transform (e.g. the tokenizer) applied to each sample
+        dataset_path: Path to the formatted conversation data JSON file
+
+    Returns:
+        SFTDataset: A dataset ready for fine-tuning with TorchTune
+    """
+    message_transform = OpenAIToMessages()
+
+    ds = SFTDataset(
+        source="json",
+        data_files=dataset_path,
+        message_transform=message_transform,
+        model_transform=model_transform,
+    )
+    return ds
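+
+# Usage sketch: the fine-tuning YAML configs point TorchTune at this builder via
+#   dataset:
+#     _component_: transferability.datasets.torchtune_format.sft_dataset
+# and finetune_grid.py supplies the data file through a CLI override such as
+#   dataset.dataset_path=<experiment_dir>/formatted_datasets/task1/train_conversation_data.json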

+ 5 - 0
end-to-end-use-cases/transferability/transferability/evals/__init__.py

@@ -0,0 +1,5 @@
+"""Evaluation utilities and grid search."""
+
+from .eval_grid import run_eval_grid
+from .grader import get_grader
+from .inference import create_inference_request, LocalModelRunner

+ 14 - 0
end-to-end-use-cases/transferability/transferability/evals/__main__.py

@@ -0,0 +1,14 @@
+"""Entry point for running evaluation grid as a module."""
+
+if __name__ == "__main__":
+    import sys
+
+    from .eval_grid import run_eval_grid
+
+    if len(sys.argv) < 2:
+        print("Usage: python -m transferability.evals <experiment_dir>")
+        print("Example: python -m transferability.evals ./experiments/my_experiment")
+        sys.exit(1)
+
+    experiment_dir = sys.argv[1]
+    run_eval_grid(experiment_dir)

+ 251 - 0
end-to-end-use-cases/transferability/transferability/evals/eval_grid.py

@@ -0,0 +1,251 @@
+import json
+from pathlib import Path
+from typing import Optional
+
+from ..utils import load_config
+
+from .grader import get_grader
+
+from .inference import create_inference_request, LocalModelRunner
+
+from .shift_analysis import calculate_transferability_index
+
+
+def load_dataset(dataset_path: Path, nb_samples: Optional[int] = None):
+    """Load conversation dataset from JSON file."""
+    with open(dataset_path, "r") as f:
+        samples = json.load(f)
+    if nb_samples is not None:
+        samples = samples[:nb_samples]
+    return samples
+
+
+def grade_dataset(llm_runner, dataset, grader, inference_params):
+    """Grade a dataset using the LLM runner."""
+    requests = [create_inference_request(m, **inference_params) for m in dataset]
+    llm_outputs = llm_runner.run_batch(requests)
+    rows = [
+        {"expected_output": m[-1]["content"][0]["text"], "raw_response": l}
+        for m, l in zip(dataset, llm_outputs)
+    ]
+    return grader.grade(rows)
+
+
+def get_finetuned_checkpoint_path(base_path: Path, checkpoint_to_eval: int):
+    """Get the path to a specific finetuned checkpoint."""
+    if checkpoint_to_eval == -1:
+        # Find the last checkpoint
+        checkpoint_dirs = []
+        if base_path.exists():
+            for item in base_path.iterdir():
+                if item.is_dir() and item.name.startswith("epoch_"):
+                    try:
+                        epoch_num = int(item.name.split("_")[1])
+                        checkpoint_dirs.append((epoch_num, item))
+                    except (ValueError, IndexError):
+                        continue
+
+        if not checkpoint_dirs:
+            raise FileNotFoundError(f"No checkpoints found in {base_path}")
+
+        # Return the highest epoch
+        return max(checkpoint_dirs, key=lambda x: x[0])[1]
+    else:
+        # Return specific epoch
+        epoch_path = base_path / f"epoch_{checkpoint_to_eval}"
+        if not epoch_path.exists():
+            raise FileNotFoundError(f"Checkpoint not found: {epoch_path}")
+        return epoch_path
+
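+# Illustrative checkpoint layout (directory names follow finetune_grid.py's convention):
+#   finetuned_checkpoints/lora_8/epoch_0
+#   finetuned_checkpoints/lora_8/epoch_1   <- picked when checkpoint_to_eval == -1
+#   finetuned_checkpoints/full_fusion/epoch_0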
+
+def run_eval_grid(experiment_dir: str):
+    print("🚀 Starting evaluation grid execution...")
+    print(f"📁 Experiment directory: {experiment_dir}")
+
+    # Get script directory and config path
+    script_dir = Path(__file__).parent.parent.parent
+    config_path = script_dir / "config.yaml"
+    print(f"📝 Loading configuration from: {config_path}")
+
+    logs_dir = Path(experiment_dir) / "grader_logs"
+
+    # Load configuration
+    config = load_config(config_path)
+    print("✅ Configuration loaded successfully")
+
+    # Load task names
+    tasks = ["task1", "task2"]
+
+    # Populate checkpoints dictionary with base and finetuned checkpoints
+    print("🔍 Building checkpoint list...")
+    checkpoints = {}
+
+    # Add base model from config
+    base_model_path = config["finetuning"]["model_path"]
+    checkpoints["base_model"] = base_model_path
+    print(f"   📋 Base model: {base_model_path}")
+
+    # Add finetuned checkpoints
+    finetuned_ckpts_dir = Path(experiment_dir) / "finetuned_checkpoints"
+    checkpoint_to_eval = config["evals"]["checkpoint_to_eval"]
+
+    if finetuned_ckpts_dir.exists():
+        for ckpt_dir in finetuned_ckpts_dir.iterdir():
+            if ckpt_dir.is_dir():
+                try:
+                    ckpt_path = get_finetuned_checkpoint_path(
+                        ckpt_dir, checkpoint_to_eval
+                    )
+                    model_name = f"finetuned_{ckpt_dir.name}"
+                    checkpoints[model_name] = str(ckpt_path)
+                    print(f"   📋 Finetuned: {model_name} -> {ckpt_path}")
+                except FileNotFoundError as e:
+                    print(f"   ⚠️  Skipping {ckpt_dir.name}: {e}")
+    else:
+        print("   ⚠️  No finetuned checkpoints directory found")
+
+    print(f"📊 Total checkpoints to evaluate: {len(checkpoints)}")
+
+    # Load model server args from config
+    model_server_args = config["evals"]["model_server_args"]
+    print(f"🔧 Model server args: {model_server_args}")
+
+    # Load inference params from config
+    inference_params = config["evals"]["inference_params"]
+    print(f"⚡ Inference params: {inference_params}")
+
+    eval_grid_results = []
+
+    print(f"\n🎯 Starting evaluation grid...")
+    print("=" * 60)
+
+    total_evaluations = len(checkpoints) * len(tasks)
+    eval_count = 0
+
+    for model_name, ckpt in checkpoints.items():
+        print(f"\n🤖 Evaluating model: {model_name}")
+        print(f"📁 Checkpoint: {ckpt}")
+
+        # Initialize model runner for this checkpoint
+        llm_runner = LocalModelRunner(ckpt, **model_server_args)
+
+        # Create log file for this model in `logs_dir`
+        logs_dir.mkdir(parents=True, exist_ok=True)
+        log_file_path = logs_dir / f"{model_name}_evaluation_log.json"
+        model_log_data = {
+            "model_name": model_name,
+            "checkpoint_path": str(ckpt),
+            "model_server_args": model_server_args,
+            "inference_params": inference_params,
+            "tasks": {},
+        }
+        for task_name in tasks:
+            eval_count += 1
+            print(f"\n📈 Evaluation {eval_count}/{total_evaluations}")
+            print(f"🎯 Model: {model_name}, Task: {task_name}")
+
+            # Get task-specific grader from config
+            grader_name = config[task_name].get("grader", "JSONGrader")
+            grader = get_grader(grader_name)
+            print(f"   🔧 Using grader: {grader_name}")
+
+            dataset_path = (
+                Path(experiment_dir)
+                / "formatted_datasets"
+                / task_name
+                / "test_conversation_data.json"
+            )
+
+            if not dataset_path.exists():
+                print(f"   ❌ Dataset not found: {dataset_path}")
+                continue
+
+            print(f"   📊 Loading dataset: {dataset_path}")
+            dataset = load_dataset(
+                dataset_path, nb_samples=config["evals"].get("nb_eval_samples")
+            )
+            print(f"   📋 Dataset size: {len(dataset)} samples")
+
+            try:
+                print("   ⏳ Running evaluation...")
+                eval_result = grade_dataset(
+                    llm_runner, dataset, grader, inference_params
+                )
+
+                # Log eval_result for each task in the log file
+                model_log_data["tasks"][task_name] = {
+                    "metrics": eval_result.metrics,
+                    "topline_metric_name": eval_result.topline_metric_name,
+                    "num_samples": len(eval_result.result_data),
+                    "result_data": eval_result.result_data,
+                    "rows": eval_result.rows,
+                }
+
+                topline_metric = eval_result.topline_metric_name
+                score = eval_result.metrics.get(topline_metric)
+
+                print(f"   ✅ {topline_metric}: {score:.4f}")
+
+                eval_grid_results.append(
+                    {
+                        "model": model_name,
+                        "task": task_name,
+                        "topline_metric": topline_metric,
+                        "score": score,
+                        "metrics": eval_result.metrics,
+                    }
+                )
+
+            except Exception as e:
+                print(f"   ❌ Evaluation failed: {e}")
+                eval_grid_results.append(
+                    {
+                        "model": model_name,
+                        "task": task_name,
+                        "topline_metric": "error",
+                        "score": -1,
+                        "error": str(e),
+                    }
+                )
+
+        # Write the log file for this model
+        with open(log_file_path, "w") as f:
+            json.dump(model_log_data, f, indent=2)
+        print(f"   📄 Evaluation log saved to: {log_file_path}")
+
+        llm_runner.shutdown()
+
+    # Save results
+    results_path = Path(experiment_dir) / "eval_grid_results.json"
+    with open(results_path, "w") as f:
+        json.dump(eval_grid_results, f, indent=2)
+
+    print("\n" + "=" * 60)
+    print("🎉 Evaluation grid completed!")
+    print(f"📁 Results saved to: {results_path}")
+
+
+    transferability_results = calculate_transferability_index(eval_grid_results)
+
+    # Print summary table
+    print("\n📊 Results Summary:")
+    print("-" * 80)
+    print(transferability_results)
+    transferability_results_path = Path(experiment_dir) / "transferability_results.csv"
+    transferability_results.to_csv(transferability_results_path, index=False)
+
+    return eval_grid_results
+
+
+if __name__ == "__main__":
+    run_eval_grid(
+        "/data/users/subramen/fbsource/fbcode/users/subramen/internal-llama-cookbook/end-to-end-use-cases/transferability/experiments/test01"
+    )

+ 116 - 0
end-to-end-use-cases/transferability/transferability/evals/grader.py

@@ -0,0 +1,116 @@
+import json
+from abc import ABC, abstractmethod
+from typing import Any
+
+import pandas as pd
+from pydantic import BaseModel
+
+from ..utils import map_with_progress
+
+Row = dict[str, Any]
+
+
+class EvalResult(BaseModel):
+    rows: list[Row]  # raw rows
+    # overall metrics
+    metrics: dict[str, float] | None = None
+    # result for each row
+    result_data: list[dict[str, Any]]
+    topline_metric_name: str
+
+
+class IGrader(ABC):
+    @abstractmethod
+    def grade_row(self, row: Row) -> dict[str, Any]:
+        """Calculates metrics for a single row."""
+        pass
+
+    @abstractmethod
+    def calculate_aggregate_metrics(
+        self, results: list[dict[str, Any]]
+    ) -> dict[str, Any]:
+        """Calculates aggregate metrics for a list of row results.
+        This is used for both overall and per-subset calculations.
+
+        :param results: List of row results, returned by grade_row
+        """
+        pass
+
+    @abstractmethod
+    def topline_metric(self) -> str:
+        """Key of the grade value in the overall metrics dict."""
+        pass
+
+    def grade(self, rows: list[Row]) -> EvalResult:
+        """Grades rows, calculating overall and per-subset metrics using helper methods."""
+
+        result_data = map_with_progress(self.grade_row, rows)
+        overall_metrics = self.calculate_aggregate_metrics(result_data)
+
+        return EvalResult(
+            rows=rows,
+            metrics=overall_metrics,
+            result_data=result_data,
+            topline_metric_name=self.topline_metric(),
+        )
+
+    def __str__(self) -> str:
+        return f"{self.__class__.__name__}"
+
+
+class JSONGrader(IGrader):
+    """Generic grader for JSON-based structured data extraction tasks.
+
+    This grader can be used for any task where the model outputs JSON that needs
+    to be compared against ground truth JSON, such as W2 form extraction, OCR tasks, etc.
+    """
+
+    def grade_row(self, row: Row) -> dict[str, Any]:
+        from .json_grading_utils import calculate_json_accuracy, JSONUtils
+
+        ground_truth = JSONUtils.load_json_from_str(row["expected_output"])
+        if "gt_parse" in ground_truth:
+            ground_truth = ground_truth["gt_parse"]
+        json_response = JSONUtils.extract_json_from_response(row["raw_response"])
+        return calculate_json_accuracy(ground_truth, json_response)
+
+    def calculate_aggregate_metrics(
+        self, results: list[dict[str, Any]]
+    ) -> dict[str, Any]:
+        """Calculate aggregate metrics for JSON-based evaluation."""
+        if not results:
+            return {"accuracy": 0.0, "error_rate": 1.0}
+
+        results_df = pd.DataFrame(results)
+
+        # Calculate accuracy (mean of scores)
+        accuracy = results_df["score"].mean() if "score" in results_df.columns else 0.0
+
+        # Calculate error rate (percentage of failed parsing attempts)
+        error_rate = (
+            (results_df["score"] == -1).mean() if "score" in results_df.columns else 0.0
+        )
+
+        return {
+            "accuracy": accuracy,
+            "error_rate": error_rate,
+            "total_samples": len(results),
+        }
+
+    def topline_metric(self) -> str:
+        return "accuracy"
+
+
+# Grader Registry - Clean factory pattern
+GRADER_REGISTRY = {
+    "JSONGrader": JSONGrader,
+}
+
+
+def get_grader(grader_name: str):
+    """Factory function to get grader by name from config."""
+    if grader_name not in GRADER_REGISTRY:
+        available = ", ".join(GRADER_REGISTRY.keys())
+        raise ValueError(f"Unknown grader '{grader_name}'. Available: {available}")
+    return GRADER_REGISTRY[grader_name]()
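+
+# Usage sketch (mirrors grade_dataset in eval_grid.py); the row values are illustrative:
+#   grader = get_grader("JSONGrader")
+#   rows = [{"expected_output": '{"name": "A"}', "raw_response": '```json\n{"name": "A"}\n```'}]
+#   result = grader.grade(rows)
+#   result.metrics["accuracy"]  # -> 1.0 when the extracted JSON matches exactly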

+ 187 - 0
end-to-end-use-cases/transferability/transferability/evals/inference.py

@@ -0,0 +1,187 @@
+import gc
+import logging
+from typing import Any, Dict, List
+
+import torch
+from openai import OpenAI
+from tqdm import tqdm
+from vllm import LLM, SamplingParams
+from vllm.sampling_params import GuidedDecodingParams
+
+logger = logging.getLogger(__name__)
+
+
+def create_inference_request(
+    messages: List[Dict[str, Any]],
+    temperature: float = 0.0,
+    top_p: float = 1.0,
+    max_completion_tokens: int = 4096,
+    seed: int = 42,
+) -> Dict[str, Any]:
+    """
+    Create an inference request for the model.
+
+    Args:
+        messages: List of message dictionaries for the conversation
+        temperature: Sampling temperature
+        top_p: Top-p sampling parameter
+        max_completion_tokens: Maximum tokens to generate
+        seed: Random seed for reproducibility
+
+    Returns:
+        Dict containing the formatted request parameters
+    """
+    # strip assistant outputs
+    messages = [m for m in messages if m["role"] != "assistant"]
+
+    try:
+        request = {
+            "temperature": temperature,
+            "top_p": top_p,
+            "max_tokens": max_completion_tokens,
+            "seed": seed,
+            "messages": messages,
+        }
+
+        # Add JSON-guided decoding if requested
+        # if use_json_decode:
+        #     if json_schema is None:
+        #         raise ValueError("json_schema is required when use_json_decode=True")
+        #     request["response_format"] = {
+        #         "type": "json_schema",
+        #         "json_schema": {"name": "ExtractionSchema", "schema": json_schema},
+        #     }
+
+        return request
+
+    except Exception as e:
+        logger.error(f"Failed to create inference request: {e}")
+        raise
+
+
+class ModelRunner:
+    """Base class for model inference runners."""
+
+    def run_batch(self, requests: List[Dict[str, Any]]) -> List[str]:
+        """
+        Run inference on a batch of requests.
+
+        Args:
+            requests: List of request parameters
+
+        Returns:
+            List of raw text responses
+        """
+        # Abstract method, to be implemented by subclasses
+        raise NotImplementedError("Subclasses must implement run_batch")
+
+
+# Not supported
+class APIModelRunner(ModelRunner):
+    """Runner for API-based model inference."""
+
+    def __init__(self, api_key: str, base_url: str):
+        """Initialize the API client."""
+        self.client = OpenAI(api_key=api_key, base_url=base_url)
+
+    def run_batch(self, requests: List[Dict[str, Any]]) -> List[str]:
+        """
+        Run inference on a batch of requests using the API.
+
+        Args:
+            requests: List of request parameters
+
+        Returns:
+            List of raw text responses
+        """
+        raise NotImplementedError(
+            "Not supported at the moment. Use local model instead."
+        )
+        responses = []
+        for request in tqdm(requests, desc="API Inference"):
+            try:
+                response = self.client.chat.completions.create(**request)
+                responses.append(response.choices[0].message.content)
+            except Exception as e:
+                responses.append(f"Error: {str(e)}")
+        return responses
+
+
+class LocalModelRunner(ModelRunner):
+    """Runner for local model inference."""
+
+    def __init__(
+        self,
+        ckpt_path: str,
+        tensor_parallel_size: int = 1,
+        max_model_len: int = 8192,
+        max_num_seqs: int = 128,
+        enforce_eager: bool = True,
+    ):
+        """Initialize the local model."""
+        try:
+            self.model = LLM(
+                ckpt_path,
+                tensor_parallel_size=tensor_parallel_size,
+                max_model_len=max_model_len,
+                max_num_seqs=max_num_seqs,
+                enforce_eager=enforce_eager,
+            )
+            logger.info(f"Initialized local model: {ckpt_path}")
+
+        except Exception as e:
+            logger.error(f"Failed to initialize local model: {e}")
+            raise
+
+    def run_batch(self, request_batch: List[Dict[str, Any]]) -> List[str]:
+        """
+        Run inference on a batch of requests using the local model.
+
+        Args:
+            request_batch: List of request parameters
+
+        Returns:
+            List of raw text responses
+        """
+        try:
+            # Extract messages
+            messages = [req["messages"] for req in request_batch]
+
+            # Prepare sampling parameters
+            common_sampling_params = {
+                "top_p": request_batch[0]["top_p"],
+                "temperature": request_batch[0]["temperature"],
+                "max_tokens": request_batch[0]["max_tokens"],
+                "seed": request_batch[0]["seed"],
+            }
+
+            # Handle JSON-guided decoding if present
+            if "response_format" in request_batch[0]:
+                sampling_params = []
+                for req in request_batch:
+                    gd_params = GuidedDecodingParams(
+                        json=req["response_format"]["json_schema"]["schema"]
+                    )
+                    sampling_params.append(
+                        SamplingParams(
+                            guided_decoding=gd_params, **common_sampling_params
+                        )
+                    )
+            else:
+                sampling_params = SamplingParams(**common_sampling_params)
+
+            # Run inference
+            outputs = self.model.chat(messages, sampling_params, use_tqdm=True)
+            return [output.outputs[0].text for output in outputs]
+
+        except Exception as e:
+            logger.error(f"Local model inference failed: {e}")
+            return [f"Error: {str(e)}" for _ in request_batch]
+
+    def shutdown(self):
+        del self.model
+        torch.cuda.empty_cache()
+        gc.collect()
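+
+# Usage sketch (mirrors run_eval_grid in eval_grid.py); the checkpoint path and
+# server arguments are illustrative:
+#   runner = LocalModelRunner("/path/to/ckpt", tensor_parallel_size=2, max_model_len=4096)
+#   requests = [create_inference_request(messages) for messages in dataset]
+#   outputs = runner.run_batch(requests)  # list of raw text completions
+#   runner.shutdown()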

+ 240 - 0
end-to-end-use-cases/transferability/transferability/evals/json_grading_utils.py

@@ -0,0 +1,240 @@
+"""
+Utility functions for evaluation of structured data extraction.
+
+This module provides helper functions for comparing JSON outputs,
+calculating accuracy metrics, and analyzing differences between
+predicted and actual structured data.
+"""
+
+import ast
+import json
+import logging
+import re
+from typing import Any, Dict, List, Union
+
+from json_repair import repair_json
+from jsondiff import diff
+
+# Setup logging
+logger = logging.getLogger(__name__)
+
+# Compile regex patterns once for better performance
+JSON_BLOCK_OPEN = re.compile(r"```json")
+JSON_BLOCK_CLOSE = re.compile(r"}\s+```")
+
+
+def calculate_json_accuracy(
+    actual: Union[str, Dict[str, Any]],
+    predicted: Union[str, Dict[str, Any]],
+) -> Dict[str, Any]:
+    """
+    Calculate accuracy metrics between predicted and actual JSON.
+
+    Args:
+        actual: The ground truth JSON as string or dict
+        predicted: The predicted JSON as string or dict
+
+    Returns:
+        Dict containing accuracy metrics including score, diffs, and field counts
+    """
+    try:
+        # Normalize both inputs to dicts via JSONUtils
+        actual = JSONUtils.load_json_from_str(actual)
+        predicted = JSONUtils.load_json_from_str(predicted)
+    except Exception as e:
+        logger.error(f"Failed to parse JSON: {e}")
+        return {
+            "score": -1,
+            "full_json_diff": {},
+            "json_diff": {},
+            "nb_different_fields": -1,
+            "total_fields": -1,
+        }
+
+    full_diff_result = diff(actual, predicted, syntax="symmetric")
+    diff_result = diff(predicted, actual)
+    total_fields = count_total_fields(actual)
+
+    if not diff_result:
+        return {
+            "score": 1,
+            "full_json_diff": {},
+            "json_diff": {},
+            "nb_different_fields": 0,
+            "total_fields": total_fields,
+        }
+
+    changes = count_number_of_differences(diff_result)
+    score = max(0, (total_fields - changes) / total_fields)
+    return {
+        "score": round(score, 4),
+        "full_json_diff": str(full_diff_result),
+        "json_diff": str(diff_result),
+        "nb_different_fields": changes,
+        "total_fields": total_fields,
+    }
+
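+# Worked example (hypothetical): if the ground truth has 10 leaf fields and the
+# prediction differs on 2 of them, score = max(0, (10 - 2) / 10) = 0.8.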
+
+def count_number_of_differences(differences) -> int:
+    """
+    Count the number of differences in a JSON diff object.
+
+    Args:
+        differences: The diff object or string representation
+
+    Returns:
+        int: Total number of differences found
+    """
+    differences = JSONUtils.load_json_from_str(differences)
+
+    def count_differences(differences: Any) -> int:
+        count = 0
+        if isinstance(differences, list) or isinstance(differences, tuple):
+            count += sum([count_differences(item) for item in differences])
+        if isinstance(differences, dict):
+            for _, value in differences.items():
+                if isinstance(value, dict):
+                    # Recursively count differences in nested objects
+                    count += count_differences(value)
+                elif isinstance(value, list):
+                    count += sum([count_differences(v) for v in value])
+                else:
+                    # Additions or deletions
+                    count += 1
+        return count
+
+    return count_differences(differences)
+
+
+def count_total_fields(obj: Any) -> int:
+    """
+    Count the total number of fields in a JSON object.
+
+    Args:
+        obj: The JSON object to count fields in
+
+    Returns:
+        int: Total number of fields
+    """
+    count = 0
+
+    def traverse(current: Any) -> None:
+        """Recursively traverse the object and count fields."""
+        nonlocal count
+        if not current or not isinstance(current, (dict, list)):
+            return
+
+        if isinstance(current, list):
+            for item in current:
+                if isinstance(item, (dict, list)):
+                    traverse(item)
+                else:
+                    count += 1
+        else:
+            for key, value in current.items():
+                if "__" in key:
+                    continue
+                if isinstance(value, (str, int, float, bool)) or value is None:
+                    count += 1
+                elif isinstance(value, (dict, list)):
+                    traverse(value)
+
+    traverse(obj)
+    return count
+
+
+class JSONUtils:
+    """Utility functions for working with JSON data."""
+
+    @staticmethod
+    def extract_json_blocks(content: str) -> List[str]:
+        """
+        Extract JSON code blocks from markdown-formatted text.
+
+        Parses a string containing markdown-formatted text and extracts all JSON blocks
+        that are enclosed in ```json ... ``` code blocks. This is useful for extracting
+        structured data from LLM responses.
+
+        Args:
+            content: The markdown-formatted text containing JSON code blocks
+
+        Returns:
+            List[str]: A list of extracted JSON strings (without the markdown delimiters)
+        """
+        blocks_ix = []
+        str_ptr = 0
+
+        while str_ptr < len(content):
+            start_ix = content.find("```json", str_ptr)
+            if start_ix == -1:
+                break
+            start_ix += len("```json")
+            end_match = JSON_BLOCK_CLOSE.search(content[start_ix:])
+            if end_match:
+                end_ix = start_ix + end_match.start() + 1
+            else:
+                end_ix = len(content)  # no closing tag, take the rest of the string
+            blocks_ix.append((start_ix, end_ix))
+            str_ptr = end_ix + 1
+
+        return [content[ix[0] : ix[1]].strip() for ix in blocks_ix]
+
+    @staticmethod
+    def load_json_from_str(json_str: str) -> Dict[str, Any]:
+        """
+        Parse a JSON string into a Python dictionary.
+
+        Attempts to parse a string as JSON using multiple methods. First tries standard
+        json.loads(), then falls back to ast.literal_eval() if that fails. This provides
+        more robust JSON parsing for LLM outputs that might not be perfectly formatted.
+
+        Args:
+            json_str: The JSON string to parse
+
+        Returns:
+            Dict[str, Any]: The parsed JSON as a dictionary
+
+        Raises:
+            ValueError: If parsing fails
+        """
+        if not isinstance(json_str, str):
+            return json_str
+
+        json_str = repair_json(json_str)
+        try:
+            return json.loads(json_str)
+        except json.decoder.JSONDecodeError:
+            # Fall back to Python literal parsing (JSON null -> Python None)
+            json_str = json_str.replace("null", "None")
+            try:
+                return ast.literal_eval(json_str)
+            except (ValueError, SyntaxError):
+                raise ValueError(f"Failed to load valid JSON from string: {json_str}")
+
+    @staticmethod
+    def extract_json_from_response(content: str) -> Dict[str, Any]:
+        """
+        Extract and parse JSON from an LLM response.
+
+        Processes a response from an LLM that may contain JSON in a markdown code block.
+        First checks if the response contains markdown-formatted JSON blocks and extracts them,
+        then parses the JSON string into a Python dictionary.
+
+        Args:
+            content: The LLM response text that may contain JSON
+
+        Returns:
+            Dict[str, Any]: The parsed JSON as a dictionary
+
+        Raises:
+            ValueError: If extraction or parsing fails
+        """
+        try:
+            if "```json" in content:
+                json_blocks = JSONUtils.extract_json_blocks(content)
+                if not json_blocks:
+                    raise ValueError("No JSON blocks found in response")
+                content = json_blocks[-1]
+
+            return JSONUtils.load_json_from_str(content)
+        except Exception as e:
+            raise ValueError(f"Failed to extract JSON from response: {str(e)}")

+ 44 - 0
end-to-end-use-cases/transferability/transferability/evals/shift_analysis.py

@@ -0,0 +1,44 @@
+from typing import Any, Dict, List
+
+import pandas as pd
+
+
+def calculate_relative_gain(tuned_score: float, baseline_score: float):
+    baseline_score = max(baseline_score, 1e-2)
+    return tuned_score / baseline_score - 1  # --> [-1, inf]
+
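+# Worked example (hypothetical scores): tuned 0.75 vs. baseline 0.50 gives a
+# relative gain of 0.75 / 0.50 - 1 = 0.5. The baseline is clamped to 0.01 to
+# avoid dividing by zero, and the transferability index below is the ratio of
+# the task2 gain to the task1 gain.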
+
+def calculate_transferability_index(eval_grid_results: List[Dict[str, Any]]):
+    df = pd.DataFrame(eval_grid_results)
+
+    # Flatten to list of models
+    df_pivoted = (
+        df.pivot(index="model", columns="task", values="score")
+        .add_suffix("_score")
+        .reset_index()
+    )
+
+    # Get all score columns dynamically
+    score_columns = [col for col in df_pivoted.columns if col.endswith("_score")]
+
+    # Get baseline scores
+    baseline_scores = df_pivoted[df_pivoted["model"] == "base_model"][
+        score_columns
+    ].iloc[0]
+
+    # Add relative gain columns
+    for score_col in score_columns:
+        task_name = score_col.replace("_score", "")
+        gain_col = f"{task_name}_relative_gain"
+        df_pivoted[gain_col] = df_pivoted.apply(
+            lambda row: calculate_relative_gain(
+                row[score_col], baseline_scores[score_col]
+            ),
+            axis=1,
+        )
+
+    df_pivoted["transferability"] = (
+        df_pivoted["task2_relative_gain"] / df_pivoted["task1_relative_gain"]
+    )
+
+    return df_pivoted

+ 94 - 0
end-to-end-use-cases/transferability/transferability/finetune/8b_full.yaml

@@ -0,0 +1,94 @@
+# Config for multi-device full finetuning in full_finetune_distributed.py
+# using a Llama3.1 8B Instruct model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+#   tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
+#
+# To launch on 4 devices, run the following command from root:
+#   tune run --nproc_per_node 4 full_finetune_distributed --config llama3_1/8B_full
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training
+# you can run:
+#   tune run --nproc_per_node 4 full_finetune_distributed --config llama3_1/8B_full checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
+#
+# This config works best when the model is being fine-tuned on 2+ GPUs.
+# Single device full finetuning requires more memory optimizations. It's
+# best to use 8B_full_single_device.yaml for those cases
+
+
+output_dir: ""
+
+# Model settings
+model:
+  _component_: torchtune.models.llama3_1.llama3_1_8b
+
+# Tokenizer / vision transform
+tokenizer:
+  _component_: torchtune.models.llama3.llama3_tokenizer
+  max_num_tiles: 16
+  max_seq_len: 8192
+
+# Checkpointing
+checkpointer:
+  _component_: torchtune.training.FullModelHFCheckpointer
+  checkpoint_files: [
+    model-00001-of-00004.safetensors,
+    model-00002-of-00004.safetensors,
+    model-00003-of-00004.safetensors,
+    model-00004-of-00004.safetensors
+  ]
+  recipe_checkpoint: null
+  output_dir: ${output_dir}
+  model_type: LLAMA3
+resume_from_checkpoint: False
+dataset:
+  _component_: transferability.datasets.torchtune_format.sft_dataset
+
+# General data handling
+seed: 42
+shuffle: False
+
+# Training loop & hyperparams
+epochs: 1
+max_steps_per_epoch: null
+batch_size: 8
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+
+optimizer:
+  _component_: torch.optim.AdamW
+  lr: 2e-5
+  fused: True
+optimizer_in_bwd: False
+
+lr_scheduler:
+  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+  num_warmup_steps: 20
+
+loss:
+  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
+
+clip_grad_norm: null
+
+# cuda, cpu, rocm, xpu...
+device: cuda
+
+# Memory management / performance
+enable_activation_checkpointing: True
+enable_activation_offloading: False
+fsdp_cpu_offload: True
+compile: False # torch.compile, set to true for perf/memory improvement
+
+# Reduced precision
+dtype: bf16
+
+# Log metrics during training
+metric_logger:
+  _component_: torchtune.training.metric_logging.DiskLogger
+  log_dir: ${output_dir}/logs
+log_every_n_steps: 1
+log_peak_memory_stats: True
+
+# Useful for understanding how to optimize memory and performance
+profiler:
+  _component_: torchtune.training.setup_torch_profiler
+  enabled: False

+ 103 - 0
end-to-end-use-cases/transferability/transferability/finetune/8b_lora.yaml

@@ -0,0 +1,103 @@
+# Config for multi-device LoRA finetuning in lora_finetune_distributed.py
+# using a Llama3.1 8B Instruct model
+#
+# This config assumes that you've run the following command before launching
+# this run:
+#   tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
+#
+# To launch on 2 devices, run the following command from root:
+#   tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_1/8B_lora
+#
+# You can add specific overrides through the command line. For example
+# to override the checkpointer directory while launching training
+# you can run:
+#   tune run --nproc_per_node 2 lora_finetune_distributed --config llama3_1/8B_lora checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
+#
+# This config works best when the model is being fine-tuned on 2+ GPUs.
+# For single device LoRA finetuning please use 8B_lora_single_device.yaml
+# or 8B_qlora_single_device.yaml
+
+output_dir: ""
+
+# Modeling Arguments
+model:
+  _component_: torchtune.models.llama3_1.llama3_1_8b
+  decoder_trainable: "lora"
+  encoder_trainable: "frozen"
+  fusion_trainable: "lora"
+  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
+  apply_lora_to_mlp: True
+  apply_lora_to_output: False
+  lora_rank: 64  # higher increases accuracy and memory
+  lora_alpha: 128  # usually alpha=2*rank
+  lora_dropout: 0.05
+
+tokenizer:
+  _component_: torchtune.models.llama3.llama3_tokenizer
+  max_seq_len: null
+  max_num_tiles: 16
+
+checkpointer:
+  _component_: torchtune.training.FullModelHFCheckpointer
+  checkpoint_files: [
+    model-00001-of-00004.safetensors,
+    model-00002-of-00004.safetensors,
+    model-00003-of-00004.safetensors,
+    model-00004-of-00004.safetensors
+  ]
+  recipe_checkpoint: null
+  output_dir: ${output_dir}
+  model_type: LLAMA3
+resume_from_checkpoint: False
+
+dataset:
+  _component_: transferability.datasets.torchtune_format.sft_dataset
+
+# Training arguments
+epochs: 1
+max_steps_per_epoch: null
+batch_size: 24
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+clip_grad_norm: null
+
+# Optimizer
+optimizer:
+  _component_: torch.optim.AdamW
+  lr: 3e-4
+  fused: False
+optimizer_in_bwd: False
+
+# Scheduler
+lr_scheduler:
+  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+  num_warmup_steps: 100
+
+loss:
+  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
+
+
+# cuda, cpu, rocm, xpu...
+device: cuda
+seed: 42
+shuffle: False
+
+# Memory management / performance
+enable_activation_checkpointing: True
+enable_activation_offloading: False
+fsdp_cpu_offload: False
+compile: False # torch.compile, set to true for perf/memory improvement
+
+# Reduced precision
+dtype: bf16
+
+# Log metrics during training
+metric_logger:
+  _component_: torchtune.training.metric_logging.DiskLogger
+  log_dir: ${output_dir}/logs
+log_every_n_steps: 1
+log_peak_memory_stats: True
+
+# Useful for understanding how to optimize memory and performance
+profiler:
+  _component_: torchtune.training.setup_torch_profiler
+  enabled: False

+ 3 - 0
end-to-end-use-cases/transferability/transferability/finetune/__init__.py

@@ -0,0 +1,3 @@
+"""Fine-tuning utilities and grid search."""
+
+from .finetune_grid import run_finetune_grid

+ 14 - 0
end-to-end-use-cases/transferability/transferability/finetune/__main__.py

@@ -0,0 +1,14 @@
+"""Entry point for running fine-tuning grid as a module."""
+
+if __name__ == "__main__":
+    import sys
+
+    from .finetune_grid import run_finetune_grid
+
+    if len(sys.argv) < 2:
+        print("Usage: python -m transferability.finetune <experiment_dir>")
+        print("Example: python -m transferability.finetune ./experiments/my_experiment")
+        sys.exit(1)
+
+    experiment_dir = sys.argv[1]
+    run_finetune_grid(experiment_dir)

+ 232 - 0
end-to-end-use-cases/transferability/transferability/finetune/finetune_grid.py

@@ -0,0 +1,232 @@
+import os
+import subprocess
+from pathlib import Path
+
+from ..utils import load_config
+
+
+def get_general_finetune_args(finetuning_config, output_dir):
+    experiment_dir = Path(output_dir).parent
+    model_path = finetuning_config["model_path"]
+    if not os.path.exists(model_path):
+        raise RuntimeError(f"Model path {model_path} does not exist")
+    tokenizer_path = finetuning_config["tokenizer_path"]
+
+    # TODO: Change "task1" to task name defined in config
+    dataset_path = (
+        experiment_dir / "formatted_datasets" / "task1" / "train_conversation_data.json"
+    )
+
+    return [
+        f"dataset.dataset_path={dataset_path}",
+        f"checkpointer.checkpoint_dir={model_path}",
+        f"tokenizer.path={tokenizer_path}",
+        f"epochs={finetuning_config['epochs']}",
+        f"batch_size={finetuning_config['batch_size']}",
+        f"metric_logger.log_dir={experiment_dir}/finetune_logs",
+    ]
+
+
+def build_fft_jobs(config, output_dir):
+    """Build FFT (Full Fine-Tuning) jobs based on config"""
+    jobs = []
+    finetuning_config = config["finetuning"]
+
+    recipe = (
+        "full_finetune_distributed"
+        if finetuning_config.get("distributed")
+        else "full_finetune_single_device"
+    )
+
+    torchtune_config = finetuning_config.get("fft_torchtune_config")
+    base_cmd = [
+        "tune",
+        "run",
+        "--nproc_per_node",
+        str(finetuning_config["ngpu"]),
+        recipe,
+        "--config",
+        torchtune_config,
+    ]
+
+    base_cmd += get_general_finetune_args(finetuning_config, output_dir)
+
+    # Build list of modules to train based on config
+    modules_to_train = []
+    if finetuning_config.get("fusion", False):
+        modules_to_train.append("fusion")
+    if finetuning_config.get("fusion+encoder", False):
+        modules_to_train.append("fusion+encoder")
+    if finetuning_config.get("fusion+decoder", False):
+        modules_to_train.append("fusion+decoder")
+    if finetuning_config.get("fusion+encoder+decoder", False):
+        modules_to_train.append("fusion+encoder+decoder")
+
+    for modules in modules_to_train:
+        op_path = f"{output_dir}/full_{modules}"
+        if os.path.exists(op_path):
+            print(f"Skipping {op_path} as it already exists")
+            continue
+        module_opts = [f"model.{mod}_trainable=True" for mod in modules.split("+")]
+        jobs.append(base_cmd + [f"output_dir={op_path}"] + module_opts)
+
+    return jobs
+
+
+def build_lora_jobs(config, output_dir):
+    """Build LoRA jobs based on config"""
+    jobs = []
+    finetuning_config = config["finetuning"]
+
+    if not finetuning_config.get("lora_ranks"):
+        return jobs
+
+    recipe = (
+        "lora_finetune_distributed"
+        if finetuning_config.get("distributed")
+        else "lora_finetune_single_device"
+    )
+
+    torchtune_config = finetuning_config.get("lora_torchtune_config")
+
+    base_cmd = [
+        "tune",
+        "run",
+        "--nproc_per_node",
+        str(finetuning_config["ngpu"]),
+        recipe,
+        "--config",
+        torchtune_config,
+    ]
+
+    base_cmd += get_general_finetune_args(finetuning_config, output_dir)
+
+    for rank in finetuning_config["lora_ranks"]:
+        op_path = f"{output_dir}/lora_{rank}"
+        if os.path.exists(op_path):
+            print(f"Skipping {op_path} as it already exists")
+            continue
+        jobs.append(
+            base_cmd
+            + [
+                f"output_dir={op_path}",
+                f"model.lora_rank={rank}",
+                f"model.lora_alpha={int(rank)*2}",
+            ]
+        )
+
+    return jobs
+
+
+def run_finetune_grid(experiment_dir: str):
+    print("🚀 Starting fine-tuning grid execution...")
+    print(f"📁 Experiment directory: {experiment_dir}")
+
+    # Get script directory and config path
+    script_dir = Path(__file__).parent.parent.parent
+    config_path = script_dir / "config.yaml"
+    print(f"📝 Loading configuration from: {config_path}")
+
+    # Load configuration
+    config = load_config(config_path)
+    print("✅ Configuration loaded successfully")
+
+    # Set output directory
+    output_dir = Path(experiment_dir) / "finetuned_checkpoints"
+    print(f"💾 Output directory: {output_dir}")
+
+    # Create output directory if it doesn't exist
+    output_dir.mkdir(parents=True, exist_ok=True)
+    print("📂 Output directory created/verified")
+
+    # Build all jobs
+    all_jobs = []
+    print("\n🔧 Building fine-tuning jobs...")
+
+    # Check if we should run FFT jobs (if any fusion settings are enabled)
+    finetuning_config = config["finetuning"]
+    if any(
+        [
+            finetuning_config.get("fusion", False),
+            finetuning_config.get("fusion+encoder", False),
+            finetuning_config.get("fusion+decoder", False),
+            finetuning_config.get("fusion+encoder+decoder", False),
+        ]
+    ):
+        print("🔄 Building Full Fine-Tuning (FFT) jobs...")
+        fft_jobs = build_fft_jobs(config, output_dir)
+        all_jobs.extend(fft_jobs)
+        print(f"✅ Built {len(fft_jobs)} FFT jobs")
+
+        # Print details of FFT jobs
+        for i, job in enumerate(fft_jobs, 1):
+            modules = [arg for arg in job if "trainable=True" in str(arg)]
+            if modules:
+                module_info = ", ".join(
+                    [
+                        mod.split(".")[1].replace("_trainable=True", "")
+                        for mod in modules
+                    ]
+                )
+                print(f"   📋 FFT Job {i}: {module_info}")
+
+    # Check if we should run LoRA jobs
+    if finetuning_config.get("lora_ranks"):
+        print("🔄 Building LoRA fine-tuning jobs...")
+        lora_jobs = build_lora_jobs(config, output_dir)
+        all_jobs.extend(lora_jobs)
+        lora_count = len(lora_jobs)
+        print(f"✅ Built {lora_count} LoRA jobs")
+
+        # Print details of LoRA jobs
+        ranks = finetuning_config.get("lora_ranks", [])
+        for i, rank in enumerate(ranks, 1):
+            print(f"   📋 LoRA Job {i}: rank={rank}, alpha={rank*2}")
+
+    total_jobs = len(all_jobs)
+    print(f"\n📊 Total jobs to execute: {total_jobs}")
+
+    # Run all jobs
+    print(f"\n🎯 Executing {total_jobs} fine-tuning jobs...")
+    print("=" * 60)
+
+    for job_idx, job in enumerate(all_jobs, 1):
+        print(f"\n📈 Job {job_idx}/{total_jobs} - Starting...")
+
+        # Extract job type and details for better logging
+        job_type = "LoRA" if "lora_finetune" in " ".join(job) else "FFT"
+        output_path = next(
+            (arg.split("=")[1] for arg in job if arg.startswith("output_dir=")),
+            "unknown",
+        )
+        job_name = Path(output_path).name if output_path != "unknown" else "unknown"
+
+        print(f"🔧 Type: {job_type}")
+        print(f"📁 Output: {job_name}")
+        # print(f"⚡ Command: {' '.join(map(str, job))}")
+        print("-" * 40)
+
+        try:
+            print(f"⏳ Executing job {job_idx}/{total_jobs}...")
+            subprocess.run(job, check=True, capture_output=False)
+            print(f"✅ Job {job_idx}/{total_jobs} completed successfully!")
+
+        except subprocess.CalledProcessError as e:
+            print(
+                f"❌ Job {job_idx}/{total_jobs} failed with return code {e.returncode}"
+            )
+            print(f"💥 Error: {e}")
+            raise
+        except Exception as e:
+            print(f"❌ Job {job_idx}/{total_jobs} failed with unexpected error: {e}")
+            raise
+
+    print("\n" + "=" * 60)
+    print("🎉 All fine-tuning jobs completed successfully!")
+    print(f"📁 Results saved to: {output_dir}")
+    print("🏁 Fine-tuning grid execution finished.")
+
+
+if __name__ == "__main__":
+    run_finetune_grid("experiments/w2_ocr")

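Each job built by `build_fft_jobs` / `build_lora_jobs` is a flat argument list handed directly to `subprocess.run`. A hedged illustration of the shape of one assembled LoRA job; the angle-bracketed values are placeholders for whatever `config.yaml` supplies, not values guaranteed by this repo:

```python
# Hypothetical assembled LoRA job (rank 8); placeholder values in <angle brackets>.
job = [
    "tune", "run",
    "--nproc_per_node", "8",
    "lora_finetune_distributed",
    "--config", "<lora_torchtune_config>",
    "dataset.dataset_path=<experiment_dir>/formatted_datasets/task1/train_conversation_data.json",
    "checkpointer.checkpoint_dir=<model_path>",
    "tokenizer.path=<tokenizer_path>",
    "epochs=<epochs>",
    "batch_size=<batch_size>",
    "metric_logger.log_dir=<experiment_dir>/finetune_logs",
    "output_dir=<experiment_dir>/finetuned_checkpoints/lora_8",
    "model.lora_rank=8",
    "model.lora_alpha=16",  # the grid always sets alpha to twice the rank
]
```

FFT jobs have the same shape but use the `full_finetune_*` recipes and `model.<module>_trainable=True` overrides in place of the LoRA rank/alpha pair.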
+ 75 - 0
end-to-end-use-cases/transferability/transferability/utils.py

@@ -0,0 +1,75 @@
+import base64
+import io
+from multiprocessing.pool import ThreadPool
+from typing import Any, Callable, Union
+
+import yaml
+from PIL import Image
+from tqdm import tqdm
+
+
+def map_with_progress(
+    f: Callable, xs: list[Any], num_threads: int = 50, show_progress: bool = True
+):
+    """
+    Apply f to each element of xs using a ThreadPool, optionally showing a progress bar.
+    """
+    if not xs:
+        return []  # ThreadPool requires at least one worker; nothing to map anyway
+    with ThreadPool(min(num_threads, len(xs))) as pool:
+        if show_progress:
+            return list(tqdm(pool.imap(f, xs), total=len(xs)))
+        return list(pool.imap(f, xs))
+
+
+def load_config(config_path):
+    """Load configuration from YAML file"""
+    with open(config_path, "r") as f:
+        return yaml.safe_load(f)
+
+
+def is_base64_encoded(s: str) -> bool:
+    """Check if a string is already base64 encoded."""
+    try:
+        # Basic character check - base64 only contains these characters
+        if not all(
+            c in "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/="
+            for c in s
+        ):
+            return False
+
+        # Try to decode - if it fails, it's not valid base64
+        decoded = base64.b64decode(s, validate=True)
+
+        # Re-encode and compare - if they match, it was valid base64
+        re_encoded = base64.b64encode(decoded).decode("utf-8")
+        return s == re_encoded or s == re_encoded.rstrip(
+            "="
+        )  # Handle padding differences
+    except Exception:
+        return False
+
+
+def image_to_base64_url(image: Union[str, list, Image.Image]):
+    if isinstance(image, str):
+        # Check if the string is already base64 encoded
+        if is_base64_encoded(image):
+            return image
+        # Otherwise, treat it as a file path
+        with open(image, "rb") as img:
+            img_format = image.split(".")[-1]
+            b64_string = base64.b64encode(img.read()).decode("utf-8")
+            return f"data:image/{img_format};base64,{b64_string}"
+    elif isinstance(image, Image.Image):
+        try:
+            img_format = image.format.lower()
+        except AttributeError:
+            img_format = "png"  # Default format
+        buffer = io.BytesIO()
+        image.save(buffer, format=img_format)
+        b64_string = base64.b64encode(buffer.getvalue()).decode("utf-8")
+        return f"data:image/{img_format};base64,{b64_string}"
+    elif isinstance(image, list):
+        return [image_to_base64_url(img) for img in image]
+    else:
+        raise TypeError(f"Unsupported image type: {type(image)}")
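A short usage sketch of the helpers above, assuming a few local image files (the file names are hypothetical):

```python
from transferability.utils import image_to_base64_url, map_with_progress

# Encode a handful of (hypothetical) image files into data URLs in parallel.
paths = ["w2_form_001.png", "w2_form_002.png", "w2_form_003.png"]
data_urls = map_with_progress(image_to_base64_url, paths, num_threads=8)

print(data_urls[0][:22])  # e.g. "data:image/png;base64,"
```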