Browse Source

Update cache_utils.py

Allen 1 năm trước cách đây
mục cha
commit
fb9373bc33
1 tập tin đã thay đổi với 29 bổ sung2 xóa
  1. 29 2
      research/long-context-llama/H2O/cache_utils.py

+ 29 - 2
research/long-context-llama/H2O/cache_utils.py

@@ -359,7 +359,7 @@ class HHCache(Cache):
         self.value_cache: List[torch.Tensor] = []
         self.window_length = window_length
         self.num_hh_tokens = num_hh_tokens
-        self.accumulated_attention_scores = {}
+        self.accumulated_attention_scores: List[torch.Tensor] = []
         self.cos_sin_cache = {}
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
 
@@ -468,9 +468,36 @@ class HHCache(Cache):
             A tuple containing the updated key and value states.
         """
 
-        import pdb; pdb.set_trace()
+        if layer_idx == 0:
+            import pdb; pdb.set_trace()
+
+
+        # Update score metrics (Accumulated attention scores)
+        if len(self.accumulated_attention_scores) <= layer_idx:
+            self.accumulated_attention_scores.append(attention_scores.sum(2)[:,::num_kv_groups, :]) # [bs, num_heads, key_len]
+        else:
+            num_new_tokens = attention_scores.shape[2]
+            updated_attention_scores = attention_scores.sum(2)[:,::num_kv_groups, :] # [bs, num_heads, key_len]
+            updated_attention_scores[:, :, :-num_new_tokens] += self.accumulated_attention_scores[layer_idx]
+            self.accumulated_attention_scores[layer_idx] = updated_attention_scores
+
+        # Update KV Cache
+        if self.get_seq_length(layer_idx) > self.window_length:
+
+            seq_scores = self.accumulated_attention_scores[layer_idx][:, :, :-self.window_length + self.num_hh_tokens]
+            _, keep_hh_index = torch.topk(seq_scores, self.num_hh_tokens, dim=-1)
+            keep_hh_index = keep_hh_index.sort().values
+
+            keep_local_index = torch.arange(self.get_seq_length(layer_idx) - self.window_length + self.num_hh_tokens, self.get_seq_length(layer_idx), device=keep_hh_index.device).repeat(keep_hh_index.shape[0], keep_hh_index.shape[1], 1)
+            keep_index = torch.cat([keep_hh_index, keep_local_index], dim=-1)
 
+            mask = torch.zeros(self.accumulated_attention_scores[layer_idx].shape, dtype=torch.bool).to(keep_hh_index.device)
+            mask = mask.scatter(-1, keep_index, 1)
 
+            bsz, num_heads, _, head_dim = self.key_cache[layer_idx].shape
+            self.key_cache[layer_idx] = self.key_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)
+            self.value_cache[layer_idx] = self.value_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)
+            self.accumulated_attention_scores[layer_idx] = self.accumulated_attention_scores[layer_idx][mask].view(bsz, num_heads, -1)
 
 
     def reorder_cache(self, beam_idx: torch.LongTensor):