# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BERT model."""

import torch

from megatron import get_args
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model import LayerNorm
from megatron.model.utils import openai_gelu, erf_gelu
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule


def bert_extended_attention_mask(attention_mask):
    # We create a 3D attention mask from a 2D tensor mask.
    # [b, 1, s]
    attention_mask_b1s = attention_mask.unsqueeze(1)
    # [b, s, 1]
    attention_mask_bs1 = attention_mask.unsqueeze(2)
    # [b, s, s]
    attention_mask_bss = attention_mask_b1s * attention_mask_bs1
    # [b, 1, s, s]
    extended_attention_mask = attention_mask_bss.unsqueeze(1)

    # Convert the attention mask to a boolean mask: True marks positions
    # that attention must be blocked from attending to (i.e. padding).
    extended_attention_mask = (extended_attention_mask < 0.5)

    return extended_attention_mask
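

# Illustrative example (added for clarity, not part of the original file).
# Assuming the input mask marks real tokens with 1 and padding with 0, as
# Megatron's BERT data pipeline does, the helper above turns a [b, s] mask
# into a [b, 1, s, s] boolean mask that is True wherever attention must be
# blocked:
#
#     mask = torch.tensor([[1, 1, 0]])          # batch of 1, sequence of 3
#     ext = bert_extended_attention_mask(mask)  # ext.shape == (1, 1, 3, 3)
#     # ext[0, 0] is True on row 2 and column 2 (every pair involving the
#     # padded position) and False everywhere else.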


def bert_position_ids(token_ids):
    # Create position ids
    seq_length = token_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=token_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(token_ids)

    return position_ids
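

# Illustrative example (added for clarity, not part of the original file):
# position ids are simply 0..s-1 broadcast over the batch dimension.
#
#     tokens = torch.zeros(2, 4, dtype=torch.long)
#     bert_position_ids(tokens)
#     # tensor([[0, 1, 2, 3],
#     #         [0, 1, 2, 3]])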


class BertLMHead(MegatronModule):
    """Masked LM head for BERT.

    Arguments:
        mpu_vocab_size: size of the vocabulary partition on this
            tensor-model-parallel rank.
        hidden_size: hidden size.
        init_method: init method for weight initialization.
        layernorm_epsilon: epsilon added inside the layer norm for
            numerical stability.
        parallel_output: whether to keep the output logits split across
            tensor-model-parallel ranks rather than gathering them.
    """

    def __init__(self, mpu_vocab_size, hidden_size, init_method,
                 layernorm_epsilon, parallel_output):
        super(BertLMHead, self).__init__()
        args = get_args()

        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        mpu.set_tensor_model_parallel_attributes(self.bias, True, 0, 1)
        self.parallel_output = parallel_output

        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
        self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
        self.gelu = torch.nn.functional.gelu
        if args.openai_gelu:
            self.gelu = openai_gelu
        elif args.onnx_safe:
            self.gelu = erf_gelu

    def forward(self, hidden_states, word_embeddings_weight):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.gelu(hidden_states)
        hidden_states = self.layernorm(hidden_states)
        output = parallel_lm_logits(hidden_states,
                                    word_embeddings_weight,
                                    self.parallel_output,
                                    bias=self.bias)
        return output
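

# Note (added for clarity, not part of the original file): BertLMHead does not
# own an output projection matrix. It transforms the encoder hidden states
# with a dense layer, GeLU and layer norm, then projects them onto the word
# embedding weights passed in by the caller (weight tying), so the LM logits
# come out of parallel_lm_logits over the vocab-parallel embedding partition.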


def post_language_model_processing(lm_output, pooled_output,
                                   lm_head, binary_head,
                                   lm_labels,
                                   logit_weights,
                                   fp16_lm_cross_entropy):
    # Output.
    lm_logits = lm_head(
        lm_output, logit_weights)

    binary_logits = None
    if binary_head is not None:
        binary_logits = binary_head(pooled_output)

    if lm_labels is None:
        return lm_logits, binary_logits
    else:
        if fp16_lm_cross_entropy:
            assert lm_logits.dtype == torch.half
            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
        else:
            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
                                                       lm_labels)
        return lm_loss, binary_logits
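

# Note (added for clarity, not part of the original file): when lm_labels is
# given (training), the first return value is the unreduced per-token
# masked-LM loss from the vocab-parallel cross entropy; the training loop is
# expected to apply the loss mask and reduce it. When lm_labels is None
# (inference), the raw LM logits are returned instead. The second return value
# is the binary (next-sentence prediction) logits, or None without a binary
# head.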


class BertModel(MegatronModule):
    """Bert Language model."""

    def __init__(self,
                 num_tokentypes=2,
                 add_binary_head=True,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True):
        super(BertModel, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.add_binary_head = add_binary_head
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process

        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=self.add_binary_head,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process)

        self.initialize_word_embeddings(init_method_normal)
        if self.post_process:
            self.lm_head = BertLMHead(
                self.word_embeddings_weight().size(0),
                args.hidden_size, init_method, args.layernorm_epsilon,
                parallel_output)
            self._lm_head_key = 'lm_head'
            self.binary_head = None
            if self.add_binary_head:
                self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                    init_method)
                self._binary_head_key = 'binary_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, bert_model_input, attention_mask,
                tokentype_ids=None, lm_labels=None):

        extended_attention_mask = bert_extended_attention_mask(attention_mask)
        input_ids = bert_model_input
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids
        )

        if self.post_process and self.add_binary_head:
            lm_output, pooled_output = lm_output
        else:
            pooled_output = None

        if self.post_process:
            return post_language_model_processing(lm_output, pooled_output,
                                                  self.lm_head, self.binary_head,
                                                  lm_labels,
                                                  self.word_embeddings_weight(),
                                                  self.fp16_lm_cross_entropy)
        else:
            return lm_output
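
    # Note (added for clarity, not part of the original file): with pipeline
    # model parallelism only the last stage has post_process=True; earlier
    # stages skip the heads and return the raw encoder hidden states, which
    # the pipeline schedule hands to the next stage through set_input_tensor()
    # above.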

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if self.post_process:
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        if self.post_process and self.add_binary_head:
            state_dict_[self._binary_head_key] \
                = self.binary_head.state_dict(destination, prefix, keep_vars)
        # Save word_embeddings.
        if self.post_process and not self.pre_process:
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self.post_process:
            self.lm_head.load_state_dict(
                state_dict[self._lm_head_key], strict=strict)
        if self.post_process and self.add_binary_head:
            self.binary_head.load_state_dict(
                state_dict[self._binary_head_key], strict=strict)
        # Load word_embeddings.
        if self.post_process and not self.pre_process:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)
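

# Usage sketch (added for illustration, not part of the original file).
# BertModel depends on Megatron's global args and parallel state, so it is
# normally built inside Megatron's pretrain() loop through a model provider,
# roughly as in pretrain_bert.py:
#
#     def model_provider(pre_process=True, post_process=True):
#         return BertModel(num_tokentypes=2,
#                          add_binary_head=True,
#                          parallel_output=True,
#                          pre_process=pre_process,
#                          post_process=post_process)
#
#     # forward step (the batch field names are illustrative and may differ
#     # between Megatron versions):
#     # output = model(tokens, padding_mask,
#     #                tokentype_ids=types, lm_labels=lm_labels)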