@@ -464,7 +464,6 @@ class HHCache(Cache):

         if layer_idx == 0:
             self._seen_tokens += key_states.shape[-2]
-            import pdb; pdb.set_trace()

         # Update the cache
         if len(self.key_cache) <= layer_idx:
@@ -499,10 +498,6 @@ class HHCache(Cache):
             A tuple containing the updated key and value states.
         """

-        if layer_idx == 0:
-            import pdb; pdb.set_trace()
-
-
         # Update score metrics (Accumulated attention scores)
         if len(self.accumulated_attention_scores) <= layer_idx:
             self.accumulated_attention_scores.append(attention_scores.sum(2)[:,::num_kv_groups, :]) # [bs, num_heads, key_len]
@@ -529,9 +524,6 @@ class HHCache(Cache):
             self.key_cache[layer_idx] = self.key_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)
             self.value_cache[layer_idx] = self.value_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)
             self.accumulated_attention_scores[layer_idx] = self.accumulated_attention_scores[layer_idx][mask].view(bsz, num_heads, -1)
-
-        if layer_idx == 0:
-            import pdb; pdb.set_trace()


     def reorder_cache(self, beam_idx: torch.LongTensor):
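Note on the eviction logic in the last hunk: the boolean mask must keep the same number of tokens per (batch, head) so that the .view(bsz, num_heads, -1, head_dim) reshape succeeds. A minimal sketch of how such a mask could be built from the accumulated attention scores; the helper name and the heavy-hitter/recent budgets are hypothetical and not part of this diff:

import torch

def build_hh_mask(accumulated_scores: torch.Tensor, num_hh: int, num_recent: int) -> torch.Tensor:
    # Hypothetical helper: keep `num_hh` heavy-hitter tokens (highest accumulated
    # attention score) plus the `num_recent` most recent tokens, per batch and head.
    # accumulated_scores: [bsz, num_heads, seq_len]; returns a bool mask of the same shape.
    bsz, num_heads, seq_len = accumulated_scores.shape
    num_recent = min(num_recent, seq_len)
    mask = torch.zeros(bsz, num_heads, seq_len, dtype=torch.bool, device=accumulated_scores.device)

    # Always keep the most recent tokens.
    mask[..., seq_len - num_recent:] = True

    # Among the older tokens, keep those with the largest accumulated scores.
    older = accumulated_scores[..., : seq_len - num_recent]
    topk = older.topk(min(num_hh, older.shape[-1]), dim=-1).indices
    mask[..., : seq_len - num_recent].scatter_(-1, topk, torch.ones_like(topk, dtype=torch.bool))

    return mask

With a mask of this form, the indexing in the hunk above reduces each head's cache to num_hh + num_recent entries while keeping the per-head lengths equal, which is what the subsequent .view calls rely on.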