@@ -49,96 +49,6 @@ def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
     x_embed = (x * cos) + (rotate_half(x) * sin)
     return x_embed
 
-class H2OKVCache_LayerWise:
-    def __init__(
-        self,
-        hh_size=4,
-        recent_size=512,
-        k_seq_dim=2,
-        v_seq_dim=2,
-    ):
-        self.hh_size = hh_size
-        self.recent_size = recent_size
-        self.cache_size = hh_size + recent_size
-        self.k_seq_dim = k_seq_dim
-        self.v_seq_dim = v_seq_dim
-        self.hh_score = None
-
-    def __call__(self, past_key_values, attn_score_cache):
-
-        self._update_hh_score(attn_score_cache)
-
-        if past_key_values is None:
-            return None
-        seq_len = past_key_values[0].size(self.k_seq_dim)
-        if seq_len <= self.cache_size:
-            return past_key_values
-
-        # hh-selection
-        bsz, num_heads, _, head_dim = past_key_values[0].shape
-
-        select_hh_scores = self.hh_score[:, :seq_len - self.recent_size]
-        _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
-        keep_topk = keep_topk.sort().values
-
-        # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
-        keep_recent = torch.arange(seq_len - self.recent_size, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
-        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)
-
-        mask = torch.zeros(self.hh_score.shape, dtype=torch.bool).to(past_key_values[0].device)
-        mask = mask.scatter(-1, keep_idx, 1)
-
-        k_hh_recent = past_key_values[0].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
-        v_hh_recent = past_key_values[1].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
-
-        self.hh_score= self.hh_score[mask].view(num_heads, self.cache_size)
-
-        return (k_hh_recent, v_hh_recent)
-
-    def evict_for_space(self, past_key_values, num_coming):
-        if past_key_values is None:
-            return None
-        seq_len = past_key_values[0][0].size(self.k_seq_dim)
-        if seq_len + num_coming <= self.cache_size:
-            return past_key_values
-
-        # hh-selection
-        bsz, num_heads, _, head_dim = past_key_values[0].shape
-
-        select_hh_scores = self.hh_score[:, :seq_len - self.recent_size + num_coming]
-        _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
-        keep_topk = keep_topk.sort().values
-
-        # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
-        keep_recent = torch.arange(seq_len - self.recent_size + num_coming, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
-        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)
-
-        mask = torch.zeros(self.hh_score.shape, dtype=torch.bool).to(past_key_values[0].device)
-        mask = mask.scatter(-1, keep_idx, 1)
-
-        k_hh_recent = past_key_values[0].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
-        v_hh_recent = past_key_values[1].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
-
-        self.hh_score= self.hh_score[mask].view(num_heads, self.cache_size)
-
-        return (k_hh_recent, v_hh_recent)
-
-    def _update_hh_score(self, attn_score_cache):
-
-        num_new_tokens = attn_score_cache.shape[2]
-
-        if self.hh_score is None:
-            self.hh_score = attn_score_cache.sum(0).sum(1)
-        else:
-            attn_score_cache = attn_score_cache.sum(0).sum(1)
-            attn_score_cache[:, :-num_new_tokens] += self.hh_score
-            self.hh_score = attn_score_cache
-
-    def _clean_scores(self):
-        self.hh_score = None
-
-
-
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -255,6 +165,11 @@ class H2OLlamaAttention(nn.Module):
 
         past_key_value = getattr(self, "past_key_value", past_key_value)
         cos, sin = self.rotary_emb(value_states, position_ids)
+
+        print(position_ids)
+
+
+
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None: