
Update utils_llama.py

Allen · 1 year ago
parent commit 0dc84a4653
1 changed file with 18 additions and 9 deletions

+ 18 - 9
research/long-context-llama/H2O/utils_llama.py

@@ -15,7 +15,6 @@ from transformers.models.llama.modeling_llama import (
     apply_rotary_pos_emb,
     repeat_kv,
     LlamaRotaryEmbedding,
-    apply_rotary_pos_emb,
     LlamaForCausalLM,
 )
 from cache_utils import Cache, HHCache, StaticCache
@@ -84,6 +83,7 @@ class H2OLlamaAttention(nn.Module):
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
         self.is_causal = True
+        self.positional_rolling = True
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -164,19 +164,28 @@ class H2OLlamaAttention(nn.Module):
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
         past_key_value = getattr(self, "past_key_value", past_key_value)
-        cos, sin = self.rotary_emb(value_states, position_ids)
 
-        if self.layer_idx == 0:
-            print(position_ids)
+        if not self.positional_rolling:
+            cos, sin = self.rotary_emb(value_states, position_ids)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+            if past_key_value is not None:
+                # sin and cos are specific to RoPE models; cache_position needed for the static cache
+                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+        else:
+            if past_key_value is not None:
+                # In rolling mode RoPE is applied after the cache update, so pass only cache_position
+                cache_kwargs = {"cache_position": cache_position}
+                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
+            kv_seq_len = past_key_value.get_seq_length() if past_key_value is not None else key_states.shape[-2]
+            cos, sin = self.rotary_emb(value_states, kv_seq_len)
 
+            if self.layer_idx == 0:
+                print(kv_seq_len, position_ids)
 
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-        if past_key_value is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
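
The change above switches H2OLlamaAttention to a "positional rolling" mode: the KV cache is updated first, and RoPE is then applied with positions derived from the post-eviction cache length (past_key_value.get_seq_length()) rather than the absolute position_ids, so evicted heavy-hitter slots do not leave gaps in the rotary positions. The sketch below only illustrates that idea with a toy rotary embedding; it is not code from this repository, and the helper names and the eviction index set are made up.

import torch

def rope_cos_sin(positions, dim, theta=10000.0):
    # positions: [seq_len] integer tensor -> cos/sin of shape [seq_len, dim]
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    freqs = torch.outer(positions.float(), inv_freq)   # [seq_len, dim/2]
    emb = torch.cat((freqs, freqs), dim=-1)            # [seq_len, dim]
    return emb.cos(), emb.sin()

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, cos, sin):
    # x: [seq_len, dim]
    return x * cos + rotate_half(x) * sin

dim, seq_len = 8, 6
keys = torch.randn(seq_len, dim)      # keys before rotation
kept = torch.tensor([0, 2, 3, 5])     # slots surviving heavy-hitter eviction (hypothetical)

# Without rolling: kept keys keep their original absolute positions (gaps at 1 and 4).
cos_abs, sin_abs = rope_cos_sin(kept, dim)
keys_absolute = apply_rope(keys[kept], cos_abs, sin_abs)

# With rolling: positions are re-derived from the current cache length, 0..len(kept)-1,
# which is what the kv_seq_len-based rotary call in the diff is aiming at.
cos_roll, sin_roll = rope_cos_sin(torch.arange(len(kept)), dim)
keys_rolled = apply_rope(keys[kept], cos_roll, sin_roll)

print(keys_absolute.shape, keys_rolled.shape)  # torch.Size([4, 8]) torch.Size([4, 8])

Re-deriving contiguous positions after eviction keeps the rotary phases of the surviving keys consistent with a dense prefix, which is the point of the positional_rolling flag.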