anyprecision_optimizer.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

# AnyPrecisionAdamW: a flexible-precision AdamW optimizer
# with optional Kahan summation for high-precision weight updates.
# Allows direct control over the momentum, variance, and auxiliary compensation
# buffer dtypes.
# Optional Kahan summation is used to offset the precision reduction in
# the weight updates. This allows full training in BFloat16 (matching or
# exceeding FP32 results in many cases) thanks to the high-precision weight updates.

import torch
from torch.optim.optimizer import Optimizer

class AnyPrecisionAdamW(Optimizer):
    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.0,
        use_kahan_summation=False,
        momentum_dtype=torch.bfloat16,
        variance_dtype=torch.bfloat16,
        compensation_buffer_dtype=torch.bfloat16,
    ):
  25. """
  26. Args:
  27. params (iterable): iterable of parameters to optimize or dicts defining
  28. parameter groups
  29. lr (float, optional): learning rate (default: 1e-3)
  30. betas (Tuple[float, float], optional): coefficients used for computing
  31. running averages of gradient and its square (default: (0.9, 0.999))
  32. eps (float, optional): term added to the denominator to improve
  33. numerical stability (default: 1e-8)
  34. weight_decay (float, optional): weight decay coefficient (default: 1e-2)
  35. # Any Precision specific
  36. use_kahan_summation = creates auxiliary buffer to ensure high precision
  37. model param updates (default: False)
  38. momentum_dtype = dtype for momentum (default: BFloat32)
  39. variance_dtype = dtype for uncentered variance (default: BFloat16)
  40. compensation_buffer_dtype = dtype for Kahan summation
  41. buffer (default: BFloat16)
  42. # Usage
  43. This optimizer implements optimizer states, and Kahan summation
  44. for high precision updates, all in user controlled dtypes.
  45. Defaults are variance in BF16, Momentum in FP32.
  46. This can be run in FSDP mixed precision, amp, or full precision,
  47. depending on what training pipeline you wish to work with.
  48. Setting to use_kahan_summation = False, and changing momentum and
  49. variance dtypes to FP32, reverts this to a standard AdamW optimizer.
  50. """
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            use_kahan_summation=use_kahan_summation,
            momentum_dtype=momentum_dtype,
            variance_dtype=variance_dtype,
            compensation_buffer_dtype=compensation_buffer_dtype,
        )

        super().__init__(params, defaults)
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        if closure is not None:
            with torch.enable_grad():
                # to satisfy the linter, we do not keep the returned loss for use atm.
                closure()

        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            lr = group["lr"]
            weight_decay = group["weight_decay"]
            eps = group["eps"]
            use_kahan_summation = group["use_kahan_summation"]

            momentum_dtype = group["momentum_dtype"]
            variance_dtype = group["variance_dtype"]
            compensation_buffer_dtype = group["compensation_buffer_dtype"]

            for p in group["params"]:
                if p.grad is None:
                    continue

                if p.grad.is_sparse:
                    raise RuntimeError(
                        "AnyPrecisionAdamW does not support sparse gradients"
                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = torch.tensor(0.0)

                    # momentum - EMA of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p,
                        dtype=momentum_dtype,
                    )

                    # uncentered variance - EMA of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p,
                        dtype=variance_dtype,
                    )

                    # optional Kahan summation - accumulated error tracker
                    if use_kahan_summation:
                        state["compensation"] = torch.zeros_like(
                            p,
                            dtype=compensation_buffer_dtype,
                        )

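                # Note: the state tensors above are created in the (possibly
                # reduced-precision) dtypes chosen in __init__, which is what
                # makes the optimizer-state memory footprint configurable.
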
                # main processing -------------------------

                # update the step count for this parameter
                state["step"] += 1
                step = state["step"]

                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]

                grad = p.grad

                # weight decay, AdamW style (decoupled from the gradient)
                if weight_decay:
                    p.data.mul_(1 - lr * weight_decay)

                # update momentum
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                # update uncentered variance
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # bias correction 1 - folded into the step size
                bias_correction1 = 1 - beta1**step
                step_size = lr / bias_correction1

                # bias correction 2 - folded into the denominator
                denom_correction = (1 - beta2**step) ** 0.5  # avoids math import

                # denominator: bias-corrected root of the uncentered variance, plus eps
                denom = (exp_avg_sq.sqrt() / denom_correction).add_(
                    eps, alpha=1
                )

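                # The resulting update below is the standard (decoupled) AdamW step,
                #     p <- p - lr * m_hat / (sqrt(v_hat) + eps),
                # where m_hat = exp_avg / (1 - beta1**step) and
                #       v_hat = exp_avg_sq / (1 - beta2**step);
                # both bias corrections are folded into step_size and denom above.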
                if use_kahan_summation:
                    # Kahan-compensated update: route the lr-scaled step through
                    # the compensation buffer
                    compensation = state["compensation"]

                    compensation.addcdiv_(exp_avg, denom, value=-step_size)

                    # update weights with compensation (Kahan summation)
                    # save the rounding error back to compensation for the next iteration
                    temp_buffer = p.detach().clone()
                    p.data.add_(compensation)
                    compensation.add_(temp_buffer.sub_(p.data))

                else:
                    # usual AdamW update
                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
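
# -----------------------------------------------------------------------------
# Minimal usage sketch: the toy model, data shapes, and hyperparameter values
# below are illustrative assumptions only, showing how AnyPrecisionAdamW might
# be wired into a plain training step with BF16 optimizer states and
# Kahan-compensated weight updates.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(16, 4)  # hypothetical toy model (FP32 parameters)

    optimizer = AnyPrecisionAdamW(
        model.parameters(),
        lr=1e-3,
        weight_decay=0.01,
        use_kahan_summation=True,  # compensate the reduced-precision updates
        momentum_dtype=torch.bfloat16,
        variance_dtype=torch.bfloat16,
        compensation_buffer_dtype=torch.bfloat16,
    )

    inputs = torch.randn(8, 16)   # illustrative batch
    targets = torch.randn(8, 4)

    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()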