|
@@ -546,8 +546,7 @@ class HHCache(Cache):
|
|
|
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
|
|
|
legacy_cache = ()
|
|
|
for layer_idx in range(len(self)):
|
|
|
- legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx], self.accumulated_attention_scores[layer_idx], ))
|
|
|
- import pdb; pdb.set_trace()
|
|
|
+ legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx], self.accumulated_attention_scores[layer_idx],))
|
|
|
return legacy_cache
|
|
|
|
|
|
@classmethod
|
|
@@ -556,8 +555,9 @@ class HHCache(Cache):
|
|
|
cache = cls(window_length, num_hh_tokens)
|
|
|
if past_key_values is not None:
|
|
|
for layer_idx in range(len(past_key_values)):
|
|
|
- import pdb; pdb.set_trace()
|
|
|
- key_states, value_states, accumulated_attention_scores = past_key_values[layer_idx]
|
|
|
+ key_states = past_key_values[layer_idx * 3]
|
|
|
+ value_states = past_key_values[layer_idx * 3 + 1]
|
|
|
+ accumulated_attention_scores = past_key_values[layer_idx * 3 + 2]
|
|
|
cache.update(key_states, value_states, layer_idx, accumulated_attention_scores=accumulated_attention_scores)
|
|
|
return cache
|
|
|
|