# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron optimizer."""

from abc import ABC
from abc import abstractmethod

import torch
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0

from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32


def _zero_grad_group_helper(group, set_to_none):
    """Zero out the gradient for a group of parameters.
    Note: copied from torch.optim.optimizer."""
    for param in group:
        if param.grad is not None:
            if set_to_none:
                param.grad = None
            else:
                if param.grad.grad_fn is not None:
                    param.grad.detach_()
                else:
                    param.grad.requires_grad_(False)
                param.grad.zero_()
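
# Illustrative sketch (not part of the original module): the difference
# between the two zero_grad modes above, for a single parameter `p`:
#
#     _zero_grad_group_helper([p], set_to_none=True)   # p.grad becomes None
#     _zero_grad_group_helper([p], set_to_none=False)  # p.grad kept, zero-filled
#
# Setting grads to None frees the memory; zero-filling keeps the buffer so
# later code can accumulate into it in place.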


def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
    """Use multi-tensor-applier to copy values from one list to another.
    We don't have a bfloat16 implementation, so for now, if the overflow_buf
    is not provided, we fall back to a simple loop copy to stay compatible
    with bfloat16."""
    # Check for None explicitly: a single-element tensor holding zero would
    # otherwise evaluate to False and silently skip the fused copy path.
    if overflow_buf is not None:
        overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             overflow_buf,
                             [this, that],
                             1.0)
    else:
        for this_, that_ in zip(this, that):
            that_.copy_(this_)
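
# Illustrative sketch (not part of the original module): copying a list of
# fp32 "main" tensors into their float16 "model" counterparts, with and
# without the fused apex kernel. `main_tensors` and `model_tensors` are
# placeholder names; `buf` plays the role of _dummy_overflow_buf below.
#
#     buf = torch.cuda.IntTensor([0])
#     _multi_tensor_copy_this_to_that(this=main_tensors, that=model_tensors,
#                                     overflow_buf=buf)   # fused apex kernel
#     _multi_tensor_copy_this_to_that(this=main_tensors, that=model_tensors)
#                                                         # plain loop copy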


class MegatronOptimizer(ABC):

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad):
        """Input optimizer is the base optimizer (for example, Adam)."""
        self.optimizer = optimizer
        assert self.optimizer, 'no optimizer is provided.'
        # Set gradient clipping and logging params.
        self.clip_grad = clip_grad
        self.log_num_zeros_in_grad = log_num_zeros_in_grad
        self.params_have_main_grad = params_have_main_grad

    def get_parameters(self):
        params = []
        for param_group in self.optimizer.param_groups:
            for param in param_group['params']:
                params.append(param)
        return params

    def clip_grad_norm(self, clip_grad):
        params = self.get_parameters()
        return clip_grad_norm_fp32(params, clip_grad)

    def count_zeros(self):
        params = self.get_parameters()
        return count_zeros_fp32(params)

    @abstractmethod
    def zero_grad(self, set_to_none=True):
        pass

    @abstractmethod
    def get_loss_scale(self):
        """The output should be a cuda tensor of size 1."""
        pass

    def scale_loss(self, loss):
        """Simple scaling."""
        return self.get_loss_scale() * loss
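
    # Illustrative sketch (not part of the original class): the loss is
    # scaled before backward so that float16 gradients do not underflow.
    # `model`, `batch`, and `optimizer` are placeholder names.
    #
    #     loss = model(batch)
    #     optimizer.scale_loss(loss).backward()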

    @abstractmethod
    def step(self):
        pass

    @abstractmethod
    def reload_model_params(self):
        """Refresh any internal state from the current model parameters.
        Call whenever the parameters are changed outside of the optimizer.
        For example, when we load a model from a checkpoint without loading
        the optimizer, the model parameters are updated, but for an fp16
        optimizer with main parameters, the main parameters need to be
        updated as well."""
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state_dict):
        pass

    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)
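

# Illustrative sketch (not part of the original module): because
# `param_groups` is promoted above, a learning-rate schedule can adjust the
# wrapped optimizer directly through the MegatronOptimizer instance.
# `optimizer` and `new_lr` are placeholder names.
#
#     for group in optimizer.param_groups:
#         group['lr'] = new_lr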


class Float16OptimizerWithFloat16Params(MegatronOptimizer):
    """Float16 optimizer for fp16 and bf16 data types.

    Arguments:
        optimizer: base optimizer such as Adam or SGD.
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0.
        log_num_zeros_in_grad: return the number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the gradients are stored in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a contiguous buffer
            holding the gradients. For example, for bfloat16 we want
            to do gradient accumulation and all-reduces in float32,
            and as a result we store those gradients in `main_grad`.
            Note that main_grad is not necessarily in float32.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also, for `bf16 = False`, we
            always require a grad scaler.
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, bf16, grad_scaler):
        super(Float16OptimizerWithFloat16Params, self).__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad)

        self.bf16 = bf16
        self.grad_scaler = grad_scaler
        # None grad scaler is only supported for bf16.
        if self.grad_scaler is None:
            assert self.bf16, 'fp16 expects a grad scaler.'

        # Tensor used to determine if an inf/nan has happened.
        # Any non-zero value indicates inf/nan.
        # It is only allocated when a grad scaler is provided; note that
        # even for bf16 with a (constant) grad scaler we still record
        # inf/nan here.
        if self.grad_scaler:
            self.found_inf = torch.cuda.FloatTensor([0.0])

        # Dummy tensor needed for apex multi-tensor apply.
        # For bfloat16, we don't have a multi-tensor apply implementation,
        # so for now we set it to None and the multi-tensor copy falls back
        # to a simple loop.
        if bf16:
            self._dummy_overflow_buf = None
        else:
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        # In case grad scaler is not passed, define the unity scale.
        if self.grad_scaler is None:
            self._scale_one = torch.cuda.FloatTensor([1.0])

        # ======================
        # main parameter stuff
        # ======================

        # Three groups of parameters:
        #   float16_groups: original float16 parameters
        #   fp32_from_float16_groups: fp32 copy of float16 parameters
        #   fp32_from_fp32_groups: original fp32 parameters
        self.float16_groups = []
        self.fp32_from_float16_groups = []
        self.fp32_from_fp32_groups = []

        # For all the groups in the original optimizer:
        for param_group in self.optimizer.param_groups:
            float16_params_this_group = []
            fp32_params_this_group = []
            fp32_from_float16_params_this_group = []
            # For all the parameters in this group:
            for i, param in enumerate(param_group['params']):
                if param.requires_grad:

                    # float16 params:
                    if param.type() in ['torch.cuda.HalfTensor',
                                        'torch.cuda.BFloat16Tensor']:
                        float16_params_this_group.append(param)
                        # Create a copy.
                        main_param = param.detach().clone().float()
                        # Copy tensor model parallel attributes.
                        mpu.copy_tensor_model_parallel_attributes(main_param,
                                                                  param)
                        if hasattr(param, 'shared'):
                            main_param.shared = param.shared
                        # Replace the optimizer params with the new fp32 copy.
                        param_group['params'][i] = main_param
                        fp32_from_float16_params_this_group.append(main_param)
                        # Reset existing state dict key to the new main param.
                        if param in self.optimizer.state:
                            self.optimizer.state[main_param] \
                                = self.optimizer.state.pop(param)

                    # fp32 params.
                    elif param.type() == 'torch.cuda.FloatTensor':
                        fp32_params_this_group.append(param)
                        param_group['params'][i] = param

                    else:
                        raise TypeError('Wrapped parameters must be one of '
                                        'torch.cuda.FloatTensor, '
                                        'torch.cuda.HalfTensor, or '
                                        'torch.cuda.BFloat16Tensor. '
                                        'Received {}'.format(param.type()))

            self.float16_groups.append(float16_params_this_group)
            self.fp32_from_float16_groups.append(
                fp32_from_float16_params_this_group)
            self.fp32_from_fp32_groups.append(fp32_params_this_group)

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors.
        self.optimizer.load_state_dict(self.optimizer.state_dict())
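
        # Illustrative note (not part of the original code): after the loop
        # above, for every trainable float16 parameter `p` in group `g`,
        #   - `p` itself is kept in self.float16_groups[g], and
        #   - its fp32 clone sits in self.fp32_from_float16_groups[g] and has
        #     replaced `p` inside self.optimizer.param_groups[g]['params'],
        # so the wrapped optimizer only ever steps fp32 (and original fp32)
        # parameters.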

    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
        float16_groups & fp32_from_fp32_groups."""
        for group in self.float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_fp32_groups:
            _zero_grad_group_helper(group, set_to_none)

    def get_loss_scale(self):
        if self.grad_scaler is None:
            return self._scale_one
        return self.grad_scaler.scale

    def _copy_model_grads_to_main_grads(self):
        # This only needs to be done for the float16 group.
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                if self.params_have_main_grad:
                    main_param.grad = model_param.main_grad.float()
                else:
                    if model_param.grad is not None:
                        main_param.grad = model_param.grad.float()
        # For fp32 grads, we need to reset the grads to main grad.
        if self.params_have_main_grad:
            for model_group in self.fp32_from_fp32_groups:
                for model_param in model_group:
                    model_param.grad = model_param.main_grad

    def _unscale_main_grads_and_check_for_nan(self):
        main_grads = []
        # fp32 params from float16 ones.
        for main_group in self.fp32_from_float16_groups:
            for main_param in main_group:
                if main_param.grad is not None:
                    main_grads.append(main_param.grad.data)
        # Append fp32 parameters.
        for main_group in self.fp32_from_fp32_groups:
            for main_param in main_group:
                if main_param.grad is not None:
                    main_grads.append(main_param.grad.data)
        # Reset found inf.
        self.found_inf.fill_(0.0)
        # Unscale and set found inf/nan.
        torch._amp_foreach_non_finite_check_and_unscale_(
            main_grads, self.found_inf, self.grad_scaler.inv_scale)
        # Update across all model parallel instances.
        torch.distributed.all_reduce(self.found_inf,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=mpu.get_model_parallel_group())
        # Check for inf/nan.
        found_inf_flag = (self.found_inf.item() > 0)
        return found_inf_flag
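
    # Illustrative sketch (not part of the original class): the fused call
    # above is roughly equivalent to the following per-tensor loop, where
    # `inv_scale` is 1 / loss_scale:
    #
    #     for g in main_grads:
    #         if not torch.isfinite(g).all():
    #             self.found_inf.fill_(1.0)
    #         g.mul_(inv_scale)
    #
    # followed by a MAX all-reduce so that every model-parallel rank agrees
    # on whether any rank saw an inf/nan.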

    def _get_model_and_main_params_data_float16(self):
        model_data = []
        main_data = []
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data

    def _copy_main_params_to_model_params(self):
        # Only needed for the float16 params.
        model_data, main_data = self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
                                        overflow_buf=self._dummy_overflow_buf)

    def _copy_model_params_to_main_params(self):
        # Only needed for the float16 params.
        model_data, main_data = self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
                                        overflow_buf=self._dummy_overflow_buf)

    def reload_model_params(self):
        self._copy_model_params_to_main_params()
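
    # Illustrative sketch (not part of the original class): typical use of
    # reload_model_params() when model weights are loaded from a checkpoint
    # without the optimizer state, so the fp32 main copies must be refreshed.
    # `load_model_weights`, `model`, and `checkpoint_path` are placeholders.
    #
    #     load_model_weights(model, checkpoint_path)
    #     optimizer.reload_model_params()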

    @torch.no_grad()
    def step(self):

        timers = get_timers()

        # Copy gradients from model params to main params.
        timers('optimizer-copy-to-main-grad').start()
        self._copy_model_grads_to_main_grads()
        timers('optimizer-copy-to-main-grad').stop()

        # Do unscale, check for inf, and update grad scaler only for
        # the case that grad scaler is provided.
        if self.grad_scaler:

            # Unscale and check for inf/nan.
            timers('optimizer-unscale-and-check-inf').start()
            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
            timers('optimizer-unscale-and-check-inf').stop()

            # We are done with scaling gradients
            # so we can update the loss scale.
            self.grad_scaler.update(found_inf_flag)

            # If we found inf/nan, skip the update.
            if found_inf_flag:
                return False, None, None

        # Clip the main gradients.
        timers('optimizer-clip-main-grad').start()
        grad_norm = None
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)
        timers('optimizer-clip-main-grad').stop()

        # Count the zeros in the grads.
        num_zeros_in_grad = self.count_zeros() if \
            self.log_num_zeros_in_grad else None

        # Step the optimizer.
        self.optimizer.step()

        # Update params from main params.
        timers('optimizer-copy-main-to-model-params').start()
        self._copy_main_params_to_model_params()
        timers('optimizer-copy-main-to-model-params').stop()

        # Successful update.
        return True, grad_norm, num_zeros_in_grad
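
    # Illustrative sketch (not part of the original class): how a training
    # loop can consume the three return values of step(). An unsuccessful
    # step means an inf/nan was found; the loss scale has already been
    # reduced by grad_scaler.update(), and the iteration is simply skipped.
    # `lr_scheduler` is a placeholder name.
    #
    #     success, grad_norm, num_zeros_in_grad = optimizer.step()
    #     if success:
    #         lr_scheduler.step()
    #     else:
    #         pass  # skip this iteration; retry with the smaller loss scale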

    def state_dict(self):
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
        return state_dict

    def load_state_dict(self, state_dict):
        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            print_rank_0('***WARNING*** found an old checkpoint, will not '
                         'load grad scaler ...')
        else:
            if self.grad_scaler:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
                print_rank_0('***WARNING*** found the grad scaler in the '
                             'checkpoint but it is None in the class. '
                             'Skipping loading grad scaler ...')

        # Copy data for the main params.
        fp32_from_float16_params_key = 'fp32_from_fp16_params'
        if fp32_from_float16_params_key not in state_dict:
            fp32_from_float16_params_key = 'fp32_from_fp16'
        for current_group, saved_group in zip(
                self.fp32_from_float16_groups,
                state_dict[fp32_from_float16_params_key]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)


class FP32Optimizer(MegatronOptimizer):

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad):
        super(FP32Optimizer, self).__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad)

        self._scale = torch.cuda.FloatTensor([1.0])

    def zero_grad(self, set_to_none=True):
        """Copied from torch.optim.optimizer."""
        for group in self.optimizer.param_groups:
            _zero_grad_group_helper(group['params'], set_to_none)

    def get_loss_scale(self):
        """FP32 optimizer does not do any scaling."""
        return self._scale

    @torch.no_grad()
    def step(self):
        """Clip gradients (if needed) and step the base optimizer.
        Always return successful since there is no overflow."""

        # Copy main_grads to grads.
        if self.params_have_main_grad:
            for param_group in self.optimizer.param_groups:
                for param in param_group['params']:
                    param.grad = param.main_grad

        # Clip gradients.
        grad_norm = None
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)

        # Count the zeros in the grads.
        num_zeros_in_grad = self.count_zeros() if \
            self.log_num_zeros_in_grad else None

        # Update parameters.
        self.optimizer.step()

        # No overflow for FP32 optimizer.
        return True, grad_norm, num_zeros_in_grad

    def reload_model_params(self):
        pass

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)
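

# Illustrative sketch (not part of the original module): how a setup
# function might pick between the two wrappers defined above. The helper
# name `build_megatron_optimizer` and the argument names are hypothetical;
# the real construction logic lives elsewhere in Megatron.
#
#     def build_megatron_optimizer(base_optimizer, args, grad_scaler,
#                                  params_have_main_grad):
#         if args.fp16 or args.bf16:
#             return Float16OptimizerWithFloat16Params(
#                 base_optimizer, args.clip_grad, args.log_num_zeros_in_grad,
#                 params_have_main_grad, args.bf16, grad_scaler)
#         return FP32Optimizer(base_optimizer, args.clip_grad,
#                              args.log_num_zeros_in_grad,
#                              params_have_main_grad)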