@@ -363,6 +363,31 @@ class HHCache(Cache):
         self.cos_sin_cache = {}
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
 
+    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+        """
+        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
+        sequence length.
+        """
+        if layer_idx < len(self):
+            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
+        else:
+            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+    def __iter__(self):
+        """
+        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
+        keys and values
+        """
+        for layer_idx in range(len(self)):
+            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
+
+    def __len__(self):
+        """
+        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
+        to the number of layers in the model.
+        """
+        return len(self.key_cache)
+
     @staticmethod
     def _rotate_half(x):
         x1 = x[..., : x.shape[-1] // 2]
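# Usage sketch (annotation, not part of the patch): with the added dunder methods an
# HHCache supports the legacy tuple-of-tuples `past_key_values` access patterns.
# `past_key_values` below is assumed to be an HHCache already populated by the model's
# forward pass; how it is constructed and filled is not shown in this hunk.
num_layers = len(past_key_values)          # __len__: one entry per model layer
keys_0, values_0 = past_key_values[0]      # __getitem__: (key_cache, value_cache) for layer 0
seq_len = past_key_values[0][0].shape[2]   # legacy sequence-length lookup still works
for keys, values in past_key_values:       # __iter__: yields one (keys, values) pair per layer
    assert keys.shape[2] == values.shape[2]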
@@ -396,7 +421,7 @@ class HHCache(Cache):
             rerotation_sin.to(key_states.dtype).unsqueeze(0),
         )
         return self.cos_sin_cache[key_states.shape[-2]]
-
+
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
         # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length