@@ -278,8 +278,6 @@ class H2OLlamaAttention(nn.Module):
 
         if past_key_value is not None:
             past_key_value.update_slimming(attn_weights, self.num_key_value_groups, self.layer_idx, cache_kwargs)
-
-
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
 
@@ -303,6 +301,8 @@ class H2OLlamaAttention(nn.Module):
 
         if not output_attentions:
             attn_weights = None
 
+        pdb.set_trace()
+
         return attn_output, attn_weights, past_key_value
 