# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""

import math
import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

# Flags required to enable jit fusion kernels.
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)

""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""


class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension. At the end, dropout is also
    applied.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        self.bias_gelu_fusion = args.bias_gelu_fusion
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):

        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate_parallel = \
                bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = \
                self.activation_func(intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias
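
# A minimal sketch of the shape flow through ParallelMLP above, assuming
# hypothetical sizes h = 1024, ffn_hidden_size = 4h = 4096 and p = 2
# tensor-parallel ranks (illustrative numbers only):
#
#     [s, b, h]                     input hidden_states (replicated per rank)
#     [s, b, 4h/p] = [s, b, 2048]   after dense_h_to_4h (column-parallel,
#                                   gather_output=False)
#     [s, b, 4h/p]                  after (bias +) GeLU, element-wise per rank
#     [s, b, h]                     after dense_4h_to_h (row-parallel; output
#                                   reduced across ranks)
#
# The bias of the last projection is returned separately (skip_bias_add=True)
# so the caller can fuse it with dropout and the residual connection (see
# bias_dropout_add below).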


class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False, encoder_output=None):
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer),
                                   key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer),
                                     value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # Preallocating result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device())

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0 / self.norm_factor))

        # Change view to [b, np, sq, sk].
        attention_scores = matmul_result.view(*output_size)

        # =====================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # =====================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================

        # Attention scores and attention mask [b, np, sq, sk].
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # Context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # Change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # Change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # Change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output, bias
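
# A minimal sketch of the score computation above: with the reshaped
# query_layer Q of shape [b * np, sq, hn] and key_layer K of shape
# [b * np, sk, hn], the baddbmm call (beta=0) computes, per head,
#
#     scores = (1 / norm_factor) * Q @ K^T,   norm_factor = sqrt(hn) [* coeff],
#
# where coeff is the layer number when apply_query_key_layer_scaling is set.
# The same coeff is passed to FusedScaleMaskSoftmax as its scale, so the extra
# down-scaling here is undone before the (fp32) softmax; it exists to keep the
# fp16 score tensor in range, not to change the math.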


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add


@torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, False)
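
# In formula form, every residual branch in ParallelTransformerLayer below
# computes
#
#     out = residual + dropout(x + bias, p),
#
# where x is a module output whose bias was split off via skip_bias_add=True,
# so the bias add, dropout, and residual add can be fused into one jitted
# kernel. Two @torch.jit.script variants bake the training flag in as a
# constant because dropout behaves differently in training and inference (see
# the comment in ParallelTransformerLayer.forward).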


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                layer_past=None, get_key_value=False):
        # hidden_states: [b, s, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(layernorm_output,
                                attention_mask,
                                layer_past=layer_past,
                                get_key_value=get_key_value)

        if get_key_value:
            attention_output, presents = attention_output

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # jit scripting for an nn.Module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # Re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # Residual connection.
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # Re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention.
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # Re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            output = bias_dropout_add_func(
                mlp_output,
                mlp_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        if get_key_value:
            output = [output, presents]

        return output
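
# With the default (pre-layernorm) residual wiring, each encoder layer above
# computes, in pseudo-code form:
#
#     x = x + dropout(SelfAttention(LayerNorm(x)) + bias)
#     x = x + dropout(MLP(LayerNorm(x)) + bias)
#
# When apply_residual_connection_post_layernorm is set, the residual is taken
# from the layernorm output instead of the layer input; decoder layers insert
# an analogous cross-attention block between the two steps.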


class ParallelTransformer(MegatronModule):
    """Transformer class."""

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 pre_process=True, post_process=True):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None

        # Store activation checkpointing flag.
        self.checkpoint_activations = args.checkpoint_activations
        self.checkpoint_num_layers = args.checkpoint_num_layers

        # Number of layers.
        assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
            'num_layers must be divisible by pipeline_model_parallel_size'
        self.num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type)

        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
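            # As a worked example of the offset computed below (illustrative
            # numbers only): with 8 layers, 2 pipeline stages, and 2 virtual
            # chunks, self.num_layers is 8 // 2 // 2 = 2, so pipeline rank 0,
            # virtual rank 1 gets offset = 1 * (8 // 2) + 0 * 2 = 4 and builds
            # layer numbers 5 and 6 (1-indexed), i.e. the [4, 5] chunk above.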
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        # Make sure memory is freed.
        mpu.reset_checkpointed_activations_memory_buffer()
        l = 0
        while l < self.num_layers:
            hidden_states = mpu.checkpoint(
                custom(l, l + self.checkpoint_num_layers),
                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
            l += self.checkpoint_num_layers

        return hidden_states
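
    # Sketch of the checkpointing loop above (illustrative numbers only): with
    # num_layers = 12 and checkpoint_num_layers = 2, mpu.checkpoint is called
    # six times, wrapping layer ranges [0, 2), [2, 4), ..., [10, 12). Only the
    # inputs at chunk boundaries are saved; activations inside each chunk are
    # recomputed during the backward pass, trading compute for memory.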

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func."""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):

        # Checks.
        if layer_past is not None:
            assert get_key_value, \
                'for not None values in layer_past, ' \
                'expected get_key_value to be set'
        if get_key_value:
            assert not self.checkpoint_activations, \
                'get_key_value does not work with ' \
                'activation checkpointing'

        if self.pre_process:
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert to float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()
        else:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        if self.checkpoint_activations:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
            if get_key_value:
                presents = []
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                past = None
                if layer_past is not None:
                    past = layer_past[index]
                hidden_states = layer(hidden_states,
                                      attention_mask,
                                      encoder_output=encoder_output,
                                      enc_dec_attn_mask=enc_dec_attn_mask,
                                      layer_past=past,
                                      get_key_value=get_key_value)
                if get_key_value:
                    hidden_states, present = hidden_states
                    presents.append(present)

        # Final layer norm.
        if self.post_process:
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states

        if get_key_value:
            output = [output, presents]

        return output
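
# A brief usage note for the module above (a sketch, assuming the surrounding
# Megatron setup — initialized args and mpu state — is already in place):
# forward() takes hidden_states of shape [b, s, h], works in [s, b, h]
# internally, and transposes back to [b, s, h] when post_process is True.
# With get_key_value=True the return value is a list [output, presents], where
# presents holds one (key, value) pair per layer that can be passed back in as
# layer_past on the next decoding step; layer_past requires get_key_value and
# is incompatible with activation checkpointing, as the asserts enforce.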