Pārlūkot izejas kodu

Update utils_llama.py

Allen 1 gadu atpakaļ
vecāks
revīzija
ab07cdcf6c
1 mainīts fails ar 2 papildinājumiem un 2 dzēšanām
  1. 2 2
      research/long-context-llama/H2O/utils_llama.py

+ 2 - 2
research/long-context-llama/H2O/utils_llama.py

@@ -182,8 +182,8 @@ class H2OLlamaAttention(nn.Module):
                 import pdb; pdb.set_trace()
             if not position_ids.nelement() > 1:
                 # decoding stage
-                key_position_ids = torch.arange(kv_seq_len, device=hidden_states.device)
-                query_position_ids = key_position_ids[:, -1]
+                key_position_ids = torch.arange(kv_seq_len, device=hidden_states.device).unsqueeze(0)
+                query_position_ids = key_position_ids[:, -1].unsqueeze(0)
             else:
                 query_position_ids = position_ids
                 key_position_ids = position_ids