# biencoder_model.py

import os
import sys

import torch

from megatron import get_args, print_rank_0
from megatron.checkpointing import fix_query_key_value_ordering
from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.checkpointing import get_checkpoint_name
from megatron import mpu, get_tokenizer
from megatron.model.bert_model import bert_position_ids
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal

from .module import MegatronModule


def get_model_provider(only_query_model=False, only_context_model=False,
                       biencoder_shared_query_context_model=False):

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        print_rank_0('building BiEncoder model ...')
        model = biencoder_model_provider(
            only_query_model=only_query_model,
            only_context_model=only_context_model,
            biencoder_shared_query_context_model=biencoder_shared_query_context_model,
            pre_process=pre_process,
            post_process=post_process)
        return model

    return model_provider
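

# A minimal usage sketch (illustrative): the closure returned above matches
# the `model_provider` signature that Megatron's `pretrain` entry point
# expects. The dataset provider and forward-step function named here are
# hypothetical placeholders a calling script would define.
#
#   from megatron.training import pretrain
#   pretrain(train_valid_test_datasets_provider,
#            get_model_provider(biencoder_shared_query_context_model=True),
#            forward_step)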


def biencoder_model_provider(only_query_model=False,
                             only_context_model=False,
                             biencoder_shared_query_context_model=False,
                             pre_process=True,
                             post_process=True):
    """Build the model."""
    assert mpu.get_tensor_model_parallel_world_size() == 1 and \
        mpu.get_pipeline_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"

    print_rank_0('building BiEncoderModel...')

    # Simpler to just keep using 2 tokentypes since
    # the LM we initialize with has 2 tokentypes.
    model = BiEncoderModel(
        num_tokentypes=2,
        parallel_output=False,
        only_query_model=only_query_model,
        only_context_model=only_context_model,
        biencoder_shared_query_context_model=biencoder_shared_query_context_model,
        pre_process=pre_process,
        post_process=post_process)

    return model


class BiEncoderModel(MegatronModule):
    """Bert-based module for Biencoder model."""

    def __init__(self,
                 num_tokentypes=1,
                 parallel_output=True,
                 only_query_model=False,
                 only_context_model=False,
                 biencoder_shared_query_context_model=False,
                 pre_process=True,
                 post_process=True):
        super(BiEncoderModel, self).__init__()
        args = get_args()

        bert_kwargs = dict(
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output,
            pre_process=pre_process,
            post_process=post_process)

        self.biencoder_shared_query_context_model = \
            biencoder_shared_query_context_model
        assert not (only_context_model and only_query_model)
        self.use_context_model = not only_query_model
        self.use_query_model = not only_context_model
        self.biencoder_projection_dim = args.biencoder_projection_dim

        if self.biencoder_shared_query_context_model:
            self.model = PretrainedBertModel(**bert_kwargs)
            self._model_key = 'shared_model'
            self.query_model, self.context_model = self.model, self.model
        else:
            if self.use_query_model:
                # This model embeds (pseudo-)queries - Embed_input in the paper.
                self.query_model = PretrainedBertModel(**bert_kwargs)
                self._query_key = 'query_model'

            if self.use_context_model:
                # This model embeds evidence blocks - Embed_doc in the paper.
                self.context_model = PretrainedBertModel(**bert_kwargs)
                self._context_key = 'context_model'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        # This is a placeholder; it will be needed once model
        # parallelism is used.
        # self.language_model.set_input_tensor(input_tensor)
        return

    def forward(self, query_tokens, query_attention_mask, query_types,
                context_tokens, context_attention_mask, context_types):
        """Run a forward pass for each of the models and
        return the respective embeddings."""

        if self.use_query_model:
            query_logits = self.embed_text(self.query_model,
                                           query_tokens,
                                           query_attention_mask,
                                           query_types)
        else:
            raise ValueError("Cannot embed query without the query model.")

        if self.use_context_model:
            context_logits = self.embed_text(self.context_model,
                                             context_tokens,
                                             context_attention_mask,
                                             context_types)
        else:
            raise ValueError("Cannot embed block without the block model.")

        return query_logits, context_logits
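
    # Illustrative note: a downstream retrieval loss typically scores every
    # query against all in-batch contexts with an inner product and trains
    # with cross-entropy over the batch, along the lines of:
    #
    #   scores = torch.matmul(query_logits, context_logits.t())
    #   labels = torch.arange(scores.size(0), device=scores.device)
    #   loss = torch.nn.functional.cross_entropy(scores, labels)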

    @staticmethod
    def embed_text(model, tokens, attention_mask, token_types):
        """Embed a batch of tokens using the model."""
        logits = model(tokens,
                       attention_mask,
                       token_types)
        return logits

    def state_dict_for_save_checkpoint(self, destination=None,
                                       prefix='', keep_vars=False):
        """Save dict with state dicts of each of the models."""
        state_dict_ = {}
        if self.biencoder_shared_query_context_model:
            state_dict_[self._model_key] = \
                self.model.state_dict_for_save_checkpoint(destination,
                                                          prefix,
                                                          keep_vars)
        else:
            if self.use_query_model:
                state_dict_[self._query_key] = \
                    self.query_model.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)

            if self.use_context_model:
                state_dict_[self._context_key] = \
                    self.context_model.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models."""
        if self.biencoder_shared_query_context_model:
            print_rank_0("Loading shared query-context model")
            self.model.load_state_dict(state_dict[self._model_key],
                                       strict=strict)
        else:
            if self.use_query_model:
                print_rank_0("Loading query model")
                self.query_model.load_state_dict(
                    state_dict[self._query_key], strict=strict)

            if self.use_context_model:
                print_rank_0("Loading context model")
                self.context_model.load_state_dict(
                    state_dict[self._context_key], strict=strict)

    def init_state_dict_from_bert(self):
        """Initialize the state from a pretrained BERT model
        on iteration zero of ICT pretraining."""
        args = get_args()

        if args.bert_load is None:
            print_rank_0("bert-load argument is None")
            return

        tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
        if not os.path.isfile(tracker_filename):
            raise FileNotFoundError("Could not find BERT checkpoint")
        with open(tracker_filename, 'r') as f:
            iteration = int(f.read().strip())
            assert iteration > 0

        checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading BERT checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        # Load the checkpoint.
        try:
            state_dict = torch.load(checkpoint_name, map_location='cpu')
        except ModuleNotFoundError:
            from megatron.fp16_deprecated import loss_scaler
            # For backward compatibility.
            print_rank_0(' > deserializing using the old code structure ...')
            sys.modules['fp16.loss_scaler'] = sys.modules[
                'megatron.fp16_deprecated.loss_scaler']
            sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
                'megatron.fp16_deprecated.loss_scaler']
            state_dict = torch.load(checkpoint_name, map_location='cpu')
            sys.modules.pop('fp16.loss_scaler', None)
            sys.modules.pop('megatron.fp16.loss_scaler', None)
        except BaseException:
            print_rank_0('could not load the BERT checkpoint')
            sys.exit()

        checkpoint_version = state_dict.get('checkpoint_version', 0)

        # Load the LM state dict into each model.
        model_dict = state_dict['model']['language_model']

        if self.biencoder_shared_query_context_model:
            self.model.language_model.load_state_dict(model_dict)
            fix_query_key_value_ordering(self.model, checkpoint_version)
        else:
            if self.use_query_model:
                self.query_model.language_model.load_state_dict(model_dict)
                # Give each model the same ict_head to begin with as well.
                if self.biencoder_projection_dim > 0:
                    query_proj_state_dict = \
                        self.state_dict_for_save_checkpoint()\
                        [self._query_key]['projection_enc']
                fix_query_key_value_ordering(self.query_model,
                                             checkpoint_version)

            if self.use_context_model:
                self.context_model.language_model.load_state_dict(model_dict)
                if self.query_model is not None and \
                        self.biencoder_projection_dim > 0:
                    self.context_model.projection_enc.load_state_dict(
                        query_proj_state_dict)
                fix_query_key_value_ordering(self.context_model,
                                             checkpoint_version)
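

# Illustrative note: `init_state_dict_from_bert` above is intended to run
# once, before the first ICT training step. A hypothetical calling pattern:
#
#   model = biencoder_model_provider()
#   if iteration == 0:
#       model.init_state_dict_from_bert()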


class PretrainedBertModel(MegatronModule):
    """BERT-based encoder for queries or contexts used for
    learned information retrieval."""

    def __init__(self, num_tokentypes=2,
                 parallel_output=True, pre_process=True, post_process=True):
        super(PretrainedBertModel, self).__init__()

        args = get_args()
        tokenizer = get_tokenizer()
        self.pad_id = tokenizer.pad
        self.biencoder_projection_dim = args.biencoder_projection_dim
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(
            args.init_method_std, args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process)

        if args.biencoder_projection_dim > 0:
            self.projection_enc = get_linear_layer(args.hidden_size,
                                                   args.biencoder_projection_dim,
                                                   init_method)
            self._projection_enc_key = 'projection_enc'
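
    # Note (assumption based on megatron.model.utils): `get_linear_layer`
    # returns a plain torch.nn.Linear whose weight is drawn from
    # `init_method` and whose bias is zeroed, so the projection maps
    # hidden_size -> biencoder_projection_dim.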

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        extended_attention_mask = attention_mask.unsqueeze(1)
        #extended_attention_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(input_ids,
                                        position_ids,
                                        extended_attention_mask,
                                        tokentype_ids=tokentype_ids)

        # This mask marks pad positions; it would drive average- or
        # max-pooling variants, but is unused with the [CLS] pooling below.
        pool_mask = (input_ids == self.pad_id).unsqueeze(2)

        # Take the representation of the [CLS] token of BERT.
        pooled_output = lm_output[:, 0, :]

        # Match the language model's output dtype (e.g. float16).
        pooled_output = pooled_output.to(lm_output.dtype)

        # Output.
        if self.biencoder_projection_dim:
            pooled_output = self.projection_enc(pooled_output)

        return pooled_output
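
    # A minimal sketch (assuming lm_output is [batch, seq, hidden], per the
    # [CLS] indexing above) of how `pool_mask` could implement mean pooling
    # instead:
    #
    #   masked = lm_output.masked_fill(pool_mask, 0.0)
    #   lengths = (~pool_mask).sum(dim=1).clamp(min=1)
    #   mean_pooled = masked.sum(dim=1) / lengths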

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""
        state_dict_ = {}
        state_dict_[self._language_model_key] = \
            self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)

        if self.biencoder_projection_dim > 0:
            state_dict_[self._projection_enc_key] = \
                self.projection_enc.state_dict(destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
        print_rank_0("loading pretrained weights")
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)

        if self.biencoder_projection_dim > 0:
            print_rank_0("loading projection head weights")
            self.projection_enc.load_state_dict(
                state_dict[self._projection_enc_key], strict=strict)
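

# --- Usage sketch (illustrative; assumes Megatron's distributed state,
# args, and tokenizer have already been initialized by the training
# script) ---
#
#   model = biencoder_model_provider()
#   model.init_state_dict_from_bert()
#   query_emb, context_emb = model(query_tokens, query_mask, query_types,
#                                  context_tokens, context_mask, context_types)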