|
@@ -187,7 +187,7 @@ class H2OLlamaAttention(nn.Module):
|
|
|
query_position_ids = position_ids
|
|
|
key_position_ids = position_ids
|
|
|
|
|
|
- cos, sin = self.rotary_emb(value_states, kv_seq_len)
|
|
|
+ cos, sin = self.rotary_emb(value_states, key_position_ids)
|
|
|
|
|
|
if self.layer_idx == 0:
|
|
|
print(kv_seq_len, query_position_ids, key_position_ids)
|