@@ -0,0 +1,972 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+import time
+
+from functools import partial
+from typing import Any, Dict, List, Optional, Tuple, Union
+from warnings import warn
+
+import torch
+from omegaconf import DictConfig, ListConfig
+
+from torch import nn
+from torch.distributed import (
+    destroy_process_group,
+    init_device_mesh,
+    init_process_group,
+)
+from torch.distributed._tensor import DTensor
+from torch.distributed.tensor.parallel import parallelize_module
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader, DistributedSampler
+from torchtune import config, modules, training, utils
+from torchtune.config._utils import _get_component_from_path
+from torchtune.data import padded_collate_packed
+from torchtune.datasets import ConcatDataset
+from torchtune.recipe_interfaces import FTRecipeInterface
+from torchtune.training import DummyProfiler, PROFILER_KEY
+from torchtune.training.activations import apply_selective_activation_checkpointing
+from torchtune.training.checkpointing._checkpoint_client import (
+    CheckpointClient,
+    TrainingProgress,
+)
+from torchtune.training.lr_schedulers import get_lr
+
+from tqdm import tqdm
+
+log = utils.get_logger("DEBUG")
+
+
+class FullFinetuneRecipeDistributed(FTRecipeInterface):
+    """
+    Full finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe supports
+    distributed training and can be run on a single node (1 to 8 GPUs).
+
+    Features:
+        - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
+            is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
+            done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
+            ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
+            DDP is currently not supported. Training on CPU is not supported.
+
+        - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``
+            flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
+            activations in memory and instead recompute them during the backward pass. This is especially
+            helpful for larger batch sizes when you're memory constrained. But these savings in memory
+            come at the cost of training performance. In most cases training can slow down quite a bit as
+            a result of this activation recomputation.
+
+        - Activation Offloading. This can be controlled using the ``enable_activation_offloading``
+            flag. Activation offloading is a technique similar to activation checkpointing that helps
+            reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activation
+            checkpointing drops the activation in the forward to recompute it later in the backward,
+            activation offloading moves the activation to the CPU in the forward and brings it
+            back during the backward pass. As always, there is a tradeoff--these savings in memory can
+            come at the cost of training performance and CPU resources. To recover some runtime cost,
+            we've added an option to enable offloading on a different stream to permit overlapping with
+            the computation. This option is currently only available on PyTorch 2.5 or later and will
+            be enabled by default if an acceptable torch version is found. Activation offloading can be
+            used in conjunction with activation checkpointing.
+
+        - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
+            flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
+            most cases this should halve the memory footprint of full precision (fp32) training, without
+            loss in model quality (will depend on the model, training data and other settings). For
+            GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16
+            precision are currently not supported.
+
+        - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is
+            controlled using the ``gradient_accumulation_steps`` flag.
+
+                Total Batch Size = batch_size * number of GPUs * gradient accumulation steps.
+
+            For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a
+            total batch size of 64.
+
+            Gradient accumulation is especially useful when you are memory constrained. In this case,
+            accumulating gradients might give you better training speed than enabling activation
+            checkpointing.
+
+        - Checkpointing. Model weights are checkpointed at the end of an epoch, every ``save_every_epochs``
+            epochs (default 1). Optimizer state and recipe state (seed, total_epochs, number of epochs run,
+            etc.) are saved alongside the model weights and are used in case of resuming training.
+
+            Resuming training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is
+            currently not supported.
+
+            For more details on the checkpointer, please take a look at
+            our checkpointer deepdive (https://pytorch.org/torchtune/main/deep_dives/checkpointer.html).
+
+        - Logging. Terminal, Disk, WandB and TensorBoard are all supported.
+
+        - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default,
+            ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set
+            ``clip_grad_norm='inf'``.
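+
+    A typical config for this recipe exposes the flags above, e.g. (an illustrative sketch only --
+    key names are the ones read by this recipe, values are placeholders):
+
+    .. code-block:: yaml
+
+        dtype: bf16
+        batch_size: 2
+        epochs: 3
+        gradient_accumulation_steps: 8
+        enable_activation_checkpointing: True
+        enable_activation_offloading: False
+        fsdp_cpu_offload: False
+        clip_grad_norm: null
+        save_every_epochs: 1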
+
|
|
|
+ For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
|
|
|
+ has example commands for how to kick-off training.
|
|
|
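+    For example, the distributed recipe is typically launched with something along the lines of
+    ``tune run --nproc_per_node 2 full_finetune_distributed --config <CONFIG>`` (the exact recipe
+    and config names depend on your install; ``tune ls`` shows what is available).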
+
+    Args:
+        cfg (DictConfig): OmegaConf object parsed from yaml file
+
+    Raises:
+        ValueError: If ``dtype`` is set to fp16.
+        RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
+        RuntimeError: If ``left_pad_sequence`` is set as the data collator.
+        RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
+        RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False.
+    """
+
+    def __init__(self, cfg: DictConfig) -> None:
+        device_type = cfg.device
+        self._device = utils.get_device(device=device_type)
+        self._dtype = training.get_dtype(cfg.dtype, device=self._device)
+
+        if self._dtype == torch.float16:
+            raise ValueError(
+                "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
+            )
+
+        # Set up the backend for distributed training (NCCL, GLOO, etc.)
+        self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
+        self.fsdp_cpu_offload = cfg.get("fsdp_cpu_offload", False)
+        self.distributed_backend = training.get_distributed_backend(
+            device_type,
+            offload_ops_to_cpu=self.fsdp_cpu_offload
+            or self._enable_async_checkpointing,
+        )
+        init_process_group(self.distributed_backend)
+
+        # Initialize distributed variables
+        self.world_size, self.rank = utils.get_world_size_and_rank()
+        self._is_rank_zero = self.rank == 0
+        self.tensor_parallel_plan = config.instantiate(
+            cfg.get("tensor_parallel_plan", None)
+        )
+        self.tensor_parallel_dim = cfg.get("tensor_parallel_dim", 1)
+        if self.tensor_parallel_dim > 1 and self.tensor_parallel_plan is None:
+            raise ValueError(
+                "Tensor Parallel plan needs to be provided when tensor parallel is enabled."
+            )
+        if self.world_size % self.tensor_parallel_dim != 0:
+            raise ValueError(
+                f"world_size {self.world_size} must be divisible by tensor_parallel_dim {self.tensor_parallel_dim}"
+            )
+        self.data_parallel_dim = self.world_size // self.tensor_parallel_dim
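+        # Example: on a single 8-GPU node with tensor_parallel_dim=2, world_size is 8 and the
+        # remaining factor of 4 (world_size // tensor_parallel_dim) is used for data parallelism
+        # (FSDP). Illustrative numbers only; any world_size divisible by tensor_parallel_dim works.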
+
+        # Logging attributes
+        self._output_dir = cfg.output_dir
+        self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
+        self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)
+        if self._log_peak_memory_stats and device_type != "cuda":
+            log.info(
+                "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False."
+            )
+            self._log_peak_memory_stats = False
+
+        # Training cfg
+        self._resume_from_checkpoint = cfg.resume_from_checkpoint
+        self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
+        self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
+        self._clip_grad_norm = cfg.get("clip_grad_norm", None)
+        self._checkpoint_client = CheckpointClient(cfg)
+        self.save_every_epochs = cfg.get("save_every_epochs", 1)
+
+        # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
+        if self._optimizer_in_bwd:
+            if self._clip_grad_norm is not None:
+                raise RuntimeError(
+                    "Gradient clipping is not supported with optimizer in bwd. "
+                    "Please set clip_grad_norm=None, or optimizer_in_bwd=False."
+                )
+            if self._gradient_accumulation_steps > 1:
+                raise RuntimeError(
+                    "Gradient accumulation is not supported with optimizer in bwd. "
+                    "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
+                )
+
+        # activation checkpointing/offloading
+        self._enable_activation_checkpointing = cfg.get(
+            "enable_activation_checkpointing", False
+        )
+        self._enable_activation_offloading = cfg.get(
+            "enable_activation_offloading", False
+        )
+        if self._enable_activation_offloading:
+            if device_type != "cuda":
+                raise RuntimeError(
+                    "enable_activation_offloading should only be True when training on CUDA"
+                )
+            if not self._enable_activation_checkpointing:
+                raise RuntimeError(
+                    "enable_activation_offloading should only be True when enable_activation_checkpointing is True"
+                )
+        elif (
+            self._enable_activation_checkpointing
+            and cfg.checkpointer.model_type != "LLAMA3_VISION"
+        ):
+            utils.log_rank_zero(
+                log,
+                "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. "
+                "Enabling activation offloading should reduce memory further.",
+            )
+
+        # These are public properties which are updated by the checkpoint loader
+        # when ``resume_from_checkpoint`` is `True` or validated in tests
+        self.seed = training.set_seed(
+            seed=cfg.seed, debug_mode=cfg.get("cudnn_deterministic_mode", None)
+        )
+        self.epochs_run = 0
+        self.total_epochs = cfg.epochs
+        self.max_steps_per_epoch = cfg.max_steps_per_epoch
+        self.global_step = 0
+
+    def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
+        """
+        Updates the recipe state from checkpoint.
+        """
+        try:
+            self.epochs_run = ckpt_dict[training.EPOCHS_KEY]
+
+            # on mismatch, warn the user and prevent the override
+            if self.seed != ckpt_dict[training.SEED_KEY]:
+                warn(
+                    message=(
+                        "Config value for seed does not match the checkpoint value, "
+                        f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}"
+                    )
+                )
+                self.seed = ckpt_dict[training.SEED_KEY]
+            if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]:
+                warn(
+                    message=(
+                        "Config value for max_steps_per_epoch does not match the checkpoint value, "
+                        f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}"
+                    )
+                )
+                self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY]
+
+            # on mismatch, warn the user but allow the override
+            if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]:
+                warn(
+                    message=(
+                        "Config value for total_epochs does not match the checkpoint value, "
+                        f"using the config value: {self.total_epochs}"
+                    )
+                )
+
+        except KeyError as e:
+            raise KeyError(
+                "Checkpoint does not contain the required keys needed for updating recipe state. "
+                "Are you sure you passed in the right recipe checkpoint?"
+            ) from e
+
+    def setup(self, cfg: DictConfig) -> None:
+        """
+        Setup the recipe. This includes training state (if resume_from_checkpoint is True),
+        model, tokenizer, loss, optimizer, lr scheduler, sampler, and dataloader.
+        """
+        if self.fsdp_cpu_offload:
+            # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
+            # speed up when benchmarking fused AdamW on CPU
+            training.set_torch_num_threads()
+
+        if self._is_rank_zero:
+            self._metric_logger = config.instantiate(cfg.metric_logger)
+            # log config with parameter override
+            self._metric_logger.log_config(cfg)
+
+        # Load the base model
+        checkpoint_dict = self._checkpoint_client.load_base_checkpoint()
+
+        self._compile = cfg.get("compile", False)
+        self._model = self._setup_model(
+            cfg_model=cfg.model,
+            enable_activation_checkpointing=self._enable_activation_checkpointing,
+            enable_activation_offloading=self._enable_activation_offloading,
+            custom_sharded_layers=cfg.get("custom_sharded_layers", None),
+            fsdp_cpu_offload=self.fsdp_cpu_offload,
+            reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
+            model_state_dict=checkpoint_dict[training.MODEL_KEY],
+            ac_mode=cfg.get("ac_mode", None),
+            ac_option=cfg.get("ac_option", None),
+        )
+        self._tokenizer = config.instantiate(cfg.tokenizer)
+
+        self._optimizer = self._setup_optimizer(
+            cfg_optimizer=cfg.optimizer,
+            optimizer_in_bwd=self._optimizer_in_bwd,
+            opt_state_dict=(
+                checkpoint_dict[training.OPT_KEY]
+                if training.OPT_KEY in checkpoint_dict
+                else None
+            ),
+        )
+
+        if self._resume_from_checkpoint:
+            # If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
+            # using the DistributedCheckpointer.
+            # Therefore the recipe needs to load the distributed checkpoint to restore the training
+            # progress.
+            if self._enable_async_checkpointing:
+                try:
+                    checkpoint_dict = (
+                        self._checkpoint_client.load_distributed_checkpoint(
+                            self._model,
+                            (
+                                self._optim_ckpt_wrapper
+                                if self._optimizer_in_bwd
+                                else self._optimizer
+                            ),
+                        )
+                    )
+                except Exception as e:
+                    log.warning(
+                        f"Failed to load distributed checkpoint: {e}. Training will start from the base checkpoint."
+                    )
+
+            # Update the recipe state from the checkpoint state dict.
+            self._update_recipe_state(checkpoint_dict)
+
+        # initialize loss
+        self._loss_fn = config.instantiate(cfg.loss)
+
+        if self._compile:
+            training.compile_loss(self._loss_fn, verbose=self._is_rank_zero)
+
+        if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
+            # set num_output_chunks for model
+            self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
+
+        utils.log_rank_zero(log, "Loss is initialized.")
+
+        # sampler and dataloader depend on the tokenizer and loss_fn and should be
+        # setup after both of these are initialized
+        collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
+        self._sampler, self._dataloader = self._setup_data(
+            cfg_dataset=cfg.dataset,
+            shuffle=cfg.shuffle,
+            batch_size=cfg.batch_size,
+            collate_fn=collate_name,
+        )
+
+        # Finally update the recipe state which can only be correctly set after all of the
+        # other components have been initialized and updated.
+        #
+        # Number of training steps in each epoch depends on the number of batches produced
+        # by the dataloader, the max_steps_per_epoch param set by the user and the
+        # gradient_accumulation_steps param. This value is used for logging and tracking
+        # training state. The computation should happen after the dataloader has been setup
+        self._steps_per_epoch = (
+            len(self._dataloader) // self._gradient_accumulation_steps
+        )
+        if (
+            self.max_steps_per_epoch is not None
+            and self.max_steps_per_epoch < self._steps_per_epoch
+        ):
+            self._steps_per_epoch = self.max_steps_per_epoch
+        self.global_step = self.epochs_run * self._steps_per_epoch
+
+        # Setup lr scheduler
+        self._lr_scheduler = self._setup_lr_scheduler(
+            cfg_lr_scheduler=cfg.get("lr_scheduler", None),
+            num_training_steps=self.total_epochs * self._steps_per_epoch,
+            last_epoch=self.global_step - 1,
+        )
+
+        # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
+        # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
+        self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
+
+        # Used to ignore labels for loss computation
+        self.ignore_labels_cache = torch.full(
+            (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device
+        )
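+        # (The cache is a [batch_size, 1] column of ignore_index values; train() appends it to
+        # the left-shifted labels so the logits never need to be sliced when computing the loss.)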
+
+    def _setup_lr_scheduler(
+        self,
+        cfg_lr_scheduler: Optional[DictConfig],
+        num_training_steps: int,
+        last_epoch: int,
+    ) -> Optional[Optimizer]:
+        """
+        Set up the learning rate scheduler based on the provided configuration.
+        It supports both standard optimization and optimizer-in-backward cases.
+
+        Args:
+            cfg_lr_scheduler (Optional[DictConfig]): The learning rate scheduler configuration.
+            num_training_steps (int): The total number of training steps.
+            last_epoch (int): The index of the last epoch.
+
+        Returns:
+            lr_scheduler (Optional[Optimizer]): The learning rate scheduler.
+        """
+        if cfg_lr_scheduler is None:
+            if self._is_rank_zero:
+                log.info(
+                    "No learning rate scheduler configured. Using constant learning rate."
+                )
+            return None
+
+        if self._optimizer_in_bwd:
+            # Use the first optimizer from the wrapper to represent the learning rate
+            optimizer = next(iter(self._optim_ckpt_wrapper.optim_map.values()))
+        else:
+            # Standard case: use the single optimizer
+            optimizer = self._optimizer
+
+        # Instantiate the learning rate scheduler
+        lr_scheduler = config.instantiate(
+            cfg_lr_scheduler,
+            optimizer,
+            num_training_steps=num_training_steps,
+            last_epoch=last_epoch,
+        )
+
+        if self._optimizer_in_bwd:
+            # Modify the scheduler for optimizer_in_bwd case
+            self._optim_ckpt_wrapper.set_lr_scheduler(lr_scheduler)
+
+        if self._is_rank_zero:
+            log.info("Learning rate scheduler is initialized.")
+
+        return lr_scheduler
+
+    def _setup_profiler(
+        self, cfg_profiler: Optional[DictConfig] = None
+    ) -> Union[torch.profiler.profile, DummyProfiler]:
+        """
+        Parses the `profiler` section of top-level `cfg` and sets up profiler
+
+        Args:
+            cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to
+                `recipe.main`). Default None.
+
+        Returns:
+            profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods
+            for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` when profiling is
+            disabled, such that the instrumented training loop does not need to be changed.
+
+        The profiler config can be provided in configs under the `profiler` key with the following layout:
+
+        .. code-block:: yaml
+
+            profiler:
+                enabled: bool
+
+                # Output directory of trace artifacts
+                output_dir: str
+
+                # `torch.profiler.ProfilerActivity` types to trace
+                cpu: bool
+                cuda: bool
+
+                # Trace options
+                profile_memory: bool
+                with_stack: bool
+                record_shapes: bool
+                with_flops: bool
+
+                # `torch.profiler.schedule` options:
+                # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+                wait_steps: int
+                warmup_steps: int
+                active_steps: int
+                num_cycles: int
+        """
+        # Missing profiler section in config, assume disabled
+        if cfg_profiler is None:
+            cfg_profiler = DictConfig({"enabled": False})
+
+        # Check that component is included and set correctly
+        if cfg_profiler.get("_component_", None) is None:
+            cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler"
+        else:
+            assert (
+                cfg_profiler.get("_component_")
+                == "torchtune.training.setup_torch_profiler"
+            ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`"
+
+        profiler, profiler_cfg = config.instantiate(cfg_profiler)
+
+        utils.log_rank_zero(
+            log, f" Profiler config after instantiation: {profiler_cfg}"
+        )
+        if self._is_rank_zero:
+            self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
+            if profiler_cfg["enabled"]:
+                self.profiler_wait_steps = profiler_cfg["wait_steps"]
+                self.profiler_warmup_steps = profiler_cfg["warmup_steps"]
+                self.profiler_active_steps = profiler_cfg["active_steps"]
+
+        return profiler
+
+    def _setup_model(
+        self,
+        cfg_model: DictConfig,
+        enable_activation_checkpointing: bool,
+        enable_activation_offloading: bool,
+        fsdp_cpu_offload: bool,
+        reshard_after_forward: bool,
+        model_state_dict: Dict[str, Any],
+        custom_sharded_layers: Optional[List[str]] = None,
+        ac_mode: Optional[str] = None,
+        ac_option: Optional[int] = None,
+    ) -> nn.Module:
+        """
+        Model initialization has some important considerations:
+           a. To minimize GPU peak memory, we initialize the model on meta device with
+              the right dtype
+           b. All ranks call ``load_state_dict`` without peaking CPU RAM since
+              full state dicts are loaded with ``torch.load(mmap=True)``
+        """
+
+        utils.log_rank_zero(
+            log,
+            "Distributed training is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
+        )
+        init_start = time.perf_counter()
+
+        with training.set_default_dtype(self._dtype), torch.device("meta"):
+            model = config.instantiate(cfg_model)
+
+        if self._compile:
+            training.compile_model(model, verbose=self._is_rank_zero)
+
+        device_mesh = init_device_mesh(
+            self._device.type,
+            mesh_shape=(self.data_parallel_dim, self.tensor_parallel_dim),
+            mesh_dim_names=("dp", "tp"),
+        )
+        self.dp_size = device_mesh["dp"].size()
+        self.dp_rank = device_mesh["dp"].get_local_rank()
+
+        # Apply tensor parallelism to the model
+        if self.tensor_parallel_dim > 1:
+            # Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor parallel
+            model = training.prepare_mha_for_tp(model, device_mesh["tp"])
+            parallelize_module(
+                model,
+                device_mesh["tp"],
+                parallelize_plan=self.tensor_parallel_plan,
+            )
+
+        # We currently have two versions of activation checkpointing in this recipe
+        # for testing and BC purposes. ``enable_activation_checkpointing`` controls
+        # the older version of AC and this behavior is unchanged
+        # ac_mode and ac_option together control selective AC. This is only enabled
+        # when these are set AND ``enable_activation_checkpointing`` is set to False
+        # We'll clean this up as soon as testing of AC is complete
+        if (not enable_activation_checkpointing) and (ac_mode is not None):
+            apply_selective_activation_checkpointing(
+                model,
+                ac_mode,
+                ac_option,
+            )
+
+        # original activation checkpointing (full) - flip the condition above
+        if enable_activation_checkpointing and ac_mode is None:
+            training.set_activation_checkpointing(
+                model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
+            )
+
+        # Apply Fully Sharded Data Parallelism to the model
+        if self.data_parallel_dim > 1:
+            fsdp_shard_conditions = [
+                partial(
+                    training.get_shard_conditions,
+                    names_to_match=custom_sharded_layers,
+                )
+            ]
+            training.shard_model(
+                model=model,
+                shard_conditions=fsdp_shard_conditions,
+                cpu_offload=fsdp_cpu_offload,
+                reshard_after_forward=reshard_after_forward,
+                dp_mesh=device_mesh["dp"],
+            )
+
+        with training.set_default_dtype(self._dtype), self._device:
+            for m in model.modules():
+                # RoPE is not covered in state dict
+                if hasattr(m, "rope_init"):
+                    m.rope_init()
+
+        # This method will convert the full model state dict into a sharded state
+        # dict and load into the model
+        training.load_from_full_model_state_dict(
+            model,
+            model_state_dict,
+            self._device,
+            strict=True,
+            cpu_offload=fsdp_cpu_offload,
+        )
+
+        # activation offloading
+        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
+            model, enable_activation_offloading
+        )
+
+        # Ensure no params and buffers are on meta device
+        training.validate_no_params_on_meta_device(model)
+
+        utils.log_rank_zero(
+            log,
+            f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
+        )
+
+        if self._is_rank_zero:
+            memory_stats = training.get_memory_stats(device=self._device)
+            training.log_memory_stats(memory_stats)
+
+        # synchronize before training begins
+        torch.distributed.barrier()
+
+        return model
+
+    def _setup_optimizer(
+        self,
+        cfg_optimizer: DictConfig,
+        optimizer_in_bwd: bool = False,
+        opt_state_dict: Optional[Dict[str, Any]] = None,
+    ) -> Optional[Optimizer]:
+        if optimizer_in_bwd:
+            # Maintain a dict of optims for every parameter.
+            optim_dict = {
+                param: config.instantiate(cfg_optimizer, [param])
+                for param in self._model.parameters()
+            }
+
+            # Register optimizer step hooks on the model to run optimizer in backward.
+            training.register_optim_in_bwd_hooks(
+                model=self._model, optim_dict=optim_dict
+            )
+            # Create a wrapper for checkpoint save/load of optimizer states when running in backward.
+            self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper(
+                model=self._model, optim_dict=optim_dict
+            )
+            # Load optimizer states for each param. If optimizer states are being restored in an optimizer in
+            # backward run, these need to have been saved with the same setting. Cannot restore from runs that
+            # did not use optimizer in backward.
+            if opt_state_dict is not None:
+                for param in opt_state_dict.keys():
+                    try:
+                        training.load_from_full_optimizer_state_dict(
+                            self._model,
+                            self._optim_ckpt_wrapper.optim_map[param],
+                            opt_state_dict[param],
+                            self._device,
+                        )
+                    except BaseException as e:
+                        raise RuntimeError(
+                            "Failed loading in-backward optimizer checkpoints. "
+                            "Please make sure the run being restored from was using the in-backward optimizer."
+                        ) from e
+            utils.log_rank_zero(log, "In-backward optimizers are set up.")
+            return None
+        else:
+            optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
+            if opt_state_dict:
+                training.load_from_full_optimizer_state_dict(
+                    self._model,
+                    optimizer,
+                    opt_state_dict,
+                    self._device,
+                )
+
+            utils.log_rank_zero(log, "Optimizer is initialized.")
+            return optimizer
+
+    def _setup_data(
+        self,
+        cfg_dataset: DictConfig,
+        shuffle: bool,
+        batch_size: int,
+        collate_fn: str,
+    ) -> Tuple[DistributedSampler, DataLoader]:
+        """
+        All data related setup happens here. Currently this recipe only supports the
+        DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
+        iterable datasets and streaming datasets are not supported.
+        """
+        if isinstance(cfg_dataset, ListConfig):
+            datasets = [
+                config.instantiate(single_cfg_dataset, self._tokenizer)
+                for single_cfg_dataset in cfg_dataset
+            ]
+            ds = ConcatDataset(datasets=datasets)
+            packed = getattr(ds, "packed", False)
+        else:
+            ds = config.instantiate(cfg_dataset, self._tokenizer)
+            packed = cfg_dataset.get("packed", False)
+
+        # Instantiate collate_fn
+        if "left_pad_sequence" in collate_fn:
+            raise RuntimeError("left_pad_sequence collator is only for inference.")
+        collate_fn = _get_component_from_path(collate_fn)
+
+        sampler = DistributedSampler(
+            ds, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle, seed=0
+        )
+        dataloader = DataLoader(
+            dataset=ds,
+            batch_size=batch_size,
+            sampler=sampler,
+            # dropping last avoids shape issues with compile + flex attention
+            drop_last=True,
+            collate_fn=(
+                partial(
+                    collate_fn,
+                    padding_idx=self._tokenizer.pad_id,
+                    ignore_idx=self._loss_fn.ignore_index,
+                )
+                if not packed
+                else padded_collate_packed
+            ),
+        )
+
+        utils.log_rank_zero(log, "Dataset and Sampler are initialized.")
+
+        return sampler, dataloader
+
+    def train(self) -> None:
+        """
+        The core training loop.
+        """
+        # clean up before training begins
+        training.cleanup_before_training()
+
+        # zero out the gradients before starting training
+        if not self._optimizer_in_bwd:
+            self._optimizer.zero_grad()
+        else:
+            for opt in self._optim_ckpt_wrapper.optim_map.values():
+                opt.zero_grad()
+
+        # Initialize tokens count and running loss (for grad accumulation)
+        t0 = time.perf_counter()
+        running_loss = 0
+        num_tokens = 0
+
+        self._profiler.start()
+        # self.epochs_run should be non-zero when we're resuming from a checkpoint
+        for curr_epoch in range(self.epochs_run, self.total_epochs):
+            # Update the sampler to ensure data is correctly shuffled across epochs
+            # in case shuffle is True
+            self._sampler.set_epoch(curr_epoch)
+
+            pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
+            for idx, batch in enumerate(self._dataloader):
+                if (
+                    self.max_steps_per_epoch is not None
+                    and (idx // self._gradient_accumulation_steps)
+                    == self.max_steps_per_epoch
+                ):
+                    break
+
+                # Start tracking CUDA memory for active steps for just the first epoch
+                if (
+                    self._is_rank_zero
+                    and curr_epoch == 0
+                    and self.profiler_profile_memory
+                    and idx == self.profiler_wait_steps + self.profiler_warmup_steps
+                    and self._device.type == "cuda"
+                ):
+                    torch.cuda.memory._record_memory_history()
+                utils.batch_to_device(batch, self._device)
+
+                # Calculate the number of unmasked tokens in the current batch
+                # and increment the total number of tokens seen in the step
+                current_num_tokens = (
+                    batch["labels"] != self._loss_fn.ignore_index
+                ).sum()
+                num_tokens += current_num_tokens
+
+                # Shape [b, s], needed for the loss not the model
+                labels = batch.pop("labels")
+
+                with self.activations_handling_ctx:
+                    logits = self._model(**batch)
+
+                # Shift labels to compute loss
+                # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
+                # But this way we don't need to slice the logits. We just add an ignore index to labels.
+                labels = torch.hstack(
+                    (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
+                )
+                if not isinstance(logits, list):
+                    labels = labels.reshape(-1)
+                    logits = logits.reshape(-1, logits.size(-1))
+
+                # Compute loss
+                # Loss is normalized by default so we multiply by the number of tokens
+                # This way we can normalize by the total number of tokens if we're accumulating gradients
+                current_loss = self._loss_fn(logits, labels) * current_num_tokens
+
+                # free logits otherwise it peaks backward memory
+                del logits
+
+                running_loss += current_loss
+
+                # For optimizer in backward, we need to normalize before calling backward
+                # This case and gradient accumulation are mutually exclusive
+                if self._optimizer_in_bwd:
+                    torch.distributed.all_reduce(num_tokens)
+                    torch.distributed.all_reduce(running_loss)
+
+                    # We multiply by world_size to undo FSDP2 gradient normalization.
+                    current_loss = current_loss * (self.world_size / num_tokens)
+
+                current_loss.backward()
+
+                # Step with optimizer
+                if (idx + 1) % self._gradient_accumulation_steps == 0:
+                    if not self._optimizer_in_bwd:
+                        # Get total number of tokens across all ranks to normalize gradients
+                        torch.distributed.all_reduce(num_tokens)
+                        # This will ensure that the logged loss matches what we're optimizing
+                        torch.distributed.all_reduce(running_loss)
+                        # Manually scale the gradients from unnormalized loss by total # of tokens
+                        # We multiply by world_size to undo FSDP2 gradient normalization.
+                        training.scale_grads(self._model, self.world_size / num_tokens)
+                        if self._clip_grad_norm is not None:
+                            grad_norm = torch.nn.utils.clip_grad_norm_(
+                                self._model.parameters(),
+                                max_norm=float(self._clip_grad_norm),
+                            )
+                            # If sharded, collect the DTensor here
+                            if isinstance(grad_norm, DTensor):
+                                grad_norm = grad_norm.full_tensor()
+                        self._optimizer.step()
+                        self._optimizer.zero_grad(set_to_none=True)
+
+                    # Update the number of steps when the weights are updated
+                    self.global_step += 1
+
+                    # Step the learning rate scheduler
+                    if self._lr_scheduler is not None:
+                        self._lr_scheduler.step()
+
+                    loss_to_log = running_loss.item() / num_tokens
+                    pbar.update(1)
+                    pbar.set_description(
+                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
+                    )
+
+                    # Log per-step metrics
+                    if (
+                        self.global_step % self._log_every_n_steps == 0
+                        and self._is_rank_zero
+                    ):
+                        time_per_step = time.perf_counter() - t0
+                        log_dict = {
+                            "loss": loss_to_log,
+                            "lr": get_lr(
+                                (
+                                    self._optimizer
+                                    if not self._optimizer_in_bwd
+                                    else self._optim_ckpt_wrapper
+                                ),
+                            ),
+                            "tokens_per_second_per_gpu": num_tokens
+                            / (time_per_step * self.world_size),
+                        }
+                        if self._log_peak_memory_stats:
+                            log_dict.update(
+                                training.get_memory_stats(device=self._device)
+                            )
+                        if self._clip_grad_norm is not None:
+                            log_dict.update({"grad_norm": grad_norm})
+                        self._metric_logger.log_dict(
+                            log_dict,
+                            step=self.global_step,
+                        )
+
+                    # Reset running stats for the next step
+                    running_loss = 0
+                    num_tokens = 0
+                    t0 = time.perf_counter()
+
+                    # Stop tracking CUDA memory now that active steps are complete
+                    if (
+                        self._is_rank_zero
+                        and curr_epoch == 0
+                        and self.profiler_profile_memory
+                        and idx
+                        == self.profiler_wait_steps
+                        + self.profiler_warmup_steps
+                        + self.profiler_active_steps
+                        and self._device.type == "cuda"
+                    ):
+                        torch.cuda.memory._record_memory_history(enabled=None)
+
+                    # Step profiler
+                    # Note that this is called within gradient accumulation block, hence
+                    # will include multiple forward / backward passes if gradient accumulation > 1
+                    self._profiler.step()
+
+            self.epochs_run += 1
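+            # NOTE: with the `curr_epoch > 0` guard below, no checkpoint is written after the
+            # first epoch (curr_epoch == 0), and intermediate checkpoints are only saved every
+            # `save_every_epochs` epochs; there is no separate end-of-training save, so a run
+            # with total_epochs == 1 finishes without writing a checkpoint.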
+            if curr_epoch > 0 and curr_epoch % self.save_every_epochs == 0:
+                utils.log_rank_zero(log, f"Saving checkpoint at epoch {curr_epoch}")
+                self._checkpoint_client.save_checkpoint(
+                    model=self._model,
+                    optimizer=(
+                        self._optimizer
+                        if not self._optimizer_in_bwd
+                        else self._optim_ckpt_wrapper
+                    ),
+                    training_progress=TrainingProgress(
+                        seed=self.seed,
+                        epochs_run=self.epochs_run,
+                        total_epochs=self.total_epochs,
+                        max_steps_per_epoch=self.max_steps_per_epoch,
+                    ),
+                    epoch=curr_epoch,
+                )
+
+        self._profiler.stop()
+
+    def cleanup(self) -> None:
+        if self._is_rank_zero:
+            self._metric_logger.close()
+        destroy_process_group()
+
+
+@config.parse
+def recipe_main(cfg: DictConfig) -> None:
+    """
+    Entry point for the recipe.
+
+    Configurable parameters are read in the following order:
+        - Parameters specified in config (see available configs through ``tune ls``)
+        - Overwritten by arguments from the command-line
+    """
+    config.log_config(recipe_name="FullFinetuneRecipeDistributed", cfg=cfg)
+    recipe = FullFinetuneRecipeDistributed(cfg=cfg)
+    recipe.setup(cfg=cfg)
+    recipe.train()
+    recipe.cleanup()
+
+
+if __name__ == "__main__":
+    sys.exit(recipe_main())