@@ -15,7 +15,6 @@ from transformers.models.llama.modeling_llama import (
     apply_rotary_pos_emb,
     repeat_kv,
     LlamaRotaryEmbedding,
-    apply_rotary_pos_emb,
     LlamaForCausalLM,
 )
 from cache_utils import Cache, HHCache, StaticCache
@@ -84,6 +83,7 @@ class H2OLlamaAttention(nn.Module):
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.is_causal = True
+        self.positional_rolling = True

         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -164,19 +164,28 @@ class H2OLlamaAttention(nn.Module):
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

         past_key_value = getattr(self, "past_key_value", past_key_value)
-        cos, sin = self.rotary_emb(value_states, position_ids)

-        if self.layer_idx == 0:
-            print(position_ids)
+        if not self.positional_rolling:
+            cos, sin = self.rotary_emb(value_states, position_ids)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+            if past_key_value is not None:
+                # sin and cos are specific to RoPE models; cache_position needed for the static cache
+                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+        else:
+            if past_key_value is not None:
+                # cache_position is needed for the static cache; sin/cos are applied only after the cache update when rolling positions
+                cache_kwargs = {"cache_position": cache_position}
+                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+            kv_seq_len = past_key_value.get_seq_length() if past_key_value is not None else key_states.shape[-2]
+            cos, sin = self.rotary_emb(value_states, kv_seq_len)
+
+            if self.layer_idx == 0:
+                print(kv_seq_len, position_ids)

-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

-        if past_key_value is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
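
For reference, a minimal standalone sketch of the positional-rolling idea the else-branch implements: keys are cached unrotated, and after any eviction RoPE is re-applied over contiguous positions 0..kv_len-1, so the attended positional range never grows past the cache budget. The helper names (rope, roll_positions) and the plain-torch rotary math below are illustrative assumptions, not code from this diff; the patch itself routes the same idea through past_key_value.update (HHCache) and LlamaRotaryEmbedding.

import torch

# Illustrative sketch only: `rope` and `roll_positions` are hypothetical helpers,
# not part of the diff above.

def rope(x, positions, base=10000.0):
    # Standard rotary embedding; x: (batch, heads, seq, head_dim), positions: (seq,).
    half = x.shape[-1] // 2
    inv_freq = 1.0 / (base ** (torch.arange(0, half, dtype=torch.float32) / half))
    angles = positions.float()[:, None] * inv_freq[None, :]  # (seq, half)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

def roll_positions(cached_keys, query):
    # cached_keys hold *unrotated* keys that survived eviction. Positions are
    # re-assigned contiguously, so evicted tokens leave no gaps and the rotated
    # range never exceeds the number of kept cache entries.
    kv_len = cached_keys.shape[-2]
    q_len = query.shape[-2]
    key_states = rope(cached_keys, torch.arange(kv_len))
    query_states = rope(query, torch.arange(kv_len - q_len, kv_len))
    return query_states, key_states

# Tiny shape check: one new query token against a rolled cache of 8 kept keys.
q = torch.randn(1, 4, 1, 64)
k_cache = torch.randn(1, 4, 8, 64)
q_rot, k_rot = roll_positions(k_cache, q)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 4, 1, 64]) torch.Size([1, 4, 8, 64])

During decoding this means the single new query token is rotated at position kv_len - 1, directly behind the surviving heavy-hitter keys, rather than at its original absolute position in the full sequence.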