from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)
 
@dataclass
class Cache:
    """
    Base, abstract class for all caches. The actual data structure is specific to each subclass.
    """

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
                cache to be created.

        Return:
            A tuple containing the updated key and value states.
        """
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states, if there is any."""
        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without size limit -> all cache is usable
        # Cache with size limit -> if the length of the cache plus the length of the new inputs is larger than the
        #   maximum cache length, we will need to evict part of the cache (and thus not all cache is usable)
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

    @property
    def seen_tokens(self):
        logger.warning_once(
            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
            "model input instead."
        )
        if hasattr(self, "_seen_tokens"):
            return self._seen_tokens
        else:
            return None
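

# Minimal, illustrative sketch (not part of the library API) of the `get_usable_length` contract
# above, using a hypothetical fixed-size subclass: with 7 tokens already cached, a 2-token input,
# and an 8-token limit, only 8 - 2 = 6 cached tokens are usable.
def _example_get_usable_length() -> None:
    class _ToyBoundedCache(Cache):
        def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
            return 7  # pretend 7 tokens are already cached

        def get_max_length(self) -> Optional[int]:
            return 8  # pretend the cache can hold at most 8 tokens

    toy_cache = _ToyBoundedCache()
    assert toy_cache.get_usable_length(new_seq_length=2) == 6
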
 
class DynamicCache(Cache):
    """
    A cache that grows dynamically as more tokens are generated. This is the default for generative models.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.
    """

    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self._seen_tokens = 0  # Used in `generate` to keep a tally of how many tokens the cache has seen

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.
        """
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """
        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
        keys and values.
        """
        for layer_idx in range(len(self)):
            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.key_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. `DynamicCache` does not have a maximum length."""
        return None

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `DynamicCache` instance into its equivalent in the legacy cache format."""
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
        return legacy_cache

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
        cache = cls()
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, layer_idx)
        return cache
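

# Minimal usage sketch (illustrative only, not part of the library API): exercises `DynamicCache`
# with dummy tensors in the documented `[batch_size, num_heads, seq_len, head_dim]` layout and
# round-trips through the legacy cache format. Shapes and the helper name are arbitrary.
def _example_dynamic_cache() -> None:
    cache = DynamicCache()
    cache.update(torch.randn(1, 4, 3, 8), torch.randn(1, 4, 3, 8), layer_idx=0)  # prefill with 3 tokens
    cache.update(torch.randn(1, 4, 1, 8), torch.randn(1, 4, 1, 8), layer_idx=0)  # one decode step
    assert cache.get_seq_length(0) == 4

    legacy = cache.to_legacy_cache()
    restored = DynamicCache.from_legacy_cache(legacy)
    assert restored.get_seq_length(0) == 4
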
 
class SinkCache(Cache):
    """
    A cache, as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
    generate beyond the length of its context window without losing fluency in the conversation. As it discards past
    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    Parameters:
        window_length (`int`):
            The length of the context window.
        num_sink_tokens (`int`):
            The number of sink tokens. See the original paper for more information.
    """

    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        self.cos_sin_cache = {}
        self._seen_tokens = 0  # Used in `generate` to keep a tally of how many tokens the cache has seen

    @staticmethod
    def _rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
        return rotated_key_states

    def _get_rerotation_cos_sin(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if key_states.shape[-2] not in self.cos_sin_cache:
            # Upcast to float32 temporarily for better accuracy
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

            # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

            self.cos_sin_cache[key_states.shape[-2]] = (
                rerotation_cos.to(key_states.dtype).unsqueeze(0),
                rerotation_sin.to(key_states.dtype).unsqueeze(0),
            )
        return self.cos_sin_cache[key_states.shape[-2]]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.window_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
                rotation as the tokens are shifted.

        Return:
            A tuple containing the updated key and value states.
        """
        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
        # with partially rotated position embeddings, like Phi or Persimmon.
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
        using_rope = cos is not None and sin is not None

        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # [bsz, num_heads, seq_len, head_dim]
        if len(self.key_cache) <= layer_idx:
            # Empty cache
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
            # Growing cache
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        else:
            # Shifting cache
            keys_to_keep = self.key_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
            ]

            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
            if using_rope:
                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
                    key_states, cos[: self.window_length], sin[: self.window_length]
                )
                if partial_rotation_size is not None:
                    keys_to_keep, keys_pass = (
                        keys_to_keep[..., :partial_rotation_size],
                        keys_to_keep[..., partial_rotation_size:],
                    )
                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
                if partial_rotation_size is not None:
                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)

            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
            values_to_keep = self.value_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
            ]
            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
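

# Minimal usage sketch (illustrative only): drives `SinkCache` past its window so that the
# shifting branch of `update` runs. The RoPE `cos`/`sin` tables are built with the standard
# rotary construction; all sizes (window of 8, 2 sink tokens, 2 heads, head_dim of 4) are arbitrary.
def _example_sink_cache() -> None:
    head_dim, max_pos = 4, 16
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    freqs = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)  # [max_pos, head_dim]
    cos, sin = emb.cos(), emb.sin()

    cache = SinkCache(window_length=8, num_sink_tokens=2)
    kwargs = {"sin": sin, "cos": cos}
    # Prefill with 7 tokens, then add 2 more: 7 + 2 >= 8 triggers the shift and key re-rotation.
    cache.update(torch.randn(1, 2, 7, head_dim), torch.randn(1, 2, 7, head_dim), 0, kwargs)
    cache.update(torch.randn(1, 2, 2, head_dim), torch.randn(1, 2, 2, head_dim), 0, kwargs)
    assert cache.get_seq_length(0) == 8  # 2 sink + 4 shifted + 2 new tokens
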
 
class HHCache(Cache):
    """
    A cache that applies the heavy-hitter oracle (H2O) eviction policy
    (https://proceedings.neurips.cc/paper_files/paper/2023/file/6ceefa7b15572587b78ecfcebb2827f8-Paper-Conference.pdf).
    Only the heavy-hitter and the most recent tokens are stored in the cache.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    Parameters:
        window_length (`int`):
            The length of the context window.
        num_hh_tokens (`int`):
            The number of heavy-hitter tokens. See the original paper for more information.
    """

    def __init__(self, window_length: int, num_hh_tokens: int) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.window_length = window_length
        self.num_hh_tokens = num_hh_tokens
        self.accumulated_attention_scores: List[torch.Tensor] = []
        self._seen_tokens = 0  # Used in `generate` to keep a tally of how many tokens the cache has seen

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.
        """
        if layer_idx < len(self):
            return (
                self.key_cache[layer_idx],
                self.value_cache[layer_idx],
                self.accumulated_attention_scores[layer_idx],
            )
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """
        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
        keys and values.
        """
        for layer_idx in range(len(self)):
            yield (
                self.key_cache[layer_idx],
                self.value_cache[layer_idx],
                self.accumulated_attention_scores[layer_idx],
            )

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.key_cache)

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.window_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
        accumulated_attention_scores: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `HHCache`.
            accumulated_attention_scores (`torch.Tensor`, `optional`):
                Pre-computed accumulated attention scores for this layer, e.g. when restoring the cache from the
                legacy format.

        Return:
            A tuple containing the updated key and value states.
        """
        if accumulated_attention_scores is not None:
            self.accumulated_attention_scores.append(accumulated_attention_scores)

        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def update_slimming(
        self,
        attention_scores: torch.Tensor,
        num_kv_groups: int,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Slims the cache based on the accumulated attention scores, keeping only the heavy-hitter and the local
        (most recent) tokens.

        Parameters:
            attention_scores (`torch.Tensor`):
                The attention scores of the current step, of shape `[batch_size, num_heads, query_len, key_len]`.
            num_kv_groups (`int`):
                The number of key/value groups used when repeating key/value heads (grouped-query attention).
            layer_idx (`int`):
                The index of the layer to slim the cache for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `HHCache`.

        Return:
            `None`. The cache is modified in place.
        """
        # Update score metrics (accumulated attention scores)
        if len(self.accumulated_attention_scores) <= layer_idx:
            self.accumulated_attention_scores.append(attention_scores.sum(2)[:, ::num_kv_groups, :])  # [bs, num_heads, key_len]
        else:
            num_new_tokens = attention_scores.shape[2]
            updated_attention_scores = attention_scores.sum(2)[:, ::num_kv_groups, :]  # [bs, num_heads, key_len]
            updated_attention_scores[:, :, :-num_new_tokens] += self.accumulated_attention_scores[layer_idx]
            self.accumulated_attention_scores[layer_idx] = updated_attention_scores

        # Update the KV cache
        if self.get_seq_length(layer_idx) > self.window_length:
            seq_scores = self.accumulated_attention_scores[layer_idx][:, :, : -self.window_length + self.num_hh_tokens]
            _, keep_hh_index = torch.topk(seq_scores, self.num_hh_tokens, dim=-1)
            keep_hh_index = keep_hh_index.sort().values

            keep_local_index = torch.arange(
                self.get_seq_length(layer_idx) - self.window_length + self.num_hh_tokens,
                self.get_seq_length(layer_idx),
                device=keep_hh_index.device,
            ).repeat(keep_hh_index.shape[0], keep_hh_index.shape[1], 1)
            keep_index = torch.cat([keep_hh_index, keep_local_index], dim=-1)

            mask = torch.zeros(self.accumulated_attention_scores[layer_idx].shape, dtype=torch.bool).to(
                keep_hh_index.device
            )
            mask = mask.scatter(-1, keep_index, 1)

            bsz, num_heads, _, head_dim = self.key_cache[layer_idx].shape
            self.key_cache[layer_idx] = self.key_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)
            self.value_cache[layer_idx] = self.value_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)
            self.accumulated_attention_scores[layer_idx] = self.accumulated_attention_scores[layer_idx][mask].view(
                bsz, num_heads, -1
            )

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `HHCache` instance into its equivalent in the legacy cache format."""
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += (
                (
                    self.key_cache[layer_idx],
                    self.value_cache[layer_idx],
                    self.accumulated_attention_scores[layer_idx],
                ),
            )
        return legacy_cache

    @classmethod
    def from_legacy_cache(
        cls,
        window_length: int,
        num_hh_tokens: int,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    ) -> "HHCache":
        """Converts a cache in the legacy cache format into an equivalent `HHCache`."""
        cache = cls(window_length, num_hh_tokens)
        if past_key_values is not None:
            # `past_key_values` is expected here as a flat tuple laid out as
            # (key_0, value_0, scores_0, key_1, value_1, scores_1, ...)
            for layer_idx in range(len(past_key_values) // 3):
                key_states = past_key_values[layer_idx * 3]
                value_states = past_key_values[layer_idx * 3 + 1]
                accumulated_attention_scores = past_key_values[layer_idx * 3 + 2]
                cache.update(
                    key_states, value_states, layer_idx, accumulated_attention_scores=accumulated_attention_scores
                )
        return cache

    def evict_for_space(self, space_needed: int):
        """Evicts tokens from every layer so that `space_needed` new tokens fit within `window_length`."""
        num_layers = len(self.key_cache)

        # The accumulated attention scores must be up to date before evicting
        if len(self.accumulated_attention_scores) < num_layers:
            raise ValueError("The accumulated_attention_scores should be updated before evicting the cache.")

        for layer_idx in range(num_layers):
            # Update the KV cache: evict to make room for the incoming prompt tokens
            if self.get_seq_length(layer_idx) + space_needed > self.window_length:
                if self.window_length - self.num_hh_tokens <= space_needed:
                    raise ValueError("The space_needed should be less than the window_length - num_hh_tokens.")

                seq_scores = self.accumulated_attention_scores[layer_idx][
                    :, :, : -self.window_length + self.num_hh_tokens + space_needed
                ]
                _, keep_hh_index = torch.topk(seq_scores, self.num_hh_tokens, dim=-1)
                keep_hh_index = keep_hh_index.sort().values

                keep_local_index = torch.arange(
                    self.get_seq_length(layer_idx) - self.window_length + self.num_hh_tokens + space_needed,
                    self.get_seq_length(layer_idx),
                    device=keep_hh_index.device,
                ).repeat(keep_hh_index.shape[0], keep_hh_index.shape[1], 1)
                keep_index = torch.cat([keep_hh_index, keep_local_index], dim=-1)

                mask = torch.zeros(self.accumulated_attention_scores[layer_idx].shape, dtype=torch.bool).to(
                    keep_hh_index.device
                )
                mask = mask.scatter(-1, keep_index, 1)

                bsz, num_heads, _, head_dim = self.key_cache[layer_idx].shape
                self.key_cache[layer_idx] = self.key_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)
                self.value_cache[layer_idx] = self.value_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)
                self.accumulated_attention_scores[layer_idx] = self.accumulated_attention_scores[layer_idx][
                    mask
                ].view(bsz, num_heads, -1)
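

# Minimal usage sketch (illustrative only): one prefill step for `HHCache` followed by a call to
# `update_slimming`, which trims the layer down to `window_length` entries (heavy hitters + recent
# tokens). The attention scores are random softmax weights; all sizes here are arbitrary.
def _example_hh_cache() -> None:
    cache = HHCache(window_length=4, num_hh_tokens=2)
    cache.update(torch.randn(1, 2, 6, 4), torch.randn(1, 2, 6, 4), layer_idx=0)

    # [bsz, num_heads, q_len, key_len]; num_kv_groups=1 means no grouped-query head repetition
    attention_scores = torch.softmax(torch.rand(1, 2, 6, 6), dim=-1)
    cache.update_slimming(attention_scores, num_kv_groups=1, layer_idx=0)
    assert cache.get_seq_length(0) == 4  # 2 heavy-hitter tokens + 2 most recent tokens
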
 
class StaticCache(Cache):
    """
    Static Cache class to be used with `torch.compile(model)`.

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
            required to initialize the static cache.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used.
        device (`torch.device`):
            The device on which the cache should be initialized. Should be the same as the layer.
        dtype (*optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.
    """

    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
        super().__init__()
        self.max_batch_size = max_batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )

        self.dtype = dtype if dtype is not None else torch.float32
        self.num_key_value_heads = (
            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
        )

        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
        self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
        self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for. Kept for backward compatibility.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
                to know which slots of the cache it should overwrite.

        Return:
            A tuple containing the updated key and value states.
        """
        new_cache_positions = cache_kwargs.get("cache_position")
        k_out = self.key_cache
        v_out = self.value_cache

        k_out[:, :, new_cache_positions] = key_states
        v_out[:, :, new_cache_positions] = value_states

        return k_out, v_out

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC."""
        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
        # limit the check to the first batch member and head dimension.
        # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
        # https://github.com/pytorch/pytorch/issues/120248 is fixed
        return (self.key_cache[0, 0].any(dim=-1)).sum()

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.max_cache_len

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        device = self.key_cache.device
        self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
        device = self.value_cache.device
        self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))

    def to_legacy_cache(self):
        """Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it."""
        return None
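

# Minimal usage sketch (illustrative only): builds a `StaticCache` from a bare `PretrainedConfig`
# whose attributes are set explicitly for this example, then writes one chunk of key/value states
# at the positions given by `cache_position`. All sizes are arbitrary.
def _example_static_cache() -> None:
    config = PretrainedConfig(
        max_position_embeddings=32,
        hidden_size=16,
        num_attention_heads=4,
        num_key_value_heads=4,
        head_dim=4,  # set explicitly so the `head_dim` lookup in `__init__` is well defined
    )
    cache = StaticCache(config, max_batch_size=1, max_cache_len=8, device="cpu", dtype=torch.float32)

    cache_position = torch.arange(3)  # the slots being written in this forward pass
    k_out, v_out = cache.update(
        torch.randn(1, 4, 3, 4),
        torch.randn(1, 4, 3, 4),
        layer_idx=0,
        cache_kwargs={"cache_position": cache_position},
    )
    assert k_out.shape == (1, 4, 8, 4)  # the full pre-allocated buffer is returned
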
 
 