@@ -186,9 +186,6 @@ class H2OLlamaAttention(nn.Module):
         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
         self._init_rope()
 
-        # self.past_key_value = HHCache()
-        # pdb.set_trace()
-
     def _init_rope(self):
         if self.config.rope_scaling is None:
             self.rotary_emb = LlamaRotaryEmbedding(
@@ -254,9 +251,6 @@ class H2OLlamaAttention(nn.Module):
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        import pdb; pdb.set_trace()
-
 
         past_key_value = getattr(self, "past_key_value", past_key_value)
         cos, sin = self.rotary_emb(value_states, position_ids)
@@ -304,9 +298,175 @@ class H2OLlamaAttention(nn.Module):
         return attn_output, attn_weights, past_key_value
 
 
+def enable_h2ocache_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    if (input_ids is None) ^ (inputs_embeds is not None):
+        raise ValueError(
+            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+        )
+
+    if self.gradient_checkpointing and self.training and use_cache:
+        logger.warning_once(
+            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+        )
+        use_cache = False
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    past_seen_tokens = 0
+    if use_cache:  # kept for BC (cache positions)
+        if not isinstance(past_key_values, StaticCache):
+            past_key_values = HHCache.from_legacy_cache(past_key_values)
+            past_seen_tokens = past_key_values.get_seq_length()
+
+    if cache_position is None:
+        if isinstance(past_key_values, StaticCache):
+            raise ValueError("cache_position is a required argument when using StaticCache.")
+        cache_position = torch.arange(
+            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+        )
+
+    if position_ids is None:
+        position_ids = cache_position.unsqueeze(0)
+
+    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+
+    # embed positions
+    hidden_states = inputs_embeds
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = None
+
+    for decoder_layer in self.layers:
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if self.gradient_checkpointing and self.training:
+            layer_outputs = self._gradient_checkpointing_func(
+                decoder_layer.__call__,
+                hidden_states,
+                causal_mask,
+                position_ids,
+                past_key_values,
+                output_attentions,
+                use_cache,
+                cache_position,
+            )
+        else:
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = None
+    if use_cache:
+        next_cache = (
+            next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
+        )
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
+
+# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
+    if self.config._attn_implementation == "flash_attention_2":
+        if attention_mask is not None and 0.0 in attention_mask:
+            return attention_mask
+        return None
+
+    dtype, device = input_tensor.dtype, input_tensor.device
+    min_dtype = torch.finfo(dtype).min
+    sequence_length = input_tensor.shape[1]
+    if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache
+        target_length = self.config.max_position_embeddings
+    else:  # dynamic cache
+        target_length = (
+            attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
+        )
+
+    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
+    if sequence_length != 1:
+        causal_mask = torch.triu(causal_mask, diagonal=1)
+    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+    causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+    if attention_mask is not None:
+        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+        if attention_mask.dim() == 2:
+            mask_length = attention_mask.shape[-1]
+            padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
+            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+        elif attention_mask.dim() == 4:
+            # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
+            # cache. In that case, the 4D attention mask attends to the newest tokens only.
+            if attention_mask.shape[-2] < cache_position[0] + sequence_length:
+                offset = cache_position[0]
+            else:
+                offset = 0
+            mask_shape = attention_mask.shape
+            mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
+            causal_mask[
+                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
+            ] = mask_slice
+
+    if (
+        self.config._attn_implementation == "sdpa"
+        and attention_mask is not None
+        and attention_mask.device.type == "cuda"
+    ):
+        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+        # Details: https://github.com/pytorch/pytorch/issues/110213
+        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
-
+    return causal_mask
@@ -545,3 +705,5 @@ class H2OLlamaForCausalLM(LlamaForCausalLM):
         num_layers = len(self.model.layers)
         for layer_idx in range(num_layers):
             self.model.layers[layer_idx].self_attn = H2OLlamaAttention(config, layer_idx)
+
+        self.model.forward = types.MethodType(enable_h2ocache_forward, self.model)
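
For reference, the final hunk above swaps the model's forward pass by binding a plain function as an instance method with `types.MethodType`. A minimal, self-contained sketch of that pattern follows (the toy class and function names are illustrative, not part of the patch):

    import types

    class Counter:
        def __init__(self):
            self.n = 0

        def step(self):
            self.n += 1
            return self.n

    def step_by_two(self):
        # Replacement with the same calling convention as the original method.
        # It must not call self.step(), since that attribute now points here.
        self.n += 2
        return self.n

    c = Counter()
    c.step = types.MethodType(step_by_two, c)  # analogous to patching self.model.forward
    print(c.step())  # 2: calls now dispatch to the replacement, bound to this instance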
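
Further up, `_update_causal_mask` builds the attention mask from `torch.full`, `torch.triu`, and a comparison against `cache_position`. Below is a small standalone sketch of that core step with made-up toy sizes, so the resulting pattern can be printed and inspected; it omits the batch/head dimensions, padding handling, and the SDPA special case:

    import torch

    dtype = torch.float32
    min_dtype = torch.finfo(dtype).min        # fill value for blocked positions

    sequence_length = 3                       # tokens processed in this forward pass
    target_length = 5                         # cached positions + new tokens
    cache_position = torch.arange(2, 2 + sequence_length)  # new tokens follow 2 cached ones

    # Start with everything blocked, clear the lower triangle, then also clear every
    # key position at or before each query's absolute (cache) position.
    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
    causal_mask = torch.triu(causal_mask, diagonal=1)
    causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)

    print(causal_mask)  # 0.0 where attention is allowed, min_dtype where it is blocked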