浏览代码

add checkpointing

Sanyam Bhutani 2 月之前
父节点
当前提交
a298a52777

+ 972 - 0
end-to-end-use-cases/data-tool/scripts/finetuning/fft.py

@@ -0,0 +1,972 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+import time
+
+from functools import partial
+from typing import Any, Dict, List, Optional, Tuple, Union
+from warnings import warn
+
+import torch
+from omegaconf import DictConfig, ListConfig
+
+from torch import nn
+from torch.distributed import (
+    destroy_process_group,
+    init_device_mesh,
+    init_process_group,
+)
+from torch.distributed._tensor import DTensor
+from torch.distributed.tensor.parallel import parallelize_module
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader, DistributedSampler
+from torchtune import config, modules, training, utils
+from torchtune.config._utils import _get_component_from_path
+from torchtune.data import padded_collate_packed
+from torchtune.datasets import ConcatDataset
+from torchtune.recipe_interfaces import FTRecipeInterface
+from torchtune.training import DummyProfiler, PROFILER_KEY
+from torchtune.training.activations import apply_selective_activation_checkpointing
+from torchtune.training.checkpointing._checkpoint_client import (
+    CheckpointClient,
+    TrainingProgress,
+)
+from torchtune.training.lr_schedulers import get_lr
+
+from tqdm import tqdm
+
+log = utils.get_logger("DEBUG")
+
+
+class FullFinetuneRecipeDistributed(FTRecipeInterface):
+    """
+    Full finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe supports
+    distributed training and can be run on a single node (1 to 8 GPUs).
+
+    Features:
+        - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
+            is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
+            done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
+            ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
+            DDP is currently not supported. Training on CPU is not supported.
+
+        - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``
+            flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
+            activations in memory and instead recompute them during the backward pass. This is especially
+            helpful for larger batch sizes when you're memory constrained. But these savings in memory
+            come at the cost of training performance. In most cases training can slow-down quite a bit as
+            a result of this activation recomputation.
+
+        - Activation Offloading. This can be controlled using the ``enable_activation_offloading``
+            flag. Activation offloading is a technique similar to activations checkpointing that helps
+            reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations
+            checkpointing drops the activation in the forward to recompute it later in the backward,
+            activations offloading will drop the activation in the forward to the CPU and bring it
+            back during the backward pass. As always, there is a tradeoff--these savings in memory can
+            come at the cost of training performance and CPU resources. To recover some runtime cost,
+            we've added an option to enable offloading on a different stream to permit overlapping with
+            the computation. This option is currently only available on PyTorch 2.5 or later and will
+            be enabled by default if an acceptable torch version is found. Activation offloading can be
+            used in conjunction with activation checkpointing.
+
+        - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
+            flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
+            most cases this should halve the memory footprint of full precision (fp32) training, without
+            loss in model quality (will depend on the model, training data and other settings). For
+            GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16
+            precision are currently not supported.
+
+        - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is
+            controlled using the ``gradient_accumulation_steps`` flag.
+
+                Total Batch Size = batch_size * number of GPUs * gradient accumulation steps.
+
+            For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a
+            total batch size of 64.
+
+            Gradient accumulation is especially useful when you are memory constrained. In this case,
+            accumulating gradients might give you better training speed than enabling activation
+            checkpointing.
+
+        - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of
+            training. Optimizer state and recipe state (seed, total_epochs, number of epochs run etc) are
+            only saved at the end of a given epoch and used in case of resuming training.
+
+            Resuming training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is
+            currently not supported.
+
+            For more details on the checkpointer, please take a look at
+            our checkpointer deepdive (https://pytorch.org/torchtune/main/deep_dives/checkpointer.html).
+
+        - Logging. Terminal, Disk, WandB and TensorBoard are all supported.
+
+        - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default,
+            ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set
+            ``clip_grad_norm='inf'``.
+
+    For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
+    has example commands for how to kick-off training.
+
+    Args:
+        cfg (DictConfig): OmegaConf object parsed from yaml file
+
+    Raises:
+        ValueError: If ``dtype`` is set to fp16.
+        RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
+        RuntimeError: If ``left_pad_sequence`` is set as the data collator.
+        RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
+        RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False.
+    """
+
+    def __init__(self, cfg: DictConfig) -> None:
+        device_type = cfg.device
+        self._device = utils.get_device(device=device_type)
+        self._dtype = training.get_dtype(cfg.dtype, device=self._device)
+
+        if self._dtype == torch.float16:
+            raise ValueError(
+                "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
+            )
+
+        # Set up the backend for distributed training (NCCL, GLOO, etc.)
+        self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
+        self.fsdp_cpu_offload = cfg.get("fsdp_cpu_offload", False)
+        self.distributed_backend = training.get_distributed_backend(
+            device_type,
+            offload_ops_to_cpu=self.fsdp_cpu_offload
+            or self._enable_async_checkpointing,
+        )
+        init_process_group(self.distributed_backend)
+
+        # Initialize distributed variables
+        self.world_size, self.rank = utils.get_world_size_and_rank()
+        self._is_rank_zero = self.rank == 0
+        self.tensor_parallel_plan = config.instantiate(
+            cfg.get("tensor_parallel_plan", None)
+        )
+        self.tensor_parallel_dim = cfg.get("tensor_parallel_dim", 1)
+        if self.tensor_parallel_dim > 1 and self.tensor_parallel_plan is None:
+            raise ValueError(
+                "Tensor Parallel plan needs to be provided when tensor parallel is enabled."
+            )
+        if self.world_size % self.tensor_parallel_dim != 0:
+            raise ValueError(
+                f"world_size {self.world_size} must be divisible by tensor_parallel_dim {self.tensor_parallel_dim}"
+            )
+        self.data_parallel_dim = self.world_size // self.tensor_parallel_dim
+
+        # Logging attributes
+        self._output_dir = cfg.output_dir
+        self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
+        self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)
+        if self._log_peak_memory_stats and device_type != "cuda":
+            log.info(
+                "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False."
+            )
+            self._log_peak_memory_stats = False
+
+        # Training cfg
+        self._resume_from_checkpoint = cfg.resume_from_checkpoint
+        self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
+        self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
+        self._clip_grad_norm = cfg.get("clip_grad_norm", None)
+        self._checkpoint_client = CheckpointClient(cfg)
+        self.save_every_epochs = cfg.get("save_every_epochs", 1)
+
+        # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
+        if self._optimizer_in_bwd:
+            if self._clip_grad_norm is not None:
+                raise RuntimeError(
+                    "Gradient clipping is not supported with optimizer in bwd."
+                    "Please set clip_grad_norm=None, or optimizer_in_bwd=False."
+                )
+            if self._gradient_accumulation_steps > 1:
+                raise RuntimeError(
+                    "Gradient accumulation is not supported with optimizer in bwd."
+                    "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
+                )
+
+        # activation checkpointing/offloading
+        self._enable_activation_checkpointing = cfg.get(
+            "enable_activation_checkpointing", False
+        )
+        self._enable_activation_offloading = cfg.get(
+            "enable_activation_offloading", False
+        )
+        if self._enable_activation_offloading:
+            if device_type != "cuda":
+                raise RuntimeError(
+                    "enable_activation_offloading should only be True when training on CUDA"
+                )
+            if not self._enable_activation_checkpointing:
+                raise RuntimeError(
+                    "enable_activation_offloading should only be True when enable_activation_checkpointing is True"
+                )
+        elif (
+            self._enable_activation_checkpointing
+            and cfg.checkpointer.model_type != "LLAMA3_VISION"
+        ):
+            utils.log_rank_zero(
+                log,
+                "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. "
+                "Enabling activation offloading should reduce memory further.",
+            )
+
+        # These are public properties which are updated by the checkpoint loader
+        # when ``resume_from_checkpoint`` is `True` or validated in tests
+        self.seed = training.set_seed(
+            seed=cfg.seed, debug_mode=cfg.get("cudnn_deterministic_mode", None)
+        )
+        self.epochs_run = 0
+        self.total_epochs = cfg.epochs
+        self.max_steps_per_epoch = cfg.max_steps_per_epoch
+        self.global_step = 0
+
+    def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
+        """
+        Updates the recipe state from checkpoint.
+        """
+        try:
+            self.epochs_run = ckpt_dict[training.EPOCHS_KEY]
+
+            # on mismatch, warn the user and prevent the override
+            if self.seed != ckpt_dict[training.SEED_KEY]:
+                warn(
+                    message=(
+                        "Config value for seed does not match the checkpoint value, "
+                        f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}"
+                    )
+                )
+                self.seed = ckpt_dict[training.SEED_KEY]
+            if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]:
+                warn(
+                    message=(
+                        "Config value for max_steps_per_epoch does not match the checkpoint value, "
+                        f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}"
+                    )
+                )
+                self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY]
+
+            # on mismatch, warn the user but allow the override
+            if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]:
+                warn(
+                    message=(
+                        "Config value for total_epochs does not match the checkpoint value, "
+                        f"using the config value: {self.total_epochs}"
+                    )
+                )
+
+        except KeyError as e:
+            raise KeyError(
+                "Checkpoint does not contain the required keys needed for updating recipe state. "
+                "Are you sure you passed in the right recipe checkpoint?"
+            ) from e
+
+    def setup(self, cfg: DictConfig) -> None:
+        """
+        Setup the recipe. This includes training state (if resume_from_checkpoint is True),
+        model, tokenizer, loss, optimizer, lr scheduler, sampler, and dataloader.
+        """
+        if self.fsdp_cpu_offload:
+            # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
+            # speed up when benchmarking fused AdamW on CPU
+            training.set_torch_num_threads()
+
+        if self._is_rank_zero:
+            self._metric_logger = config.instantiate(cfg.metric_logger)
+            # log config with parameter override
+            self._metric_logger.log_config(cfg)
+
+        # Load the base model
+        checkpoint_dict = self._checkpoint_client.load_base_checkpoint()
+
+        self._compile = cfg.get("compile", False)
+        self._model = self._setup_model(
+            cfg_model=cfg.model,
+            enable_activation_checkpointing=self._enable_activation_checkpointing,
+            enable_activation_offloading=self._enable_activation_offloading,
+            custom_sharded_layers=cfg.get("custom_sharded_layers", None),
+            fsdp_cpu_offload=self.fsdp_cpu_offload,
+            reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
+            model_state_dict=checkpoint_dict[training.MODEL_KEY],
+            ac_mode=cfg.get("ac_mode", None),
+            ac_option=cfg.get("ac_option", None),
+        )
+        self._tokenizer = config.instantiate(cfg.tokenizer)
+
+        self._optimizer = self._setup_optimizer(
+            cfg_optimizer=cfg.optimizer,
+            optimizer_in_bwd=self._optimizer_in_bwd,
+            opt_state_dict=(
+                checkpoint_dict[training.OPT_KEY]
+                if training.OPT_KEY in checkpoint_dict
+                else None
+            ),
+        )
+
+        if self._resume_from_checkpoint:
+            # If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
+            # using the DistributedCheckpointer.
+            # Therefore the recipe needs to load the distributed checkpoint to restore the training
+            # progress.
+            if self._enable_async_checkpointing:
+                try:
+                    checkpoint_dict = (
+                        self._checkpoint_client.load_distributed_checkpoint(
+                            self._model,
+                            (
+                                self._optim_ckpt_wrapper
+                                if self._optimizer_in_bwd
+                                else self._optimizer
+                            ),
+                        )
+                    )
+                except Exception as e:
+                    log.warning(
+                        f"Failed to load distributed checkpoint: {e}. Training will start from the base checkpoint."
+                    )
+
+            # Update the recipe state from the checkpoint state dict.
+            self._update_recipe_state(checkpoint_dict)
+
+        # initialize loss
+        self._loss_fn = config.instantiate(cfg.loss)
+
+        if self._compile:
+            training.compile_loss(self._loss_fn, verbose=self._is_rank_zero)
+
+        if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
+            # set num_output_chunks for model
+            self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
+
+        utils.log_rank_zero(log, "Loss is initialized.")
+
+        # sampler and dataloader depend on the tokenizer and loss_fn and should be
+        # setup after both of these are initialized
+        collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
+        self._sampler, self._dataloader = self._setup_data(
+            cfg_dataset=cfg.dataset,
+            shuffle=cfg.shuffle,
+            batch_size=cfg.batch_size,
+            collate_fn=collate_name,
+        )
+
+        # Finally update the recipe state which can only be correctly set after all of the
+        # other components have been initialized and updated.
+        #
+        # Number of training steps in each epoch depends on the number of batches produced
+        # by the dataloader, the max_steps_per_epoch param set by the user and the
+        # gradient_accumulation_steps param. This value is used for logging and tracking
+        # training state. The computation should happen after the dataloader has been setup
+        self._steps_per_epoch = (
+            len(self._dataloader) // self._gradient_accumulation_steps
+        )
+        if (
+            self.max_steps_per_epoch is not None
+            and self.max_steps_per_epoch < self._steps_per_epoch
+        ):
+            self._steps_per_epoch = self.max_steps_per_epoch
+        self.global_step = self.epochs_run * self._steps_per_epoch
+
+        # Setup lr scheduler
+        self._lr_scheduler = self._setup_lr_scheduler(
+            cfg_lr_scheduler=cfg.get("lr_scheduler", None),
+            num_training_steps=self.total_epochs * self._steps_per_epoch,
+            last_epoch=self.global_step - 1,
+        )
+
+        # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
+        # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
+        self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
+
+        # Used to ignore labels for loss computation
+        self.ignore_labels_cache = torch.full(
+            (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device
+        )
+
+    def _setup_lr_scheduler(
+        self,
+        cfg_lr_scheduler: Optional[DictConfig],
+        num_training_steps: int,
+        last_epoch: int,
+    ) -> Optional[Optimizer]:
+        """
+        Set up the learning rate scheduler based on the provided configuration.
+        It supports both standard optimization and optimizer-in-backward cases.
+
+        Args:
+            cfg_lr_scheduler (Optional[DictConfig]): The learning rate scheduler configuration.
+            num_training_steps (int): The total number of training steps.
+            last_epoch (int): The index of the last epoch.
+
+        Returns:
+            lr_scheduler (Optional[Optimizer]): The learning rate scheduler.
+        """
+        if cfg_lr_scheduler is None:
+            if self._is_rank_zero:
+                log.info(
+                    "No learning rate scheduler configured. Using constant learning rate."
+                )
+            return None
+
+        if self._optimizer_in_bwd:
+            # Use the first optimizer from the wrapper to represent the learning rate
+            optimizer = next(iter(self._optim_ckpt_wrapper.optim_map.values()))
+        else:
+            # Standard case: use the single optimizer
+            optimizer = self._optimizer
+
+        # Instantiate the learning rate scheduler
+        lr_scheduler = config.instantiate(
+            cfg_lr_scheduler,
+            optimizer,
+            num_training_steps=num_training_steps,
+            last_epoch=last_epoch,
+        )
+
+        if self._optimizer_in_bwd:
+            # Modify the scheduler for optimizer_in_bwd case
+            self._optim_ckpt_wrapper.set_lr_scheduler(lr_scheduler)
+
+        if self._is_rank_zero:
+            log.info("Learning rate scheduler is initialized.")
+
+        return lr_scheduler
+
+    def _setup_profiler(
+        self, cfg_profiler: Optional[DictConfig] = None
+    ) -> Union[torch.profiler.profile, DummyProfiler]:
+        """
+        Parses the `profiler` section of top-level `cfg` and sets up profiler
+
+        Args:
+            cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to
+                `recipe.main`). Default None.
+
+        Returns:
+            profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods
+            for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such
+            that the instrumented training loop does not need to be changed profiling is disabled.
+
+        The profiler config can be provided in configs under the `profiler` key with the following layout:
+
+        .. code-block:: yaml
+            profiler:
+                enabled: bool
+
+                #Output directory of trace artifacts
+                output_dir: str
+
+            #`torch.profiler.ProfilerActivity` types to trace
+            cpu: bool
+            cuda: bool
+
+                #Trace options
+                profile_memory: bool
+                with_stack: bool
+                record_shapes: bool
+                with_flops: bool
+
+            # `torch.profiler.schedule` options:
+            # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+            wait_steps: int
+            warmup_steps: int
+            active_steps: int
+            num_cycles: int
+        """
+        # Missing profiler section in config, assume disabled
+        if cfg_profiler is None:
+            cfg_profiler = DictConfig({"enabled": False})
+
+        # Check that component is included and set correctly
+        if cfg_profiler.get("_component_", None) is None:
+            cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler"
+        else:
+            assert (
+                cfg_profiler.get("_component_")
+                == "torchtune.training.setup_torch_profiler"
+            ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`"
+
+        profiler, profiler_cfg = config.instantiate(cfg_profiler)
+
+        utils.log_rank_zero(
+            log, f" Profiler config after instantiation: {profiler_cfg}"
+        )
+        if self._is_rank_zero:
+            self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
+            if profiler_cfg["enabled"]:
+                self.profiler_wait_steps = profiler_cfg["wait_steps"]
+                self.profiler_warmup_steps = profiler_cfg["warmup_steps"]
+                self.profiler_active_steps = profiler_cfg["active_steps"]
+
+        return profiler
+
+    def _setup_model(
+        self,
+        cfg_model: DictConfig,
+        enable_activation_checkpointing: bool,
+        enable_activation_offloading: bool,
+        fsdp_cpu_offload: bool,
+        reshard_after_forward: bool,
+        model_state_dict: Dict[str, Any],
+        custom_sharded_layers: Optional[List[str]] = None,
+        ac_mode: Optional[str] = None,
+        ac_option: Optional[int] = None,
+    ) -> nn.Module:
+        """
+        Model initialization has some important considerations:
+           a. To minimize GPU peak memory, we initialize the model on meta device with
+              the right dtype
+           b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since
+              full state dicts are loaded with ``torch.load(mmap=True)``
+        """
+
+        utils.log_rank_zero(
+            log,
+            "Distributed training is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
+        )
+        init_start = time.perf_counter()
+
+        with training.set_default_dtype(self._dtype), torch.device("meta"):
+            model = config.instantiate(cfg_model)
+
+        if self._compile:
+            training.compile_model(model, verbose=self._is_rank_zero)
+
+        device_mesh = init_device_mesh(
+            self._device.type,
+            mesh_shape=(self.data_parallel_dim, self.tensor_parallel_dim),
+            mesh_dim_names=("dp", "tp"),
+        )
+        self.dp_size = device_mesh["dp"].size()
+        self.dp_rank = device_mesh["dp"].get_local_rank()
+
+        # Apply tensor parallelism to the model
+        if self.tensor_parallel_dim > 1:
+            # Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor parallel
+            model = training.prepare_mha_for_tp(model, device_mesh["tp"])
+            parallelize_module(
+                model,
+                device_mesh["tp"],
+                parallelize_plan=self.tensor_parallel_plan,
+            )
+
+        # We currently have two versions of activation checkpointing in this recipe
+        # for testing and BC purposes. ``enable_activation_checkpointing`` controls
+        # the older version of AC and this behavior is unchanged
+        # ac_mode and ac_option together control selective AC. This is only enabled
+        # when these are set AND ``enable_activation_checkpointing`` is set to False
+        # We'll clean this up as soon as testing of AC is complete
+        if (not enable_activation_checkpointing) and (ac_mode is not None):
+            apply_selective_activation_checkpointing(
+                model,
+                ac_mode,
+                ac_option,
+            )
+
+        # original activation checkpointing (full) - flip the condition above
+        if enable_activation_checkpointing and ac_mode is None:
+            training.set_activation_checkpointing(
+                model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
+            )
+
+        # Apply Fully Sharded Data Parallelism to the model
+        if self.data_parallel_dim > 1:
+            fsdp_shard_conditions = [
+                partial(
+                    training.get_shard_conditions,
+                    names_to_match=custom_sharded_layers,
+                )
+            ]
+            training.shard_model(
+                model=model,
+                shard_conditions=fsdp_shard_conditions,
+                cpu_offload=fsdp_cpu_offload,
+                reshard_after_forward=reshard_after_forward,
+                dp_mesh=device_mesh["dp"],
+            )
+
+        with training.set_default_dtype(self._dtype), self._device:
+            for m in model.modules():
+                # RoPE is not covered in state dict
+                if hasattr(m, "rope_init"):
+                    m.rope_init()
+
+        # This method will convert the full model state dict into a sharded state
+        # dict and load into the model
+        training.load_from_full_model_state_dict(
+            model,
+            model_state_dict,
+            self._device,
+            strict=True,
+            cpu_offload=fsdp_cpu_offload,
+        )
+
+        # activation offloading
+        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
+            model, enable_activation_offloading
+        )
+
+        # Ensure no params and buffers are on meta device
+        training.validate_no_params_on_meta_device(model)
+
+        utils.log_rank_zero(
+            log,
+            f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
+        )
+
+        if self._is_rank_zero:
+            memory_stats = training.get_memory_stats(device=self._device)
+            training.log_memory_stats(memory_stats)
+
+        # synchronize before training begins
+        torch.distributed.barrier()
+
+        return model
+
+    def _setup_optimizer(
+        self,
+        cfg_optimizer: DictConfig,
+        optimizer_in_bwd: bool = False,
+        opt_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> Optional[Optimizer]:
+        if optimizer_in_bwd:
+            # Maintain a dict of optims for every parameter.
+            optim_dict = {
+                param: config.instantiate(cfg_optimizer, [param])
+                for param in self._model.parameters()
+            }
+
+            # Register optimizer step hooks on the model to run optimizer in backward.
+            training.register_optim_in_bwd_hooks(
+                model=self._model, optim_dict=optim_dict
+            )
+            # Create a wrapper for checkpoint save/load of optimizer states when running in backward.
+            self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper(
+                model=self._model, optim_dict=optim_dict
+            )
+            # Load optimizer states for each param. If optimizer states are being restored in an optimizer in
+            # backward run, these need to have been saved with the same setting. Cannot restore from runs that
+            # did not use optimizer in backward.
+            if opt_state_dict is not None:
+                for param in opt_state_dict.keys():
+                    try:
+                        training.load_from_full_optimizer_state_dict(
+                            self._model,
+                            self._optim_ckpt_wrapper.optim_map[param],
+                            opt_state_dict[param],
+                            self._device,
+                        )
+                    except BaseException as e:
+                        raise RuntimeError(
+                            "Failed loading in-backward optimizer checkpoints."
+                            "Please make sure run being restored from was using in-backward optimizer."
+                        ) from e
+            utils.log_rank_zero(log, "In-backward optimizers are set up.")
+            return None
+        else:
+            optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
+            if opt_state_dict:
+                training.load_from_full_optimizer_state_dict(
+                    self._model,
+                    optimizer,
+                    opt_state_dict,
+                    self._device,
+                )
+
+            utils.log_rank_zero(log, "Optimizer is initialized.")
+            return optimizer
+
+    def _setup_data(
+        self,
+        cfg_dataset: DictConfig,
+        shuffle: bool,
+        batch_size: int,
+        collate_fn: str,
+    ) -> Tuple[DistributedSampler, DataLoader]:
+        """
+        All data related setup happens here. Currently this recipe only supports the
+        DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
+        iterable datasets and streaming datasets are not supported.
+        """
+        if isinstance(cfg_dataset, ListConfig):
+            datasets = [
+                config.instantiate(single_cfg_dataset, self._tokenizer)
+                for single_cfg_dataset in cfg_dataset
+            ]
+            ds = ConcatDataset(datasets=datasets)
+            packed = getattr(ds, "packed", False)
+        else:
+            ds = config.instantiate(cfg_dataset, self._tokenizer)
+            packed = cfg_dataset.get("packed", False)
+
+        # Instantiate collate_fn
+        if "left_pad_sequence" in collate_fn:
+            raise RuntimeError("left_pad_sequence collator is only for inference.")
+        collate_fn = _get_component_from_path(collate_fn)
+
+        sampler = DistributedSampler(
+            ds, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle, seed=0
+        )
+        dataloader = DataLoader(
+            dataset=ds,
+            batch_size=batch_size,
+            sampler=sampler,
+            # dropping last avoids shape issues with compile + flex attention
+            drop_last=True,
+            collate_fn=(
+                partial(
+                    collate_fn,
+                    padding_idx=self._tokenizer.pad_id,
+                    ignore_idx=self._loss_fn.ignore_index,
+                )
+                if not packed
+                else padded_collate_packed
+            ),
+        )
+
+        utils.log_rank_zero(log, "Dataset and Sampler are initialized.")
+
+        return sampler, dataloader
+
+    def train(self) -> None:
+        """
+        The core training loop.
+        """
+        # clean up before training begins
+        training.cleanup_before_training()
+
+        # zero out the gradients before starting training
+        if not self._optimizer_in_bwd:
+            self._optimizer.zero_grad()
+        else:
+            for opt in self._optim_ckpt_wrapper.optim_map.values():
+                opt.zero_grad()
+
+        # Initialize tokens count and running loss (for grad accumulation)
+        t0 = time.perf_counter()
+        running_loss = 0
+        num_tokens = 0
+
+        self._profiler.start()
+        # self.epochs_run should be non-zero when we're resuming from a checkpoint
+        for curr_epoch in range(self.epochs_run, self.total_epochs):
+            # Update the sampler to ensure data is correctly shuffled across epochs
+            # in case shuffle is True
+            self._sampler.set_epoch(curr_epoch)
+
+            pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
+            for idx, batch in enumerate(self._dataloader):
+                if (
+                    self.max_steps_per_epoch is not None
+                    and (idx // self._gradient_accumulation_steps)
+                    == self.max_steps_per_epoch
+                ):
+                    break
+
+                # Start tracking CUDA memory for active steps for just the first epoch
+                if (
+                    self._is_rank_zero
+                    and curr_epoch == 0
+                    and self.profiler_profile_memory
+                    and idx == self.profiler_wait_steps + self.profiler_warmup_steps
+                    and self._device.type == "cuda"
+                ):
+                    torch.cuda.memory._record_memory_history()
+                utils.batch_to_device(batch, self._device)
+
+                # Calculate the number of unmasked tokens in the current batch
+                # and increment the total number of tokens seen in the step
+                current_num_tokens = (
+                    batch["labels"] != self._loss_fn.ignore_index
+                ).sum()
+                num_tokens += current_num_tokens
+
+                # Shape [b, s], needed for the loss not the model
+                labels = batch.pop("labels")
+
+                with self.activations_handling_ctx:
+                    logits = self._model(**batch)
+
+                # Shift labels to compute loss
+                # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
+                # But this way we dont need to slice the logits. We just add an ignore index to labels.
+                labels = torch.hstack(
+                    (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
+                )
+                if not isinstance(logits, list):
+                    labels = labels.reshape(-1)
+                    logits = logits.reshape(-1, logits.size(-1))
+
+                # Compute loss
+                # Loss is normalized by default so we multiply by the number of tokens
+                # This way we can normalize by the total number of tokens if we're accumulating gradients
+                current_loss = self._loss_fn(logits, labels) * current_num_tokens
+
+                # free logits otherwise it peaks backward memory
+                del logits
+
+                running_loss += current_loss
+
+                # For optimizer in backward, we need to normalize before calling backward
+                # This case and gradient accumulation are mutually exclusive
+                if self._optimizer_in_bwd:
+                    torch.distributed.all_reduce(num_tokens)
+                    torch.distributed.all_reduce(running_loss)
+
+                    # We multiply by world_size to undo FSDP2 gradient normalization.
+                    current_loss = current_loss * (self.world_size / num_tokens)
+
+                current_loss.backward()
+
+                # Step with optimizer
+                if (idx + 1) % self._gradient_accumulation_steps == 0:
+                    if not self._optimizer_in_bwd:
+                        # Get total number of tokens across all ranks to normalize gradients
+                        torch.distributed.all_reduce(num_tokens)
+                        # This will ensure that the logged loss matches what we're optimizing
+                        torch.distributed.all_reduce(running_loss)
+                        # Manually scale the gradients from unnormalized loss by total # of tokens
+                        # We multiply by world_size to undo FSDP2 gradient normalization.
+                        training.scale_grads(self._model, self.world_size / num_tokens)
+                        if self._clip_grad_norm is not None:
+                            grad_norm = torch.nn.utils.clip_grad_norm_(
+                                self._model.parameters(),
+                                max_norm=float(self._clip_grad_norm),
+                            )
+                            # If sharded, collect the DTensor here
+                            if isinstance(grad_norm, DTensor):
+                                grad_norm = grad_norm.full_tensor()
+                        self._optimizer.step()
+                        self._optimizer.zero_grad(set_to_none=True)
+
+                    # Update the number of steps when the weights are updated
+                    self.global_step += 1
+
+                    # Step the learning rate scheduler
+                    if self._lr_scheduler is not None:
+                        self._lr_scheduler.step()
+
+                    loss_to_log = running_loss.item() / num_tokens
+                    pbar.update(1)
+                    pbar.set_description(
+                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
+                    )
+
+                    # Log per-step metrics
+                    if (
+                        self.global_step % self._log_every_n_steps == 0
+                        and self._is_rank_zero
+                    ):
+                        time_per_step = time.perf_counter() - t0
+                        log_dict = {
+                            "loss": loss_to_log,
+                            "lr": get_lr(
+                                (
+                                    self._optimizer
+                                    if not self._optimizer_in_bwd
+                                    else self._optim_ckpt_wrapper
+                                ),
+                            ),
+                            "tokens_per_second_per_gpu": num_tokens
+                            / (time_per_step * self.world_size),
+                        }
+                        if self._log_peak_memory_stats:
+                            log_dict.update(
+                                training.get_memory_stats(device=self._device)
+                            )
+                        if self._clip_grad_norm is not None:
+                            log_dict.update({"grad_norm": grad_norm})
+                        self._metric_logger.log_dict(
+                            log_dict,
+                            step=self.global_step,
+                        )
+
+                    # Reset running stats for the next step
+                    running_loss = 0
+                    num_tokens = 0
+                    t0 = time.perf_counter()
+
+                    # Stop tracking CUDA memory now that active steps are complete
+                    if (
+                        self._is_rank_zero
+                        and curr_epoch == 0
+                        and self.profiler_profile_memory
+                        and idx
+                        == self.profiler_wait_steps
+                        + self.profiler_warmup_steps
+                        + self.profiler_active_steps
+                        and self._device.type == "cuda"
+                    ):
+                        torch.cuda.memory._record_memory_history(enabled=None)
+
+                    # Step profiler
+                    # Note that this is called within gradient accumulation block, hence
+                    # will include multiple forward / backward passes if gradient accumulation > 1
+                    self._profiler.step()
+
+            self.epochs_run += 1
+            # self._checkpoint_client.save_checkpoint(
+            #     model=self._model,
+            #     optimizer=(
+            #         self._optimizer
+            #         if not self._optimizer_in_bwd
+            #         else self._optim_ckpt_wrapper
+            #     ),
+            #     training_progress=TrainingProgress(
+            #         seed=self.seed,
+            #         epochs_run=self.epochs_run,
+            #         total_epochs=self.total_epochs,
+            #         max_steps_per_epoch=self.max_steps_per_epoch,
+            #     ),
+            #     epoch=curr_epoch,
+            # )
+            self.epochs_run += 1
+            if curr_epoch > 0 and curr_epoch % self.save_every_epochs == 0:
+                utils.log_rank_zero(log, f"Saving checkpoint at epoch {curr_epoch}")
+                self._checkpoint_client.save_checkpoint(
+                    model=self._model,
+                    optimizer=(
+                        self._optimizer
+                        if not self._optimizer_in_bwd
+                        else self._optim_ckpt_wrapper
+                    ),
+                    training_progress=TrainingProgress(
+                        seed=self.seed,
+                        epochs_run=self.epochs_run,
+                        total_epochs=self.total_epochs,
+                        max_steps_per_epoch=self.max_steps_per_epoch,
+                    ),
+                    epoch=curr_epoch,
+                )
+
+        self._profiler.stop()
+
+    def cleanup(self) -> None:
+        if self._is_rank_zero:
+            self._metric_logger.close()
+        destroy_process_group()
+
+
+@config.parse
+def recipe_main(cfg: DictConfig) -> None:
+    """
+    Entry point for the recipe.
+
+    Configurable parameters are read in the following order:
+        - Parameters specified in config (see available configs through ``tune ls``)
+        - Overwritten by arguments from the command-line
+    """
+    config.log_config(recipe_name="FullFinetuneRecipeDistributed", cfg=cfg)
+    recipe = FullFinetuneRecipeDistributed(cfg=cfg)
+    recipe.setup(cfg=cfg)
+    recipe.train()
+    recipe.cleanup()
+
+
+if __name__ == "__main__":
+    sys.exit(recipe_main())

+ 12 - 6
end-to-end-use-cases/data-tool/scripts/finetuning/ft-config.yaml

@@ -52,8 +52,14 @@ checkpointer:
 resume_from_checkpoint: False
 
 # Fine-tuning arguments
-batch_size: 2
-epochs: 1
+batch_size: 4
+epochs: 30
+save_every_epochs: 10
+max_steps_per_epoch: null
+
+lr_scheduler:
+  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
+  num_warmup_steps: 10
 
 optimizer:
   _component_: torch.optim.AdamW
@@ -64,7 +70,6 @@ optimizer:
 
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
-max_steps_per_epoch: null
 gradient_accumulation_steps: 1  # Use to increase effective batch size
 
 
@@ -74,8 +79,8 @@ device: cuda
 # Memory management
 enable_activation_checkpointing: True  # True reduces memory
 enable_activation_offloading: False  # True reduces memory
-custom_sharded_layers: ['tok_embeddings', 'output']  # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed.
-fsdp_cpu_offload: True
+#custom_sharded_layers: ['tok_embeddings', 'output']  # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed.
+fsdp_cpu_offload: False
 clip_grad_norm: null
 compile: False  # torch.compile the model + loss, True increases speed + decreases memory
 optimizer_in_bwd: False  # True saves memory. Requires gradient_accumulation_steps=1
@@ -85,7 +90,8 @@ dtype: bf16
 
 # Logging
 metric_logger:
-  _component_: torchtune.training.metric_logging.DiskLogger
+  _component_: torchtune.training.metric_logging.WandBLogger
+  project: ctt
   log_dir: ${output_dir}/logs
 log_every_n_steps: 1
 log_peak_memory_stats: True