# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys
import time

from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
from warnings import warn

import torch
from omegaconf import DictConfig, ListConfig
from torch import nn
from torch.distributed import (
    destroy_process_group,
    init_device_mesh,
    init_process_group,
)
from torch.distributed._tensor import DTensor
from torch.distributed.tensor.parallel import parallelize_module
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler

from torchtune import config, modules, training, utils
from torchtune.config._utils import _get_component_from_path
from torchtune.data import padded_collate_packed
from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DummyProfiler, PROFILER_KEY
from torchtune.training.activations import apply_selective_activation_checkpointing
from torchtune.training.checkpointing._checkpoint_client import (
    CheckpointClient,
    TrainingProgress,
)
from torchtune.training.lr_schedulers import get_lr

from tqdm import tqdm

log = utils.get_logger("DEBUG")


class FullFinetuneRecipeDistributed(FTRecipeInterface):
    """
    Full finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe supports
    distributed training and can be run on a single node (1 to 8 GPUs).

    Features:
        - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
            is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
            done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
            ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
            DDP is currently not supported. Training on CPU is not supported.

        - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``
            flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
            activations in memory and instead recompute them during the backward pass. This is especially
            helpful for larger batch sizes when you're memory constrained. But these savings in memory
            come at the cost of training performance. In most cases training can slow down quite a bit as
            a result of this activation recomputation.

        - Activation Offloading. This can be controlled using the ``enable_activation_offloading``
            flag. Activation offloading is a technique similar to activation checkpointing that helps
            reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activation
            checkpointing drops the activation in the forward to recompute it later in the backward,
            activation offloading moves the activation to the CPU in the forward and brings it
            back during the backward pass. As always, there is a tradeoff--these savings in memory can
            come at the cost of training performance and CPU resources. To recover some runtime cost,
            we've added an option to enable offloading on a different stream to permit overlapping with
            the computation. This option is currently only available on PyTorch 2.5 or later and will
            be enabled by default if an acceptable torch version is found. Activation offloading can be
            used in conjunction with activation checkpointing.

        - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
            flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
            most cases this should halve the memory footprint of full precision (fp32) training, without
            loss in model quality (this will depend on the model, training data and other settings). For
            GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16
            precision are currently not supported.

        - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is
            controlled using the ``gradient_accumulation_steps`` flag.

                Total Batch Size = batch_size * number of GPUs * gradient accumulation steps.

            For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a
            total batch size of 64.

            Gradient accumulation is especially useful when you are memory constrained. In this case,
            accumulating gradients might give you better training speed than enabling activation
            checkpointing.

        - Checkpointing. Model weights are checkpointed every ``save_every_epochs`` epochs
            (default 1, i.e. at the end of each epoch). Optimizer state and recipe state (seed,
            total_epochs, number of epochs run, etc.) are saved alongside the weights and used in case of
            resuming training. Resuming training is controlled by the ``resume_from_checkpoint`` flag.
            Mid-epoch checkpointing is currently not supported.

            For more details on the checkpointer, please take a look at
            our checkpointer deepdive (https://pytorch.org/torchtune/main/deep_dives/checkpointer.html).

        - Logging. Terminal, Disk, WandB and TensorBoard are all supported.

        - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default,
            ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set
            ``clip_grad_norm='inf'``.

    For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
    has example commands for how to kick-off training.
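
    For example, on a single node with 8 GPUs (the config name below is illustrative;
    run ``tune ls`` to see the configs shipped with your torchtune version):

    .. code-block:: bash

        tune run --nproc_per_node 8 full_finetune_distributed --config llama2/7B_full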

    Args:
        cfg (DictConfig): OmegaConf object parsed from yaml file

    Raises:
        ValueError: If ``dtype`` is set to fp16.
        RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
        RuntimeError: If ``left_pad_sequence`` is set as the data collator.
        RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
        RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False.
    """

    def __init__(self, cfg: DictConfig) -> None:
        device_type = cfg.device
        self._device = utils.get_device(device=device_type)
        self._dtype = training.get_dtype(cfg.dtype, device=self._device)

        if self._dtype == torch.float16:
            raise ValueError(
                "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
            )

        # Set up the backend for distributed training (NCCL, GLOO, etc.)
        self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
        self.fsdp_cpu_offload = cfg.get("fsdp_cpu_offload", False)
        self.distributed_backend = training.get_distributed_backend(
            device_type,
            offload_ops_to_cpu=self.fsdp_cpu_offload
            or self._enable_async_checkpointing,
        )
        init_process_group(self.distributed_backend)

        # Initialize distributed variables
        self.world_size, self.rank = utils.get_world_size_and_rank()
        self._is_rank_zero = self.rank == 0
        self.tensor_parallel_plan = config.instantiate(
            cfg.get("tensor_parallel_plan", None)
        )
        self.tensor_parallel_dim = cfg.get("tensor_parallel_dim", 1)
        if self.tensor_parallel_dim > 1 and self.tensor_parallel_plan is None:
            raise ValueError(
                "Tensor Parallel plan needs to be provided when tensor parallel is enabled."
            )
        if self.world_size % self.tensor_parallel_dim != 0:
            raise ValueError(
                f"world_size {self.world_size} must be divisible by tensor_parallel_dim {self.tensor_parallel_dim}"
            )
        self.data_parallel_dim = self.world_size // self.tensor_parallel_dim
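        # e.g. world_size=8 with tensor_parallel_dim=2 gives data_parallel_dim=4;
        # the GPUs are laid out as a 4 (dp) x 2 (tp) device mesh in _setup_model.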

        # Logging attributes
        self._output_dir = cfg.output_dir
        self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
        self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)

        if self._log_peak_memory_stats and device_type != "cuda":
            log.info(
                "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False."
            )
            self._log_peak_memory_stats = False

        # Training cfg
        self._resume_from_checkpoint = cfg.resume_from_checkpoint
        self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
        self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
        self._clip_grad_norm = cfg.get("clip_grad_norm", None)
        self._checkpoint_client = CheckpointClient(cfg)
        self.save_every_epochs = cfg.get("save_every_epochs", 1)

        # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
        if self._optimizer_in_bwd:
            if self._clip_grad_norm is not None:
                raise RuntimeError(
                    "Gradient clipping is not supported with optimizer in bwd. "
                    "Please set clip_grad_norm=None, or optimizer_in_bwd=False."
                )
            if self._gradient_accumulation_steps > 1:
                raise RuntimeError(
                    "Gradient accumulation is not supported with optimizer in bwd. "
                    "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
                )

        # activation checkpointing/offloading
        self._enable_activation_checkpointing = cfg.get(
            "enable_activation_checkpointing", False
        )
        self._enable_activation_offloading = cfg.get(
            "enable_activation_offloading", False
        )
        if self._enable_activation_offloading:
            if device_type != "cuda":
                raise RuntimeError(
                    "enable_activation_offloading should only be True when training on CUDA"
                )
            if not self._enable_activation_checkpointing:
                raise RuntimeError(
                    "enable_activation_offloading should only be True when enable_activation_checkpointing is True"
                )
        elif (
            self._enable_activation_checkpointing
            and cfg.checkpointer.model_type != "LLAMA3_VISION"
        ):
            utils.log_rank_zero(
                log,
                "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. "
                "Enabling activation offloading should reduce memory further.",
            )

        # These are public properties which are updated by the checkpoint loader
        # when ``resume_from_checkpoint`` is `True` or validated in tests
        self.seed = training.set_seed(
            seed=cfg.seed, debug_mode=cfg.get("cudnn_deterministic_mode", None)
        )
        self.epochs_run = 0
        self.total_epochs = cfg.epochs
        self.max_steps_per_epoch = cfg.max_steps_per_epoch
        self.global_step = 0

    def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
        """
        Updates the recipe state from checkpoint.
        """
        try:
            self.epochs_run = ckpt_dict[training.EPOCHS_KEY]

            # on mismatch, warn the user and prevent the override
            if self.seed != ckpt_dict[training.SEED_KEY]:
                warn(
                    message=(
                        "Config value for seed does not match the checkpoint value, "
                        f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}"
                    )
                )
                self.seed = ckpt_dict[training.SEED_KEY]
            if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]:
                warn(
                    message=(
                        "Config value for max_steps_per_epoch does not match the checkpoint value, "
                        f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}"
                    )
                )
                self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY]

            # on mismatch, warn the user but allow the override
            if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]:
                warn(
                    message=(
                        "Config value for total_epochs does not match the checkpoint value, "
                        f"using the config value: {self.total_epochs}"
                    )
                )
        except KeyError as e:
            raise KeyError(
                "Checkpoint does not contain the required keys needed for updating recipe state. "
                "Are you sure you passed in the right recipe checkpoint?"
            ) from e

    def setup(self, cfg: DictConfig) -> None:
        """
        Setup the recipe. This includes training state (if resume_from_checkpoint is True),
        model, tokenizer, loss, optimizer, lr scheduler, sampler, and dataloader.
        """
        if self.fsdp_cpu_offload:
            # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
            # speed up when benchmarking fused AdamW on CPU
            training.set_torch_num_threads()

        if self._is_rank_zero:
            self._metric_logger = config.instantiate(cfg.metric_logger)
            # log config with parameter override
            self._metric_logger.log_config(cfg)

        # Load the base model
        checkpoint_dict = self._checkpoint_client.load_base_checkpoint()

        self._compile = cfg.get("compile", False)
        self._model = self._setup_model(
            cfg_model=cfg.model,
            enable_activation_checkpointing=self._enable_activation_checkpointing,
            enable_activation_offloading=self._enable_activation_offloading,
            custom_sharded_layers=cfg.get("custom_sharded_layers", None),
            fsdp_cpu_offload=self.fsdp_cpu_offload,
            reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
            model_state_dict=checkpoint_dict[training.MODEL_KEY],
            ac_mode=cfg.get("ac_mode", None),
            ac_option=cfg.get("ac_option", None),
        )
        self._tokenizer = config.instantiate(cfg.tokenizer)

        self._optimizer = self._setup_optimizer(
            cfg_optimizer=cfg.optimizer,
            optimizer_in_bwd=self._optimizer_in_bwd,
            opt_state_dict=(
                checkpoint_dict[training.OPT_KEY]
                if training.OPT_KEY in checkpoint_dict
                else None
            ),
        )

        if self._resume_from_checkpoint:
            # If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
            # using the DistributedCheckpointer.
            # Therefore the recipe needs to load the distributed checkpoint to restore the training
            # progress.
            if self._enable_async_checkpointing:
                try:
                    checkpoint_dict = (
                        self._checkpoint_client.load_distributed_checkpoint(
                            self._model,
                            (
                                self._optim_ckpt_wrapper
                                if self._optimizer_in_bwd
                                else self._optimizer
                            ),
                        )
                    )
                except Exception as e:
                    log.warning(
                        f"Failed to load distributed checkpoint: {e}. Training will start from the base checkpoint."
                    )

            # Update the recipe state from the checkpoint state dict.
            self._update_recipe_state(checkpoint_dict)

        # initialize loss
        self._loss_fn = config.instantiate(cfg.loss)

        if self._compile:
            training.compile_loss(self._loss_fn, verbose=self._is_rank_zero)

        if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
            # set num_output_chunks for model
            self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)

        utils.log_rank_zero(log, "Loss is initialized.")

        # sampler and dataloader depend on the tokenizer and loss_fn and should be
        # setup after both of these are initialized
        collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
        self._sampler, self._dataloader = self._setup_data(
            cfg_dataset=cfg.dataset,
            shuffle=cfg.shuffle,
            batch_size=cfg.batch_size,
            collate_fn=collate_name,
        )

        # Finally update the recipe state which can only be correctly set after all of the
        # other components have been initialized and updated.
        #
        # Number of training steps in each epoch depends on the number of batches produced
        # by the dataloader, the max_steps_per_epoch param set by the user and the
        # gradient_accumulation_steps param. This value is used for logging and tracking
        # training state. The computation should happen after the dataloader has been setup
        self._steps_per_epoch = (
            len(self._dataloader) // self._gradient_accumulation_steps
        )
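        # e.g. a dataloader with 1000 batches and gradient_accumulation_steps=8
        # yields 125 optimizer steps per epoch.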
        if (
            self.max_steps_per_epoch is not None
            and self.max_steps_per_epoch < self._steps_per_epoch
        ):
            self._steps_per_epoch = self.max_steps_per_epoch
        self.global_step = self.epochs_run * self._steps_per_epoch
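        # e.g. resuming after 2 completed epochs at 125 steps/epoch restores
        # global_step to 250, which also seeds the lr scheduler's last_epoch below.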

        # Setup lr scheduler
        self._lr_scheduler = self._setup_lr_scheduler(
            cfg_lr_scheduler=cfg.get("lr_scheduler", None),
            num_training_steps=self.total_epochs * self._steps_per_epoch,
            last_epoch=self.global_step - 1,
        )

        # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
        # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
        self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))

        # Used to ignore labels for loss computation
        self.ignore_labels_cache = torch.full(
            (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device
        )
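        # One ignore_index column per sample; in ``train`` it is hstacked onto the
        # left-shifted labels so the final position is masked without slicing logits.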

    def _setup_lr_scheduler(
        self,
        cfg_lr_scheduler: Optional[DictConfig],
        num_training_steps: int,
        last_epoch: int,
    ) -> Optional[Optimizer]:
        """
        Set up the learning rate scheduler based on the provided configuration.
        It supports both standard optimization and optimizer-in-backward cases.

        Args:
            cfg_lr_scheduler (Optional[DictConfig]): The learning rate scheduler configuration.
            num_training_steps (int): The total number of training steps.
            last_epoch (int): The index of the last epoch.

        Returns:
            lr_scheduler (Optional[Optimizer]): The learning rate scheduler.
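
        An illustrative config (the component path is an assumption; check the lr
        scheduler components available in your torchtune version):

        .. code-block:: yaml

            lr_scheduler:
              _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
              num_warmup_steps: 100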
- """
- if cfg_lr_scheduler is None:
- if self._is_rank_zero:
- log.info(
- "No learning rate scheduler configured. Using constant learning rate."
- )
- return None
- if self._optimizer_in_bwd:
- # Use the first optimizer from the wrapper to represent the learning rate
- optimizer = next(iter(self._optim_ckpt_wrapper.optim_map.values()))
- else:
- # Standard case: use the single optimizer
- optimizer = self._optimizer
- # Instantiate the learning rate scheduler
- lr_scheduler = config.instantiate(
- cfg_lr_scheduler,
- optimizer,
- num_training_steps=num_training_steps,
- last_epoch=last_epoch,
- )
- if self._optimizer_in_bwd:
- # Modify the scheduler for optimizer_in_bwd case
- self._optim_ckpt_wrapper.set_lr_scheduler(lr_scheduler)
- if self._is_rank_zero:
- log.info("Learning rate scheduler is initialized.")
- return lr_scheduler

    def _setup_profiler(
        self, cfg_profiler: Optional[DictConfig] = None
    ) -> Union[torch.profiler.profile, DummyProfiler]:
        """
        Parses the `profiler` section of top-level `cfg` and sets up profiler

        Args:
            cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to
                `recipe.main`). Default None.

        Returns:
            profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods
            for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled,
            such that the instrumented training loop does not need to be changed when profiling is disabled.

        The profiler config can be provided in configs under the `profiler` key with the following layout:

        .. code-block:: yaml

            profiler:
              enabled: bool

              # Output directory of trace artifacts
              output_dir: str

              # `torch.profiler.ProfilerActivity` types to trace
              cpu: bool
              cuda: bool

              # Trace options
              profile_memory: bool
              with_stack: bool
              record_shapes: bool
              with_flops: bool

              # `torch.profiler.schedule` options:
              # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
              wait_steps: int
              warmup_steps: int
              active_steps: int
              num_cycles: int
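
        For example, a config which traces CPU and CUDA activity for 2 active steps
        per cycle (the values here are illustrative):

        .. code-block:: yaml

            profiler:
              _component_: torchtune.training.setup_torch_profiler
              enabled: True
              output_dir: profiling_outputs
              cpu: True
              cuda: True
              profile_memory: False
              with_stack: False
              record_shapes: True
              with_flops: False
              wait_steps: 5
              warmup_steps: 3
              active_steps: 2
              num_cycles: 1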
- """
- # Missing profiler section in config, assume disabled
- if cfg_profiler is None:
- cfg_profiler = DictConfig({"enabled": False})
- # Check that component is included and set correctly
- if cfg_profiler.get("_component_", None) is None:
- cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler"
- else:
- assert (
- cfg_profiler.get("_component_")
- == "torchtune.training.setup_torch_profiler"
- ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`"
- profiler, profiler_cfg = config.instantiate(cfg_profiler)
- utils.log_rank_zero(
- log, f" Profiler config after instantiation: {profiler_cfg}"
- )
- if self._is_rank_zero:
- self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
- if profiler_cfg["enabled"]:
- self.profiler_wait_steps = profiler_cfg["wait_steps"]
- self.profiler_warmup_steps = profiler_cfg["warmup_steps"]
- self.profiler_active_steps = profiler_cfg["active_steps"]
- return profiler

    def _setup_model(
        self,
        cfg_model: DictConfig,
        enable_activation_checkpointing: bool,
        enable_activation_offloading: bool,
        fsdp_cpu_offload: bool,
        reshard_after_forward: bool,
        model_state_dict: Dict[str, Any],
        custom_sharded_layers: Optional[List[str]] = None,
        ac_mode: Optional[str] = None,
        ac_option: Optional[int] = None,
    ) -> nn.Module:
        """
        Model initialization has some important considerations:
            a. To minimize GPU peak memory, we initialize the model on meta device with
               the right dtype
            b. All ranks call ``load_state_dict`` without peaking CPU RAM since
               full state dicts are loaded with ``torch.load(mmap=True)``
        """
        utils.log_rank_zero(
            log,
            "Distributed training is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
        )
        init_start = time.perf_counter()

        with training.set_default_dtype(self._dtype), torch.device("meta"):
            model = config.instantiate(cfg_model)

        if self._compile:
            training.compile_model(model, verbose=self._is_rank_zero)

        device_mesh = init_device_mesh(
            self._device.type,
            mesh_shape=(self.data_parallel_dim, self.tensor_parallel_dim),
            mesh_dim_names=("dp", "tp"),
        )
        self.dp_size = device_mesh["dp"].size()
        self.dp_rank = device_mesh["dp"].get_local_rank()

        # Apply tensor parallelism to the model
        if self.tensor_parallel_dim > 1:
            # Use the local numbers (num_heads, num_kv_heads, embed_dim) to account for tensor parallel
            model = training.prepare_mha_for_tp(model, device_mesh["tp"])
            parallelize_module(
                model,
                device_mesh["tp"],
                parallelize_plan=self.tensor_parallel_plan,
            )

        # We currently have two versions of activation checkpointing in this recipe
        # for testing and BC purposes. ``enable_activation_checkpointing`` controls
        # the older version of AC and this behavior is unchanged.
        # ac_mode and ac_option together control selective AC. This is only enabled
        # when these are set AND ``enable_activation_checkpointing`` is set to False.
        # We'll clean this up as soon as testing of AC is complete.
        if (not enable_activation_checkpointing) and (ac_mode is not None):
            apply_selective_activation_checkpointing(
                model,
                ac_mode,
                ac_option,
            )

        # original activation checkpointing (full) - flip the condition above
        if enable_activation_checkpointing and ac_mode is None:
            training.set_activation_checkpointing(
                model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
            )

        # Apply Fully Sharded Data Parallelism to the model
        if self.data_parallel_dim > 1:
            fsdp_shard_conditions = [
                partial(
                    training.get_shard_conditions,
                    names_to_match=custom_sharded_layers,
                )
            ]
            training.shard_model(
                model=model,
                shard_conditions=fsdp_shard_conditions,
                cpu_offload=fsdp_cpu_offload,
                reshard_after_forward=reshard_after_forward,
                dp_mesh=device_mesh["dp"],
            )

        with training.set_default_dtype(self._dtype), self._device:
            for m in model.modules():
                # RoPE is not covered in state dict
                if hasattr(m, "rope_init"):
                    m.rope_init()

        # This method will convert the full model state dict into a sharded state
        # dict and load into the model
        training.load_from_full_model_state_dict(
            model,
            model_state_dict,
            self._device,
            strict=True,
            cpu_offload=fsdp_cpu_offload,
        )

        # activation offloading
        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
            model, enable_activation_offloading
        )

        # Ensure no params and buffers are on meta device
        training.validate_no_params_on_meta_device(model)

        utils.log_rank_zero(
            log,
            f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
        )
        if self._is_rank_zero:
            memory_stats = training.get_memory_stats(device=self._device)
            training.log_memory_stats(memory_stats)

        # synchronize before training begins
        torch.distributed.barrier()

        return model

    def _setup_optimizer(
        self,
        cfg_optimizer: DictConfig,
        optimizer_in_bwd: bool = False,
        opt_state_dict: Optional[Dict[str, Any]] = None,
    ) -> Optional[Optimizer]:
        if optimizer_in_bwd:
            # Maintain a dict of optims for every parameter.
            optim_dict = {
                param: config.instantiate(cfg_optimizer, [param])
                for param in self._model.parameters()
            }
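            # Each parameter gets its own optimizer instance, so its update can run
            # (and its gradient be freed) as soon as that parameter's backward completes.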
            # Register optimizer step hooks on the model to run optimizer in backward.
            training.register_optim_in_bwd_hooks(
                model=self._model, optim_dict=optim_dict
            )
            # Create a wrapper for checkpoint save/load of optimizer states when running in backward.
            self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper(
                model=self._model, optim_dict=optim_dict
            )
            # Load optimizer states for each param. If optimizer states are being restored in an optimizer in
            # backward run, these need to have been saved with the same setting. Cannot restore from runs that
            # did not use optimizer in backward.
            if opt_state_dict is not None:
                for param in opt_state_dict.keys():
                    try:
                        training.load_from_full_optimizer_state_dict(
                            self._model,
                            self._optim_ckpt_wrapper.optim_map[param],
                            opt_state_dict[param],
                            self._device,
                        )
                    except BaseException as e:
                        raise RuntimeError(
                            "Failed loading in-backward optimizer checkpoints. "
                            "Please make sure the run being restored from was using in-backward optimizer."
                        ) from e
            utils.log_rank_zero(log, "In-backward optimizers are set up.")
            return None
        else:
            optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
            if opt_state_dict:
                training.load_from_full_optimizer_state_dict(
                    self._model,
                    optimizer,
                    opt_state_dict,
                    self._device,
                )

            utils.log_rank_zero(log, "Optimizer is initialized.")
            return optimizer

    def _setup_data(
        self,
        cfg_dataset: DictConfig,
        shuffle: bool,
        batch_size: int,
        collate_fn: str,
    ) -> Tuple[DistributedSampler, DataLoader]:
        """
        All data related setup happens here. Currently this recipe only supports
        DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
        iterable datasets and streaming datasets are not supported.
        """
        if isinstance(cfg_dataset, ListConfig):
            datasets = [
                config.instantiate(single_cfg_dataset, self._tokenizer)
                for single_cfg_dataset in cfg_dataset
            ]
            ds = ConcatDataset(datasets=datasets)
            packed = getattr(ds, "packed", False)
        else:
            ds = config.instantiate(cfg_dataset, self._tokenizer)
            packed = cfg_dataset.get("packed", False)

        # Instantiate collate_fn
        if "left_pad_sequence" in collate_fn:
            raise RuntimeError("left_pad_sequence collator is only for inference.")
        collate_fn = _get_component_from_path(collate_fn)

        sampler = DistributedSampler(
            ds, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle, seed=0
        )
        dataloader = DataLoader(
            dataset=ds,
            batch_size=batch_size,
            sampler=sampler,
            # dropping last avoids shape issues with compile + flex attention
            drop_last=True,
            collate_fn=(
                partial(
                    collate_fn,
                    padding_idx=self._tokenizer.pad_id,
                    ignore_idx=self._loss_fn.ignore_index,
                )
                if not packed
                else padded_collate_packed
            ),
        )

        utils.log_rank_zero(log, "Dataset and Sampler are initialized.")

        return sampler, dataloader

    def train(self) -> None:
        """
        The core training loop.
        """
        # clean up before training begins
        training.cleanup_before_training()

        # zero out the gradients before starting training
        if not self._optimizer_in_bwd:
            self._optimizer.zero_grad()
        else:
            for opt in self._optim_ckpt_wrapper.optim_map.values():
                opt.zero_grad()

        # Initialize tokens count and running loss (for grad accumulation)
        t0 = time.perf_counter()
        running_loss = 0
        num_tokens = 0

        self._profiler.start()
        # self.epochs_run should be non-zero when we're resuming from a checkpoint
        for curr_epoch in range(self.epochs_run, self.total_epochs):
            # Update the sampler to ensure data is correctly shuffled across epochs
            # in case shuffle is True
            self._sampler.set_epoch(curr_epoch)

            pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
            for idx, batch in enumerate(self._dataloader):
                if (
                    self.max_steps_per_epoch is not None
                    and (idx // self._gradient_accumulation_steps)
                    == self.max_steps_per_epoch
                ):
                    break

                # Start tracking CUDA memory for active steps for just the first epoch
                if (
                    self._is_rank_zero
                    and curr_epoch == 0
                    and self.profiler_profile_memory
                    and idx == self.profiler_wait_steps + self.profiler_warmup_steps
                    and self._device.type == "cuda"
                ):
                    torch.cuda.memory._record_memory_history()

                utils.batch_to_device(batch, self._device)

                # Calculate the number of unmasked tokens in the current batch
                # and increment the total number of tokens seen in the step
                current_num_tokens = (
                    batch["labels"] != self._loss_fn.ignore_index
                ).sum()
                num_tokens += current_num_tokens

                # Shape [b, s], needed for the loss not the model
                labels = batch.pop("labels")

                with self.activations_handling_ctx:
                    logits = self._model(**batch)

                # Shift labels to compute loss
                # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
                # But this way we don't need to slice the logits. We just add an ignore index to labels.
                labels = torch.hstack(
                    (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
                )
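                # e.g. labels [a, b, c] become [b, c, ignore_index]: position i is
                # trained to predict token i+1 and the final position is masked out.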
                if not isinstance(logits, list):
                    labels = labels.reshape(-1)
                    logits = logits.reshape(-1, logits.size(-1))

                # Compute loss
                # Loss is normalized by default so we multiply by the number of tokens
                # This way we can normalize by the total number of tokens if we're accumulating gradients
                current_loss = self._loss_fn(logits, labels) * current_num_tokens

                # free logits otherwise it peaks backward memory
                del logits

                running_loss += current_loss

                # For optimizer in backward, we need to normalize before calling backward
                # This case and gradient accumulation are mutually exclusive
                if self._optimizer_in_bwd:
                    torch.distributed.all_reduce(num_tokens)
                    torch.distributed.all_reduce(running_loss)
                    # We multiply by world_size to undo FSDP2 gradient normalization.
                    current_loss = current_loss * (self.world_size / num_tokens)

                current_loss.backward()

                # Step with optimizer
                if (idx + 1) % self._gradient_accumulation_steps == 0:
                    if not self._optimizer_in_bwd:
                        # Get total number of tokens across all ranks to normalize gradients
                        torch.distributed.all_reduce(num_tokens)
                        # This will ensure that the logged loss matches what we're optimizing
                        torch.distributed.all_reduce(running_loss)
                        # Manually scale the gradients from unnormalized loss by total # of tokens
                        # We multiply by world_size to undo FSDP2 gradient normalization.
                        training.scale_grads(self._model, self.world_size / num_tokens)
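                        # Net effect: gradients equal the sum of per-token gradients
                        # divided by total tokens, i.e. a per-token average across
                        # ranks and accumulation steps.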
                        if self._clip_grad_norm is not None:
                            grad_norm = torch.nn.utils.clip_grad_norm_(
                                self._model.parameters(),
                                max_norm=float(self._clip_grad_norm),
                            )
                            # If sharded, collect the DTensor here
                            if isinstance(grad_norm, DTensor):
                                grad_norm = grad_norm.full_tensor()
                        self._optimizer.step()
                        self._optimizer.zero_grad(set_to_none=True)

                    # Update the number of steps when the weights are updated
                    self.global_step += 1

                    # Step the learning rate scheduler
                    if self._lr_scheduler is not None:
                        self._lr_scheduler.step()

                    loss_to_log = running_loss.item() / num_tokens
                    pbar.update(1)
                    pbar.set_description(
                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
                    )

                    # Log per-step metrics
                    if (
                        self.global_step % self._log_every_n_steps == 0
                        and self._is_rank_zero
                    ):
                        time_per_step = time.perf_counter() - t0
                        log_dict = {
                            "loss": loss_to_log,
                            "lr": get_lr(
                                (
                                    self._optimizer
                                    if not self._optimizer_in_bwd
                                    else self._optim_ckpt_wrapper
                                ),
                            ),
                            "tokens_per_second_per_gpu": num_tokens
                            / (time_per_step * self.world_size),
                        }
                        if self._log_peak_memory_stats:
                            log_dict.update(
                                training.get_memory_stats(device=self._device)
                            )
                        if self._clip_grad_norm is not None:
                            log_dict.update({"grad_norm": grad_norm})
                        self._metric_logger.log_dict(
                            log_dict,
                            step=self.global_step,
                        )

                    # Reset running stats for the next step
                    running_loss = 0
                    num_tokens = 0
                    t0 = time.perf_counter()

                # Stop tracking CUDA memory now that active steps are complete
                if (
                    self._is_rank_zero
                    and curr_epoch == 0
                    and self.profiler_profile_memory
                    and idx
                    == self.profiler_wait_steps
                    + self.profiler_warmup_steps
                    + self.profiler_active_steps
                    and self._device.type == "cuda"
                ):
                    torch.cuda.memory._record_memory_history(enabled=None)

                # Step profiler
                # Note that this is called within gradient accumulation block, hence
                # will include multiple forward / backward passes if gradient accumulation > 1
                self._profiler.step()

            self.epochs_run += 1
            # Save a checkpoint every ``save_every_epochs`` epochs (default 1, i.e.
            # after every epoch, including the last one).
            if self.epochs_run % self.save_every_epochs == 0:
                utils.log_rank_zero(log, f"Saving checkpoint at epoch {curr_epoch}")
                self._checkpoint_client.save_checkpoint(
                    model=self._model,
                    optimizer=(
                        self._optimizer
                        if not self._optimizer_in_bwd
                        else self._optim_ckpt_wrapper
                    ),
                    training_progress=TrainingProgress(
                        seed=self.seed,
                        epochs_run=self.epochs_run,
                        total_epochs=self.total_epochs,
                        max_steps_per_epoch=self.max_steps_per_epoch,
                    ),
                    epoch=curr_epoch,
                )

        self._profiler.stop()

    def cleanup(self) -> None:
        if self._is_rank_zero:
            self._metric_logger.close()
        destroy_process_group()


@config.parse
def recipe_main(cfg: DictConfig) -> None:
    """
    Entry point for the recipe.

    Configurable parameters are read in the following order:
        - Parameters specified in config (see available configs through ``tune ls``)
        - Overwritten by arguments from the command-line
    """
    config.log_config(recipe_name="FullFinetuneRecipeDistributed", cfg=cfg)

    recipe = FullFinetuneRecipeDistributed(cfg=cfg)
    recipe.setup(cfg=cfg)
    recipe.train()
    recipe.cleanup()


if __name__ == "__main__":
    sys.exit(recipe_main())