Allen пре 1 година
родитељ
комит
b33b68c3f7
1 измењених фајлова са 8 додато и 10 уклоњено
  1. 8 10
      research/long-context-llama/H2O/utils_llama.py

+ 8 - 10
research/long-context-llama/H2O/utils_llama.py

@@ -39,13 +39,12 @@ def _make_causal_mask(
         mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
-def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+def apply_rotary_pos_emb_single(x, cos, sin, position_ids=None, unsqueeze_dim=1):
+
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
     x_embed = (x * cos) + (rotate_half(x) * sin)
+
     return x_embed
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -178,8 +177,7 @@ class H2OLlamaAttention(nn.Module):
                 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
 
             kv_seq_len = past_key_value.get_seq_length() if past_key_value is not None else key_states.shape[-2]
-            if self.layer_idx == 0:
-                import pdb; pdb.set_trace()
+
             if not position_ids.nelement() > 1:
                 # decoding stage
                 key_position_ids = torch.arange(kv_seq_len, device=hidden_states.device).unsqueeze(0)
@@ -194,8 +192,8 @@ class H2OLlamaAttention(nn.Module):
             if self.layer_idx == 0:
                 print(kv_seq_len, query_position_ids, key_position_ids)
 
-            query_states = apply_rotary_pos_emb_single(query_states, query_cos, query_sin, query_position_ids)
-            key_states = apply_rotary_pos_emb_single(key_states, key_cos, key_sin, key_position_ids)
+            query_states = apply_rotary_pos_emb_single(query_states, query_cos, query_sin)
+            key_states = apply_rotary_pos_emb_single(key_states, key_cos, key_sin)
 
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)