Allen 1 year ago
parent
commit
89e576bded

+ 0 - 3
research/long-context-llama/H2O/cache_utils.py

@@ -502,9 +502,6 @@ class HHCache(Cache):
         if len(self.accumulated_attention_scores) <= layer_idx:
             self.accumulated_attention_scores.append(attention_scores.sum(2)[:,::num_kv_groups, :]) # [bs, num_heads, key_len]
         else:
-            if layer_idx == 0:
-                import pdb; pdb.set_trace()
-
             num_new_tokens = attention_scores.shape[2]
             updated_attention_scores = attention_scores.sum(2)[:,::num_kv_groups, :] # [bs, num_heads, key_len]
             updated_attention_scores[:, :, :-num_new_tokens] += self.accumulated_attention_scores[layer_idx]
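
For reference, the accumulation step kept by this hunk can be sketched in isolation: scores are summed over the query dimension, subsampled to one head per KV group, and the previously accumulated totals are added to every position except the newest ones. The helper name and shapes below are assumptions for illustration, not part of the HHCache API.

    import torch

    def accumulate_attention_scores(accumulated, attention_scores, num_kv_groups):
        # attention_scores: [bs, num_heads, q_len, key_len]
        # Sum over the query dimension and keep one score per KV group,
        # giving [bs, num_kv_heads, key_len].
        updated = attention_scores.sum(2)[:, ::num_kv_groups, :]
        if accumulated is None:
            return updated
        # Only positions that existed before this forward pass already have
        # running totals; the last q_len entries are the newly added tokens.
        num_new_tokens = attention_scores.shape[2]
        updated[:, :, :-num_new_tokens] += accumulated
        return updated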

+ 5 - 90
research/long-context-llama/H2O/utils_llama.py

@@ -49,96 +49,6 @@ def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
     x_embed = (x * cos) + (rotate_half(x) * sin)
     return x_embed
 
-class H2OKVCache_LayerWise:
-    def __init__(
-        self,
-        hh_size=4,
-        recent_size=512,
-        k_seq_dim=2,
-        v_seq_dim=2,
-    ):
-        self.hh_size = hh_size
-        self.recent_size = recent_size
-        self.cache_size = hh_size + recent_size
-        self.k_seq_dim = k_seq_dim
-        self.v_seq_dim = v_seq_dim
-        self.hh_score = None
-
-    def __call__(self, past_key_values, attn_score_cache):
-
-        self._update_hh_score(attn_score_cache)
-
-        if past_key_values is None:
-            return None
-        seq_len = past_key_values[0].size(self.k_seq_dim)
-        if seq_len <= self.cache_size:
-            return past_key_values
-
-        # hh-selection
-        bsz, num_heads, _, head_dim = past_key_values[0].shape
-
-        select_hh_scores = self.hh_score[:, :seq_len - self.recent_size]
-        _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
-        keep_topk = keep_topk.sort().values
-
-        # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
-        keep_recent = torch.arange(seq_len - self.recent_size, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
-        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)
-
-        mask = torch.zeros(self.hh_score.shape, dtype=torch.bool).to(past_key_values[0].device)
-        mask = mask.scatter(-1, keep_idx, 1)
-
-        k_hh_recent = past_key_values[0].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
-        v_hh_recent = past_key_values[1].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
-
-        self.hh_score= self.hh_score[mask].view(num_heads, self.cache_size)
-
-        return (k_hh_recent, v_hh_recent)
-
-    def evict_for_space(self, past_key_values, num_coming):
-        if past_key_values is None:
-            return None
-        seq_len = past_key_values[0][0].size(self.k_seq_dim)
-        if seq_len + num_coming <= self.cache_size:
-            return past_key_values
-
-        # hh-selection
-        bsz, num_heads, _, head_dim = past_key_values[0].shape
-
-        select_hh_scores = self.hh_score[:, :seq_len - self.recent_size + num_coming]
-        _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
-        keep_topk = keep_topk.sort().values
-
-        # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
-        keep_recent = torch.arange(seq_len - self.recent_size + num_coming, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
-        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)
-
-        mask = torch.zeros(self.hh_score.shape, dtype=torch.bool).to(past_key_values[0].device)
-        mask = mask.scatter(-1, keep_idx, 1)
-
-        k_hh_recent = past_key_values[0].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
-        v_hh_recent = past_key_values[1].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
-
-        self.hh_score= self.hh_score[mask].view(num_heads, self.cache_size)
-
-        return (k_hh_recent, v_hh_recent)
-
-    def _update_hh_score(self, attn_score_cache):
-
-        num_new_tokens = attn_score_cache.shape[2]
-
-        if self.hh_score is None:
-            self.hh_score = attn_score_cache.sum(0).sum(1)
-        else:
-            attn_score_cache = attn_score_cache.sum(0).sum(1)
-            attn_score_cache[:, :-num_new_tokens] += self.hh_score
-            self.hh_score = attn_score_cache
-
-    def _clean_scores(self):
-        self.hh_score = None
-
-
-
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     """
     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
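
The block removed above is the original layer-wise H2O cache. For reference, a minimal sketch of its eviction rule: keep the hh_size positions with the highest accumulated attention scores plus the recent_size most recent positions, and drop the rest. h2o_evict is a hypothetical standalone helper reconstructed from the removed code and, like the original, the squeeze assumes batch size 1.

    import torch

    def h2o_evict(keys, values, hh_score, hh_size, recent_size):
        # keys/values: [bs, num_heads, seq_len, head_dim]; hh_score: [num_heads, seq_len]
        bsz, num_heads, seq_len, head_dim = keys.shape
        if seq_len <= hh_size + recent_size:
            return keys, values, hh_score

        # Heavy hitters: per-head top-k over the non-recent positions.
        select_scores = hh_score[:, : seq_len - recent_size]
        _, keep_topk = torch.topk(select_scores, hh_size, dim=-1)
        # Recent window: always keep the last recent_size positions.
        keep_recent = torch.arange(seq_len - recent_size, seq_len, device=keys.device).repeat(num_heads, 1)
        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)

        mask = torch.zeros_like(hh_score, dtype=torch.bool).scatter(-1, keep_idx, True)
        k_kept = keys.squeeze(0)[mask].view(bsz, num_heads, -1, head_dim)  # assumes bsz == 1
        v_kept = values.squeeze(0)[mask].view(bsz, num_heads, -1, head_dim)
        hh_score = hh_score[mask].view(num_heads, hh_size + recent_size)
        return k_kept, v_kept, hh_score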
@@ -255,6 +165,11 @@ class H2OLlamaAttention(nn.Module):
 
         past_key_value = getattr(self, "past_key_value", past_key_value)
         cos, sin = self.rotary_emb(value_states, position_ids)
+
+        print(position_ids)
+
+
+
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None:
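
The second hunk adds a debugging print of position_ids just before the rotary embedding is applied. For context, the rotation used here (see apply_rotary_pos_emb_single in the hunk header above) follows the standard Llama recipe; the sketch below assumes cos and sin have already been gathered at position_ids and broadcast to x's shape.

    import torch

    def rotate_half(x):
        # Split the head dimension in half and swap the halves,
        # negating the second half.
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary(x, cos, sin):
        # x, cos, sin broadcastable to [bs, num_heads, seq_len, head_dim]
        return (x * cos) + (rotate_half(x) * sin)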