# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys
import time

from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
from warnings import warn

import torch
from omegaconf import DictConfig, ListConfig
from torch import nn
from torch.distributed import (
    destroy_process_group,
    init_device_mesh,
    init_process_group,
)
from torch.distributed._tensor import DTensor
from torch.distributed.tensor.parallel import parallelize_module
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler

from torchtune import config, modules, training, utils
from torchtune.config._utils import _get_component_from_path
from torchtune.data import padded_collate_packed
from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DummyProfiler, PROFILER_KEY
from torchtune.training.activations import apply_selective_activation_checkpointing
from torchtune.training.checkpointing._checkpoint_client import (
    CheckpointClient,
    TrainingProgress,
)
from torchtune.training.lr_schedulers import get_lr

from tqdm import tqdm

log = utils.get_logger("DEBUG")

class FullFinetuneRecipeDistributed(FTRecipeInterface):
    """
    Full finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe supports
    distributed training and can be run on a single node (1 to 8 GPUs).

    Features:
        - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
            is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
            done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
            ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
            DDP is currently not supported. Training on CPU is not supported.

        - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``
            flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
            activations in memory and instead recompute them during the backward pass. This is especially
            helpful for larger batch sizes when you're memory constrained. But these savings in memory
            come at the cost of training performance. In most cases training can slow down quite a bit as
            a result of this activation recomputation.

        - Activation Offloading. This can be controlled using the ``enable_activation_offloading``
            flag. Activation offloading is a technique similar to activation checkpointing that helps
            reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activation
            checkpointing drops the activation in the forward to recompute it later in the backward,
            activation offloading will drop the activation in the forward to the CPU and bring it
            back during the backward pass. As always, there is a tradeoff--these savings in memory can
            come at the cost of training performance and CPU resources. To recover some runtime cost,
            we've added an option to enable offloading on a different stream to permit overlapping with
            the computation. This option is currently only available on PyTorch 2.5 or later and will
            be enabled by default if an acceptable torch version is found. Activation offloading can be
            used in conjunction with activation checkpointing.

        - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
            flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
            most cases this should halve the memory footprint of full precision (fp32) training, without
            loss in model quality (will depend on the model, training data and other settings). For
            GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16
            precision are currently not supported.

        - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is
            controlled using the ``gradient_accumulation_steps`` flag.

                Total Batch Size = batch_size * number of GPUs * gradient accumulation steps.

            For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a
            total batch size of 64.

            Gradient accumulation is especially useful when you are memory constrained. In this case,
            accumulating gradients might give you better training speed than enabling activation
            checkpointing.

        - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of
            training. The ``save_every_epochs`` flag controls how often these end-of-epoch checkpoints
            are written. Optimizer state and recipe state (seed, total_epochs, number of epochs run etc) are
            only saved at the end of a given epoch and used in case of resuming training.

            Resuming training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is
            currently not supported.

            For more details on the checkpointer, please take a look at
            our checkpointer deepdive (https://pytorch.org/torchtune/main/deep_dives/checkpointer.html).

        - Logging. Terminal, Disk, WandB and TensorBoard are all supported.

        - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default,
            ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set
            ``clip_grad_norm='inf'``.

    For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config
    has example commands for how to kick off training.

    Args:
        cfg (DictConfig): OmegaConf object parsed from yaml file

    Raises:
        ValueError: If ``dtype`` is set to fp16.
        RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
        RuntimeError: If ``left_pad_sequence`` is set as the data collator.
        RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
        RuntimeError: If ``enable_activation_offloading`` is True and ``enable_activation_checkpointing`` is False.
    """

    def __init__(self, cfg: DictConfig) -> None:
        device_type = cfg.device
        self._device = utils.get_device(device=device_type)
        self._dtype = training.get_dtype(cfg.dtype, device=self._device)

        if self._dtype == torch.float16:
            raise ValueError(
                "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
            )

        # Set up the backend for distributed training (NCCL, GLOO, etc.)
        self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
        self.fsdp_cpu_offload = cfg.get("fsdp_cpu_offload", False)
        self.distributed_backend = training.get_distributed_backend(
            device_type,
            offload_ops_to_cpu=self.fsdp_cpu_offload
            or self._enable_async_checkpointing,
        )
        init_process_group(self.distributed_backend)

        # Initialize distributed variables
        self.world_size, self.rank = utils.get_world_size_and_rank()
        self._is_rank_zero = self.rank == 0
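
        # Tensor parallel settings: the world is split into data-parallel replicas of
        # tensor-parallel groups, i.e. data_parallel_dim = world_size // tensor_parallel_dim.
        # The actual 2D device mesh is constructed later in ``_setup_model``.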
        self.tensor_parallel_plan = config.instantiate(
            cfg.get("tensor_parallel_plan", None)
        )
        self.tensor_parallel_dim = cfg.get("tensor_parallel_dim", 1)
        if self.tensor_parallel_dim > 1 and self.tensor_parallel_plan is None:
            raise ValueError(
                "Tensor Parallel plan needs to be provided when tensor parallel is enabled."
            )
        if self.world_size % self.tensor_parallel_dim != 0:
            raise ValueError(
                f"world_size {self.world_size} must be divisible by tensor_parallel_dim {self.tensor_parallel_dim}"
            )
        self.data_parallel_dim = self.world_size // self.tensor_parallel_dim

        # Logging attributes
        self._output_dir = cfg.output_dir
        self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
        self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)

        if self._log_peak_memory_stats and device_type != "cuda":
            log.info(
                "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False."
            )
            self._log_peak_memory_stats = False

        # Training cfg
        self._resume_from_checkpoint = cfg.resume_from_checkpoint
        self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
        self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
        self._clip_grad_norm = cfg.get("clip_grad_norm", None)
        self._checkpoint_client = CheckpointClient(cfg)
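        # How often (in epochs) a full checkpoint is written; 1 means at the end of every epoch.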
        self.save_every_epochs = cfg.get("save_every_epochs", 1)

        # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
        if self._optimizer_in_bwd:
            if self._clip_grad_norm is not None:
                raise RuntimeError(
                    "Gradient clipping is not supported with optimizer in bwd. "
                    "Please set clip_grad_norm=None, or optimizer_in_bwd=False."
                )
            if self._gradient_accumulation_steps > 1:
                raise RuntimeError(
                    "Gradient accumulation is not supported with optimizer in bwd. "
                    "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
                )

        # activation checkpointing/offloading
        self._enable_activation_checkpointing = cfg.get(
            "enable_activation_checkpointing", False
        )
        self._enable_activation_offloading = cfg.get(
            "enable_activation_offloading", False
        )
        if self._enable_activation_offloading:
            if device_type != "cuda":
                raise RuntimeError(
                    "enable_activation_offloading should only be True when training on CUDA"
                )
            if not self._enable_activation_checkpointing:
                raise RuntimeError(
                    "enable_activation_offloading should only be True when enable_activation_checkpointing is True"
                )
        elif (
            self._enable_activation_checkpointing
            and cfg.checkpointer.model_type != "LLAMA3_VISION"
        ):
            utils.log_rank_zero(
                log,
                "Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. "
                "Enabling activation offloading should reduce memory further.",
            )

        # These are public properties which are updated by the checkpoint loader
        # when ``resume_from_checkpoint`` is `True` or validated in tests
        self.seed = training.set_seed(
            seed=cfg.seed, debug_mode=cfg.get("cudnn_deterministic_mode", None)
        )
        self.epochs_run = 0
        self.total_epochs = cfg.epochs
        self.max_steps_per_epoch = cfg.max_steps_per_epoch
        self.global_step = 0

    def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None:
        """
        Updates the recipe state from checkpoint.
        """
        try:
            self.epochs_run = ckpt_dict[training.EPOCHS_KEY]

            # on mismatch, warn the user and prevent the override
            if self.seed != ckpt_dict[training.SEED_KEY]:
                warn(
                    message=(
                        "Config value for seed does not match the checkpoint value, "
                        f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}"
                    )
                )
                self.seed = ckpt_dict[training.SEED_KEY]
            if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]:
                warn(
                    message=(
                        "Config value for max_steps_per_epoch does not match the checkpoint value, "
                        f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}"
                    )
                )
                self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY]

            # on mismatch, warn the user but allow the override
            if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]:
                warn(
                    message=(
                        "Config value for total_epochs does not match the checkpoint value, "
                        f"using the config value: {self.total_epochs}"
                    )
                )

        except KeyError as e:
            raise KeyError(
                "Checkpoint does not contain the required keys needed for updating recipe state. "
                "Are you sure you passed in the right recipe checkpoint?"
            ) from e

    def setup(self, cfg: DictConfig) -> None:
        """
        Setup the recipe. This includes training state (if resume_from_checkpoint is True),
        model, tokenizer, loss, optimizer, lr scheduler, sampler, and dataloader.
        """
        if self.fsdp_cpu_offload:
            # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
            # speed up when benchmarking fused AdamW on CPU
            training.set_torch_num_threads()

        if self._is_rank_zero:
            self._metric_logger = config.instantiate(cfg.metric_logger)

            # log config with parameter override
            self._metric_logger.log_config(cfg)

        # Load the base model
        checkpoint_dict = self._checkpoint_client.load_base_checkpoint()

        self._compile = cfg.get("compile", False)
        self._model = self._setup_model(
            cfg_model=cfg.model,
            enable_activation_checkpointing=self._enable_activation_checkpointing,
            enable_activation_offloading=self._enable_activation_offloading,
            custom_sharded_layers=cfg.get("custom_sharded_layers", None),
            fsdp_cpu_offload=self.fsdp_cpu_offload,
            reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
            model_state_dict=checkpoint_dict[training.MODEL_KEY],
            ac_mode=cfg.get("ac_mode", None),
            ac_option=cfg.get("ac_option", None),
        )
        self._tokenizer = config.instantiate(cfg.tokenizer)

        self._optimizer = self._setup_optimizer(
            cfg_optimizer=cfg.optimizer,
            optimizer_in_bwd=self._optimizer_in_bwd,
            opt_state_dict=(
                checkpoint_dict[training.OPT_KEY]
                if training.OPT_KEY in checkpoint_dict
                else None
            ),
        )

        if self._resume_from_checkpoint:
            # If async checkpointing is enabled, intermediate checkpoints are saved asynchronously
            # using the DistributedCheckpointer.
            # Therefore the recipe needs to load the distributed checkpoint to restore the training
            # progress.
            if self._enable_async_checkpointing:
                try:
                    checkpoint_dict = (
                        self._checkpoint_client.load_distributed_checkpoint(
                            self._model,
                            (
                                self._optim_ckpt_wrapper
                                if self._optimizer_in_bwd
                                else self._optimizer
                            ),
                        )
                    )
                except Exception as e:
                    log.warning(
                        f"Failed to load distributed checkpoint: {e}. Training will start from the base checkpoint."
                    )

            # Update the recipe state from the checkpoint state dict.
            self._update_recipe_state(checkpoint_dict)

        # initialize loss
        self._loss_fn = config.instantiate(cfg.loss)

        if self._compile:
            training.compile_loss(self._loss_fn, verbose=self._is_rank_zero)

        if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
            # set num_output_chunks for model
            self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)

        utils.log_rank_zero(log, "Loss is initialized.")

        # sampler and dataloader depend on the tokenizer and loss_fn and should be
        # setup after both of these are initialized
        collate_name = cfg.get("collate_fn", "torchtune.data.padded_collate_sft")
        self._sampler, self._dataloader = self._setup_data(
            cfg_dataset=cfg.dataset,
            shuffle=cfg.shuffle,
            batch_size=cfg.batch_size,
            collate_fn=collate_name,
        )

        # Finally update the recipe state which can only be correctly set after all of the
        # other components have been initialized and updated.
        #
        # Number of training steps in each epoch depends on the number of batches produced
        # by the dataloader, the max_steps_per_epoch param set by the user and the
        # gradient_accumulation_steps param. This value is used for logging and tracking
        # training state. The computation should happen after the dataloader has been setup
        self._steps_per_epoch = (
            len(self._dataloader) // self._gradient_accumulation_steps
        )
        if (
            self.max_steps_per_epoch is not None
            and self.max_steps_per_epoch < self._steps_per_epoch
        ):
            self._steps_per_epoch = self.max_steps_per_epoch
        self.global_step = self.epochs_run * self._steps_per_epoch

        # Setup lr scheduler
        self._lr_scheduler = self._setup_lr_scheduler(
            cfg_lr_scheduler=cfg.get("lr_scheduler", None),
            num_training_steps=self.total_epochs * self._steps_per_epoch,
            last_epoch=self.global_step - 1,
        )

        # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
        # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
        self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))

        # Used to ignore labels for loss computation
        self.ignore_labels_cache = torch.full(
            (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device
        )

    def _setup_lr_scheduler(
        self,
        cfg_lr_scheduler: Optional[DictConfig],
        num_training_steps: int,
        last_epoch: int,
    ) -> Optional[Optimizer]:
        """
        Set up the learning rate scheduler based on the provided configuration.
        It supports both standard optimization and optimizer-in-backward cases.

        Args:
            cfg_lr_scheduler (Optional[DictConfig]): The learning rate scheduler configuration.
            num_training_steps (int): The total number of training steps.
            last_epoch (int): The index of the last epoch.

        Returns:
            lr_scheduler (Optional[Optimizer]): The learning rate scheduler.
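
        Example:
            A minimal ``lr_scheduler`` config sketch (the component path is shown for
            illustration; any scheduler factory accepting an optimizer plus
            ``num_training_steps`` and ``last_epoch`` keyword arguments works here):

            .. code-block:: yaml

                lr_scheduler:
                  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
                  num_warmup_steps: 100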
  352. """
  353. if cfg_lr_scheduler is None:
  354. if self._is_rank_zero:
  355. log.info(
  356. "No learning rate scheduler configured. Using constant learning rate."
  357. )
  358. return None
  359. if self._optimizer_in_bwd:
  360. # Use the first optimizer from the wrapper to represent the learning rate
  361. optimizer = next(iter(self._optim_ckpt_wrapper.optim_map.values()))
  362. else:
  363. # Standard case: use the single optimizer
  364. optimizer = self._optimizer
  365. # Instantiate the learning rate scheduler
  366. lr_scheduler = config.instantiate(
  367. cfg_lr_scheduler,
  368. optimizer,
  369. num_training_steps=num_training_steps,
  370. last_epoch=last_epoch,
  371. )
  372. if self._optimizer_in_bwd:
  373. # Modify the scheduler for optimizer_in_bwd case
  374. self._optim_ckpt_wrapper.set_lr_scheduler(lr_scheduler)
  375. if self._is_rank_zero:
  376. log.info("Learning rate scheduler is initialized.")
  377. return lr_scheduler
  378. def _setup_profiler(
  379. self, cfg_profiler: Optional[DictConfig] = None
  380. ) -> Union[torch.profiler.profile, DummyProfiler]:
  381. """
  382. Parses the `profiler` section of top-level `cfg` and sets up profiler
  383. Args:
  384. cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to
  385. `recipe.main`). Default None.
  386. Returns:
  387. profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods
  388. for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such
  389. that the instrumented training loop does not need to be changed profiling is disabled.
  390. The profiler config can be provided in configs under the `profiler` key with the following layout:
  391. .. code-block:: yaml
  392. profiler:
  393. enabled: bool
  394. #Output directory of trace artifacts
  395. output_dir: str
  396. #`torch.profiler.ProfilerActivity` types to trace
  397. cpu: bool
  398. cuda: bool
  399. #Trace options
  400. profile_memory: bool
  401. with_stack: bool
  402. record_shapes: bool
  403. with_flops: bool
  404. # `torch.profiler.schedule` options:
  405. # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  406. wait_steps: int
  407. warmup_steps: int
  408. active_steps: int
  409. num_cycles: int
  410. """
  411. # Missing profiler section in config, assume disabled
  412. if cfg_profiler is None:
  413. cfg_profiler = DictConfig({"enabled": False})
  414. # Check that component is included and set correctly
  415. if cfg_profiler.get("_component_", None) is None:
  416. cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler"
  417. else:
  418. assert (
  419. cfg_profiler.get("_component_")
  420. == "torchtune.training.setup_torch_profiler"
  421. ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`"
  422. profiler, profiler_cfg = config.instantiate(cfg_profiler)
  423. utils.log_rank_zero(
  424. log, f" Profiler config after instantiation: {profiler_cfg}"
  425. )
  426. if self._is_rank_zero:
  427. self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
  428. if profiler_cfg["enabled"]:
  429. self.profiler_wait_steps = profiler_cfg["wait_steps"]
  430. self.profiler_warmup_steps = profiler_cfg["warmup_steps"]
  431. self.profiler_active_steps = profiler_cfg["active_steps"]
  432. return profiler

    def _setup_model(
        self,
        cfg_model: DictConfig,
        enable_activation_checkpointing: bool,
        enable_activation_offloading: bool,
        fsdp_cpu_offload: bool,
        reshard_after_forward: bool,
        model_state_dict: Dict[str, Any],
        custom_sharded_layers: Optional[List[str]] = None,
        ac_mode: Optional[str] = None,
        ac_option: Optional[int] = None,
    ) -> nn.Module:
        """
        Model initialization has some important considerations:
            a. To minimize GPU peak memory, we initialize the model on meta device with
               the right dtype
            b. All ranks call ``load_state_dict`` without peaking CPU RAM since
               full state dicts are loaded with ``torch.load(mmap=True)``
        """
        utils.log_rank_zero(
            log,
            "Distributed training is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
        )
        init_start = time.perf_counter()

        with training.set_default_dtype(self._dtype), torch.device("meta"):
            model = config.instantiate(cfg_model)

        if self._compile:
            training.compile_model(model, verbose=self._is_rank_zero)

        device_mesh = init_device_mesh(
            self._device.type,
            mesh_shape=(self.data_parallel_dim, self.tensor_parallel_dim),
            mesh_dim_names=("dp", "tp"),
        )
        self.dp_size = device_mesh["dp"].size()
        self.dp_rank = device_mesh["dp"].get_local_rank()

        # Apply tensor parallelism to the model
        if self.tensor_parallel_dim > 1:
            # Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor parallel
            model = training.prepare_mha_for_tp(model, device_mesh["tp"])
            parallelize_module(
                model,
                device_mesh["tp"],
                parallelize_plan=self.tensor_parallel_plan,
            )

        # We currently have two versions of activation checkpointing in this recipe
        # for testing and BC purposes. ``enable_activation_checkpointing`` controls
        # the older version of AC and this behavior is unchanged
        # ac_mode and ac_option together control selective AC. This is only enabled
        # when these are set AND ``enable_activation_checkpointing`` is set to False
        # We'll clean this up as soon as testing of AC is complete
        if (not enable_activation_checkpointing) and (ac_mode is not None):
            apply_selective_activation_checkpointing(
                model,
                ac_mode,
                ac_option,
            )

        # original activation checkpointing (full) - flip the condition above
        if enable_activation_checkpointing and ac_mode is None:
            training.set_activation_checkpointing(
                model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
            )

        # Apply Fully Sharded Data Parallelism to the model
        if self.data_parallel_dim > 1:
            fsdp_shard_conditions = [
                partial(
                    training.get_shard_conditions,
                    names_to_match=custom_sharded_layers,
                )
            ]
            training.shard_model(
                model=model,
                shard_conditions=fsdp_shard_conditions,
                cpu_offload=fsdp_cpu_offload,
                reshard_after_forward=reshard_after_forward,
                dp_mesh=device_mesh["dp"],
            )

        with training.set_default_dtype(self._dtype), self._device:
            for m in model.modules():
                # RoPE is not covered in state dict
                if hasattr(m, "rope_init"):
                    m.rope_init()

        # This method will convert the full model state dict into a sharded state
        # dict and load into the model
        training.load_from_full_model_state_dict(
            model,
            model_state_dict,
            self._device,
            strict=True,
            cpu_offload=fsdp_cpu_offload,
        )

        # activation offloading
        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
            model, enable_activation_offloading
        )

        # Ensure no params and buffers are on meta device
        training.validate_no_params_on_meta_device(model)

        utils.log_rank_zero(
            log,
            f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
        )
        if self._is_rank_zero:
            memory_stats = training.get_memory_stats(device=self._device)
            training.log_memory_stats(memory_stats)

        # synchronize before training begins
        torch.distributed.barrier()

        return model

    def _setup_optimizer(
        self,
        cfg_optimizer: DictConfig,
        optimizer_in_bwd: bool = False,
        opt_state_dict: Optional[Dict[str, Any]] = None,
    ) -> Optional[Optimizer]:
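        """
        Set up the optimizer.

        When ``optimizer_in_bwd`` is False, a single optimizer over all model parameters is
        instantiated and returned. When it is True, one optimizer per parameter is created and
        registered to step during the backward pass, the optimizers are tracked via
        ``self._optim_ckpt_wrapper`` for checkpoint save/load, and ``None`` is returned.
        If ``opt_state_dict`` is provided, optimizer state is restored in either case.
        """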
        if optimizer_in_bwd:
            # Maintain a dict of optims for every parameter.
            optim_dict = {
                param: config.instantiate(cfg_optimizer, [param])
                for param in self._model.parameters()
            }

            # Register optimizer step hooks on the model to run optimizer in backward.
            training.register_optim_in_bwd_hooks(
                model=self._model, optim_dict=optim_dict
            )
            # Create a wrapper for checkpoint save/load of optimizer states when running in backward.
            self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper(
                model=self._model, optim_dict=optim_dict
            )
            # Load optimizer states for each param. If optimizer states are being restored in an optimizer in
            # backward run, these need to have been saved with the same setting. Cannot restore from runs that
            # did not use optimizer in backward.
            if opt_state_dict is not None:
                for param in opt_state_dict.keys():
                    try:
                        training.load_from_full_optimizer_state_dict(
                            self._model,
                            self._optim_ckpt_wrapper.optim_map[param],
                            opt_state_dict[param],
                            self._device,
                        )
                    except BaseException as e:
                        raise RuntimeError(
                            "Failed loading in-backward optimizer checkpoints. "
                            "Please make sure run being restored from was using in-backward optimizer."
                        ) from e

            utils.log_rank_zero(log, "In-backward optimizers are set up.")
            return None
        else:
            optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
            if opt_state_dict:
                training.load_from_full_optimizer_state_dict(
                    self._model,
                    optimizer,
                    opt_state_dict,
                    self._device,
                )

            utils.log_rank_zero(log, "Optimizer is initialized.")
            return optimizer

    def _setup_data(
        self,
        cfg_dataset: DictConfig,
        shuffle: bool,
        batch_size: int,
        collate_fn: str,
    ) -> Tuple[DistributedSampler, DataLoader]:
        """
        All data related setup happens here. Currently this recipe only supports the
        DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
        iterable datasets and streaming datasets are not supported.
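
        If ``cfg_dataset`` is a list of configs, each dataset is instantiated separately and the
        results are wrapped in a ``ConcatDataset``. A config sketch of that case (the dataset
        builders are shown purely for illustration):

        .. code-block:: yaml

            dataset:
              - _component_: torchtune.datasets.alpaca_dataset
              - _component_: torchtune.datasets.samsum_dataset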
  600. """
  601. if isinstance(cfg_dataset, ListConfig):
  602. datasets = [
  603. config.instantiate(single_cfg_dataset, self._tokenizer)
  604. for single_cfg_dataset in cfg_dataset
  605. ]
  606. ds = ConcatDataset(datasets=datasets)
  607. packed = getattr(ds, "packed", False)
  608. else:
  609. ds = config.instantiate(cfg_dataset, self._tokenizer)
  610. packed = cfg_dataset.get("packed", False)
  611. # Instantiate collate_fn
  612. if "left_pad_sequence" in collate_fn:
  613. raise RuntimeError("left_pad_sequence collator is only for inference.")
  614. collate_fn = _get_component_from_path(collate_fn)
  615. sampler = DistributedSampler(
  616. ds, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle, seed=0
  617. )
  618. dataloader = DataLoader(
  619. dataset=ds,
  620. batch_size=batch_size,
  621. sampler=sampler,
  622. # dropping last avoids shape issues with compile + flex attention
  623. drop_last=True,
  624. collate_fn=(
  625. partial(
  626. collate_fn,
  627. padding_idx=self._tokenizer.pad_id,
  628. ignore_idx=self._loss_fn.ignore_index,
  629. )
  630. if not packed
  631. else padded_collate_packed
  632. ),
  633. )
  634. utils.log_rank_zero(log, "Dataset and Sampler are initialized.")
  635. return sampler, dataloader

    def train(self) -> None:
        """
        The core training loop.
        """
        # clean up before training begins
        training.cleanup_before_training()

        # zero out the gradients before starting training
        if not self._optimizer_in_bwd:
            self._optimizer.zero_grad()
        else:
            for opt in self._optim_ckpt_wrapper.optim_map.values():
                opt.zero_grad()

        # Initialize tokens count and running loss (for grad accumulation)
        t0 = time.perf_counter()
        running_loss = 0
        num_tokens = 0

        self._profiler.start()
        # self.epochs_run should be non-zero when we're resuming from a checkpoint
        for curr_epoch in range(self.epochs_run, self.total_epochs):
            # Update the sampler to ensure data is correctly shuffled across epochs
            # in case shuffle is True
            self._sampler.set_epoch(curr_epoch)

            pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
            for idx, batch in enumerate(self._dataloader):
                if (
                    self.max_steps_per_epoch is not None
                    and (idx // self._gradient_accumulation_steps)
                    == self.max_steps_per_epoch
                ):
                    break

                # Start tracking CUDA memory for active steps for just the first epoch
                if (
                    self._is_rank_zero
                    and curr_epoch == 0
                    and self.profiler_profile_memory
                    and idx == self.profiler_wait_steps + self.profiler_warmup_steps
                    and self._device.type == "cuda"
                ):
                    torch.cuda.memory._record_memory_history()

                utils.batch_to_device(batch, self._device)

                # Calculate the number of unmasked tokens in the current batch
                # and increment the total number of tokens seen in the step
                current_num_tokens = (
                    batch["labels"] != self._loss_fn.ignore_index
                ).sum()
                num_tokens += current_num_tokens

                # Shape [b, s], needed for the loss not the model
                labels = batch.pop("labels")

                with self.activations_handling_ctx:
                    logits = self._model(**batch)

                # Shift labels to compute loss
                # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
                # But this way we don't need to slice the logits. We just add an ignore index to labels.
                labels = torch.hstack(
                    (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
                )
                if not isinstance(logits, list):
                    labels = labels.reshape(-1)
                    logits = logits.reshape(-1, logits.size(-1))

                # Compute loss
                # Loss is normalized by default so we multiply by the number of tokens
                # This way we can normalize by the total number of tokens if we're accumulating gradients
                current_loss = self._loss_fn(logits, labels) * current_num_tokens

                # free logits otherwise it peaks backward memory
                del logits

                running_loss += current_loss

                # For optimizer in backward, we need to normalize before calling backward
                # This case and gradient accumulation are mutually exclusive
                if self._optimizer_in_bwd:
                    torch.distributed.all_reduce(num_tokens)
                    torch.distributed.all_reduce(running_loss)
                    # We multiply by world_size to undo FSDP2 gradient normalization.
                    current_loss = current_loss * (self.world_size / num_tokens)

                current_loss.backward()

                # Step with optimizer
                if (idx + 1) % self._gradient_accumulation_steps == 0:
                    if not self._optimizer_in_bwd:
                        # Get total number of tokens across all ranks to normalize gradients
                        torch.distributed.all_reduce(num_tokens)
                        # This will ensure that the logged loss matches what we're optimizing
                        torch.distributed.all_reduce(running_loss)
                        # Manually scale the gradients from unnormalized loss by total # of tokens
                        # We multiply by world_size to undo FSDP2 gradient normalization.
                        training.scale_grads(self._model, self.world_size / num_tokens)
                        if self._clip_grad_norm is not None:
                            grad_norm = torch.nn.utils.clip_grad_norm_(
                                self._model.parameters(),
                                max_norm=float(self._clip_grad_norm),
                            )
                            # If sharded, collect the DTensor here
                            if isinstance(grad_norm, DTensor):
                                grad_norm = grad_norm.full_tensor()
                        self._optimizer.step()
                        self._optimizer.zero_grad(set_to_none=True)

                    # Update the number of steps when the weights are updated
                    self.global_step += 1

                    # Step the learning rate scheduler
                    if self._lr_scheduler is not None:
                        self._lr_scheduler.step()

                    loss_to_log = running_loss.item() / num_tokens
                    pbar.update(1)
                    pbar.set_description(
                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
                    )

                    # Log per-step metrics
                    if (
                        self.global_step % self._log_every_n_steps == 0
                        and self._is_rank_zero
                    ):
                        time_per_step = time.perf_counter() - t0
                        log_dict = {
                            "loss": loss_to_log,
                            "lr": get_lr(
                                (
                                    self._optimizer
                                    if not self._optimizer_in_bwd
                                    else self._optim_ckpt_wrapper
                                ),
                            ),
                            "tokens_per_second_per_gpu": num_tokens
                            / (time_per_step * self.world_size),
                        }
                        if self._log_peak_memory_stats:
                            log_dict.update(
                                training.get_memory_stats(device=self._device)
                            )
                        if self._clip_grad_norm is not None:
                            log_dict.update({"grad_norm": grad_norm})
                        self._metric_logger.log_dict(
                            log_dict,
                            step=self.global_step,
                        )

                    # Reset running stats for the next step
                    running_loss = 0
                    num_tokens = 0
                    t0 = time.perf_counter()

                    # Stop tracking CUDA memory now that active steps are complete
                    if (
                        self._is_rank_zero
                        and curr_epoch == 0
                        and self.profiler_profile_memory
                        and idx
                        == self.profiler_wait_steps
                        + self.profiler_warmup_steps
                        + self.profiler_active_steps
                        and self._device.type == "cuda"
                    ):
                        torch.cuda.memory._record_memory_history(enabled=None)

                    # Step profiler
                    # Note that this is called within gradient accumulation block, hence
                    # will include multiple forward / backward passes if gradient accumulation > 1
                    self._profiler.step()

            self.epochs_run += 1

            # self._checkpoint_client.save_checkpoint(
            #     model=self._model,
            #     optimizer=(
            #         self._optimizer
            #         if not self._optimizer_in_bwd
            #         else self._optim_ckpt_wrapper
            #     ),
            #     training_progress=TrainingProgress(
            #         seed=self.seed,
            #         epochs_run=self.epochs_run,
            #         total_epochs=self.total_epochs,
            #         max_steps_per_epoch=self.max_steps_per_epoch,
            #     ),
            #     epoch=curr_epoch,
            # )

            if curr_epoch > 0 and curr_epoch % self.save_every_epochs == 0:
                utils.log_rank_zero(log, f"Saving checkpoint at epoch {curr_epoch}")
                self._checkpoint_client.save_checkpoint(
                    model=self._model,
                    optimizer=(
                        self._optimizer
                        if not self._optimizer_in_bwd
                        else self._optim_ckpt_wrapper
                    ),
                    training_progress=TrainingProgress(
                        seed=self.seed,
                        epochs_run=self.epochs_run,
                        total_epochs=self.total_epochs,
                        max_steps_per_epoch=self.max_steps_per_epoch,
                    ),
                    epoch=curr_epoch,
                )

        self._profiler.stop()

    def cleanup(self) -> None:
        if self._is_rank_zero:
            self._metric_logger.close()
        destroy_process_group()


@config.parse
def recipe_main(cfg: DictConfig) -> None:
    """
    Entry point for the recipe.

    Configurable parameters are read in the following order:
        - Parameters specified in config (see available configs through ``tune ls``)
        - Overwritten by arguments from the command-line
    """
    config.log_config(recipe_name="FullFinetuneRecipeDistributed", cfg=cfg)

    recipe = FullFinetuneRecipeDistributed(cfg=cfg)
    recipe.setup(cfg=cfg)
    recipe.train()
    recipe.cleanup()


if __name__ == "__main__":
    sys.exit(recipe_main())