|
@@ -364,7 +364,6 @@ class HHCache(Cache):
|
|
|
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
|
|
|
sequence length.
|
|
|
"""
|
|
|
- print(layer_idx, len(self))
|
|
|
if layer_idx < len(self):
|
|
|
return (self.key_cache[layer_idx], self.value_cache[layer_idx])
|
|
|
else:
|
|
@@ -408,6 +407,7 @@ class HHCache(Cache):
|
|
|
Return:
|
|
|
A tuple containing the updated key and value states.
|
|
|
"""
|
|
|
+ import pdb; pdb.set_trace()
|
|
|
# Update the number of seen tokens
|
|
|
if layer_idx == 0:
|
|
|
self._seen_tokens += key_states.shape[-2]
|