# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transformer."""

import math

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import mpu
from .module import MegatronModule
from megatron.model.enums import AttnMaskType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
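
# The profiling executor is disabled and the legacy fuser re-enabled so that
# the small @torch.jit.script functions below (the fused bias+dropout+add
# variants) can be fused into single elementwise kernels; the exact behavior
# of these private flags may differ across PyTorch releases.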
- """ We use the following notation throughout this file:
- h: hidden size
- n: number of attention heads
- p: number of model parallel partitions
- np: n/p
- hp: h/p
- hn: h/n
- b: batch size
- s: sequence length
- l: number of layers
- Transformer takes input of size [s, b, h] and returns a
- tensor of the same size. We use the following arguments:
- hyperparameters: transformer hyperparameters
- """


class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform a nonlinear transformation, and project the
    state back into h hidden dimension. The bias of the final projection
    is returned separately (skip_bias_add) so the caller can fuse the bias
    add with dropout and the residual connection.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Project to 4h.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        self.bias_gelu_fusion = args.bias_gelu_fusion
        self.activation_func = F.gelu
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu

        # Project back to h.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):

        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate_parallel = \
                bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = \
                self.activation_func(intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias


class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type

        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)
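
        # Why layer-number scaling keeps fp16 attention stable: raw scores
        # are computed as QK^T / (sqrt(hn) * layer_number), shrinking their
        # magnitude, and the softmax multiplies by coeff = layer_number
        # (in fp32) before normalizing, so the net result still equals
        # softmax(QK^T / sqrt(hn)) while the intermediate matmul stays
        # within fp16 range.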

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs for different numbers of parallel partitions, but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

        # Output.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False, encoder_output=None):
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer),
                                   key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer),
                                     value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)
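
        # layer_past caches the keys/values of previously decoded tokens;
        # the concatenation above is along dim 0 (the sequence axis), so sk
        # grows with each generated token while sq is typically 1 during
        # incremental decoding.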

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0]*output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device())

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),   # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))
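
        # With beta=0.0 the uninitialized contents of matmul_result are
        # ignored; baddbmm here is just alpha * (Q @ K^T) written into a
        # preallocated output buffer.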

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ==================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]
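
        # During incremental decoding (layer_past is not None) only the
        # newest query position is present, so the causal mask is sliced
        # down to the single row for the current token over the first sk
        # key positions.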

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        with mpu.get_cuda_rng_tracker().fork():
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output, bias


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add


@torch.jit.script
def bias_dropout_add_fused_train(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, True)


@torch.jit.script
def bias_dropout_add_fused_inference(x, bias, residual, prob):
    # type: (Tensor, Tensor, Tensor, float) -> Tensor
    return bias_dropout_add(x, bias, residual, prob, False)
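

# All of the variants above compute residual + dropout(x + bias); the
# jit-scripted versions bake the training flag in as a constant so the
# elementwise chain can be fused into a single kernel. Illustrative
# equivalence check (dropout is a no-op in inference mode):
#
#   x, b, r = (torch.randn(4, 2, 8) for _ in range(3))
#   assert torch.allclose(
#       bias_dropout_add_fused_inference(x, b, r, 0.1), r + x + b)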


class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion

        # Layernorm on the attention output.
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon)

        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)
            # Layernorm on the attention output.
            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

        # MLP
        self.mlp = ParallelMLP(init_method,
                               output_layer_init_method)

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                layer_past=None, get_key_value=False):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, attention_bias = \
            self.self_attention(layernorm_output,
                                attention_mask,
                                layer_past=layer_past,
                                get_key_value=get_key_value)

        if get_key_value:
            attention_output, presents = attention_output

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # jit scripting for an nn.Module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)
            # residual connection
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # re-enable torch grad to enable fused optimization.
            with torch.enable_grad():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention.
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            output = bias_dropout_add_func(
                mlp_output,
                mlp_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        if get_key_value:
            output = [output, presents]

        return output


class ParallelTransformer(MegatronModule):
    """Transformer class."""

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 pre_process=True, post_process=True):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None

        # Store activation checkpointing flag.
        self.checkpoint_activations = args.checkpoint_activations
        self.checkpoint_num_layers = args.checkpoint_num_layers

        # Number of layers.
        assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
            'num_layers must be divisible by pipeline_model_parallel_size'
        self.num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()

        # Transformer layers.
        def build_layer(layer_number):
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type)

        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
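
            # Worked example for the first case above (8 layers, 2 pipeline
            # stages, virtual size 4): self.num_layers = 8 // 2 // 4 = 1 and
            # offset = virtual_rank * (8 // 4) + pipeline_rank * 1, so
            # pipeline rank 0 builds chunks at offsets 0, 2, 4, 6 and rank 1
            # at offsets 1, 3, 5, 7 -- matching the interleaved layout.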
        else:
            # Each stage gets a contiguous set of layers.
            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        self.layers = torch.nn.ModuleList(
            [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        # Make sure memory is freed.
        mpu.reset_checkpointed_activations_memory_buffer()
        l = 0
        while l < self.num_layers:
            hidden_states = mpu.checkpoint(
                custom(l, l + self.checkpoint_num_layers),
                hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
            l += self.checkpoint_num_layers
        return hidden_states
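
    # Note on the checkpointing scheme above: layers are processed in chunks
    # of checkpoint_num_layers; only the inputs at chunk boundaries are kept
    # alive during the forward pass, and mpu.checkpoint recomputes each
    # chunk's forward during backward to rebuild the activations it needs,
    # trading extra compute for a large reduction in activation memory.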

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func.
        """
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False, encoder_output=None, enc_dec_attn_mask=None):

        # Checks.
        if layer_past is not None:
            assert get_key_value, \
                'when layer_past is not None, ' \
                'get_key_value must also be set'
        if get_key_value:
            assert not self.checkpoint_activations, \
                'get_key_value does not work with ' \
                'activation checkpointing'

        if self.pre_process:
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            # If the input flag for fp32 residual connection is set, convert to float.
            if self.fp32_residual_connection:
                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
            # Otherwise, leave it as is.
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()
        else:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        if encoder_output is not None:
            encoder_output = encoder_output.transpose(0, 1).contiguous()

        if self.checkpoint_activations:
            hidden_states = self._checkpointed_forward(hidden_states,
                                                       attention_mask,
                                                       encoder_output,
                                                       enc_dec_attn_mask)
        else:
            if get_key_value:
                presents = []
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                past = None
                if layer_past is not None:
                    past = layer_past[index]
                hidden_states = layer(hidden_states,
                                      attention_mask,
                                      encoder_output=encoder_output,
                                      enc_dec_attn_mask=enc_dec_attn_mask,
                                      layer_past=past,
                                      get_key_value=get_key_value)
                if get_key_value:
                    hidden_states, present = hidden_states
                    presents.append(present)

        # Final layer norm.
        if self.post_process:
            # Reverting data format change [s b h] --> [b s h].
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states

        if get_key_value:
            output = [output, presents]

        return output