Allen преди 1 година
родител
ревизия
008238b0e9
променен е 1 файл, в който са добавени 12 реда и са изтрити 5 реда
  1. 12 5
      research/long-context-llama/H2O/utils_llama.py

+ 12 - 5
research/long-context-llama/H2O/utils_llama.py

@@ -177,16 +177,23 @@ class H2OLlamaAttention(nn.Module):
                 # sin and cos are specific to RoPE models; cache_position needed for the static cache
                 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
 
-            import pdb; pdb.set_trace()
-
             kv_seq_len = past_key_value.get_seq_length() if past_key_value is not None else key_states.shape[-2]
+
+            if not position_ids.nelement() > 1:
+                # decoding stage
+                query_position_ids = kv_seq_len - 1
+                key_position_ids = torch.arange(kv_seq_len, device=hidden_states.device)
+            else:
+                query_position_ids = position_ids
+                key_position_ids = position_ids
+
             cos, sin = self.rotary_emb(value_states, kv_seq_len)
 
             if self.layer_idx == 0:
-                print(kv_seq_len, position_ids)
-
-            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+                print(kv_seq_len, query_position_ids, key_position_ids)
 
+            query_states = apply_rotary_pos_emb_single(query_states, cos, sin, query_position_ids)
+            key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
 
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)