# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# AnyPrecisionAdamW: a flexible-precision AdamW optimizer
# with optional Kahan summation for high-precision weight updates.
# Allows direct control over the momentum, variance, and auxiliary compensation
# buffer dtypes.
# Optional Kahan summation is used to offset the precision loss in the
# weight updates. This allows full training in BFloat16 (equal to or
# better than FP32 results in many cases) thanks to high-precision weight updates.
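
# Why the compensation buffer helps: BFloat16 keeps only about 8 bits of
# significand precision (7 stored), so adding a small update to a much larger
# weight can round the update away entirely (e.g. 1.0 + 0.001 rounds back to
# 1.0 in BFloat16). Kahan summation stores that rounding residual and folds it
# into subsequent updates, so small updates are not silently lost.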

import torch
from torch.optim.optimizer import Optimizer


class AnyPrecisionAdamW(Optimizer):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.0,
        use_kahan_summation=False,
        momentum_dtype=torch.bfloat16,
        variance_dtype=torch.bfloat16,
        compensation_buffer_dtype=torch.bfloat16,
    ):
- """
- Args:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float, optional): learning rate (default: 1e-3)
- betas (Tuple[float, float], optional): coefficients used for computing
- running averages of gradient and its square (default: (0.9, 0.999))
- eps (float, optional): term added to the denominator to improve
- numerical stability (default: 1e-8)
- weight_decay (float, optional): weight decay coefficient (default: 1e-2)
- # Any Precision specific
- use_kahan_summation = creates auxiliary buffer to ensure high precision
- model param updates (default: False)
- momentum_dtype = dtype for momentum (default: BFloat32)
- variance_dtype = dtype for uncentered variance (default: BFloat16)
- compensation_buffer_dtype = dtype for Kahan summation
- buffer (default: BFloat16)
- # Usage
- This optimizer implements optimizer states, and Kahan summation
- for high precision updates, all in user controlled dtypes.
- Defaults are variance in BF16, Momentum in FP32.
- This can be run in FSDP mixed precision, amp, or full precision,
- depending on what training pipeline you wish to work with.
- Setting to use_kahan_summation = False, and changing momentum and
- variance dtypes to FP32, reverts this to a standard AdamW optimizer.
- """
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            use_kahan_summation=use_kahan_summation,
            momentum_dtype=momentum_dtype,
            variance_dtype=variance_dtype,
            compensation_buffer_dtype=compensation_buffer_dtype,
        )

        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """

        if closure is not None:
            with torch.enable_grad():
                # to fix linter, we do not keep the returned loss for use atm.
                closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]
            use_kahan_summation = group["use_kahan_summation"]

            momentum_dtype = group["momentum_dtype"]
            variance_dtype = group["variance_dtype"]
            compensation_buffer_dtype = group["compensation_buffer_dtype"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                if p.grad.is_sparse:
                    raise RuntimeError(
                        "AnyPrecisionAdamW does not support sparse gradients"
                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = torch.tensor(0.0)

                    # momentum - EMA of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p,
                        dtype=momentum_dtype,
                    )

                    # uncentered variance - EMA of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p,
                        dtype=variance_dtype,
                    )

                    # optional Kahan summation - accumulated error tracker
                    if use_kahan_summation:
                        state["compensation"] = torch.zeros_like(
                            p,
                            dtype=compensation_buffer_dtype,
                        )

                # main processing -------------------------

                # update the steps for each param group update
                state["step"] += 1
                step = state["step"]

                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]

                grad = p.grad

                # weight decay, AdamW style
                if weight_decay:
                    p.data.mul_(1 - lr * weight_decay)

                # update momentum
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                # update uncentered variance
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
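
                # The code below computes the standard AdamW update
                #   p <- p - lr * exp_avg_hat / (sqrt(exp_avg_sq_hat) + eps)
                # where the _hat values are the bias-corrected moments; the
                # bias corrections are folded into step_size and the denominator.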

                # bias correction for the first moment, folded into the step size
                bias_correction1 = 1 - beta1**step
                step_size = lr / bias_correction1

                # bias correction for the second moment, folded into the denominator
                denom_correction = (1 - beta2**step) ** 0.5  # avoids math import

                denom = (exp_avg_sq.sqrt() / denom_correction).add_(
                    eps, alpha=1
                )

                if use_kahan_summation:
                    compensation = state["compensation"]

                    # accumulate the scaled update into the compensation buffer
                    compensation.addcdiv_(exp_avg, denom, value=-step_size)

                    # update weights with compensation (Kahan summation),
                    # then save the rounding error back to the compensation
                    # buffer for the next iteration
                    temp_buffer = p.detach().clone()
                    p.data.add_(compensation)
                    compensation.add_(temp_buffer.sub_(p.data))

                else:
                    # usual AdamW update
                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
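

if __name__ == "__main__":
    # Minimal usage sketch: a toy nn.Linear model and synthetic data (both are
    # illustrative assumptions, not part of the optimizer itself). It shows
    # pure-BFloat16 training with Kahan summation enabled, plus the
    # configuration that reverts to standard AdamW behavior.
    torch.manual_seed(0)

    model = torch.nn.Linear(16, 4).to(torch.bfloat16)
    optimizer = AnyPrecisionAdamW(
        model.parameters(),
        lr=1e-3,
        use_kahan_summation=True,
        momentum_dtype=torch.bfloat16,
        variance_dtype=torch.bfloat16,
        compensation_buffer_dtype=torch.bfloat16,
    )

    for _ in range(3):
        inputs = torch.randn(8, 16, dtype=torch.bfloat16)
        targets = torch.randn(8, 4, dtype=torch.bfloat16)
        loss = torch.nn.functional.mse_loss(model(inputs), targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"loss: {loss.item():.4f}")

    # With Kahan summation off and FP32 state dtypes, this behaves like
    # standard AdamW (see the class docstring).
    fp32_model = torch.nn.Linear(16, 4)
    standard_adamw = AnyPrecisionAdamW(
        fp32_model.parameters(),
        lr=1e-3,
        use_kahan_summation=False,
        momentum_dtype=torch.float32,
        variance_dtype=torch.float32,
    )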