@@ -408,32 +408,6 @@ class HHCache(Cache):
         """Returns the maximum sequence length of the cached states."""
         return self.window_length
 
-    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
-        """
-        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
-        sequence length.
-        """
-        if layer_idx < len(self):
-            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
-        else:
-            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
-
-    def __iter__(self):
-        """
-        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
-        keys and values
-        """
-        for layer_idx in range(len(self)):
-            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
-
-    def __len__(self):
-        """
-        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
-        to the number of layers in the model.
-        """
-        return len(self.key_cache)
-
     def update(
         self,
         key_states: torch.Tensor,
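The `__getitem__`/`__iter__`/`__len__` helpers removed above existed only so an HHCache could be treated like the legacy tuple-of-tuples `past_key_values` object. A minimal sketch of the access pattern they supported, assuming `cache` is a hypothetical, already-populated HHCache and the usual `[batch, num_heads, seq_len, head_dim]` tensor layout:

    # Legacy-style access enabled by the removed dunder methods.
    num_layers = len(cache)          # __len__: number of layers
    keys_0, values_0 = cache[0]      # __getitem__: per-layer (key, value) pair
    seq_len = cache[0][0].shape[2]   # the `past_key_value[0][0].shape[2]` idiom
    for keys, values in cache:       # __iter__: walk the layers in order
        assert keys.shape[2] == values.shape[2]

Unless the base `Cache` class provides equivalent hooks, callers that still lean on this tuple-style interface will break after this change.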
@@ -517,12 +491,6 @@ class HHCache(Cache):
             device = self.value_cache[layer_idx].device
             self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
 
-
-
-
-
-
-
     def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
|
|
|
         legacy_cache = ()
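The two `index_select` context lines above are the tail end of the beam-search reordering logic: dimension 0 is the beam (batch) dimension, and `beam_idx` picks which beams survive into the next step. A toy illustration of that operation on a single stand-in tensor (the shape is an assumption made only for this example):

    import torch

    # Stand-in value tensor: [num_beams, num_heads, seq_len, head_dim]
    values = torch.arange(24, dtype=torch.float32).reshape(4, 2, 3, 1)

    # Beams 0 and 2 survived, each duplicated once.
    beam_idx = torch.tensor([0, 0, 2, 2])

    reordered = values.index_select(0, beam_idx)
    assert torch.equal(reordered[1], values[0])  # slot 1 now holds old beam 0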
@@ -533,7 +501,6 @@ class HHCache(Cache):
     @classmethod
     def from_legacy_cache(cls, window_length: int, num_hh_tokens: int, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "HHCache":
         """Converts a cache in the legacy cache format into an equivalent `HHCache`."""
-        import pdb; pdb.set_trace()
         cache = cls(window_length, num_hh_tokens)
         if past_key_values is not None:
             for layer_idx in range(len(past_key_values)):
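Taken together, `to_legacy_cache` and `from_legacy_cache` give a round trip between HHCache and the legacy tuple-of-tuples format. A minimal usage sketch under assumed shapes, with illustrative `window_length`/`num_hh_tokens` values and random placeholder tensors rather than real model states:

    import torch
    # (import of HHCache from wherever this class lives is assumed)

    # Hypothetical legacy cache: one (key, value) pair per layer,
    # each tensor shaped [batch, num_heads, seq_len, head_dim].
    legacy = tuple(
        (torch.randn(1, 8, 16, 64), torch.randn(1, 8, 16, 64))
        for _ in range(2)
    )

    cache = HHCache.from_legacy_cache(window_length=256, num_hh_tokens=32, past_key_values=legacy)

    # ...run generation with `cache`, then hand tuples back to older code paths.
    legacy_again = cache.to_legacy_cache()
    assert len(legacy_again) == len(legacy)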