import os
import sys

import torch

from megatron import get_args, get_tokenizer, mpu, print_rank_0
from megatron.checkpointing import fix_query_key_value_ordering
from megatron.checkpointing import get_checkpoint_name
from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.model.bert_model import bert_position_ids
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal

from .module import MegatronModule


def get_model_provider(only_query_model=False, only_context_model=False,
                       biencoder_shared_query_context_model=False):

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        print_rank_0('building BiEncoder model ...')
        model = biencoder_model_provider(
            only_query_model=only_query_model,
            only_context_model=only_context_model,
            biencoder_shared_query_context_model=\
                biencoder_shared_query_context_model,
            pre_process=pre_process,
            post_process=post_process)
        return model

    return model_provider
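
# Hedged usage sketch: `get_model_provider` returns a closure in the shape
# Megatron's training entry point expects. The `pretrain` import and the
# argument order below follow megatron.training as an assumption and may
# differ across Megatron-LM versions:
#
#   from megatron.training import pretrain
#   pretrain(train_valid_test_datasets_provider,
#            get_model_provider(biencoder_shared_query_context_model=True),
#            forward_step)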


def biencoder_model_provider(only_query_model=False,
                             only_context_model=False,
                             biencoder_shared_query_context_model=False,
                             pre_process=True,
                             post_process=True):
    """Build the model."""
    assert mpu.get_tensor_model_parallel_world_size() == 1 and \
        mpu.get_pipeline_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"

    print_rank_0('building BiEncoderModel ...')

    # Simpler to just keep using 2 tokentypes, since
    # the LM we initialize with has 2 tokentypes.
    model = BiEncoderModel(
        num_tokentypes=2,
        parallel_output=False,
        only_query_model=only_query_model,
        only_context_model=only_context_model,
        biencoder_shared_query_context_model=\
            biencoder_shared_query_context_model,
        pre_process=pre_process,
        post_process=post_process)

    return model


class BiEncoderModel(MegatronModule):
    """BERT-based module for the biencoder model."""

    def __init__(self,
                 num_tokentypes=1,
                 parallel_output=True,
                 only_query_model=False,
                 only_context_model=False,
                 biencoder_shared_query_context_model=False,
                 pre_process=True,
                 post_process=True):
        super(BiEncoderModel, self).__init__()
        args = get_args()

        bert_kwargs = dict(
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output,
            pre_process=pre_process,
            post_process=post_process)

        self.biencoder_shared_query_context_model = \
            biencoder_shared_query_context_model
        assert not (only_context_model and only_query_model)
        self.use_context_model = not only_query_model
        self.use_query_model = not only_context_model
        self.biencoder_projection_dim = args.biencoder_projection_dim

        if self.biencoder_shared_query_context_model:
            self.model = PretrainedBertModel(**bert_kwargs)
            self._model_key = 'shared_model'
            self.query_model, self.context_model = self.model, self.model
        else:
            if self.use_query_model:
                # This model embeds (pseudo-)queries: Embed_input in the paper.
                self.query_model = PretrainedBertModel(**bert_kwargs)
                self._query_key = 'query_model'

            if self.use_context_model:
                # This model embeds evidence blocks: Embed_doc in the paper.
                self.context_model = PretrainedBertModel(**bert_kwargs)
                self._context_key = 'context_model'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()."""
        # This is just a placeholder; it will be needed once model
        # parallelism is used:
        # self.language_model.set_input_tensor(input_tensor)
        return

    def forward(self, query_tokens, query_attention_mask, query_types,
                context_tokens, context_attention_mask, context_types):
        """Run a forward pass for each of the models and
        return the respective embeddings."""

        if self.use_query_model:
            query_logits = self.embed_text(self.query_model,
                                           query_tokens,
                                           query_attention_mask,
                                           query_types)
        else:
            raise ValueError("Cannot embed query without the query model.")

        if self.use_context_model:
            context_logits = self.embed_text(self.context_model,
                                             context_tokens,
                                             context_attention_mask,
                                             context_types)
        else:
            raise ValueError("Cannot embed block without the block model.")

        return query_logits, context_logits

    @staticmethod
    def embed_text(model, tokens, attention_mask, token_types):
        """Embed a batch of tokens using the model."""
        logits = model(tokens,
                       attention_mask,
                       token_types)
        return logits
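
    # Shape reference, inferred from PretrainedBertModel.forward below:
    # `tokens`, `attention_mask` and `token_types` are [batch, seq_len]
    # tensors, and the returned embedding is [batch, hidden_size], or
    # [batch, biencoder_projection_dim] when a projection head is configured.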

    def state_dict_for_save_checkpoint(self, destination=None,
                                       prefix='', keep_vars=False):
        """Save dict with state dicts of each of the models."""
        state_dict_ = {}
        if self.biencoder_shared_query_context_model:
            state_dict_[self._model_key] = \
                self.model.state_dict_for_save_checkpoint(destination,
                                                          prefix,
                                                          keep_vars)
        else:
            if self.use_query_model:
                state_dict_[self._query_key] = \
                    self.query_model.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)

            if self.use_context_model:
                state_dict_[self._context_key] = \
                    self.context_model.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)

        return state_dict_
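
    # Resulting checkpoint layout, following the keys defined in __init__:
    #   shared tower:    {'shared_model': {...}}
    #   separate towers: {'query_model': {...}, 'context_model': {...}}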

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models."""
        if self.biencoder_shared_query_context_model:
            print_rank_0("Loading shared query-context model")
            self.model.load_state_dict(state_dict[self._model_key],
                                       strict=strict)
        else:
            if self.use_query_model:
                print_rank_0("Loading query model")
                self.query_model.load_state_dict(
                    state_dict[self._query_key], strict=strict)

            if self.use_context_model:
                print_rank_0("Loading context model")
                self.context_model.load_state_dict(
                    state_dict[self._context_key], strict=strict)

    def init_state_dict_from_bert(self):
        """Initialize the state from a pretrained BERT model
        on iteration zero of ICT pretraining."""
        args = get_args()

        if args.bert_load is None:
            print_rank_0("bert-load argument is None")
            return

        tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
        if not os.path.isfile(tracker_filename):
            raise FileNotFoundError("Could not find BERT checkpoint")
        with open(tracker_filename, 'r') as f:
            iteration = int(f.read().strip())
            assert iteration > 0

        checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading BERT checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        # Load the checkpoint.
        try:
            state_dict = torch.load(checkpoint_name, map_location='cpu')
        except ModuleNotFoundError:
            from megatron.fp16_deprecated import loss_scaler
            # For backward compatibility, alias the deprecated fp16 module
            # under the names pickled into old checkpoints.
            print_rank_0(' > deserializing using the old code structure ...')
            sys.modules['fp16.loss_scaler'] = sys.modules[
                'megatron.fp16_deprecated.loss_scaler']
            sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
                'megatron.fp16_deprecated.loss_scaler']
            state_dict = torch.load(checkpoint_name, map_location='cpu')
            sys.modules.pop('fp16.loss_scaler', None)
            sys.modules.pop('megatron.fp16.loss_scaler', None)
        except BaseException:
            print_rank_0('could not load the BERT checkpoint')
            sys.exit()

        checkpoint_version = state_dict.get('checkpoint_version', 0)

        # Load the LM state dict into each model.
        model_dict = state_dict['model']['language_model']

        if self.biencoder_shared_query_context_model:
            self.model.language_model.load_state_dict(model_dict)
            fix_query_key_value_ordering(self.model, checkpoint_version)
        else:
            if self.use_query_model:
                self.query_model.language_model.load_state_dict(model_dict)
                # Give each model the same projection head to begin with.
                if self.biencoder_projection_dim > 0:
                    query_proj_state_dict = \
                        self.state_dict_for_save_checkpoint()\
                        [self._query_key]['projection_enc']
                fix_query_key_value_ordering(self.query_model,
                                             checkpoint_version)

            if self.use_context_model:
                self.context_model.language_model.load_state_dict(model_dict)
                if self.query_model is not None and \
                        self.biencoder_projection_dim > 0:
                    self.context_model.projection_enc.load_state_dict(
                        query_proj_state_dict)
                fix_query_key_value_ordering(self.context_model,
                                             checkpoint_version)
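
    # Note on the tracker file read above: it is a plain-text file in the
    # checkpoint directory (named `latest_checkpointed_iteration.txt` in the
    # Megatron-LM versions this pattern comes from; treat the exact name as
    # an assumption) whose entire content is the iteration number, e.g.
    # "100000", which is why a bare int(f.read().strip()) suffices.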


class PretrainedBertModel(MegatronModule):
    """BERT-based encoder for queries or contexts used for
    learned information retrieval."""

    def __init__(self, num_tokentypes=2,
                 parallel_output=True, pre_process=True, post_process=True):
        super(PretrainedBertModel, self).__init__()
        args = get_args()
        tokenizer = get_tokenizer()
        self.pad_id = tokenizer.pad
        self.biencoder_projection_dim = args.biencoder_projection_dim
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(
            args.init_method_std, args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process)

        if args.biencoder_projection_dim > 0:
            self.projection_enc = get_linear_layer(
                args.hidden_size, args.biencoder_projection_dim, init_method)
            self._projection_enc_key = 'projection_enc'

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        extended_attention_mask = attention_mask.unsqueeze(1)
        # extended_attention_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(input_ids,
                                        position_ids,
                                        extended_attention_mask,
                                        tokentype_ids=tokentype_ids)
        # This mask could be used in average-pooling and max-pooling;
        # the default below pools with the [CLS] token instead.
        pool_mask = (input_ids == self.pad_id).unsqueeze(2)

        # Take the representation of the [CLS] token of BERT.
        pooled_output = lm_output[:, 0, :]

        # Keep the pooled output in the LM output dtype
        # (e.g. float16 under mixed precision).
        pooled_output = pooled_output.to(lm_output.dtype)

        # Output.
        if self.biencoder_projection_dim:
            pooled_output = self.projection_enc(pooled_output)

        return pooled_output
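
    # `pool_mask` is computed above but unused with [CLS] pooling. A hedged
    # sketch of masked average-pooling, were it used instead (illustrative
    # only, not the behavior of this module):
    #
    #   masked = lm_output.masked_fill(pool_mask, 0)
    #   lengths = (~pool_mask).sum(dim=1).clamp(min=1)
    #   pooled_output = masked.sum(dim=1) / lengths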

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when the model is combined with other heads,
        add an extra key."""
        state_dict_ = {}
        state_dict_[self._language_model_key] = \
            self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)

        if self.biencoder_projection_dim > 0:
            state_dict_[self._projection_enc_key] = \
                self.projection_enc.state_dict(destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
        print_rank_0("loading pretrained weights")
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)

        if self.biencoder_projection_dim > 0:
            print_rank_0("loading projection head weights")
            self.projection_enc.load_state_dict(
                state_dict[self._projection_enc_key], strict=strict)
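

# End-to-end sketch of the intended flow. `initialize_megatron` and its
# arguments are assumptions based on Megatron-LM's usual entry points and
# may differ across versions; the tensor inputs are illustrative:
#
#   from megatron.initialize import initialize_megatron
#   initialize_megatron(args_defaults={'tokenizer_type':
#                                      'BertWordPieceLowerCase'})
#   model = biencoder_model_provider(biencoder_shared_query_context_model=True)
#   model.init_state_dict_from_bert()  # warm-start both towers from BERT
#   q_emb, c_emb = model(query_tokens, query_mask, query_types,
#                        context_tokens, context_mask, context_types)
#   scores = torch.matmul(q_emb, c_emb.t())  # in-batch retrieval logits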