|
@@ -193,8 +193,8 @@ class H2OLlamaAttention(nn.Module):
|
|
if self.layer_idx == 0:
|
|
if self.layer_idx == 0:
|
|
print(kv_seq_len, query_position_ids, key_position_ids)
|
|
print(kv_seq_len, query_position_ids, key_position_ids)
|
|
|
|
|
|
- query_states = apply_rotary_pos_emb_single(query_states, cos, sin, query_position_ids)
|
|
|
|
- key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
|
|
|
|
|
|
+ query_states = apply_rotary_pos_emb_single(query_states, query_cos, query_sin, query_position_ids)
|
|
|
|
+ key_states = apply_rotary_pos_emb_single(key_states, key_cos, key_sin, key_position_ids)
|
|
|
|
|
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|