|
@@ -302,7 +302,7 @@ class H2OLlamaAttention(nn.Module):
|
|
attn_weights = None
|
|
attn_weights = None
|
|
|
|
|
|
if self.layer_idx == 0:
|
|
if self.layer_idx == 0:
|
|
- print(past_key_value.key_cache[0].shape, past_key_value.value_cache[0].shape, past_key_value.accumulated_attention_scores[0][0,0,0].item())
|
|
|
|
|
|
+ print(past_key_value.key_cache[0].shape, hidden_states.shape, past_key_value.value_cache[0].shape, past_key_value.accumulated_attention_scores[0][0,0,0].item())
|
|
|
|
|
|
return attn_output, attn_weights, past_key_value
|
|
return attn_output, attn_weights, past_key_value
|
|
|
|
|