Selaa lähdekoodia

Update utils_llama.py

Allen 1 vuosi sitten
vanhempi
commit
ab07cdcf6c
1 muutettu tiedosto, jossa 2 lisäystä ja 2 poistoa
  1. 2 2
      research/long-context-llama/H2O/utils_llama.py

+ 2 - 2
research/long-context-llama/H2O/utils_llama.py

@@ -182,8 +182,8 @@ class H2OLlamaAttention(nn.Module):
                 import pdb; pdb.set_trace()
             if not position_ids.nelement() > 1:
                 # decoding stage
-                key_position_ids = torch.arange(kv_seq_len, device=hidden_states.device)
-                query_position_ids = key_position_ids[:, -1]
+                key_position_ids = torch.arange(kv_seq_len, device=hidden_states.device).unsqueeze(0)
+                query_position_ids = key_position_ids[:, -1].unsqueeze(0)
             else:
                 query_position_ids = position_ids
                 key_position_ids = position_ids