@@ -401,8 +401,6 @@ def enable_h2ocache_forward(
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)
-
-    import pdb; pdb.set_trace()

    # add hidden states from the last decoder layer
    if output_hidden_states:
@@ -422,70 +420,6 @@ def enable_h2ocache_forward(
        attentions=all_self_attns,
    )


-# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
-    if self.config._attn_implementation == "flash_attention_2":
-        if attention_mask is not None and 0.0 in attention_mask:
-            return attention_mask
-        return None
-
-    dtype, device = input_tensor.dtype, input_tensor.device
-    min_dtype = torch.finfo(dtype).min
-    sequence_length = input_tensor.shape[1]
-    if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache
-        target_length = self.config.max_position_embeddings
-    else:  # dynamic cache
-        target_length = (
-            attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
-        )
-
-    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
-    if sequence_length != 1:
-        causal_mask = torch.triu(causal_mask, diagonal=1)
-    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-    causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
-    if attention_mask is not None:
-        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-        if attention_mask.dim() == 2:
-            mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
-            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
-        elif attention_mask.dim() == 4:
-            # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
-            # cache. In that case, the 4D attention mask attends to the newest tokens only.
-            if attention_mask.shape[-2] < cache_position[0] + sequence_length:
-                offset = cache_position[0]
-            else:
-                offset = 0
-            mask_shape = attention_mask.shape
-            mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-            causal_mask[
-                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-            ] = mask_slice
-
-    if (
-        self.config._attn_implementation == "sdpa"
-        and attention_mask is not None
-        and attention_mask.device.type == "cuda"
-    ):
-        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-        # Details: https://github.com/pytorch/pytorch/issues/110213
-        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
-
-    return causal_mask
-
-
-
-
-
-
-
-
-
-

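For reference, the deleted `_update_causal_mask` builds an additive causal mask: it starts from a matrix filled with the dtype's minimum value, re-opens the lower triangle with `torch.triu`, and gates key positions against `cache_position`. The sketch below replays only that core construction, restricted to the dynamic-cache path with no padding mask and no SDPA special-casing; `build_causal_mask` and the toy shapes are invented here for illustration and are not part of the patch.

```python
import torch


def build_causal_mask(input_tensor, cache_position):
    # Hypothetical standalone version of the mask construction in the removed
    # helper (dynamic-cache path only, no attention_mask handling).
    dtype, device = input_tensor.dtype, input_tensor.device
    min_dtype = torch.finfo(dtype).min
    sequence_length = input_tensor.shape[1]
    # The mask spans every position written so far: last cache position + 1.
    target_length = int(cache_position[-1]) + 1

    # Start fully masked, then re-open the lower triangle (causal attention).
    causal_mask = torch.full(
        (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
    )
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    # Keep only key positions beyond each query's cache position masked.
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    # Broadcast to (batch, 1, query_len, key_len) as the attention layers expect.
    return causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)


# Example: a 3-token prompt, batch size 1, hidden size 8.
hidden = torch.zeros(1, 3, 8)
mask = build_causal_mask(hidden, cache_position=torch.arange(3))
print(mask.shape)        # torch.Size([1, 1, 3, 3])
print((mask < 0).sum())  # 3 masked (future) positions above the diagonal
```

The full helper that the patch removes additionally covers static caches, 2D and 4D padding masks, and the SDPA fully-masked-row workaround, all of which are omitted from this sketch.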