
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron Module"""

import torch
from torch.autograd import Variable
from torch.nn.parameter import Parameter

from megatron import get_args
from megatron import mpu


_FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
_HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
_BF16_TYPES = (torch.BFloat16Tensor, torch.cuda.BFloat16Tensor)


def param_is_not_shared(param):
    # A parameter counts as "shared" only if it carries a `shared` attribute
    # set to True (e.g. the duplicated word embeddings created on the last
    # pipeline stage in initialize_word_embeddings below).
    return not hasattr(param, 'shared') or not param.shared


class MegatronModule(torch.nn.Module):
    """Megatron specific extensions of torch Module with support
    for pipelining."""

    def __init__(self, share_word_embeddings=True):
        super(MegatronModule, self).__init__()
        self.share_word_embeddings = share_word_embeddings

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """Use this function to override the state dict for
        saving checkpoints."""
        return self.state_dict(destination, prefix, keep_vars)

    def word_embeddings_weight(self):
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            return self.language_model.embedding.word_embeddings.weight
        if mpu.is_pipeline_last_stage(ignore_virtual=True):
            if not self.share_word_embeddings:
                raise Exception('word_embeddings_weight() called for last '
                                'stage, but share_word_embeddings is false')
            return self.word_embeddings.weight
        raise Exception('word_embeddings_weight() should be '
                        'called for first and last stage only')

    def initialize_word_embeddings(self, init_method_normal):
        args = get_args()
        if not self.share_word_embeddings:
            raise Exception('initialize_word_embeddings() was called but '
                            'share_word_embeddings is false')

        # This function just initializes the word embeddings in the final stage
        # when we are using pipeline parallelism. If we aren't using pipeline
        # parallelism there is nothing to do.
        if args.pipeline_model_parallel_size == 1:
            return

        # Parameters are shared between the word embeddings layer and the
        # heads at the end of the model. In a pipelined setup with more than
        # one stage, the initial embedding layer and the head are on different
        # workers, so we do the following:
        # 1. Create a second copy of word_embeddings on the last stage, with
        #    initial parameters of 0.0.
        # 2. Do an all-reduce between the first and last stage to ensure that
        #    the two copies of word_embeddings start off with the same
        #    parameter values.
        # 3. In the training loop, do an all-reduce between the grads of the
        #    two word_embeddings layers to ensure that every applied weight
        #    update is the same on both stages (sketched after this class).
        if mpu.is_pipeline_last_stage():
            assert not mpu.is_pipeline_first_stage()
            self._word_embeddings_for_head_key = 'word_embeddings_for_head'
            # Set word_embeddings weights to 0 here, then copy the first
            # stage's weights using the all_reduce below.
            self.word_embeddings = mpu.VocabParallelEmbedding(
                args.padded_vocab_size, args.hidden_size,
                init_method=init_method_normal(args.init_method_std))
            self.word_embeddings.weight.data.fill_(0)
            self.word_embeddings.weight.shared = True

        # Ensure that first and last stages have the same initial parameter
        # values.
        if torch.distributed.is_initialized():
            if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
                torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                             group=mpu.get_embedding_group())
        else:
            print("WARNING! Distributed processes aren't initialized, so "
                  "word embeddings in the last layer are not initialized. "
                  "If you are just manipulating a model, this is fine, but "
                  "it needs to be handled manually. If you are training, "
                  "something is definitely wrong.")


def conversion_helper(val, conversion):
    """Apply conversion to `val`. Recursively apply conversion if `val`
    is a nested tuple/list structure."""
    if not isinstance(val, (tuple, list)):
        return conversion(val)
    rtn = [conversion_helper(v, conversion) for v in val]
    if isinstance(val, tuple):
        rtn = tuple(rtn)
    return rtn
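
# Example (illustrative, not part of the original module): the nesting
# structure is preserved and only the leaves are transformed.
#
#     nested = (torch.ones(2), [torch.zeros(3), torch.arange(4.0)])
#     halved = conversion_helper(nested, lambda t: t.half())
#     # `halved` is still a (tensor, [tensor, tensor]) tuple; every tensor
#     # now has dtype torch.float16.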


def fp32_to_float16(val, float16_convertor):
    """Convert fp32 `val` to fp16/bf16."""
    def half_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, _FLOAT_TYPES):
            val = float16_convertor(val)
        return val
    return conversion_helper(val, half_conversion)
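
# Example (illustrative): only fp32 tensors are converted; other leaves such
# as integer tensors pass through unchanged.
#
#     batch = (torch.randn(4, 8), torch.arange(4))   # fp32 tensor, int64 tensor
#     half_batch = fp32_to_float16(batch, lambda v: v.half())
#     # half_batch[0].dtype == torch.float16, half_batch[1].dtype == torch.int64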


def float16_to_fp32(val):
    """Convert fp16/bf16 `val` to fp32."""
    def float_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, (_BF16_TYPES, _HALF_TYPES)):
            val = val.float()
        return val
    return conversion_helper(val, float_conversion)
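
# Example (illustrative): fp16/bf16 leaves are upcast to fp32; anything else
# (including None) is returned untouched.
#
#     outputs = (torch.randn(2, 5).half(), None)
#     fp32_outputs = float16_to_fp32(outputs)
#     # fp32_outputs[0].dtype == torch.float32; the None passes through.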


class Float16Module(MegatronModule):

    def __init__(self, module, args):
        super(Float16Module, self).__init__()

        if args.fp16:
            self.add_module('module', module.half())
            def float16_convertor(val):
                return val.half()
        elif args.bf16:
            self.add_module('module', module.bfloat16())
            def float16_convertor(val):
                return val.bfloat16()
        else:
            raise Exception('Float16Module expects either fp16 or bf16 '
                            'to be enabled in args')

        self.float16_convertor = float16_convertor

    def forward(self, *inputs, **kwargs):
        if mpu.is_pipeline_first_stage():
            inputs = fp32_to_float16(inputs, self.float16_convertor)
        outputs = self.module(*inputs, **kwargs)
        if mpu.is_pipeline_last_stage():
            outputs = float16_to_fp32(outputs)
        return outputs

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        return self.module.state_dict(destination, prefix, keep_vars)

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        return self.module.state_dict_for_save_checkpoint(destination, prefix,
                                                          keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)
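

# Example usage (illustrative sketch; `model` is assumed to be a MegatronModule
# built on this rank, `args` comes from get_args() with fp16 or bf16 enabled,
# and the input names below are hypothetical):
#
#     model = Float16Module(model, args)
#     # fp32 inputs entering the first pipeline stage are cast to fp16/bf16,
#     # and outputs leaving the last pipeline stage are cast back to fp32.
#     output = model(tokens, position_ids, attention_mask)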