optimizer.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron optimizer."""

from abc import ABC
from abc import abstractmethod

import torch

from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0

from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
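
# Note: `multi_tensor_applier` and `amp_C` come from NVIDIA Apex; they provide
# fused multi-tensor CUDA kernels used below to copy whole lists of parameters
# in a handful of kernel launches instead of one launch per tensor.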


def _zero_grad_group_helper(group, set_to_none):
    """Zero out the gradient for a group of parameters.
    Note: copied from torch.optim.optimizer."""
    for param in group:
        if param.grad is not None:
            if set_to_none:
                param.grad = None
            else:
                if param.grad.grad_fn is not None:
                    param.grad.detach_()
                else:
                    param.grad.requires_grad_(False)
                param.grad.zero_()
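
# With `set_to_none=True` the gradient tensors are dropped entirely, which
# frees their memory and avoids accumulating into stale buffers; with
# `set_to_none=False` the existing gradients are detached and zeroed in place.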


def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
    """Use multi-tensor-applier to copy values from one list to another.
    We don't have a bfloat16 implementation so for now if the overflow_buf
    is not provided, we default back to simple loop copy to be compatible
    with bfloat16."""
    if overflow_buf:
        overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             overflow_buf,
                             [this, that],
                             1.0)
    else:
        for this_, that_ in zip(this, that):
            that_.copy_(this_)
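
# `amp_C.multi_tensor_scale` writes `this * 1.0` into `that`, so a scale of
# 1.0 turns the kernel into a fused bulk copy; `overflow_buf` is the int
# tensor the Apex kernels use as their no-op/overflow flag.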


class MegatronOptimizer(ABC):

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad):
        """Input optimizer is the base optimizer for example Adam."""
        self.optimizer = optimizer
        assert self.optimizer, 'no optimizer is provided.'
        # Set gradient clipping and logging params.
        self.clip_grad = clip_grad
        self.log_num_zeros_in_grad = log_num_zeros_in_grad
        self.params_have_main_grad = params_have_main_grad

    def get_parameters(self):
        params = []
        for param_group in self.optimizer.param_groups:
            for param in param_group['params']:
                params.append(param)
        return params

    def clip_grad_norm(self, clip_grad):
        params = self.get_parameters()
        return clip_grad_norm_fp32(params, clip_grad)

    def count_zeros(self):
        params = self.get_parameters()
        return count_zeros_fp32(params)

    @abstractmethod
    def zero_grad(self, set_to_none=True):
        pass

    @abstractmethod
    def get_loss_scale(self):
        """The output should be a cuda tensor of size 1."""
        pass

    def scale_loss(self, loss):
        """Simple scaling."""
        return self.get_loss_scale() * loss

    @abstractmethod
    def step(self):
        pass

    @abstractmethod
    def reload_model_params(self):
        """Refreshes any internal state from the current model parameters.
        Call whenever the parameters are changed outside of the optimizer.
        For example, when we load a model from a checkpoint without loading
        the optimizer, the model parameters are updated but for fp16 optimizer
        with main parameters, the main parameters need to also be updated."""
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state_dict):
        pass

    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)
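
# Both concrete optimizers below implement `step()` with the same return
# signature: `(update_successful, grad_norm, num_zeros_in_grad)`, where the
# last two entries are None when clipping or zero counting is disabled.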


class Float16OptimizerWithFloat16Params(MegatronOptimizer):
    """Float16 optimizer for fp16 and bf16 data types.

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0.
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the gradients are stored in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a contiguous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
            Note that main grad is not necessarily in float32.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
    """

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, bf16, grad_scaler):

        super(Float16OptimizerWithFloat16Params, self).__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad)

        self.bf16 = bf16
        self.grad_scaler = grad_scaler
        # None grad scaler is only supported for bf16.
        if self.grad_scaler is None:
            assert self.bf16, 'fp16 expects a grad scaler.'

        # Tensor used to determine if a nan/inf has happened.
        # Any non-zero value indicates inf/nan.
        # Note that we keep this for the cases that grad scaler is none.
        # We still record nan/inf if we have a bfloat16 with a grad scaler.
        if self.grad_scaler:
            self.found_inf = torch.cuda.FloatTensor([0.0])

        # Dummy tensor needed for apex multi-tensor apply.
        # For bfloat16, we don't have multi-tensor apply and for now
        # we set it to none so the multi-tensor apply gets ignored.
        if bf16:
            self._dummy_overflow_buf = None
        else:
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        # In case grad scaler is not passed, define the unity scale.
        if self.grad_scaler is None:
            self._scale_one = torch.cuda.FloatTensor([1.0])

        # ======================
        # main parameter stuff
        # ======================

        # Three groups of parameters:
        #   float16_groups: original float16 parameters
        #   fp32_from_float16_groups: fp32 copy of float16 parameters
        #   fp32_from_fp32_groups: original fp32 parameters
        self.float16_groups = []
        self.fp32_from_float16_groups = []
        self.fp32_from_fp32_groups = []

        # For all the groups in the original optimizer:
        for param_group in self.optimizer.param_groups:
            float16_params_this_group = []
            fp32_params_this_group = []
            fp32_from_float16_params_this_group = []
            # For all the parameters in this group:
            for i, param in enumerate(param_group['params']):
                if param.requires_grad:

                    # float16 params:
                    if param.type() in ['torch.cuda.HalfTensor',
                                        'torch.cuda.BFloat16Tensor']:
                        float16_params_this_group.append(param)
                        # Create a copy
                        main_param = param.detach().clone().float()
                        # Copy tensor model parallel attributes.
                        mpu.copy_tensor_model_parallel_attributes(main_param,
                                                                  param)
                        if hasattr(param, 'shared'):
                            main_param.shared = param.shared
                        # Replace the optimizer params with the new fp32 copy.
                        param_group['params'][i] = main_param
                        fp32_from_float16_params_this_group.append(main_param)
                        # Reset existing state dict key to the new main param.
                        if param in self.optimizer.state:
                            self.optimizer.state[main_param] \
                                = self.optimizer.state.pop(param)

                    # fp32 params.
                    elif param.type() == 'torch.cuda.FloatTensor':
                        fp32_params_this_group.append(param)
                        param_group['params'][i] = param

                    else:
                        raise TypeError('Wrapped parameters must be one of '
                                        'torch.cuda.FloatTensor, '
                                        'torch.cuda.HalfTensor, or '
                                        'torch.cuda.BFloat16Tensor. '
                                        'Received {}'.format(param.type()))

            self.float16_groups.append(float16_params_this_group)
            self.fp32_from_float16_groups.append(
                fp32_from_float16_params_this_group)
            self.fp32_from_fp32_groups.append(fp32_params_this_group)

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors.
        self.optimizer.load_state_dict(self.optimizer.state_dict())
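
        # At this point every fp16/bf16 parameter in the base optimizer's
        # param groups has been replaced by its fp32 main copy, so the inner
        # optimizer only ever updates fp32 tensors.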

    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
        float16_groups & fp32_from_fp32_groups."""
        for group in self.float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_fp32_groups:
            _zero_grad_group_helper(group, set_to_none)

    def get_loss_scale(self):
        if self.grad_scaler is None:
            return self._scale_one
        return self.grad_scaler.scale

    def _copy_model_grads_to_main_grads(self):
        # This only needs to be done for the float16 group.
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                if self.params_have_main_grad:
                    main_param.grad = model_param.main_grad.float()
                else:
                    if model_param.grad is not None:
                        main_param.grad = model_param.grad.float()

        # For fp32 grads, we need to reset the grads to main grad.
        if self.params_have_main_grad:
            for model_group in self.fp32_from_fp32_groups:
                for model_param in model_group:
                    model_param.grad = model_param.main_grad

    def _unscale_main_grads_and_check_for_nan(self):
        main_grads = []
        # fp32 params from float16 ones.
        for main_group in self.fp32_from_float16_groups:
            for main_param in main_group:
                if main_param.grad is not None:
                    main_grads.append(main_param.grad.data)
        # Append fp32 parameters.
        for main_group in self.fp32_from_fp32_groups:
            for main_param in main_group:
                if main_param.grad is not None:
                    main_grads.append(main_param.grad.data)
        # Reset found inf.
        self.found_inf.fill_(0.0)
        # Unscale and set found inf/nan.
        torch._amp_foreach_non_finite_check_and_unscale_(
            main_grads, self.found_inf, self.grad_scaler.inv_scale)
        # Update across all model parallel instances.
        torch.distributed.all_reduce(self.found_inf,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=mpu.get_model_parallel_group())
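        # The MAX all-reduce ensures that if any model-parallel rank sees an
        # inf/nan in its local gradients, every rank agrees to skip this
        # update, keeping the replicas consistent.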
        # Check for nan.
        found_inf_flag = (self.found_inf.item() > 0)
        return found_inf_flag

    def _get_model_and_main_params_data_float16(self):
        model_data = []
        main_data = []
        for model_group, main_group in zip(self.float16_groups,
                                           self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data

    def _copy_main_params_to_model_params(self):
        # Only needed for the float16 params.
        model_data, main_data = self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
                                        overflow_buf=self._dummy_overflow_buf)

    def _copy_model_params_to_main_params(self):
        # Only needed for the float16 params.
        model_data, main_data = self._get_model_and_main_params_data_float16()
        _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
                                        overflow_buf=self._dummy_overflow_buf)

    def reload_model_params(self):
        self._copy_model_params_to_main_params()
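
    # step() performs one mixed-precision update:
    #   1. copy the model gradients (`.grad` or `.main_grad`) onto the fp32
    #      main params,
    #   2. if a grad scaler is present, unscale, check for inf/nan, update the
    #      loss scale, and return early on overflow,
    #   3. optionally clip gradients and count zeros,
    #   4. step the base optimizer on the fp32 main params,
    #   5. copy the updated main params back into the model params.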
    @torch.no_grad()
    def step(self):

        timers = get_timers()

        # Copy gradients from model params to main params.
        timers('optimizer-copy-to-main-grad').start()
        self._copy_model_grads_to_main_grads()
        timers('optimizer-copy-to-main-grad').stop()

        # Do unscale, check for inf, and update grad scaler only for
        # the case that grad scaler is provided.
        if self.grad_scaler:

            # Unscale and check for inf/nan.
            timers('optimizer-unscale-and-check-inf').start()
            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
            timers('optimizer-unscale-and-check-inf').stop()

            # We are done with scaling gradients
            # so we can update the loss scale.
            self.grad_scaler.update(found_inf_flag)

            # If we found inf/nan, skip the update.
            if found_inf_flag:
                return False, None, None

        # Clip the main gradients.
        timers('optimizer-clip-main-grad').start()
        grad_norm = None
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)
        timers('optimizer-clip-main-grad').stop()

        # count the zeros in the grads
        num_zeros_in_grad = self.count_zeros() if \
            self.log_num_zeros_in_grad else None

        # Step the optimizer.
        self.optimizer.step()

        # Update params from main params.
        timers('optimizer-copy-main-to-model-params').start()
        self._copy_main_params_to_model_params()
        timers('optimizer-copy-main-to-model-params').stop()

        # Successful update.
        return True, grad_norm, num_zeros_in_grad

    def state_dict(self):
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
        return state_dict

    def load_state_dict(self, state_dict):
        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            print_rank_0('***WARNING*** found an old checkpoint, will not '
                         'load grad scaler ...')
        else:
            if self.grad_scaler:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
                print_rank_0('***WARNING*** found the grad scaler in the '
                             'checkpoint but it is None in the class. '
                             'Skipping loading grad scaler ...')

        # Copy data for the main params.
        fp32_from_float16_params_key = 'fp32_from_fp16_params'
        if fp32_from_float16_params_key not in state_dict:
            fp32_from_float16_params_key = 'fp32_from_fp16'
        for current_group, saved_group in zip(
                self.fp32_from_float16_groups,
                state_dict[fp32_from_float16_params_key]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)
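

# FP32Optimizer is the plain-fp32 path: no loss scaling and no separate main
# parameter copies, just optional gradient clipping around the base optimizer.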
class FP32Optimizer(MegatronOptimizer):

    def __init__(self, optimizer, clip_grad,
                 log_num_zeros_in_grad,
                 params_have_main_grad):

        super(FP32Optimizer, self).__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad)

        self._scale = torch.cuda.FloatTensor([1.0])

    def zero_grad(self, set_to_none=True):
        """Copied from torch.optim.optimizer."""
        for group in self.optimizer.param_groups:
            _zero_grad_group_helper(group['params'], set_to_none)

    def get_loss_scale(self):
        """FP32 optimizer does not do any scaling."""
        return self._scale

    @torch.no_grad()
    def step(self):
        """Clip gradients (if needed) and step the base optimizer.
        Always returns success since there is no overflow."""

        # Copy main_grads to grads.
        if self.params_have_main_grad:
            for param_group in self.optimizer.param_groups:
                for param in param_group['params']:
                    param.grad = param.main_grad

        # Clip gradients.
        grad_norm = None
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)

        # count the zeros in the grads
        num_zeros_in_grad = self.count_zeros() if \
            self.log_num_zeros_in_grad else None

        # Update parameters.
        self.optimizer.step()

        # No overflow for FP32 optimizer.
        return True, grad_norm, num_zeros_in_grad

    def reload_model_params(self):
        pass

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)