|
@@ -254,6 +254,9 @@ class H2OLlamaAttention(nn.Module):
|
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
|
|
|
+
|
|
|
|
+ import pdb; pdb.set_trace()
|
|
|
|
+
|
|
|
|
|
|
past_key_value = getattr(self, "past_key_value", past_key_value)
|
|
past_key_value = getattr(self, "past_key_value", past_key_value)
|
|
cos, sin = self.rotary_emb(value_states, position_ids)
|
|
cos, sin = self.rotary_emb(value_states, position_ids)
|