clip_grads.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. # coding=utf-8
  2. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Gradient clipping."""
  16. import torch
  17. from torch._six import inf
  18. from apex.multi_tensor_apply import multi_tensor_applier
  19. import amp_C
  20. from megatron import mpu
  21. from megatron.model.module import param_is_not_shared
  22. from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
  23. def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
  24. """Clips gradient norm of an iterable of parameters whose gradients
  25. are in fp32.
  26. This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
  27. added functionality to handle model parallel parameters. Note that
  28. the gradients are modified in place.
  29. Arguments:
  30. parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
  31. single Tensor that will have gradients normalized
  32. max_norm (float or int): max norm of the gradients
  33. norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
  34. infinity norm.
  35. Returns:
  36. Total norm of the parameters (viewed as a single vector).
  37. """
  38. if isinstance(parameters, torch.Tensor):
  39. parameters = [parameters]
  40. # Filter parameters based on:
  41. # - grad should not be none
  42. # - parameter should not be shared
  43. # - should not be a replica due to tensor model parallelism
  44. grads = []
  45. grads_for_norm = []
  46. for param in parameters:
  47. grad_not_none = param.grad is not None
  48. is_not_shared = param_is_not_shared(param)
  49. is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
  50. grad = param.grad.detach()
  51. if grad_not_none:
  52. # Make sure the grads are in fp32
  53. assert param.grad.type() == 'torch.cuda.FloatTensor'
  54. grads.append(grad)
  55. if grad_not_none and is_not_shared and is_not_tp_duplicate:
  56. grads_for_norm.append(grad)
  57. # Norm parameters.
  58. max_norm = float(max_norm)
  59. norm_type = float(norm_type)
  60. total_norm = 0.0
  61. # Calculate norm.
  62. if norm_type == inf:
  63. total_norm = max(grad.abs().max() for grad in grads_for_norm)
  64. total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
  65. # Take max across all model-parallel GPUs.
  66. torch.distributed.all_reduce(total_norm_cuda,
  67. op=torch.distributed.ReduceOp.MAX,
  68. group=mpu.get_model_parallel_group())
  69. total_norm = total_norm_cuda[0].item()
  70. else:
  71. if norm_type == 2.0:
  72. dummy_overflow_buf = torch.cuda.IntTensor([0])
  73. # Use apex's multi-tensor applier for efficiency reasons.
  74. # Multi-tensor applier takes a function and a list of list
  75. # and performs the operation on that list all in one kernel.
  76. grad_norm, _ = multi_tensor_applier(
  77. amp_C.multi_tensor_l2norm,
  78. dummy_overflow_buf,
  79. [grads_for_norm],
  80. False # no per-parameter norm
  81. )
  82. # Since we will be summing across data parallel groups,
  83. # we need the pow(norm-type).
  84. total_norm = grad_norm ** norm_type
  85. else:
  86. for grad in grads_for_norm:
  87. grad_norm = torch.norm(grad, norm_type)
  88. total_norm += grad_norm ** norm_type
  89. # Sum across all model-parallel GPUs.
  90. torch.distributed.all_reduce(total_norm,
  91. op=torch.distributed.ReduceOp.SUM,
  92. group=mpu.get_model_parallel_group())
  93. total_norm = total_norm.item() ** (1.0 / norm_type)
  94. # Scale.
  95. clip_coeff = max_norm / (total_norm + 1.0e-6)
  96. if clip_coeff < 1.0:
  97. dummy_overflow_buf = torch.cuda.IntTensor([0])
  98. multi_tensor_applier(amp_C.multi_tensor_scale,
  99. dummy_overflow_buf,
  100. [grads, grads],
  101. clip_coeff)
  102. return total_norm
  103. def count_zeros_fp32(parameters):
  104. if isinstance(parameters, torch.Tensor):
  105. parameters = [parameters]
  106. # Filter parameters based on:
  107. # - grad should not be none
  108. # - parameter should not be shared
  109. # - should not be a replica due to tensor model parallelism
  110. total_num_zeros = 0.0
  111. for param in parameters:
  112. grad_not_none = param.grad is not None
  113. is_not_shared = param_is_not_shared(param)
  114. is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
  115. if grad_not_none and is_not_shared and is_not_tp_duplicate:
  116. grad = param.grad.detach()
  117. num_zeros = grad.numel() - torch.count_nonzero(grad)
  118. total_num_zeros = num_zeros + total_num_zeros
  119. # Sum across all model-parallel GPUs.
  120. torch.distributed.all_reduce(total_num_zeros,
  121. op=torch.distributed.ReduceOp.SUM,
  122. group=mpu.get_model_parallel_group())
  123. total_num_zeros = total_num_zeros.item()
  124. return total_num_zeros