gpt_model.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GPT-2 model."""

import torch

from megatron import get_args
from megatron import mpu
from .module import MegatronModule

from .enums import AttnMaskType
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .utils import init_method_normal
from .utils import scaled_init_method_normal


def post_language_model_processing(lm_output, labels, logit_weights,
                                   get_key_value, parallel_output,
                                   forward_method_parallel_output,
                                   fp16_lm_cross_entropy):
    if get_key_value:
        lm_output, presents = lm_output

    # Output.
    if forward_method_parallel_output is not None:
        parallel_output = forward_method_parallel_output
    output = parallel_lm_logits(
        lm_output,
        logit_weights,
        parallel_output)

    if get_key_value:
        output = [output, presents]

    if labels is None:
        return output
    else:
        if fp16_lm_cross_entropy:
            assert output.dtype == torch.half
            loss = mpu.vocab_parallel_cross_entropy(output, labels)
        else:
            loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
        return loss
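
# Note (added commentary, not upstream code): vocab_parallel_cross_entropy
# returns an *unreduced* per-token loss with the same batch/sequence shape as
# `labels`; the pretraining loss function is expected to apply a loss mask and
# average, roughly torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum().
# `loss_mask` here is an assumption about the caller and is not defined in
# this file.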


class GPTModel(MegatronModule):
    """GPT-2 Language model."""

    def __init__(self,
                 num_tokentypes=0,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True):
        super(GPTModel, self).__init__()
        args = get_args()

        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            encoder_attn_mask_type=AttnMaskType.causal,
            init_method=init_method_normal(args.init_method_std),
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers),
            pre_process=self.pre_process,
            post_process=self.post_process)

        self.initialize_word_embeddings(init_method_normal)
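
    # Added commentary (not upstream code): initialize_word_embeddings() comes
    # from MegatronModule. With pipeline parallelism, the first stage owns the
    # input word embeddings while the last stage needs the same weights to
    # project hidden states back to vocabulary logits, so the base class keeps
    # a synchronized copy on the final stage; word_embeddings_weight() below
    # reads from whichever copy the current stage holds.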

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, input_ids, position_ids, attention_mask, labels=None,
                tokentype_ids=None, layer_past=None, get_key_value=False,
                forward_method_parallel_output=None):

        lm_output = self.language_model(
            input_ids,
            position_ids,
            attention_mask,
            layer_past=layer_past,
            get_key_value=get_key_value)

        if self.post_process:
            return post_language_model_processing(
                lm_output, labels,
                self.word_embeddings_weight(),
                get_key_value,
                self.parallel_output,
                forward_method_parallel_output,
                self.fp16_lm_cross_entropy)
        else:
            return lm_output
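
    # Added commentary (not upstream code): on intermediate pipeline stages
    # (post_process=False) forward() returns the raw transformer hidden
    # states for the next stage; only the final stage produces logits or the
    # per-token loss via post_language_model_processing().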

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        # Save word_embeddings.
        if self.post_process and not self.pre_process:
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Load word_embeddings.
        if self.post_process and not self.pre_process:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)

        if self._language_model_key in state_dict:
            state_dict = state_dict[self._language_model_key]
        self.language_model.load_state_dict(state_dict, strict=strict)
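

# ---------------------------------------------------------------------------
# Illustrative usage sketch (added commentary, not part of the upstream file).
# In Megatron-LM, a model-provider function along these lines is handed to the
# pretraining entry point; the function name below is an assumption made only
# to show how GPTModel is constructed. It assumes Megatron has already been
# initialized so that get_args() and the mpu model-parallel state are set up.
# ---------------------------------------------------------------------------
def example_gpt_model_provider(pre_process=True, post_process=True):
    """Build a GPTModel the way a pretraining script typically would."""
    model = GPTModel(num_tokentypes=0,
                     parallel_output=True,
                     pre_process=pre_process,
                     post_process=post_process)
    return model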