@@ -182,8 +182,8 @@ class H2OLlamaAttention(nn.Module):
         import pdb; pdb.set_trace()
         if not position_ids.nelement() > 1:
             # decoding stage
-            query_position_ids = kv_seq_len - 1
             key_position_ids = torch.arange(kv_seq_len, device=hidden_states.device)
+            query_position_ids = key_position_ids[-1]
         else:
             query_position_ids = position_ids
             key_position_ids = position_ids
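
For reference, below is a minimal standalone sketch of the decode-versus-prefill position-id selection this hunk implements. The function name `select_position_ids` and its signature are illustrative, not part of the repository; it assumes `kv_seq_len` counts cached plus current tokens, while `position_ids` holds only the positions of the tokens in the current forward pass.

```python
import torch

def select_position_ids(position_ids: torch.Tensor, kv_seq_len: int, device) -> tuple:
    """Illustrative sketch, not the repository's implementation.

    Decoding step: a single new token is processed, so its query position is
    the last key position (kv_seq_len - 1), while the keys span the whole cache.
    Prefill step: queries and keys share the positions passed in.
    """
    if not position_ids.nelement() > 1:
        # decoding stage: one new token
        key_position_ids = torch.arange(kv_seq_len, device=device)
        # last key position, as a tensor on the same device (== kv_seq_len - 1)
        query_position_ids = key_position_ids[-1]
    else:
        # prefill stage
        query_position_ids = position_ids
        key_position_ids = position_ids
    return query_position_ids, key_position_ids

# Example: decode step with a cache of 10 tokens plus the new one.
q_pos, k_pos = select_position_ids(torch.tensor([[10]]), kv_seq_len=11, device="cpu")
assert q_pos.item() == 10 and k_pos.shape == (11,)
```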