@@ -177,16 +177,23 @@ class H2OLlamaAttention(nn.Module):
         # sin and cos are specific to RoPE models; cache_position needed for the static cache
         key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
 
-        import pdb; pdb.set_trace()
-
         kv_seq_len = past_key_value.get_seq_length() if past_key_value is not None else key_states.shape[-2]
+
+        if not position_ids.nelement() > 1:
+            # decoding stage
+            query_position_ids = kv_seq_len - 1
+            key_position_ids = torch.arange(kv_seq_len, device=hidden_states.device)
+        else:
+            query_position_ids = position_ids
+            key_position_ids = position_ids
+
         cos, sin = self.rotary_emb(value_states, kv_seq_len)
 
         if self.layer_idx == 0:
-            print(kv_seq_len, position_ids)
-
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+            print(kv_seq_len, query_position_ids, key_position_ids)
 
+        query_states = apply_rotary_pos_emb_single(query_states, cos, sin, query_position_ids)
+        key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
 
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
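The two `apply_rotary_pos_emb_single` calls introduced above rotate queries and keys independently, each with its own position ids, instead of the stock `apply_rotary_pos_emb` which rotates both tensors with one shared `position_ids`. The helper itself is not shown in this hunk; the following is only a minimal sketch of what such a function could look like, assuming `self.rotary_emb(value_states, kv_seq_len)` returns `cos`/`sin` of shape `[kv_seq_len, head_dim]` and reusing `rotate_half` from `transformers.models.llama.modeling_llama`. The int/1-D handling mirrors the `query_position_ids` (a plain integer during decoding) and `key_position_ids` (`torch.arange`) values produced above.

```python
import torch
from transformers.models.llama.modeling_llama import rotate_half


def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
    """Sketch: apply RoPE to a single tensor (query or key) with its own position ids."""
    # Normalise position_ids to a [batch, seq_len] index tensor so the same code path
    # covers an int (single decoding position), a 1-D arange (keys), or a
    # [batch, q_len] tensor (prefill).
    if not torch.is_tensor(position_ids):
        position_ids = torch.tensor([position_ids], device=x.device)
    if position_ids.dim() == 1:
        position_ids = position_ids.unsqueeze(0)
    cos = cos[position_ids].unsqueeze(1)  # [batch, 1, seq_len, head_dim]
    sin = sin[position_ids].unsqueeze(1)  # [batch, 1, seq_len, head_dim]
    # Standard RoPE rotation, applied to just this tensor.
    return (x * cos) + (rotate_half(x) * sin)
```

With this shape convention, `query_states` of shape `[bs, num_heads, q_len, head_dim]` broadcasts against `cos`/`sin` of shape `[1, 1, q_len, head_dim]` (or `[bs, 1, q_len, head_dim]` during prefill); if your `rotary_emb` returns tensors with extra leading singleton dimensions, squeeze them before indexing.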