t5_model.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""T5 model."""

import torch

from megatron import (
    get_args,
    mpu
)
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits, get_language_model
from megatron.model.transformer import LayerNorm
from megatron.model.utils import (
    openai_gelu,
    get_linear_layer,
    init_method_normal,
    scaled_init_method_normal
)
from .module import MegatronModule


def t5_extended_attention_mask(attention_mask_list):

    def attn_mask_postprocess(attn_mask):
        # [b, 1, s, s]
        extended_attention_mask = attn_mask.unsqueeze(1)
        return extended_attention_mask

    return [attn_mask_postprocess(attn_mask) for attn_mask in attention_mask_list]


def t5_position_ids(token_ids):
    # Create position ids
    seq_length = token_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=token_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(token_ids)

    return position_ids
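

# Shape sketch for the two helpers above (illustrative, not part of the
# original file; the tensor values are hypothetical and only torch is needed):
#
#   tokens = torch.zeros(2, 8, dtype=torch.long)        # [b=2, s=8]
#   t5_position_ids(tokens)                             # [2, 8], each row 0..7
#   mask = torch.ones(2, 8, 8, dtype=torch.bool)        # [b, s, s]
#   t5_extended_attention_mask([mask])[0]               # [2, 1, 8, 8]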


class T5LMHead(MegatronModule):
    """Masked LM head for T5.

    Arguments:
        mpu_vocab_size: model-parallel size of the vocabulary.
        parallel_output: whether the output logits stay distributed across
            model-parallel ranks or are gathered.
    """

    def __init__(self, mpu_vocab_size, parallel_output):
        super(T5LMHead, self).__init__()

        args = get_args()

        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        self.bias.model_parallel = True
        self.bias.partition_dim = 0
        self.bias.stride = 1
        self.parallel_output = parallel_output

    def forward(self, hidden_states, word_embeddings_weight):
        output = parallel_lm_logits(hidden_states,
                                    word_embeddings_weight,
                                    self.parallel_output,
                                    bias=self.bias)
        return output


class T5Model(MegatronModule):
    """T5 Language model."""

    def __init__(self, num_tokentypes=0, parallel_output=True):
        super(T5Model, self).__init__()
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            add_decoder=True,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method)

        self.lm_head = T5LMHead(
            self.language_model.embedding.word_embeddings.weight.size(0),
            parallel_output)
        self._lm_head_key = 'lm_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, encoder_input_ids, decoder_input_ids, encoder_attn_mask,
                decoder_attn_mask, encoder_decoder_attn_mask,
                tokentype_ids=None, lm_labels=None, enc_hidden_states=None):

        # Convert the attention masks to the extended [b, 1, s, s] form.
        encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask = t5_extended_attention_mask(
            [encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask])
        encoder_position_ids = t5_position_ids(encoder_input_ids)
        decoder_position_ids = t5_position_ids(decoder_input_ids)

        # Run the encoder-decoder language model.
        lm_output = self.language_model(encoder_input_ids,
                                        encoder_position_ids,
                                        encoder_attn_mask,
                                        decoder_input_ids,
                                        decoder_position_ids,
                                        decoder_attn_mask,
                                        encoder_decoder_attn_mask,
                                        tokentype_ids=tokentype_ids,
                                        enc_hidden_states=enc_hidden_states)
        decoder_output, encoder_output = lm_output

        # Output: project decoder hidden states to vocabulary logits.
        lm_logits = self.lm_head(decoder_output,
                                 self.language_model.embedding.word_embeddings.weight)

        if lm_labels is None:
            return lm_logits, encoder_output
        else:
            if self.fp16_lm_cross_entropy:
                assert lm_logits.dtype == torch.half
                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
            else:
                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
                                                           lm_labels)
            return lm_loss, encoder_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when the model is combined with other heads,
        add an extra key for each sub-module."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        state_dict_[self._lm_head_key] \
            = self.lm_head.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        self.lm_head.load_state_dict(state_dict[self._lm_head_key],
                                     strict=strict)
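

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). T5Model reads
# global state via get_args() and mpu, so it can only be constructed after
# Megatron has been initialized, e.g. from a pretraining script; the tensor
# names below are placeholders:
#
#   model = T5Model(num_tokentypes=0, parallel_output=True)
#   loss, encoder_output = model(tokens_enc, tokens_dec,
#                                enc_mask, dec_mask, enc_dec_mask,
#                                lm_labels=lm_labels)
#
# With lm_labels=None the first return value is the logits tensor (still
# distributed across model-parallel ranks when parallel_output=True) rather
# than the loss from mpu.vocab_parallel_cross_entropy.
# ---------------------------------------------------------------------------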