@@ -18,7 +18,7 @@ from transformers.models.llama.modeling_llama import (
     apply_rotary_pos_emb,
     LlamaForCausalLM,
 )
-from cache_utils import Cache
+from cache_utils import Cache, HHCache
 from transformers.utils import logging
 
 logger = logging.get_logger(__name__)
@@ -186,6 +186,8 @@ class H2OLlamaAttention(nn.Module):
         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
         self._init_rope()
 
+        self.past_key_value = HHCache()
+
     def _init_rope(self):
         if self.config.rope_scaling is None:
             self.rotary_emb = LlamaRotaryEmbedding(
@@ -252,8 +254,6 @@ class H2OLlamaAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        print(self.past_key_value)
-
         past_key_value = getattr(self, "past_key_value", past_key_value)
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
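
Note: the net effect of this patch is that each H2OLlamaAttention layer eagerly constructs its own HHCache (with no arguments, as in the second hunk), the forward pass prefers that module-owned cache over any cache passed in via the retained getattr fallback, and the leftover debug print is dropped. The self-contained Python sketch below illustrates that ownership/fallback pattern; the HHCache shown here is a stub whose update behavior is an assumption, not the repo's actual cache_utils implementation.

    # Sketch of the pattern introduced by this patch; HHCache is a stub
    # standing in for the real class imported from cache_utils above.

    class HHCache:
        """Stub heavy-hitter cache holding per-layer key/value lists."""

        def __init__(self):
            # The patch calls HHCache() with no arguments; how the real class
            # configures its heavy-hitter budget is not shown here.
            self.key_cache = []
            self.value_cache = []

        def update(self, key_states, value_states, layer_idx):
            # Append new states for this layer; the real HHCache would also
            # evict tokens that are not heavy hitters (assumption).
            if len(self.key_cache) <= layer_idx:
                self.key_cache.append(list(key_states))
                self.value_cache.append(list(value_states))
            else:
                self.key_cache[layer_idx].extend(key_states)
                self.value_cache[layer_idx].extend(value_states)
            return self.key_cache[layer_idx], self.value_cache[layer_idx]


    class AttentionWithOwnedCache:
        """Mirrors how H2OLlamaAttention owns its cache after this patch."""

        def __init__(self):
            # Built eagerly in __init__, as in the second hunk.
            self.past_key_value = HHCache()

        def forward(self, key_states, value_states, layer_idx=0, past_key_value=None):
            # The retained line from the third hunk: prefer the module-owned
            # cache over whatever cache (if any) was passed as an argument.
            past_key_value = getattr(self, "past_key_value", past_key_value)
            return past_key_value.update(key_states, value_states, layer_idx)


    if __name__ == "__main__":
        attn = AttentionWithOwnedCache()
        print(attn.forward(["k0"], ["v0"]))  # first step fills the cache
        print(attn.forward(["k1"], ["v1"]))  # later steps extend it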