123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205 |
- import os
- import torch
- from megatron import get_args, print_rank_0
- from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
- from megatron.model import BertModel
- from .module import MegatronModule
- from megatron import mpu
- from megatron.model.enums import AttnMaskType
- from megatron.model.utils import get_linear_layer
- from megatron.model.utils import init_method_normal
- from megatron.model.language_model import get_language_model
- from megatron.model.utils import scaled_init_method_normal
- from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids
def general_ict_model_provider(only_query_model=False, only_block_model=False):
    """Build an ICTBertModel configured from the global command-line args.

    Args:
        only_query_model: if True, build only the query encoder.
        only_block_model: if True, build only the evidence-block encoder.

    Returns:
        ICTBertModel. Requires --ict-head-size to be set; tensor and
        pipeline model parallelism must both be disabled (world size 1).
    """
    args = get_args()
    assert args.ict_head_size is not None, \
        "Need to specify --ict-head-size to provide an ICTBertModel"
    assert mpu.get_tensor_model_parallel_world_size() == 1 and \
        mpu.get_pipeline_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"

    print_rank_0('building ICTBertModel...')

    # The pretrained LM used for initialization has 2 tokentypes, so it is
    # simpler to keep using 2 tokentypes here as well.
    return ICTBertModel(
        ict_head_size=args.ict_head_size,
        num_tokentypes=2,
        parallel_output=True,
        only_query_model=only_query_model,
        only_block_model=only_block_model)
class ICTBertModel(MegatronModule):
    """Bert-based module for the Inverse Cloze Task (ICT).

    Wraps up to two IREncoderBertModel instances: a query encoder
    (Embed_input in the paper) and an evidence-block encoder (Embed_doc).
    """

    def __init__(self,
                 ict_head_size,
                 num_tokentypes=1,
                 parallel_output=True,
                 only_query_model=False,
                 only_block_model=False):
        """Construct query and/or block encoders.

        Args:
            ict_head_size: output dimension of the retrieval (ICT) head.
            num_tokentypes: number of token-type embeddings for the encoders.
            parallel_output: forwarded to the underlying encoders.
            only_query_model: if True, skip building the block encoder.
            only_block_model: if True, skip building the query encoder.
        """
        super(ICTBertModel, self).__init__()
        bert_kwargs = dict(
            ict_head_size=ict_head_size,
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output
        )
        # At least one of the two encoders must be built.
        assert not (only_block_model and only_query_model)
        self.use_block_model = not only_query_model
        self.use_query_model = not only_block_model

        if self.use_query_model:
            # this model embeds (pseudo-)queries - Embed_input in the paper
            self.query_model = IREncoderBertModel(**bert_kwargs)
            self._query_key = 'question_model'

        if self.use_block_model:
            # this model embeds evidence blocks - Embed_doc in the paper
            self.block_model = IREncoderBertModel(**bert_kwargs)
            self._block_key = 'context_model'

    def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask):
        """Run a forward pass for each of the models and return the respective embeddings."""
        query_logits = self.embed_query(query_tokens, query_attention_mask)
        block_logits = self.embed_block(block_tokens, block_attention_mask)
        return query_logits, block_logits

    def embed_query(self, query_tokens, query_attention_mask):
        """Embed a batch of tokens using the query model.

        Raises:
            ValueError: if this instance was built without a query model.
        """
        if self.use_query_model:
            # All-zero token types: queries use a single segment.
            query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
            query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types)
            return query_ict_logits
        else:
            raise ValueError("Cannot embed query without query model.")

    def embed_block(self, block_tokens, block_attention_mask):
        """Embed a batch of tokens using the block model.

        Raises:
            ValueError: if this instance was built without a block model.
        """
        if self.use_block_model:
            # All-zero token types: blocks use a single segment.
            block_types = torch.cuda.LongTensor(*block_tokens.shape).fill_(0)
            block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types)
            return block_ict_logits
        else:
            raise ValueError("Cannot embed block without block model.")

    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
        """Save dict with state dicts of each of the models."""
        state_dict_ = {}
        if self.use_query_model:
            state_dict_[self._query_key] \
                = self.query_model.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        if self.use_block_model:
            state_dict_[self._block_key] \
                = self.block_model.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models"""
        if self.use_query_model:
            print("Loading ICT query model", flush=True)
            self.query_model.load_state_dict(
                state_dict[self._query_key], strict=strict)
        if self.use_block_model:
            print("Loading ICT block model", flush=True)
            self.block_model.load_state_dict(
                state_dict[self._block_key], strict=strict)

    def init_state_dict_from_bert(self):
        """Initialize the state from a pretrained BERT model on iteration zero of ICT pretraining.

        Raises:
            FileNotFoundError: if no checkpoint tracker file exists at
                args.bert_load.
            ValueError: if the checkpoint file cannot be loaded (chained
                from the underlying torch.load failure).
        """
        args = get_args()
        tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
        if not os.path.isfile(tracker_filename):
            raise FileNotFoundError("Could not find BERT load for ICT")
        with open(tracker_filename, 'r') as f:
            iteration = int(f.read().strip())
        assert iteration > 0

        checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        try:
            state_dict = torch.load(checkpoint_name, map_location='cpu')
        except Exception as e:
            # Chain the original exception so the real cause of the load
            # failure (corrupt file, pickle error, ...) is not discarded.
            raise ValueError("Could not load checkpoint") from e

        # load the LM state dict into each model
        model_dict = state_dict['model']['language_model']
        self.query_model.language_model.load_state_dict(model_dict)
        self.block_model.language_model.load_state_dict(model_dict)

        # give each model the same ict_head to begin with as well
        query_ict_head_state_dict = self.state_dict_for_save_checkpoint()[self._query_key]['ict_head']
        self.block_model.ict_head.load_state_dict(query_ict_head_state_dict)
class IREncoderBertModel(MegatronModule):
    """BERT-based encoder for queries or blocks used for learned information retrieval."""

    def __init__(self, ict_head_size, num_tokentypes=2, parallel_output=True):
        """Build the underlying language model plus a linear ICT head.

        Args:
            ict_head_size: output dimension of the linear retrieval head.
            num_tokentypes: number of token-type embeddings.
            parallel_output: stored for use by downstream consumers.
        """
        super(IREncoderBertModel, self).__init__()
        args = get_args()

        self.ict_head_size = ict_head_size
        self.parallel_output = parallel_output

        # Standard and depth-scaled weight initializers.
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(
            args.init_method_std, args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method)

        # Linear projection from pooled hidden state to the retrieval space.
        self.ict_head = get_linear_layer(
            args.hidden_size, ict_head_size, init_method)
        self._ict_head_key = 'ict_head'

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        """Encode tokens and project the pooled output through the ICT head.

        Returns:
            Tuple of (ict_logits, None); the second slot is kept for
            interface compatibility with other model heads.
        """
        dtype = next(self.language_model.parameters()).dtype
        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, dtype)
        position_ids = bert_position_ids(input_ids)

        lm_output, pooled_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids)

        # Output: project the pooled [CLS]-style representation.
        ict_logits = self.ict_head(pooled_output)
        return ict_logits, None

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""
        state_dict_ = {
            self._language_model_key:
                self.language_model.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars),
            self._ict_head_key:
                self.ict_head.state_dict(destination, prefix, keep_vars),
        }
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load: restore the language model and ICT head separately."""
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        self.ict_head.load_state_dict(
            state_dict[self._ict_head_key], strict=strict)
|