|
@@ -359,7 +359,7 @@ class HHCache(Cache):
|
|
|
self.value_cache: List[torch.Tensor] = []
|
|
|
self.window_length = window_length
|
|
|
self.num_hh_tokens = num_hh_tokens
|
|
|
- self.accumulated_attention_scores = {}
|
|
|
+ self.accumulated_attention_scores: List[torch.Tensor] = []
|
|
|
self.cos_sin_cache = {}
|
|
|
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
|
|
|
|
|
@@ -468,9 +468,36 @@ class HHCache(Cache):
|
|
|
A tuple containing the updated key and value states.
|
|
|
"""
|
|
|
|
|
|
- import pdb; pdb.set_trace()
|
|
|
+ if layer_idx == 0:
|
|
|
+ import pdb; pdb.set_trace()
|
|
|
+
|
|
|
+
|
|
|
+ # Update score metrics (Accumulated attention scores)
|
|
|
+ if len(self.accumulated_attention_scores) <= layer_idx:
|
|
|
+ self.accumulated_attention_scores.append(attention_scores.sum(2)[:,::num_kv_groups, :]) # [bs, num_heads, key_len]
|
|
|
+ else:
|
|
|
+ num_new_tokens = attention_scores.shape[2]
|
|
|
+ updated_attention_scores = attention_scores.sum(2)[:,::num_kv_groups, :] # [bs, num_heads, key_len]
|
|
|
+ updated_attention_scores[:, :, :-num_new_tokens] += self.accumulated_attention_scores[layer_idx]
|
|
|
+ self.accumulated_attention_scores[layer_idx] = updated_attention_scores
|
|
|
+
|
|
|
+ # Update KV Cache
|
|
|
+ if self.get_seq_length(layer_idx) > self.window_length:
|
|
|
+
|
|
|
+ seq_scores = self.accumulated_attention_scores[layer_idx][:, :, :-self.window_length + self.num_hh_tokens]
|
|
|
+ _, keep_hh_index = torch.topk(seq_scores, self.num_hh_tokens, dim=-1)
|
|
|
+ keep_hh_index = keep_hh_index.sort().values
|
|
|
+
|
|
|
+ keep_local_index = torch.arange(self.get_seq_length(layer_idx) - self.window_length + self.num_hh_tokens, self.get_seq_length(layer_idx), device=keep_hh_index.device).repeat(keep_hh_index.shape[0], keep_hh_index.shape[1], 1)
|
|
|
+ keep_index = torch.cat([keep_hh_index, keep_local_index], dim=-1)
|
|
|
|
|
|
+ mask = torch.zeros(self.accumulated_attention_scores[layer_idx].shape, dtype=torch.bool).to(keep_hh_index.device)
|
|
|
+ mask = mask.scatter(-1, keep_index, 1)
|
|
|
|
|
|
+ bsz, num_heads, _, head_dim = self.key_cache[layer_idx].shape
|
|
|
+ self.key_cache[layer_idx] = self.key_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)
|
|
|
+ self.value_cache[layer_idx] = self.value_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)
|
|
|
+ self.accumulated_attention_scores[layer_idx] = self.accumulated_attention_scores[layer_idx][mask].view(bsz, num_heads, -1)
|
|
|
|
|
|
|
|
|
def reorder_cache(self, beam_idx: torch.LongTensor):
|