@@ -39,13 +39,12 @@ def _make_causal_mask(
     mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

-def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+def apply_rotary_pos_emb_single(x, cos, sin, position_ids=None, unsqueeze_dim=1):
+    # cos/sin now come in already gathered per position; position_ids is no longer used here
+    cos = cos.unsqueeze(unsqueeze_dim)  # e.g. [bs, seq_len, dim] -> [bs, 1, seq_len, dim], broadcasting over heads
+    sin = sin.unsqueeze(unsqueeze_dim)
     x_embed = (x * cos) + (rotate_half(x) * sin)
+
     return x_embed

 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
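
The rewritten helper follows the newer convention where the caller gathers cos/sin for the right positions before the call, and the helper only broadcasts them across the head axis. A minimal, self-contained sketch of the new calling convention; the frequency computation below mirrors the usual Llama rotary embedding and is illustrative, not taken from this patch:

    import torch

    def rotate_half(x):
        # standard transformers helper: rotate half of the hidden dims
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb_single(x, cos, sin, position_ids=None, unsqueeze_dim=1):
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        return (x * cos) + (rotate_half(x) * sin)

    bs, heads, seq, dim = 1, 8, 5, 64
    x = torch.randn(bs, heads, seq, dim)
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    freqs = torch.outer(torch.arange(seq).float(), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos, sin = emb.cos()[None], emb.sin()[None]  # pre-gathered tables, [1, seq, dim]
    out = apply_rotary_pos_emb_single(x, cos, sin)  # no position_ids at the call site
    assert out.shape == x.shape
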
@@ -178,8 +177,7 @@ class H2OLlamaAttention(nn.Module):
         key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

         kv_seq_len = past_key_value.get_seq_length() if past_key_value is not None else key_states.shape[-2]
-        if self.layer_idx == 0:
-            import pdb; pdb.set_trace()
+
         if not position_ids.nelement() > 1:
             # decoding stage
             key_position_ids = torch.arange(kv_seq_len, device=hidden_states.device).unsqueeze(0)
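
During decoding, `position_ids` holds a single position, so the branch above re-derives key positions for the entire cache: keys get consecutive positions 0..kv_seq_len-1 spanning whatever is currently held in the cache (H2O may have evicted earlier tokens), while the query keeps its own position. A tiny illustration with made-up sizes:

    import torch

    kv_seq_len = 7                        # current cache length (illustrative)
    position_ids = torch.tensor([[6]])    # single-token decode step
    if not position_ids.nelement() > 1:   # decoding stage
        key_position_ids = torch.arange(kv_seq_len).unsqueeze(0)
    print(key_position_ids)               # tensor([[0, 1, 2, 3, 4, 5, 6]])
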
@@ -194,8 +192,6 @@ class H2OLlamaAttention(nn.Module):
-        if self.layer_idx == 0:
-            print(kv_seq_len, query_position_ids, key_position_ids)

-        query_states = apply_rotary_pos_emb_single(query_states, query_cos, query_sin, query_position_ids)
-        key_states = apply_rotary_pos_emb_single(key_states, key_cos, key_sin, key_position_ids)
+        query_states = apply_rotary_pos_emb_single(query_states, query_cos, query_sin)
+        key_states = apply_rotary_pos_emb_single(key_states, key_cos, key_sin)

         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
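
For context, `repeat_kv` (unchanged by this patch) is the standard grouped-query-attention helper from the Llama modeling code: each KV head is tiled across its group of query heads before the attention product. A sketch of its usual implementation, since the body falls outside these hunks:

    import torch

    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        # [batch, num_kv_heads, seq, head_dim] -> [batch, num_kv_heads * n_rep, seq, head_dim]
        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

    kv = torch.randn(2, 4, 10, 64)                    # 4 KV heads
    assert repeat_kv(kv, 8).shape == (2, 32, 10, 64)  # expanded to 32 query heads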