# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer based language model."""

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import LayerType, AttnMaskType
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal, scaled_init_method_normal


def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                       bias=None):
    """LM logits using word embedding weights."""
    # Parallel logits.
    input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
    # Matrix multiply.
    if bias is None:
        logits_parallel = F.linear(input_parallel, word_embeddings_weight)
    else:
        logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)
    # Gather if needed.
    if parallel_output:
        return logits_parallel

    return mpu.gather_from_tensor_model_parallel_region(logits_parallel)
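

# Illustrative sketch (not part of the original module): with a tensor model
# parallel size of 1, parallel_lm_logits above reduces to a plain linear
# projection of the hidden states onto the vocabulary using the (tied) word
# embedding matrix. The sizes below are assumptions chosen for the example,
# and the helper is never called at import time.
def _example_lm_logits_shapes():
    batch, seq, hidden, vocab = 2, 16, 8, 32
    hidden_states = torch.randn(batch, seq, hidden)   # [b, s, h]
    embedding_weight = torch.randn(vocab, hidden)     # [v, h]
    # Single-GPU equivalent of the parallel path: logits = hidden @ E^T.
    logits = F.linear(hidden_states, embedding_weight)
    assert logits.shape == (batch, seq, vocab)
    return logits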


def get_language_model(num_tokentypes, add_pooler,
                       encoder_attn_mask_type, init_method=None,
                       scaled_init_method=None, add_decoder=False,
                       decoder_attn_mask_type=AttnMaskType.causal,
                       pre_process=True, post_process=True):
    """Build language model and return along with the key to save."""
    args = get_args()

    if init_method is None:
        init_method = init_method_normal(args.init_method_std)

    if scaled_init_method is None:
        scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                       args.num_layers)

    # Language model.
    language_model = TransformerLanguageModel(
        init_method,
        scaled_init_method,
        encoder_attn_mask_type,
        num_tokentypes=num_tokentypes,
        add_decoder=add_decoder,
        decoder_attn_mask_type=decoder_attn_mask_type,
        add_pooler=add_pooler,
        pre_process=pre_process,
        post_process=post_process
    )
    # Key used for checkpoints.
    language_model_key = 'language_model'

    return language_model, language_model_key
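

# Illustrative sketch (not part of the original module): how callers such as
# the GPT and BERT model classes typically obtain the backbone from
# get_language_model above. It assumes megatron has already been initialized
# (so get_args() works) and is wrapped in a function so nothing runs at
# import time; the num_tokentypes values are assumptions for the example.
def _example_get_language_model():
    # GPT-style decoder-only setup: causal self-attention mask, no pooler.
    gpt_model, gpt_key = get_language_model(
        num_tokentypes=0,
        add_pooler=False,
        encoder_attn_mask_type=AttnMaskType.causal)
    # BERT-style setup: padding mask plus a pooler over the first token.
    bert_model, bert_key = get_language_model(
        num_tokentypes=2,
        add_pooler=True,
        encoder_attn_mask_type=AttnMaskType.padding)
    return (gpt_model, gpt_key), (bert_model, bert_key)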


class Pooler(MegatronModule):
    """Pooler layer.

    Pool hidden states of a specific token (for example start of the
    sequence) and add a linear transformation followed by a tanh.

    Arguments:
        hidden_size: hidden size
        init_method: weight initialization method for the linear layer.
                     bias is set to zero.
    """

    def __init__(self, hidden_size, init_method):
        super(Pooler, self).__init__()
        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)

    def forward(self, hidden_states, sequence_index=0):
        # hidden_states: [b, s, h]
        # sequence_index: index of the token to pool.
        pooled = hidden_states[:, sequence_index, :]
        pooled = self.dense(pooled)
        pooled = torch.tanh(pooled)
        return pooled
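

# Illustrative sketch (not part of the original module): Pooler takes the
# hidden state of a single position (here the first token) and maps it
# through a dense layer followed by tanh. Only torch-level pieces are used,
# so this runs without any distributed setup; the sizes are assumptions for
# the example.
def _example_pooler():
    hidden_size = 8
    hidden_states = torch.randn(2, 16, hidden_size)     # [b, s, h]
    pooler = Pooler(hidden_size, torch.nn.init.xavier_normal_)
    pooled = pooler(hidden_states, sequence_index=0)    # [b, h]
    assert pooled.shape == (2, hidden_size)
    return pooled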


class Embedding(MegatronModule):
    """Language model embeddings.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum size of sequence. This
                             is used for positional embedding
        embedding_dropout_prob: dropout probability for embeddings
        init_method: weight initialization method
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 max_sequence_length,
                 embedding_dropout_prob,
                 init_method,
                 num_tokentypes=0):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
        self.init_method = init_method
        self.num_tokentypes = num_tokentypes

        args = get_args()

        # Word embeddings (parallel).
        self.word_embeddings = mpu.VocabParallelEmbedding(
            vocab_size, self.hidden_size,
            init_method=self.init_method)
        self._word_embeddings_key = 'word_embeddings'

        # Position embedding (serial).
        self.position_embeddings = torch.nn.Embedding(
            max_sequence_length, self.hidden_size)
        self._position_embeddings_key = 'position_embeddings'
        # Initialize the position embeddings.
        self.init_method(self.position_embeddings.weight)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # a method call so we can load a pretrained model without
        # token types and add them as needed.
        self._tokentype_embeddings_key = 'tokentype_embeddings'
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
                                                           self.hidden_size)
            # Initialize the token-type embeddings.
            self.init_method(self.tokentype_embeddings.weight)
        else:
            self.tokentype_embeddings = None

        # Embeddings dropout.
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    def add_tokentype_embeddings(self, num_tokentypes):
        """Add token-type embedding. This function is provided so we can add
        token-type embeddings in case the pretrained model does not have it.
        This allows us to load the model normally and then add this embedding.
        """
        if self.tokentype_embeddings is not None:
            raise Exception('tokentype embeddings is already initialized')
        if torch.distributed.get_rank() == 0:
            print('adding embedding for {} tokentypes'.format(num_tokentypes),
                  flush=True)
        self.num_tokentypes = num_tokentypes
        self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes,
                                                       self.hidden_size)
        # Initialize the token-type embeddings.
        args = get_args()
        self.init_method(self.tokentype_embeddings.weight)

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings
        if tokentype_ids is not None:
            assert self.tokentype_embeddings is not None
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
        else:
            assert self.tokentype_embeddings is None

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        return embeddings
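
    # Illustrative sketch (not part of the original class): the forward pass
    # above is, conceptually, word_emb(ids) + pos_emb(positions)
    # [+ tokentype_emb(tokentype_ids)] followed by dropout. This static helper
    # reproduces that sum with plain torch.nn.Embedding layers so it can run
    # without the mpu/args setup the real class needs; all sizes are
    # assumptions for the example.
    @staticmethod
    def _example_embedding_sum():
        vocab, max_len, hidden = 32, 16, 8
        word = torch.nn.Embedding(vocab, hidden)
        position = torch.nn.Embedding(max_len, hidden)
        input_ids = torch.randint(0, vocab, (2, max_len))                # [b, s]
        position_ids = torch.arange(max_len).unsqueeze(0).expand(2, -1)
        embeddings = word(input_ids) + position(position_ids)            # [b, s, h]
        return F.dropout(embeddings, p=0.1)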

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load."""

        state_dict_ = {}
        state_dict_[self._word_embeddings_key] \
            = self.word_embeddings.state_dict(destination, prefix, keep_vars)
        state_dict_[self._position_embeddings_key] \
            = self.position_embeddings.state_dict(
                destination, prefix, keep_vars)
        if self.num_tokentypes > 0:
            state_dict_[self._tokentype_embeddings_key] \
                = self.tokentype_embeddings.state_dict(
                    destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Word embedding.
        if self._word_embeddings_key in state_dict:
            state_dict_ = state_dict[self._word_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'word_embeddings' in key:
                    state_dict_[key.split('word_embeddings.')[1]] \
                        = state_dict[key]
        self.word_embeddings.load_state_dict(state_dict_, strict=strict)

        # Position embedding.
        if self._position_embeddings_key in state_dict:
            state_dict_ = state_dict[self._position_embeddings_key]
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'position_embeddings' in key:
                    state_dict_[key.split('position_embeddings.')[1]] \
                        = state_dict[key]
        self.position_embeddings.load_state_dict(state_dict_, strict=strict)

        # Tokentype embedding.
        if self.num_tokentypes > 0:
            state_dict_ = {}
            if self._tokentype_embeddings_key in state_dict:
                state_dict_ = state_dict[self._tokentype_embeddings_key]
            else:
                # for backward compatibility.
                for key in state_dict.keys():
                    if 'tokentype_embeddings' in key:
                        state_dict_[key.split('tokentype_embeddings.')[1]] \
                            = state_dict[key]
            if len(state_dict_.keys()) > 0:
                self.tokentype_embeddings.load_state_dict(state_dict_,
                                                          strict=strict)
            else:
                print('***WARNING*** expected tokentype embeddings in the '
                      'checkpoint but could not find it', flush=True)
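

# Illustrative sketch (not part of the original module): the customized
# state_dict_for_save_checkpoint/load_state_dict pair above nests each
# sub-embedding's weights under a short key rather than a long dotted prefix.
# The structure below shows the resulting layout for an Embedding with token
# types enabled; tensor values are elided and the layout is inferred from the
# keys defined in the class.
def _example_embedding_checkpoint_layout():
    return {
        'word_embeddings': {'weight': '...'},       # mpu.VocabParallelEmbedding
        'position_embeddings': {'weight': '...'},   # torch.nn.Embedding
        'tokentype_embeddings': {'weight': '...'},  # only if num_tokentypes > 0
    }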


class TransformerLanguageModel(MegatronModule):
    """Transformer language model.

    Arguments:
        init_method: weight initialization method for the embeddings and
                     most transformer weights
        output_layer_init_method: (scaled) initialization method for the
                                  transformer output layers
        encoder_attn_mask_type: AttnMaskType used for encoder self-attention
        num_tokentypes: size of the token-type embeddings. 0 value
                        will ignore this embedding
        add_decoder: if True, also build a decoder stack
        decoder_attn_mask_type: AttnMaskType used for decoder self-attention
        add_pooler: if True, add a Pooler on the last pipeline stage
        pre_process: this stage holds the embeddings (first pipeline stage)
        post_process: this stage holds the output (last pipeline stage)
    """

    def __init__(self,
                 init_method,
                 output_layer_init_method,
                 encoder_attn_mask_type,
                 num_tokentypes=0,
                 add_decoder=False,
                 decoder_attn_mask_type=AttnMaskType.causal,
                 add_pooler=False,
                 pre_process=True,
                 post_process=True):
        super(TransformerLanguageModel, self).__init__()
        args = get_args()

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method
        self.encoder_attn_mask_type = encoder_attn_mask_type
        self.add_decoder = add_decoder
        self.decoder_attn_mask_type = decoder_attn_mask_type
        self.add_pooler = add_pooler

        # Embeddings.
        if self.pre_process:
            self.embedding = Embedding(self.hidden_size,
                                       args.padded_vocab_size,
                                       args.max_position_embeddings,
                                       args.hidden_dropout,
                                       self.init_method,
                                       self.num_tokentypes)
            self._embedding_key = 'embedding'

        # Transformer.
        self.encoder = ParallelTransformer(
            self.init_method,
            output_layer_init_method,
            self_attn_mask_type=self.encoder_attn_mask_type,
            pre_process=self.pre_process,
            post_process=self.post_process
        )
        self._encoder_key = 'encoder'

        # Decoder.
        if self.add_decoder:
            assert args.pipeline_model_parallel_size == 1, \
                'pipeline parallelism is not supported in the presence of decoder'
            self.decoder = ParallelTransformer(
                self.init_method,
                output_layer_init_method,
                layer_type=LayerType.decoder,
                self_attn_mask_type=self.decoder_attn_mask_type)
            self._decoder_key = 'decoder'

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)
                self._pooler_key = 'pooler'

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.encoder.set_input_tensor(input_tensor)

    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
                get_key_value=False, pooling_sequence_index=0,
                enc_hidden_states=None, output_enc_hidden=False):

        # Embeddings.
        if self.pre_process:
            embedding_output = self.embedding(enc_input_ids, enc_position_ids,
                                              tokentype_ids=tokentype_ids)
            encoder_input = embedding_output
        else:
            encoder_input = None

        # Encoder.
        if enc_hidden_states is None:
            encoder_output = self.encoder(encoder_input,
                                          enc_attn_mask,
                                          layer_past=layer_past,
                                          get_key_value=get_key_value)
        else:
            encoder_output = enc_hidden_states.to(encoder_input.dtype)

        if self.post_process:
            if self.add_pooler:
                pooled_output = self.pooler(encoder_output,
                                            pooling_sequence_index)

        # output_enc_hidden refers to when we just need the encoder's
        # output. For example, it is helpful to compute
        # similarity between two sequences by average pooling.
        if not self.add_decoder or output_enc_hidden:
            if self.add_pooler and self.post_process:
                return encoder_output, pooled_output
            else:
                return encoder_output

        # Decoder embedding.
        dec_embedding_output = self.embedding(dec_input_ids,
                                              dec_position_ids)
        # Decoder.
        decoder_output = self.decoder(dec_embedding_output,
                                      dec_attn_mask,
                                      layer_past=layer_past,
                                      get_key_value=get_key_value,
                                      encoder_output=encoder_output,
                                      enc_dec_attn_mask=enc_dec_attn_mask)

        if self.add_pooler and self.post_process:
            return decoder_output, encoder_output, pooled_output
        else:
            return decoder_output, encoder_output
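
    # Illustrative sketch (not part of the original class): a typical
    # encoder-only call of forward() as a GPT-style model would issue it.
    # It assumes megatron is fully initialized, pre_process=True, and a padded
    # vocabulary of at least 1000 tokens; the batch/sequence sizes are
    # assumptions for the example, and the method is never called here.
    def _example_forward_call(self, micro_batch_size=2, seq_length=16):
        input_ids = torch.randint(0, 1000, (micro_batch_size, seq_length))
        position_ids = torch.arange(seq_length).unsqueeze(0).expand(
            micro_batch_size, -1)
        # Boolean mask of shape [1, 1, s, s]; True entries are masked out,
        # matching the left-to-right mask megatron builds for GPT training.
        attention_mask = torch.tril(torch.ones(
            seq_length, seq_length)).view(1, 1, seq_length, seq_length) < 0.5
        # With no decoder this returns the encoder output in the [b, s, h]
        # layout that Pooler indexes (plus pooled_output when add_pooler is
        # set and this is the last pipeline stage).
        return self.forward(input_ids, position_ids, attention_mask)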

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load."""

        state_dict_ = {}
        if self.pre_process:
            state_dict_[self._embedding_key] \
                = self.embedding.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)
        state_dict_[self._encoder_key] \
            = self.encoder.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        if self.post_process:
            if self.add_pooler:
                state_dict_[self._pooler_key] \
                    = self.pooler.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)
        if self.add_decoder:
            state_dict_[self._decoder_key] \
                = self.decoder.state_dict_for_save_checkpoint(
                    destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Embedding.
        if self.pre_process:
            if self._embedding_key in state_dict:
                state_dict_ = state_dict[self._embedding_key]
            else:
                # for backward compatibility.
                state_dict_ = {}
                for key in state_dict.keys():
                    if '_embeddings' in key:
                        state_dict_[key] = state_dict[key]
            self.embedding.load_state_dict(state_dict_, strict=strict)

        # Encoder.
        if self._encoder_key in state_dict:
            state_dict_ = state_dict[self._encoder_key]
        # for backward compatibility.
        elif 'transformer' in state_dict:
            state_dict_ = state_dict['transformer']
        else:
            # for backward compatibility.
            state_dict_ = {}
            for key in state_dict.keys():
                if 'transformer.' in key:
                    state_dict_[key.split('transformer.')[1]] = state_dict[key]

        # for backward compatibility: older checkpoints used '.attention.'
        # where current layers use '.self_attention.'.
        state_dict_self_attention = {}
        for key in state_dict_.keys():
            if '.attention.' in key:
                state_dict_self_attention[key.replace(".attention.",
                    ".self_attention.")] = state_dict_[key]
            else:
                state_dict_self_attention[key] = state_dict_[key]
        state_dict_ = state_dict_self_attention

        self.encoder.load_state_dict(state_dict_, strict=strict)

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                assert 'pooler' in state_dict, \
                    'could not find data for pooler in the checkpoint'
                self.pooler.load_state_dict(state_dict[self._pooler_key],
                                            strict=strict)
        # Decoder.
        if self.add_decoder:
            assert 'decoder' in state_dict, \
                'could not find data for decoder in the checkpoint'
            self.decoder.load_state_dict(state_dict[self._decoder_key],
                                         strict=strict)
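

# Illustrative sketch (not part of the original module): the save/load pair
# above is what megatron's checkpointing code invokes on the backbone. This
# hedged example shows the round trip for an already-constructed
# TransformerLanguageModel; it is never called here.
def _example_checkpoint_round_trip(language_model):
    # Save: a nested dict keyed by 'embedding', 'encoder', and optionally
    # 'pooler'/'decoder', as assembled in state_dict_for_save_checkpoint().
    state_dict = language_model.state_dict_for_save_checkpoint()
    # Load: the customized load_state_dict understands the same nesting and
    # also several older, backward-compatible key layouts.
    language_model.load_state_dict(state_dict, strict=True)
    return state_dict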