@@ -439,6 +439,7 @@ class HHCache(Cache):
         value_states: torch.Tensor,
         layer_idx: int,
         cache_kwargs: Optional[Dict[str, Any]] = None,
+        accumulated_attention_scores: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
@@ -457,7 +458,10 @@ class HHCache(Cache):
             A tuple containing the updated key and value states.
         """
         # Update the number of seen tokens
-
+
+        if accumulated_attention_scores is not None:
+            self.accumulated_attention_scores.append(accumulated_attention_scores)
+
         if layer_idx == 0:
             self._seen_tokens += key_states.shape[-2]
             import pdb; pdb.set_trace()
@@ -542,7 +546,7 @@ class HHCache(Cache):
         """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
         legacy_cache = ()
         for layer_idx in range(len(self)):
-            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
+            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx], self.accumulated_attention_scores[layer_idx]),)
         return legacy_cache

     @classmethod
@@ -551,8 +555,8 @@ class HHCache(Cache):
         cache = cls(window_length, num_hh_tokens)
         if past_key_values is not None:
             for layer_idx in range(len(past_key_values)):
-                key_states, value_states = past_key_values[layer_idx]
-                cache.update(key_states, value_states, layer_idx)
+                key_states, value_states, accumulated_attention_scores = past_key_values[layer_idx]
+                cache.update(key_states, value_states, layer_idx, accumulated_attention_scores=accumulated_attention_scores)
         return cache
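
For orientation, a minimal, self-contained sketch of the three-element per-layer legacy format that `to_legacy_cache` and `from_legacy_cache` exchange after this change. The tensor shapes and the per-head layout of `accumulated_attention_scores` are illustrative assumptions, not taken from the class itself.

import torch

# Assumed shapes: keys/values are (batch, num_heads, seq_len, head_dim),
# accumulated attention scores are (batch, num_heads, seq_len) -- one score per cached token.
batch, heads, seq, dim = 1, 8, 16, 64
legacy_cache = tuple(
    (
        torch.randn(batch, heads, seq, dim),  # key_states for this layer
        torch.randn(batch, heads, seq, dim),  # value_states for this layer
        torch.rand(batch, heads, seq),        # accumulated_attention_scores for this layer
    )
    for _ in range(2)  # two layers, purely for illustration
)

# Mirrors the unpacking in `from_legacy_cache`: each per-layer entry must be a single
# 3-tuple, which is why `to_legacy_cache` packs the scores inside the layer tuple.
for key_states, value_states, accumulated_attention_scores in legacy_cache:
    assert accumulated_attention_scores.shape[-1] == key_states.shape[-2]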