소스 검색

Update utils_llama.py

Allen 1 년 전
부모
커밋
ab07cdcf6c
1개의 변경된 파일2개의 추가작업 그리고 2개의 파일을 삭제
  1. 2 2
      research/long-context-llama/H2O/utils_llama.py

+ 2 - 2
research/long-context-llama/H2O/utils_llama.py

@@ -182,8 +182,8 @@ class H2OLlamaAttention(nn.Module):
                 import pdb; pdb.set_trace()
             if not position_ids.nelement() > 1:
                 # decoding stage
-                key_position_ids = torch.arange(kv_seq_len, device=hidden_states.device)
-                query_position_ids = key_position_ids[:, -1]
+                key_position_ids = torch.arange(kv_seq_len, device=hidden_states.device).unsqueeze(0)
+                query_position_ids = key_position_ids[:, -1].unsqueeze(0)
             else:
                 query_position_ids = position_ids
                 key_position_ids = position_ids