# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# AnyPrecisionAdamW: a flexible-precision AdamW optimizer
# with optional Kahan summation for high-precision weight updates.
# Allows direct control over the momentum, variance, and auxiliary
# compensation buffer dtypes.
# Optional Kahan summation is used to offset the precision reduction in
# the weight updates. This allows full training in BFloat16 (equal or
# better than FP32 results in many cases) thanks to high-precision weight updates.
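
# Background (illustrative note): BFloat16 keeps only 8 significand bits, so
# a per-step update smaller than half a ULP of a weight vanishes entirely
# when applied in pure BF16, e.g.:
#   >>> import torch
#   >>> w = torch.tensor(1.0, dtype=torch.bfloat16)
#   >>> bool(w + 1e-4 == w)   # the update is rounded away
#   True
# The Kahan compensation buffer stores this rounded-away remainder and
# re-applies it on later steps, recovering most of the lost precision.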
 
import torch
from torch.optim.optimizer import Optimizer

 
class AnyPrecisionAdamW(Optimizer):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.0,
        use_kahan_summation=False,
        momentum_dtype=torch.bfloat16,
        variance_dtype=torch.bfloat16,
        compensation_buffer_dtype=torch.bfloat16,
    ):
-         """
 
-         Args:
 
-                 params (iterable): iterable of parameters to optimize or dicts defining
 
-                     parameter groups
 
-                 lr (float, optional): learning rate (default: 1e-3)
 
-                 betas (Tuple[float, float], optional): coefficients used for computing
 
-                     running averages of gradient and its square (default: (0.9, 0.999))
 
-                 eps (float, optional): term added to the denominator to improve
 
-                     numerical stability (default: 1e-8)
 
-                 weight_decay (float, optional): weight decay coefficient (default: 1e-2)
 
-                 # Any Precision specific
 
-                 use_kahan_summation = creates auxiliary buffer to ensure high precision
 
-                 model param updates (default: False)
 
-                 momentum_dtype = dtype for momentum  (default: BFloat32)
 
-                 variance_dtype = dtype for uncentered variance (default: BFloat16)
 
-                 compensation_buffer_dtype  = dtype for Kahan summation
 
-                                              buffer (default: BFloat16)
 
-                 # Usage
 
-                 This optimizer implements optimizer states, and Kahan summation
 
-                 for high precision updates, all in user controlled dtypes.
 
-                 Defaults are variance in BF16, Momentum in FP32.
 
-                 This can be run in FSDP mixed precision, amp, or full precision,
 
-                 depending on what training pipeline you wish to work with.
 
-                 Setting to use_kahan_summation = False, and changing momentum and
 
-                 variance dtypes to FP32, reverts this to a standard AdamW optimizer.
 
-         """
 
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            use_kahan_summation=use_kahan_summation,
            momentum_dtype=momentum_dtype,
            variance_dtype=variance_dtype,
            compensation_buffer_dtype=compensation_buffer_dtype,
        )

        super().__init__(params, defaults)
 
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        if closure is not None:
            with torch.enable_grad():
                # the returned loss is not kept for use at the moment
                closure()
 
        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]
            use_kahan_summation = group["use_kahan_summation"]

            momentum_dtype = group["momentum_dtype"]
            variance_dtype = group["variance_dtype"]
            compensation_buffer_dtype = group["compensation_buffer_dtype"]
 
-             for p in group["params"]:
 
-                 if p.grad is None:
 
-                     continue
 
-                 if p.grad.is_sparse:
 
-                     raise RuntimeError(
 
-                         "AnyPrecisionAdamW does not support sparse gradients"
 
-                     )
 
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = torch.tensor(0.0)

                    # momentum - EMA of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p,
                        dtype=momentum_dtype,
                    )

                    # variance, uncentered - EMA of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p,
                        dtype=variance_dtype,
                    )

                    # optional Kahan summation - accumulated error tracker
                    if use_kahan_summation:
                        state["compensation"] = torch.zeros_like(
                            p,
                            dtype=compensation_buffer_dtype,
                        )
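
                # Because the state buffers above are allocated in the
                # user-selected dtypes, the all-BF16 defaults roughly halve
                # optimizer-state memory relative to standard FP32 AdamW.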
 
                # main processing -------------------------

                # update the steps for each param group update
                state["step"] += 1
                step = state["step"]

                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]

                grad = p.grad
 
                # weight decay, AdamW style: decay is decoupled from the
                # gradient and applied directly to the weights
                if weight_decay:
                    p.data.mul_(1 - lr * weight_decay)
 
                # update momentum
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                # update uncentered variance
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
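
                # The buffers now hold the standard Adam moment estimates:
                #   exp_avg    = beta1 * exp_avg    + (1 - beta1) * grad
                #   exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad**2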
 
                # bias correction for the first moment, folded into the step size
                bias_correction1 = 1 - beta1**step
                step_size = lr / bias_correction1

                # bias correction for the second moment, folded into the denominator
                denom_correction = (1 - beta2**step) ** 0.5  # avoids math import

                denom = (exp_avg_sq.sqrt() / denom_correction).add_(eps)
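
                # Net effect of the update below: p <- p - lr * m_hat / (sqrt(v_hat) + eps),
                # where m_hat = exp_avg / (1 - beta1**step) and
                #       v_hat = exp_avg_sq / (1 - beta2**step); folding the bias
                # corrections into step_size and denom avoids extra tensor ops.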
 
                # the lr-scaled update is routed through the compensation buffer
                if use_kahan_summation:
                    compensation = state["compensation"]

                    compensation.addcdiv_(exp_avg, denom, value=-step_size)

                    # update weights with compensation (Kahan summation)
                    # and save the error back to compensation for the next iteration
                    temp_buffer = p.detach().clone()
                    p.data.add_(compensation)
                    compensation.add_(temp_buffer.sub_(p.data))
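                    # compensation now holds (p_old + intended update) - p_new,
                    # i.e. the rounding error just discarded by the low-precision
                    # add; it is carried into the next step's update.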
 
                else:
                    # usual AdamW update
                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
 
 
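# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module): fits
# a tiny linear model for a few steps with the all-BF16 optimizer state and
# Kahan-compensated weight updates. The toy model and random data below are
# assumptions made purely for the example.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = torch.nn.Linear(8, 1)
    optimizer = AnyPrecisionAdamW(
        model.parameters(),
        lr=1e-3,
        weight_decay=0.01,
        use_kahan_summation=True,  # enable the compensation buffer
    )

    data = torch.randn(32, 8)
    target = torch.randn(32, 1)

    for _ in range(5):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(data), target)
        loss.backward()
        optimizer.step()

    print(f"final loss: {loss.item():.4f}")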