|
@@ -369,7 +369,7 @@ class HHCache(Cache):
|
|
|
sequence length.
|
|
|
"""
|
|
|
if layer_idx < len(self):
|
|
|
- return (self.key_cache[layer_idx], self.value_cache[layer_idx])
|
|
|
+ return (self.key_cache[layer_idx], self.value_cache[layer_idx], self.accumulated_attention_scores[layer_idx])
|
|
|
else:
|
|
|
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
|
|
|
|
|
@@ -379,7 +379,7 @@ class HHCache(Cache):
|
|
|
keys and values
|
|
|
"""
|
|
|
for layer_idx in range(len(self)):
|
|
|
- yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
|
|
|
+ yield (self.key_cache[layer_idx], self.value_cache[layer_idx], self.accumulated_attention_scores[layer_idx])
|
|
|
|
|
|
def __len__(self):
|
|
|
"""
|