realm_model.py

import os

import torch

from megatron import get_args, print_rank_0
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.model import BertModel
from .module import MegatronModule
from megatron import mpu
from megatron.model.enums import AttnMaskType
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.language_model import get_language_model
from megatron.model.utils import scaled_init_method_normal
from megatron.model.bert_model import bert_extended_attention_mask, bert_position_ids


def general_ict_model_provider(only_query_model=False, only_block_model=False):
    """Build the model."""
    args = get_args()
    assert args.ict_head_size is not None, \
        "Need to specify --ict-head-size to provide an ICTBertModel"
    assert mpu.get_tensor_model_parallel_world_size() == 1 and \
        mpu.get_pipeline_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"

    print_rank_0('building ICTBertModel...')

    # simpler to just keep using 2 tokentypes since
    # the LM we initialize with has 2 tokentypes
    model = ICTBertModel(
        ict_head_size=args.ict_head_size,
        num_tokentypes=2,
        parallel_output=True,
        only_query_model=only_query_model,
        only_block_model=only_block_model)

    return model
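
# Usage sketch (added, not part of the original file): in a Megatron-LM
# pretraining script this provider is usually handed to the training loop
# rather than called directly; the surrounding names below are assumptions
# about that setup, roughly:
#
#     from megatron.training import pretrain
#     pretrain(train_valid_test_datasets_provider,   # hypothetical dataset provider
#              general_ict_model_provider,
#              forward_step)                         # hypothetical forward-step function
#
# Called directly (after Megatron has been initialized), it returns the
# two-tower model, or a single tower if one of the flags is set:
#
#     model = general_ict_model_provider()                        # query + block towers
#     query_only = general_ict_model_provider(only_query_model=True)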


class ICTBertModel(MegatronModule):
    """Bert-based module for Inverse Cloze task."""

    def __init__(self,
                 ict_head_size,
                 num_tokentypes=1,
                 parallel_output=True,
                 only_query_model=False,
                 only_block_model=False):
        super(ICTBertModel, self).__init__()
        bert_kwargs = dict(
            ict_head_size=ict_head_size,
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output
        )
        assert not (only_block_model and only_query_model)
        self.use_block_model = not only_query_model
        self.use_query_model = not only_block_model

        if self.use_query_model:
            # this model embeds (pseudo-)queries - Embed_input in the paper
            self.query_model = IREncoderBertModel(**bert_kwargs)
            self._query_key = 'question_model'

        if self.use_block_model:
            # this model embeds evidence blocks - Embed_doc in the paper
            self.block_model = IREncoderBertModel(**bert_kwargs)
            self._block_key = 'context_model'

    def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask):
        """Run a forward pass for each of the models and return the respective embeddings."""
        query_logits = self.embed_query(query_tokens, query_attention_mask)
        block_logits = self.embed_block(block_tokens, block_attention_mask)
        return query_logits, block_logits

    def embed_query(self, query_tokens, query_attention_mask):
        """Embed a batch of tokens using the query model."""
        if self.use_query_model:
            query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
            query_ict_logits, _ = self.query_model.forward(
                query_tokens, query_attention_mask, query_types)
            return query_ict_logits
        else:
            raise ValueError("Cannot embed query without query model.")

    def embed_block(self, block_tokens, block_attention_mask):
        """Embed a batch of tokens using the block model."""
        if self.use_block_model:
            block_types = torch.cuda.LongTensor(*block_tokens.shape).fill_(0)
            block_ict_logits, _ = self.block_model.forward(
                block_tokens, block_attention_mask, block_types)
            return block_ict_logits
        else:
            raise ValueError("Cannot embed block without block model.")
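
    # Note (added): during ICT pretraining the two embeddings returned by
    # forward() are typically scored against each other with a dot product and
    # trained with an in-batch softmax cross-entropy, so each query is pulled
    # toward the block it was removed from. A minimal sketch of that scoring
    # step (assumed usage, not part of this file):
    #
    #     query_logits, block_logits = model(query_tokens, query_mask,
    #                                        block_tokens, block_mask)
    #     retrieval_scores = query_logits.matmul(block_logits.t())   # [batch, batch]
    #     loss = torch.nn.functional.cross_entropy(
    #         retrieval_scores,
    #         torch.arange(retrieval_scores.size(0), device=retrieval_scores.device))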

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """Save dict with state dicts of each of the models."""
        state_dict_ = {}
        if self.use_query_model:
            state_dict_[self._query_key] \
                = self.query_model.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)

        if self.use_block_model:
            state_dict_[self._block_key] \
                = self.block_model.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models."""
        if self.use_query_model:
            print("Loading ICT query model", flush=True)
            self.query_model.load_state_dict(
                state_dict[self._query_key], strict=strict)

        if self.use_block_model:
            print("Loading ICT block model", flush=True)
            self.block_model.load_state_dict(
                state_dict[self._block_key], strict=strict)

    def init_state_dict_from_bert(self):
        """Initialize the state from a pretrained BERT model
        on iteration zero of ICT pretraining."""
        args = get_args()
        tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
        if not os.path.isfile(tracker_filename):
            raise FileNotFoundError("Could not find BERT load for ICT")
        with open(tracker_filename, 'r') as f:
            iteration = int(f.read().strip())
            assert iteration > 0

        checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        try:
            state_dict = torch.load(checkpoint_name, map_location='cpu')
        except BaseException:
            raise ValueError("Could not load checkpoint")

        # load the LM state dict into each model
        model_dict = state_dict['model']['language_model']
        self.query_model.language_model.load_state_dict(model_dict)
        self.block_model.language_model.load_state_dict(model_dict)

        # give each model the same ict_head to begin with as well
        query_ict_head_state_dict = \
            self.state_dict_for_save_checkpoint()[self._query_key]['ict_head']
        self.block_model.ict_head.load_state_dict(query_ict_head_state_dict)
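
    # Note (added): this assumes the standard Megatron-LM BERT checkpoint
    # layout, i.e. <args.bert_load>/latest_checkpointed_iteration.txt plus an
    # iter_XXXXXXX/mp_rank_XX/model_optim_rng.pt file whose 'model' entry
    # contains a 'language_model' sub-dict. Only that language-model portion is
    # copied into both towers; the randomly initialized ict_head of the query
    # tower is then copied into the block tower so the two heads start out
    # identical.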


class IREncoderBertModel(MegatronModule):
    """BERT-based encoder for queries or blocks used for learned information retrieval."""

    def __init__(self, ict_head_size, num_tokentypes=2, parallel_output=True):
        super(IREncoderBertModel, self).__init__()
        args = get_args()

        self.ict_head_size = ict_head_size
        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=True,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method)

        self.ict_head = get_linear_layer(args.hidden_size, ict_head_size, init_method)
        self._ict_head_key = 'ict_head'

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        extended_attention_mask = bert_extended_attention_mask(
            attention_mask, next(self.language_model.parameters()).dtype)
        position_ids = bert_position_ids(input_ids)

        lm_output, pooled_output = self.language_model(
            input_ids,
            position_ids,
            extended_attention_mask,
            tokentype_ids=tokentype_ids)

        # Output.
        ict_logits = self.ict_head(pooled_output)
        return ict_logits, None
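
    # Note (added): pooled_output comes from the language model's pooler over
    # the first ([CLS]-position) token, shape [batch_size, hidden_size], so
    # ict_logits has shape [batch_size, ict_head_size]. The trailing None keeps
    # a (logits, _) pair for callers such as ICTBertModel.embed_query that
    # unpack two return values.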

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""
        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        state_dict_[self._ict_head_key] \
            = self.ict_head.state_dict(destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        self.ict_head.load_state_dict(
            state_dict[self._ict_head_key], strict=strict)
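

# End-to-end sketch (added; tokenization and setup are assumptions): after
# building the model with general_ict_model_provider() inside an initialized
# Megatron run, padded token ids and attention masks (CUDA LongTensors of shape
# [batch, seq_len]) can be embedded per tower, which is how an evidence-block
# index is typically built offline for retrieval:
#
#     query_emb = model.embed_query(query_tokens, query_pad_mask)   # [batch, ict_head_size]
#     block_emb = model.embed_block(block_tokens, block_pad_mask)   # [batch, ict_head_size]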