# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer based language model."""

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import LayerType, AttnMaskType
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal, scaled_init_method_normal


def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                       bias=None):
    """LM logits using word embedding weights."""
    # Parallel logits.
    input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
    # Matrix multiply.
    if bias is None:
        logits_parallel = F.linear(input_parallel, word_embeddings_weight)
    else:
        logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)
    # Gather if needed.
    if parallel_output:
        return logits_parallel

    return mpu.gather_from_tensor_model_parallel_region(logits_parallel)
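
# A minimal usage sketch (not part of this file): at the output head, logits are
# typically computed with the tied word-embedding weights. Names such as
# `lm_output` and `model` below are illustrative assumptions.
#
#   logits = parallel_lm_logits(
#       lm_output,                                        # [b, s, h] hidden states
#       model.language_model.embedding.word_embeddings.weight,
#       parallel_output=True)                             # keep vocab-parallel shards
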

def get_language_model(num_tokentypes, add_pooler,
                       encoder_attn_mask_type, init_method=None,
                       scaled_init_method=None, add_decoder=False,
                       decoder_attn_mask_type=AttnMaskType.causal,
                       pre_process=True, post_process=True):
    """Build language model and return along with the key to save."""
    args = get_args()

    if init_method is None:
        init_method = init_method_normal(args.init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

    # Language model.
    language_model = TransformerLanguageModel(
        init_method,
        scaled_init_method,
        encoder_attn_mask_type,
        num_tokentypes=num_tokentypes,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
        pre_process=pre_process,
        post_process=post_process
    )
    # Key used for checkpoints.
    language_model_key = 'language_model'

    return language_model, language_model_key
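
# A hedged usage sketch: a decoder-only (GPT-style) model would typically build
# its backbone roughly like this inside its own __init__; `args` and the chosen
# values are assumptions for illustration.
#
#   self.language_model, self._language_model_key = get_language_model(
#       num_tokentypes=0,
#       add_pooler=False,
#       encoder_attn_mask_type=AttnMaskType.causal,
#       init_method=init_method_normal(args.init_method_std),
#       scaled_init_method=scaled_init_method_normal(args.init_method_std,
#                                                    args.num_layers))
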

class Pooler(MegatronModule):
    """Pooler layer.

    Pool hidden states of a specific token (for example start of the
    sequence) and add a linear transformation followed by a tanh.

    Arguments:
        hidden_size: hidden size
        init_method: weight initialization method for the linear layer.
            bias is set to zero.
    """

    def __init__(self, hidden_size, init_method):
        super(Pooler, self).__init__()
        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)

    def forward(self, hidden_states, sequence_index=0):
        # hidden_states: [b, s, h]
        # sequence_index: index of the token to pool.
        pooled = hidden_states[:, sequence_index, :]
        pooled = self.dense(pooled)
        pooled = torch.tanh(pooled)
        return pooled
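
# A minimal shape sketch (illustrative values, not from this file): pooling keeps
# the batch dimension and selects a single token along the sequence dimension.
#
#   pooler = Pooler(hidden_size=1024, init_method=init_method_normal(0.02))
#   hidden_states = torch.randn(4, 512, 1024)   # [b, s, h]
#   pooled = pooler(hidden_states)              # [4, 1024]: token 0 -> dense -> tanh
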

class Embedding(MegatronModule):
    """Language model embeddings.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
            is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
            will ignore this embedding
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 max_sequence_length,
                 embedding_dropout_prob,
                 init_method,
                 num_tokentypes=0):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
        self.init_method = init_method
        self.num_tokentypes = num_tokentypes

        args = get_args()

        # Word embeddings (parallel).
        self.word_embeddings = mpu.VocabParallelEmbedding(
            vocab_size, self.hidden_size,
            init_method=self.init_method)
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding (serial).
        self.position_embeddings = torch.nn.Embedding(
            max_sequence_length, self.hidden_size)
        self._position_embeddings_key = 'position_embeddings'
        # Initialize the position embeddings.
        self.init_method(self.position_embeddings.weight)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # a method call so we can load a pretrained model without
        # token types and add them as needed.
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
                                                           self.hidden_size)
            # Initialize the token-type embeddings.
            self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        # Embeddings dropout.
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
        token-type embeddings in case the pretrained model does not have them.
        This allows us to load the model normally and then add this embedding.
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings is already initialized')
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes),
                  flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
                                                       self.hidden_size)
        # Initialize the token-type embeddings.
        args = get_args()
        self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings
        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        return embeddings
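
    # A hedged call sketch: position_ids are built by the caller from the token
    # ids; the tensors below are illustrative placeholders, not part of this
    # module.
    #
    #   position_ids = torch.arange(input_ids.size(1), dtype=torch.long,
    #                               device=input_ids.device)
    #   position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    #   embeddings = embedding(input_ids, position_ids)   # [b, s, h]
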
    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load."""

        state_dict_ = {}
        state_dict_[self._word_embeddings_key] \
            = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        state_dict_[self._position_embeddings_key] \
            = self.position_embeddings.state_dict(
                destination, prefix, keep_vars)
        if self.num_tokentypes > 0:
            state_dict_[self._tokentype_embeddings_key] \
                = self.tokentype_embeddings.state_dict(
                    destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Word embedding.
        if self._word_embeddings_key in state_dict:
            state_dict_ = state_dict[self._word_embeddings_key]
        else:
            # For backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'word_embeddings' in key:
                    state_dict_[key.split('word_embeddings.')[1]] \
                        = state_dict[key]
        self.word_embeddings.load_state_dict(state_dict_, strict=strict)

        # Position embedding.
        if self._position_embeddings_key in state_dict:
            state_dict_ = state_dict[self._position_embeddings_key]
        else:
            # For backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'position_embeddings' in key:
                    state_dict_[key.split('position_embeddings.')[1]] \
                        = state_dict[key]
        self.position_embeddings.load_state_dict(state_dict_, strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            state_dict_ = {}
            if self._tokentype_embeddings_key in state_dict:
                state_dict_ = state_dict[self._tokentype_embeddings_key]
            else:
                # For backward compatibility.
                for key in state_dict.keys():
                    if 'tokentype_embeddings' in key:
                        state_dict_[key.split('tokentype_embeddings.')[1]] \
                            = state_dict[key]
            if len(state_dict_.keys()) > 0:
                self.tokentype_embeddings.load_state_dict(state_dict_,
                                                          strict=strict)
            else:
                print('***WARNING*** expected tokentype embeddings in the '
                      'checkpoint but could not find it', flush=True)
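
# A hedged usage sketch for the token-type path: after loading a checkpoint that
# was trained without token types, the embedding can be extended in place. The
# value 2 (e.g. segment A / segment B) and the `model` handle are assumptions.
#
#   model.language_model.embedding.add_tokentype_embeddings(2)
#   out = model.language_model.embedding(input_ids, position_ids,
#                                        tokentype_ids=tokentype_ids)
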

class TransformerLanguageModel(MegatronModule):
    """Transformer language model.

    Arguments:
        init_method: weight initialization method for most layers
        output_layer_init_method: scaled weight initialization method for the
            output layers of each transformer block
        encoder_attn_mask_type: attention mask type used by the encoder
        num_tokentypes: size of the token-type embeddings. 0 value
            will ignore this embedding
        add_decoder: if True, also build a decoder stack
        decoder_attn_mask_type: attention mask type used by the decoder
        add_pooler: if True, add a Pooler over the encoder output
        pre_process: whether this pipeline stage holds the input embeddings
        post_process: whether this pipeline stage produces the final output
    """

    def __init__(self,
                 init_method,
                 output_layer_init_method,
                 encoder_attn_mask_type,
                 num_tokentypes=0,
                 add_decoder=False,
                 decoder_attn_mask_type=AttnMaskType.causal,
                 add_pooler=False,
                 pre_process=True,
                 post_process=True):
        super(TransformerLanguageModel, self).__init__()
        args = get_args()

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler

        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(self.hidden_size,
                                       args.padded_vocab_size,
                                       args.max_position_embeddings,
                                       args.hidden_dropout,
                                       self.init_method,
                                       self.num_tokentypes)
            self._embedding_key = 'embedding'

        # Transformer.
        self.encoder = ParallelTransformer(
            self.init_method,
            output_layer_init_method,
            self_attn_mask_type=self.encoder_attn_mask_type,
            pre_process=self.pre_process,
            post_process=self.post_process
        )
        self._encoder_key = 'encoder'

        # Decoder.
        if self.add_decoder:
            assert args.pipeline_model_parallel_size == 1, \
                'pipeline parallelism is not supported in the presence of decoder'
            self.decoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type)
            self._decoder_key = 'decoder'

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)
                self._pooler_key = 'pooler'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.encoder.set_input_tensor(input_tensor)

    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
                get_key_value=False, pooling_sequence_index=0,
                enc_hidden_states=None, output_enc_hidden=False):

        # Embeddings.
        if self.pre_process:
            embedding_output = self.embedding(enc_input_ids, enc_position_ids,
                                              tokentype_ids=tokentype_ids)
            encoder_input = embedding_output
        else:
            encoder_input = None

        # Encoder.
        if enc_hidden_states is None:
            encoder_output = self.encoder(encoder_input,
                                          enc_attn_mask,
                                          layer_past=layer_past,
                                          get_key_value=get_key_value)
        else:
            encoder_output = enc_hidden_states.to(encoder_input.dtype)

        if self.post_process:
            if self.add_pooler:
                pooled_output = self.pooler(encoder_output,
                                            pooling_sequence_index)

        # output_enc_hidden refers to when we just need the encoder's
        # output. For example, it is helpful to compute
        # similarity between two sequences by average pooling.
        if not self.add_decoder or output_enc_hidden:
            if self.add_pooler and self.post_process:
                return encoder_output, pooled_output
            else:
                return encoder_output

        # Decoder embedding.
        dec_embedding_output = self.embedding(dec_input_ids,
                                              dec_position_ids)
        # Decoder.
        decoder_output = self.decoder(dec_embedding_output,
                                      dec_attn_mask,
                                      layer_past=layer_past,
                                      get_key_value=get_key_value,
                                      encoder_output=encoder_output,
                                      enc_dec_attn_mask=enc_dec_attn_mask)

        if self.add_pooler and self.post_process:
            return decoder_output, encoder_output, pooled_output
        else:
            return decoder_output, encoder_output
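
    # A hedged call sketch for the common encoder-only path; the tensors named
    # here are illustrative placeholders.
    #
    #   lm_output = language_model(tokens, position_ids, attention_mask,
    #                              tokentype_ids=tokentype_ids)
    #   # With add_pooler=True and post_process=True this returns a
    #   # (encoder_output, pooled_output) pair instead.
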
    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load."""

        state_dict_ = {}
        if self.pre_process:
            state_dict_[self._embedding_key] \
                = self.embedding.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        state_dict_[self._encoder_key] \
            = self.encoder.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if self.post_process:
            if self.add_pooler:
                state_dict_[self._pooler_key] \
                    = self.pooler.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)
        if self.add_decoder:
            state_dict_[self._decoder_key] \
                = self.decoder.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)

        return state_dict_
    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Embedding.
        if self.pre_process:
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # For backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if '_embeddings' in key:
                        state_dict_[key] = state_dict[key]
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Encoder.
        if self._encoder_key in state_dict:
            state_dict_ = state_dict[self._encoder_key]
        # For backward compatibility.
        elif 'transformer' in state_dict:
            state_dict_ = state_dict['transformer']
        else:
            # For backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'transformer.' in key:
                    state_dict_[key.split('transformer.')[1]] = state_dict[key]

        # For backward compatibility.
        state_dict_self_attention = {}
        for key in state_dict_.keys():
            if '.attention.' in key:
                state_dict_self_attention[key.replace(".attention.",
                    ".self_attention.")] = state_dict_[key]
            else:
                state_dict_self_attention[key] = state_dict_[key]
        state_dict_ = state_dict_self_attention

        self.encoder.load_state_dict(state_dict_, strict=strict)

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                assert 'pooler' in state_dict, \
                    'could not find data for pooler in the checkpoint'
                self.pooler.load_state_dict(state_dict[self._pooler_key],
                                            strict=strict)
        # Decoder.
        if self.add_decoder:
            assert 'decoder' in state_dict, \
                'could not find data for decoder in the checkpoint'
            self.decoder.load_state_dict(state_dict[self._decoder_key],
                                         strict=strict)
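
# A hedged illustration of the backward-compatibility remap performed in
# TransformerLanguageModel.load_state_dict: older checkpoints stored attention
# weights under '.attention.' while current transformer layers expect
# '.self_attention.'. The layer index and parameter name are illustrative.
#
#   'layers.0.attention.query_key_value.weight'
#       -> 'layers.0.self_attention.query_key_value.weight'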