@@ -177,7 +177,7 @@ class H2OLlamaAttention(nn.Module):
# sin and cos are specific to RoPE models; cache_position needed for the static cache
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
- kv_seq_len = past_key_value.get_seq_length() if past_key_value is not None else key_states.shape[-2]
+ kv_seq_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else key_states.shape[-2]
if not position_ids.nelement() > 1:
# decoding stage