@@ -228,7 +228,7 @@ class H2OLlamaAttention(nn.Module):
         bsz, q_len, _ = hidden_states.size()
 
         if self.layer_idx == 0:
-            import pdb; pdb.set_trace()
+            import pdb;pdb.set_trace()
 
         if self.config.pretraining_tp > 1:
             key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
@@ -369,6 +369,9 @@ def enable_h2ocache_forward(
     all_self_attns = () if output_attentions else None
     next_decoder_cache = None
 
+    import pdb;pdb.set_trace()
+
+
     for decoder_layer in self.layers:
         if output_hidden_states:
             all_hidden_states += (hidden_states,)