فهرست منبع

Update utils_llama.py

Allen 1 سال پیش
والد
کامیت
07efbea605
1فایلهای تغییر یافته به همراه3 افزوده شده و 3 حذف شده
  1. 3 3
      research/long-context-llama/H2O/utils_llama.py

+ 3 - 3
research/long-context-llama/H2O/utils_llama.py

@@ -18,7 +18,7 @@ from transformers.models.llama.modeling_llama import (
     apply_rotary_pos_emb,
     LlamaForCausalLM,
 )
-from cache_utils import Cache
+from cache_utils import Cache, HHCache
 from transformers.utils import logging
 
 logger = logging.get_logger(__name__)
@@ -186,6 +186,8 @@ class H2OLlamaAttention(nn.Module):
         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
         self._init_rope()
 
+        self.past_key_value = HHCache()
+
     def _init_rope(self):
         if self.config.rope_scaling is None:
             self.rotary_emb = LlamaRotaryEmbedding(
@@ -252,8 +254,6 @@ class H2OLlamaAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        print(self.past_key_value)
-
         past_key_value = getattr(self, "past_key_value", past_key_value)
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)