|
@@ -369,6 +369,7 @@ class HHCache(Cache):
|
|
sequence length.
|
|
sequence length.
|
|
"""
|
|
"""
|
|
if layer_idx < len(self):
|
|
if layer_idx < len(self):
|
|
|
|
+ print(layer_idx, self.key_cache[layer_idx].shape)
|
|
return (self.key_cache[layer_idx], self.value_cache[layer_idx])
|
|
return (self.key_cache[layer_idx], self.value_cache[layer_idx])
|
|
else:
|
|
else:
|
|
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
|
|
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
|
|
@@ -379,6 +380,7 @@ class HHCache(Cache):
|
|
keys and values
|
|
keys and values
|
|
"""
|
|
"""
|
|
for layer_idx in range(len(self)):
|
|
for layer_idx in range(len(self)):
|
|
|
|
+ print(layer_idx, self.key_cache[layer_idx].shape)
|
|
yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
|
|
yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
|
|
|
|
|
|
def __len__(self):
|
|
def __len__(self):
|