__init__.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD

from megatron import get_args
from megatron.model import LayerNorm

from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer


def _get_params_for_weight_decay_optimization(modules):
    """Divide params into with-weight-decay and without-weight-decay groups.
    Layernorms and biases will have no weight decay but the rest will.
    """
    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module in modules:
        for module_ in module.modules():
            if isinstance(module_, LayerNorm):
                no_weight_decay_params['params'].extend(
                    [p for p in list(module_._parameters.values())
                     if p is not None])
            else:
                weight_decay_params['params'].extend(
                    [p for n, p in list(module_._parameters.items())
                     if p is not None and n != 'bias'])
                no_weight_decay_params['params'].extend(
                    [p for n, p in list(module_._parameters.items())
                     if p is not None and n == 'bias'])

    return weight_decay_params, no_weight_decay_params
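

# Illustrative sketch (an assumption for clarity, not part of the upstream
# file): the two dicts returned above are ordinary PyTorch param groups, so
# they could be consumed directly as
#
#     wd_group, no_wd_group = _get_params_for_weight_decay_optimization([my_module])
#     # wd_group    -> {'params': [...weights...]}                    (optimizer default decay)
#     # no_wd_group -> {'params': [...biases, LayerNorm params...], 'weight_decay': 0.0}
#     opt = Adam([wd_group, no_wd_group], lr=1e-4, weight_decay=0.01)
#
# where `my_module` is any torch.nn.Module. get_megatron_optimizer below does
# the same thing, passing the caller's `model` (an iterable of modules)
# straight through.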


def get_megatron_optimizer(model):
    args = get_args()

    # Base optimizer.
    param_groups = _get_params_for_weight_decay_optimization(model)
    if args.optimizer == 'adam':
        optimizer = Adam(param_groups,
                         lr=args.lr,
                         weight_decay=args.weight_decay,
                         betas=(args.adam_beta1, args.adam_beta2),
                         eps=args.adam_eps)
    elif args.optimizer == 'sgd':
        optimizer = SGD(param_groups,
                        lr=args.lr,
                        weight_decay=args.weight_decay,
                        momentum=args.sgd_momentum)
    else:
        raise Exception('{} optimizer is not supported.'.format(
            args.optimizer))

    # Determine whether the params have main-grad field.
    params_have_main_grad = False
    if args.DDP_impl == 'local':
        params_have_main_grad = True

    if args.fp16 or args.bf16:

        # Grad scaler:
        #    if loss-scale is provided, instantiate the constant scaler.
        #    if we are using fp16 and loss-scale is not present, use a
        #       dynamic scaler.
        #    otherwise we are running in bf16 with no loss-scale so
        #       leave it as None.
        grad_scaler = None
        # Constant loss scale.
        if args.loss_scale:
            grad_scaler = ConstantGradScaler(args.loss_scale)
        # Dynamic loss scale.
        else:
            if args.fp16:
                grad_scaler = DynamicGradScaler(
                    initial_scale=args.initial_loss_scale,
                    min_scale=args.min_loss_scale,
                    growth_factor=2.0,
                    backoff_factor=0.5,
                    growth_interval=args.loss_scale_window,
                    hysteresis=args.hysteresis)

        # Megatron optimizer.
        return Float16OptimizerWithFloat16Params(optimizer,
                                                 args.clip_grad,
                                                 args.log_num_zeros_in_grad,
                                                 params_have_main_grad,
                                                 args.bf16,
                                                 grad_scaler)

    # FP32.
    return FP32Optimizer(optimizer, args.clip_grad,
                         args.log_num_zeros_in_grad,
                         params_have_main_grad)
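

# Illustrative usage sketch (an assumption about the surrounding training code,
# not part of this file): in Megatron-LM's setup phase the global args are
# initialized before the optimizer is requested, roughly
#
#     model = ...                                # iterable of torch.nn modules
#     optimizer = get_megatron_optimizer(model)
#
# The returned wrapper (Float16OptimizerWithFloat16Params or FP32Optimizer)
# handles gradient clipping via args.clip_grad and, in the fp16/bf16 case,
# loss scaling and the fp32 main-parameter copies used for the update.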