|
@@ -178,7 +178,8 @@ class H2OLlamaAttention(nn.Module):
|
|
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
|
|
|
|
|
|
kv_seq_len = past_key_value.get_seq_length() if past_key_value is not None else key_states.shape[-2]
|
|
|
-
|
|
|
+ if self.layer_idx == 0:
|
|
|
+ import pdb; pdb.set_trace()
|
|
|
if not position_ids.nelement() > 1:
|
|
|
# decoding stage
|
|
|
query_position_ids = kv_seq_len - 1
|