瀏覽代碼

Update cache.py

Allen 1 年之前
父節點
當前提交
115e9306d5
共有 1 個文件被更改,包括 614 次插入0 次删除
  1. 614 0
      research/long-context-llama/H2O/utils/cache.py

+ 614 - 0
research/long-context-llama/H2O/utils/cache.py

@@ -0,0 +1,614 @@
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+@dataclass
+class Cache:
+    """
+    Base, abstract class for all caches. The actual data structure is specific to each subclass.
+    """
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
+                cache to be created.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        raise NotImplementedError("Make sure to implement `update` in a subclass.")
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
+
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states, if there is any."""
+        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
+
+    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
+        """Given the sequence length of the new inputs, returns the usable length of the cache."""
+        # Cache without size limit -> all cache is usable
+        # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
+        #   length, we will need to evict part of the cache (and thus not all cache is usable)
+        max_length = self.get_max_length()
+        previous_seq_length = self.get_seq_length(layer_idx)
+        if max_length is not None and previous_seq_length + new_seq_length > max_length:
+            return max_length - new_seq_length
+        return previous_seq_length
+
+    @property
+    def seen_tokens(self):
+        logger.warning_once(
+            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
+            "model input instead."
+        )
+        if hasattr(self, "_seen_tokens"):
+            return self._seen_tokens
+        else:
+            return None
+
+
+class DynamicCache(Cache):
+    """
+    A cache that grows dynamically as more tokens are generated. This is the default for generative models.
+
+    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
+    `[batch_size, num_heads, seq_len, head_dim]`.
+    """
+
+    def __init__(self) -> None:
+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
+        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+
+    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+        """
+        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
+        sequence length.
+        """
+        if layer_idx < len(self):
+            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
+        else:
+            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+    def __iter__(self):
+        """
+        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
+        keys and values
+        """
+        for layer_idx in range(len(self)):
+            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
+
+    def __len__(self):
+        """
+        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
+        to the number of layers in the model.
+        """
+        return len(self.key_cache)
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        # Update the number of seen tokens
+        if layer_idx == 0:
+            self._seen_tokens += key_states.shape[-2]
+
+        # Update the cache
+        if len(self.key_cache) <= layer_idx:
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+        else:
+            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        return self.key_cache[layer_idx].shape[-2]
+
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
+        return None
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        for layer_idx in range(len(self.key_cache)):
+            device = self.key_cache[layer_idx].device
+            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+            device = self.value_cache[layer_idx].device
+            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
+    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+        """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
+        legacy_cache = ()
+        for layer_idx in range(len(self)):
+            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
+        return legacy_cache
+
+    @classmethod
+    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
+        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
+        cache = cls()
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx]
+                cache.update(key_states, value_states, layer_idx)
+        return cache
+
+
+class SinkCache(Cache):
+    """
+    A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
+    generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
+    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
+
+    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
+    `[batch_size, num_heads, seq_len, head_dim]`.
+
+    Parameters:
+        window_length (`int`):
+            The length of the context window.
+        num_sink_tokens (`int`):
+            The number of sink tokens. See the original paper for more information.
+    """
+
+    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
+        self.window_length = window_length
+        self.num_sink_tokens = num_sink_tokens
+        self.cos_sin_cache = {}
+        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+
+    @staticmethod
+    def _rotate_half(x):
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def _apply_key_rotary_pos_emb(
+        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+    ) -> torch.Tensor:
+        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
+        return rotated_key_states
+
+    def _get_rerotation_cos_sin(
+        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if key_states.shape[-2] not in self.cos_sin_cache:
+            # Upcast to float32 temporarily for better accuracy
+            cos = cos.to(torch.float32)
+            sin = sin.to(torch.float32)
+
+            # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
+            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
+            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
+            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
+            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
+            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
+            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
+
+            self.cos_sin_cache[key_states.shape[-2]] = (
+                rerotation_cos.to(key_states.dtype).unsqueeze(0),
+                rerotation_sin.to(key_states.dtype).unsqueeze(0),
+            )
+        return self.cos_sin_cache[key_states.shape[-2]]
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        return self.key_cache[layer_idx].shape[-2]
+
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states."""
+        return self.window_length
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
+                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
+                rotation as the tokens are shifted.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
+        # with partially rotated position embeddings, like Phi or Persimmon.
+        sin = cache_kwargs.get("sin")
+        cos = cache_kwargs.get("cos")
+        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
+        using_rope = cos is not None and sin is not None
+
+        # Update the number of seen tokens
+        if layer_idx == 0:
+            self._seen_tokens += key_states.shape[-2]
+
+        # [bsz, num_heads, seq_len, head_dim]
+        if len(self.key_cache) <= layer_idx:
+            # Empty cache
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+
+        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
+            # Growing cache
+            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+
+        else:
+            # Shifting cache
+            keys_to_keep = self.key_cache[layer_idx][
+                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
+            ]
+
+            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
+            if using_rope:
+                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
+                    key_states, cos[: self.window_length], sin[: self.window_length]
+                )
+                if partial_rotation_size is not None:
+                    keys_to_keep, keys_pass = (
+                        keys_to_keep[..., :partial_rotation_size],
+                        keys_to_keep[..., partial_rotation_size:],
+                    )
+                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
+                if partial_rotation_size is not None:
+                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
+
+            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
+            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
+            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
+
+            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
+            values_to_keep = self.value_cache[layer_idx][
+                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
+            ]
+            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        for layer_idx in range(len(self.key_cache)):
+            device = self.key_cache[layer_idx].device
+            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+            device = self.value_cache[layer_idx].device
+            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
+
+class HHCache(Cache):
+    """
+    A cache that apply heavy-hitter oracle (https://proceedings.neurips.cc/paper_files/paper/2023/file/6ceefa7b15572587b78ecfcebb2827f8-Paper-Conference.pdf).
+    Only the heavy-hitter and the recent tokens are stored in the cache.
+
+    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
+    `[batch_size, num_heads, seq_len, head_dim]`.
+
+    Parameters:
+        window_length (`int`):
+            The length of the context window.
+        num_hh_tokens (`int`):
+            The number of heavy hitter tokens. See the original paper for more information.
+    """
+
+    def __init__(self, window_length: int, num_hh_tokens: int) -> None:
+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
+        self.window_length = window_length
+        self.num_hh_tokens = num_hh_tokens
+        self.accumulated_attention_scores: List[torch.Tensor] = []
+        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+
+    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+        """
+        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
+        sequence length.
+        """
+        if layer_idx < len(self):
+            return (self.key_cache[layer_idx], self.value_cache[layer_idx], self.accumulated_attention_scores[layer_idx])
+        else:
+            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+    def __iter__(self):
+        """
+        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
+        keys and values
+        """
+        for layer_idx in range(len(self)):
+            yield (self.key_cache[layer_idx], self.value_cache[layer_idx], self.accumulated_attention_scores[layer_idx])
+
+    def __len__(self):
+        """
+        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
+        to the number of layers in the model.
+        """
+        return len(self.key_cache)
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        return self.key_cache[layer_idx].shape[-2]
+
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states."""
+        return self.window_length
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+        accumulated_attention_scores: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        # Update the number of seen tokens
+
+        if accumulated_attention_scores is not None:
+            self.accumulated_attention_scores.append(accumulated_attention_scores)
+
+        if layer_idx == 0:
+            self._seen_tokens += key_states.shape[-2]
+
+        # Update the cache
+        if len(self.key_cache) <= layer_idx:
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+        else:
+            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+    def update_slimming(
+        self,
+        attention_scores: torch.Tensor,
+        num_kv_groups: int,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Slimming the cache based on accumulated attention scores, only keep heavy-hitters + local tokens.
+
+        Parameters:
+            attention_scores (`torch.Tensor`):
+                Attention_scores for current steps.
+            num_kv_groups (`int`):
+                The number of kv groups in repeat kv.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
+        Return:
+            A tuple containing the updated key and value states.
+        """
+
+        # Update score metrics (Accumulated attention scores)
+        if len(self.accumulated_attention_scores) <= layer_idx:
+            self.accumulated_attention_scores.append(attention_scores.sum(2)[:,::num_kv_groups, :]) # [bs, num_heads, key_len]
+        else:
+            num_new_tokens = attention_scores.shape[2]
+            updated_attention_scores = attention_scores.sum(2)[:,::num_kv_groups, :] # [bs, num_heads, key_len]
+            updated_attention_scores[:, :, :-num_new_tokens] += self.accumulated_attention_scores[layer_idx]
+            self.accumulated_attention_scores[layer_idx] = updated_attention_scores
+
+        # Update KV Cache
+        if self.get_seq_length(layer_idx) > self.window_length:
+
+            seq_scores = self.accumulated_attention_scores[layer_idx][:, :, :-self.window_length + self.num_hh_tokens]
+            _, keep_hh_index = torch.topk(seq_scores, self.num_hh_tokens, dim=-1)
+            keep_hh_index = keep_hh_index.sort().values
+
+            keep_local_index = torch.arange(self.get_seq_length(layer_idx) - self.window_length + self.num_hh_tokens, self.get_seq_length(layer_idx), device=keep_hh_index.device).repeat(keep_hh_index.shape[0], keep_hh_index.shape[1], 1)
+            keep_index = torch.cat([keep_hh_index, keep_local_index], dim=-1)
+
+            mask = torch.zeros(self.accumulated_attention_scores[layer_idx].shape, dtype=torch.bool).to(keep_hh_index.device)
+            mask = mask.scatter(-1, keep_index, 1)
+
+            bsz, num_heads, _, head_dim = self.key_cache[layer_idx].shape
+            self.key_cache[layer_idx] = self.key_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)
+            self.value_cache[layer_idx] = self.value_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)
+            self.accumulated_attention_scores[layer_idx] = self.accumulated_attention_scores[layer_idx][mask].view(bsz, num_heads, -1)
+
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        for layer_idx in range(len(self.key_cache)):
+            device = self.key_cache[layer_idx].device
+            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+            device = self.value_cache[layer_idx].device
+            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
+    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+        """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
+        legacy_cache = ()
+        for layer_idx in range(len(self)):
+            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx], self.accumulated_attention_scores[layer_idx],))
+        return legacy_cache
+
+    @classmethod
+    def from_legacy_cache(cls, window_length: int, num_hh_tokens: int, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
+        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
+        cache = cls(window_length, num_hh_tokens)
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values) // 3):
+                key_states = past_key_values[layer_idx * 3]
+                value_states = past_key_values[layer_idx * 3 + 1]
+                accumulated_attention_scores = past_key_values[layer_idx * 3 + 2]
+                cache.update(key_states, value_states, layer_idx, accumulated_attention_scores=accumulated_attention_scores)
+        return cache
+
+
+class StaticCache(Cache):
+    """
+    Static Cache class to be used with `torch.compile(model)`.
+
+    Parameters:
+        config (`PretrainedConfig):
+            The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
+            required to initialize the static cache.
+        max_batch_size (`int`):
+            The maximum batch size with which the model will be used.
+        max_cache_len (`int`):
+            The maximum sequence length with which the model will be used.
+        device (`torch.device`):
+            The device on which the cache should be initialized. Should be the same as the layer.
+        dtype (*optional*, defaults to `torch.float32`):
+            The default `dtype` to use when initializing the layer.
+    """
+
+    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
+        super().__init__()
+        self.max_batch_size = max_batch_size
+        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
+        # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
+        self.head_dim = (
+            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
+        )
+
+        self.dtype = dtype if dtype is not None else torch.float32
+        self.num_key_value_heads = (
+            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
+        )
+
+        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
+        self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+        self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for. Kept for backward compatibility
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. The `StaticCache` just needs the `q_len`
+                to know how much of the cache it should overwrite.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        new_cache_positions = cache_kwargs.get("cache_position")
+        k_out = self.key_cache
+        v_out = self.value_cache
+
+        k_out[:, :, new_cache_positions] = key_states
+        v_out[:, :, new_cache_positions] = value_states
+
+        return k_out, v_out
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
+        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
+        # limit the check to the first batch member and head dimension.
+        # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
+        # https://github.com/pytorch/pytorch/issues/120248 is fixed
+        return (self.key_cache[0, 0].any(dim=-1)).sum()
+
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
+        return self.max_cache_len
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        device = self.key_cache.device
+        self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
+        device = self.value_cache.device
+        self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))
+
+    def to_legacy_cache(self):
+        """Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it"""
+        return None
+