|
@@ -227,6 +227,9 @@ class H2OLlamaAttention(nn.Module):
|
|
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
|
|
bsz, q_len, _ = hidden_states.size()
|
|
|
|
|
|
+ if self.layer_idx == 0:
|
|
|
+ import pdb; pdb.set_trace()
|
|
|
+
|
|
|
if self.config.pretraining_tp > 1:
|
|
|
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
|
|
|
query_slices = self.q_proj.weight.split(
|