@@ -209,7 +209,7 @@ class H2OLlamaAttention(nn.Module):

         # Update KV Cache based on Heavy-Hitter Oracle
         if past_key_value is not None:
-            past_key_value.update_slimming(attn_weights, self.num_key_value_groups, self.layer_idx, cache_kwargs)
+            past_key_value.update_slimming(attn_weights, self.num_key_value_groups, self.layer_idx)

         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
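
For context, update_slimming is the hook where the Heavy-Hitter Oracle trims the KV cache after the attention weights are computed. The sketch below is a hypothetical illustration of what such a method might do, matching the three-argument call on the + side of the hunk; the class name HHCacheSketch, the update helper, and the num_hh/num_recent budget parameters are assumptions made for illustration, not the repository's actual API.

import torch


class HHCacheSketch:
    """Hypothetical per-layer KV cache with H2O-style slimming (illustrative only).

    Keeps the `num_recent` most recent tokens plus the `num_hh` tokens with the
    highest accumulated attention scores ("heavy hitters"); everything else is
    evicted once the cache exceeds the combined budget.
    """

    def __init__(self, num_hh: int, num_recent: int):
        self.num_hh = num_hh
        self.num_recent = num_recent
        self.key_cache = []    # per layer: (batch, kv_heads, seq, head_dim)
        self.value_cache = []
        self.acc_scores = []   # per layer: (batch, kv_heads, seq)

    def update(self, key_states, value_states, layer_idx):
        # Standard cache append, assumed to run before update_slimming.
        if layer_idx >= len(self.key_cache):
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], key_states], dim=2)
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], value_states], dim=2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def update_slimming(self, attn_weights, num_key_value_groups, layer_idx):
        # attn_weights: (batch, q_heads, q_len, seq). Sum over the query axis to
        # score every cached position, then fold grouped query heads back onto
        # their shared KV head (GQA layout, as repeat_kv produces it).
        bsz, q_heads, _, seq = attn_weights.shape
        scores = attn_weights.sum(dim=2)
        scores = scores.view(
            bsz, q_heads // num_key_value_groups, num_key_value_groups, seq
        ).sum(dim=2)                                    # (batch, kv_heads, seq)

        # Accumulate with scores from previous decoding steps; cached positions
        # come first in the concatenated sequence, so they align with prev.
        if layer_idx >= len(self.acc_scores):
            self.acc_scores.append(scores)
        else:
            prev = self.acc_scores[layer_idx]
            scores[..., : prev.shape[-1]] += prev
            self.acc_scores[layer_idx] = scores

        budget = self.num_hh + self.num_recent
        if seq <= budget:
            return  # under budget, nothing to evict

        acc = self.acc_scores[layer_idx]
        kv_heads = acc.shape[1]
        # The recent window is always kept; heavy hitters are picked among the
        # remaining (older) positions by accumulated score.
        hh_idx = acc[..., : seq - self.num_recent].topk(self.num_hh, dim=-1).indices
        recent_idx = torch.arange(
            seq - self.num_recent, seq, device=acc.device
        ).view(1, 1, -1).expand(bsz, kv_heads, -1)
        keep = torch.cat([hh_idx, recent_idx], dim=-1).sort(dim=-1).values

        # Gather the surviving positions out of the per-layer caches.
        head_dim = self.key_cache[layer_idx].shape[-1]
        kv_gather = keep.unsqueeze(-1).expand(-1, -1, -1, head_dim)
        self.key_cache[layer_idx] = self.key_cache[layer_idx].gather(2, kv_gather)
        self.value_cache[layer_idx] = self.value_cache[layer_idx].gather(2, kv_gather)
        self.acc_scores[layer_idx] = acc.gather(2, keep)

Under these assumptions, slimming consumes only the attention weights and the cache's own per-layer bookkeeping, which is consistent with the cache_kwargs argument being dropped from the call in the hunk above.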