Parcourir la source

Update utils_llama.py

Allen il y a 1 an
Parent
commit
84ddf520a7
1 fichier modifié avec 1 ajout et 1 suppression
  1. 1 1
      research/long-context-llama/H2O/utils_llama.py

+ 1 - 1
research/long-context-llama/H2O/utils_llama.py

@@ -182,8 +182,8 @@ class H2OLlamaAttention(nn.Module):
                 import pdb; pdb.set_trace()
             if not position_ids.nelement() > 1:
                 # decoding stage
-                query_position_ids = kv_seq_len - 1
                 key_position_ids = torch.arange(kv_seq_len, device=hidden_states.device)
+                query_position_ids = key_position_ids[:, -1]
             else:
                 query_position_ids = position_ids
                 key_position_ids = position_ids