|
@@ -555,8 +555,8 @@ class HHCache(Cache):
|
|
|
cache = cls(window_length, num_hh_tokens)
|
|
|
if past_key_values is not None:
|
|
|
for layer_idx in range(len(past_key_values)):
|
|
|
- key_states, value_states, accumulated_attention_scores = past_key_values[layer_idx]
|
|
|
import pdb; pdb.set_trace()
|
|
|
+ key_states, value_states, accumulated_attention_scores = past_key_values[layer_idx]
|
|
|
cache.update(key_states, value_states, layer_idx, accumulated_attention_scores=accumulated_attention_scores)
|
|
|
return cache
|
|
|
|