Allen, 1 year ago
parent
commit 9320185d3c
2 changed files with 724 additions and 109 deletions
  1. research/long-context-llama/H2O/cache_utils.py (+433, -0)
  2. research/long-context-llama/H2O/utils_llama.py (+291, -109)

research/long-context-llama/H2O/cache_utils.py (+433, -0)

@@ -0,0 +1,433 @@
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+@dataclass
+class Cache:
+    """
+    Base, abstract class for all caches. The actual data structure is specific to each subclass.
+    """
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
+                cache to be created.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        raise NotImplementedError("Make sure to implement `update` in a subclass.")
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
+
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states, if there is any."""
+        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
+
+    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
+        """Given the sequence length of the new inputs, returns the usable length of the cache."""
+        # Cache without size limit -> all cache is usable
+        # Cache with size limit -> if the length of the cache plus the length of the new inputs is larger than the
+        #   maximum cache length, we will need to evict part of the cache (and thus not all of the cache is usable)
+        max_length = self.get_max_length()
+        previous_seq_length = self.get_seq_length(layer_idx)
+        if max_length is not None and previous_seq_length + new_seq_length > max_length:
+            return max_length - new_seq_length
+        return previous_seq_length
+
+    @property
+    def seen_tokens(self):
+        logger.warning_once(
+            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
+            "model input instead."
+        )
+        if hasattr(self, "_seen_tokens"):
+            return self._seen_tokens
+        else:
+            return None
+
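As a quick illustration of the `get_usable_length` arithmetic above, here is a minimal sketch using the `DynamicCache` and `SinkCache` subclasses defined further down in this file (the import path and tensor shapes are assumptions, for illustration only):

import torch
from cache_utils import DynamicCache, SinkCache  # hypothetical import path for this file

# Unbounded cache: get_max_length() is None, so every previously cached token stays usable.
dyn = DynamicCache()
k = v = torch.zeros(1, 8, 16, 64)                  # [batch, heads, seq_len, head_dim]
dyn.update(k, v, layer_idx=0)
print(dyn.get_usable_length(new_seq_length=4))     # 16

# Bounded cache: 18 cached + 4 new tokens exceed window_length=20, so only
# 20 - 4 = 16 of the old tokens remain usable and the 2 oldest are evicted.
sink = SinkCache(window_length=20, num_sink_tokens=2)
sink.key_cache.append(torch.zeros(1, 8, 18, 64))   # seeded directly, purely for illustration
sink.value_cache.append(torch.zeros(1, 8, 18, 64))
print(sink.get_usable_length(new_seq_length=4))    # 16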
+
+class DynamicCache(Cache):
+    """
+    A cache that grows dynamically as more tokens are generated. This is the default for generative models.
+
+    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
+    `[batch_size, num_heads, seq_len, head_dim]`.
+    """
+
+    def __init__(self) -> None:
+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
+        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+
+    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+        """
+        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
+        sequence length.
+        """
+        if layer_idx < len(self):
+            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
+        else:
+            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+    def __iter__(self):
+        """
+        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
+        keys and values
+        """
+        for layer_idx in range(len(self)):
+            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
+
+    def __len__(self):
+        """
+        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
+        to the number of layers in the model.
+        """
+        return len(self.key_cache)
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        # Update the number of seen tokens
+        if layer_idx == 0:
+            self._seen_tokens += key_states.shape[-2]
+
+        # Update the cache
+        if len(self.key_cache) <= layer_idx:
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+        else:
+            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        return self.key_cache[layer_idx].shape[-2]
+
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
+        return None
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        for layer_idx in range(len(self.key_cache)):
+            device = self.key_cache[layer_idx].device
+            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+            device = self.value_cache[layer_idx].device
+            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
+    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+        """Converts the `DynamicCache` instance into its equivalent in the legacy cache format."""
+        legacy_cache = ()
+        for layer_idx in range(len(self)):
+            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
+        return legacy_cache
+
+    @classmethod
+    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
+        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
+        cache = cls()
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx]
+                cache.update(key_states, value_states, layer_idx)
+        return cache
+
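A minimal usage sketch for `DynamicCache`: one `update` call per layer per forward pass, plus the round trip through the legacy tuple format (the import path and tensor shapes are illustrative assumptions):

import torch
from cache_utils import DynamicCache  # hypothetical import path for this file

num_layers, bsz, heads, head_dim = 2, 1, 8, 64
cache = DynamicCache()

for step_len in (5, 1, 1):                 # a 5-token prompt followed by two decode steps
    for layer_idx in range(num_layers):
        k = torch.randn(bsz, heads, step_len, head_dim)
        v = torch.randn(bsz, heads, step_len, head_dim)
        k_all, v_all = cache.update(k, v, layer_idx)   # full cached K/V for this layer so far

print(cache.get_seq_length())              # 7 tokens cached per layer
legacy = cache.to_legacy_cache()           # tuple of per-layer (key, value) pairs
print(DynamicCache.from_legacy_cache(legacy).get_seq_length())  # 7 again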
+
+class SinkCache(Cache):
+    """
+    A cache that, as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453), allows the model to
+    generate beyond the length of its context window without losing fluency in the conversation. As it discards past
+    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
+
+    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
+    `[batch_size, num_heads, seq_len, head_dim]`.
+
+    Parameters:
+        window_length (`int`):
+            The length of the context window.
+        num_sink_tokens (`int`):
+            The number of sink tokens. See the original paper for more information.
+    """
+
+    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
+        self.key_cache: List[torch.Tensor] = []
+        self.value_cache: List[torch.Tensor] = []
+        self.window_length = window_length
+        self.num_sink_tokens = num_sink_tokens
+        self.cos_sin_cache = {}
+        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
+
+    @staticmethod
+    def _rotate_half(x):
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def _apply_key_rotary_pos_emb(
+        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+    ) -> torch.Tensor:
+        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
+        return rotated_key_states
+
+    def _get_rerotation_cos_sin(
+        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if key_states.shape[-2] not in self.cos_sin_cache:
+            # Upcast to float32 temporarily for better accuracy
+            cos = cos.to(torch.float32)
+            sin = sin.to(torch.float32)
+
+            # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
+            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
+            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
+            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
+            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
+            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
+            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
+
+            self.cos_sin_cache[key_states.shape[-2]] = (
+                rerotation_cos.to(key_states.dtype).unsqueeze(0),
+                rerotation_sin.to(key_states.dtype).unsqueeze(0),
+            )
+        return self.cos_sin_cache[key_states.shape[-2]]
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        return self.key_cache[layer_idx].shape[-2]
+
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states."""
+        return self.window_length
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
+                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
+                rotation as the tokens are shifted.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
+        # with partially rotated position embeddings, like Phi or Persimmon.
+        sin = cache_kwargs.get("sin")
+        cos = cache_kwargs.get("cos")
+        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
+        using_rope = cos is not None and sin is not None
+
+        # Update the number of seen tokens
+        if layer_idx == 0:
+            self._seen_tokens += key_states.shape[-2]
+
+        # [bsz, num_heads, seq_len, head_dim]
+        if len(self.key_cache) <= layer_idx:
+            # Empty cache
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+
+        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
+            # Growing cache
+            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+
+        else:
+            # Shifting cache
+            keys_to_keep = self.key_cache[layer_idx][
+                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
+            ]
+
+            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
+            if using_rope:
+                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
+                    key_states, cos[: self.window_length], sin[: self.window_length]
+                )
+                if partial_rotation_size is not None:
+                    keys_to_keep, keys_pass = (
+                        keys_to_keep[..., :partial_rotation_size],
+                        keys_to_keep[..., partial_rotation_size:],
+                    )
+                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
+                if partial_rotation_size is not None:
+                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
+
+            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
+            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
+            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
+
+            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
+            values_to_keep = self.value_cache[layer_idx][
+                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
+            ]
+            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        for layer_idx in range(len(self.key_cache)):
+            device = self.key_cache[layer_idx].device
+            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+            device = self.value_cache[layer_idx].device
+            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
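A rough sketch of driving `SinkCache` directly (import path assumed). Without `sin`/`cos` in `cache_kwargs` the re-rotation branch is skipped; when they are supplied, `rerotation_cos` and `rerotation_sin` above are just the angle-difference identities cos(original - shifted) and sin(shifted - original), i.e. each kept key is rotated back by the number of positions it slides toward the start of the window:

import torch
from cache_utils import SinkCache  # hypothetical import path for this file

cache = SinkCache(window_length=8, num_sink_tokens=2)
k = v = torch.randn(1, 4, 1, 16)            # one new token: [batch, heads, 1, head_dim]

for _ in range(12):                          # feed 12 tokens, one at a time
    k_all, v_all = cache.update(k, v, layer_idx=0, cache_kwargs={})

# 2 sink tokens + the most recent tokens are kept, so the cache never exceeds window_length.
print(k_all.shape[-2], cache.get_max_length())   # 8 8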
+
+class StaticCache(Cache):
+    """
+    Static Cache class to be used with `torch.compile(model)`.
+
+    Parameters:
+        config (`PretrainedConfig`):
+            The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
+            required to initialize the static cache.
+        max_batch_size (`int`):
+            The maximum batch size with which the model will be used.
+        max_cache_len (`int`):
+            The maximum sequence length with which the model will be used.
+        device (`torch.device`):
+            The device on which the cache should be initialized. Should be the same as the layer.
+        dtype (*optional*, defaults to `torch.float32`):
+            The default `dtype` to use when initializing the layer.
+    """
+
+    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
+        super().__init__()
+        self.max_batch_size = max_batch_size
+        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
+        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
+        self.head_dim = (
+            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
+        )
+
+        self.dtype = dtype if dtype is not None else torch.float32
+        self.num_key_value_heads = (
+            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
+        )
+
+        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
+        self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+        self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for. Kept for backward compatibility
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. The `StaticCache` just needs the `cache_position` input
+                to know where in the pre-allocated cache the new states should be written.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+        new_cache_positions = cache_kwargs.get("cache_position")
+        k_out = self.key_cache
+        v_out = self.value_cache
+
+        k_out[:, :, new_cache_positions] = key_states
+        v_out[:, :, new_cache_positions] = value_states
+
+        return k_out, v_out
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
+        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
+        # limit the check to the first batch member and head dimension.
+        # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
+        # https://github.com/pytorch/pytorch/issues/120248 is fixed
+        return (self.key_cache[0, 0].any(dim=-1)).sum()
+
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states."""
+        return self.max_cache_len
+
+    def reorder_cache(self, beam_idx: torch.LongTensor):
+        """Reorders the cache for beam search, given the selected beam indices."""
+        device = self.key_cache.device
+        self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
+        device = self.value_cache.device
+        self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))
+
+    def to_legacy_cache(self):
+        """Dummy function for BC. We have to keep it because otherwise the call in the forward of models would break."""
+        return None
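And a brief sketch of the `StaticCache` flow: the key/value buffers are pre-allocated at full size and `update` scatters the new states at `cache_position` (illustrative config and import path; note that this copy of the class keeps a single buffer and ignores `layer_idx`; assumes a transformers version contemporary with this code):

import torch
from transformers import LlamaConfig
from cache_utils import StaticCache  # hypothetical import path for this file

# Small illustrative config; StaticCache only reads the head/hidden sizes and KV head count.
config = LlamaConfig(hidden_size=256, num_attention_heads=8, num_key_value_heads=8)
cache = StaticCache(config, max_batch_size=1, max_cache_len=128, device="cpu", dtype=torch.float16)

# Write 4 new tokens into positions 10..13; indexing with a tensor avoids a device copy.
cache_position = torch.arange(10, 14)
k = v = torch.randn(1, 8, 4, cache.head_dim, dtype=torch.float16)
k_out, v_out = cache.update(k, v, layer_idx=0, cache_kwargs={"cache_position": cache_position})

print(k_out.shape)                  # torch.Size([1, 8, 128, 32]) -- the full pre-allocated buffer
print(int(cache.get_seq_length()))  # 4 occupied (non-zero) positions detected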

research/long-context-llama/H2O/utils_llama.py (+291, -109)

@@ -132,12 +132,35 @@ class H2OKVCache_LayerWise:
     def _clean_scores(self):
         self.hh_score = None
 
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
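A quick shape check for `repeat_kv` (grouped-query attention: each key/value head is duplicated for its group of query heads); the import path is an assumption:

import torch
from utils_llama import repeat_kv  # hypothetical import path for this module

kv = torch.randn(2, 8, 16, 64)      # [batch, num_key_value_heads, seq_len, head_dim]
out = repeat_kv(kv, n_rep=4)        # -> [batch, num_key_value_heads * n_rep, seq_len, head_dim]
print(out.shape)                    # torch.Size([2, 32, 16, 64])
assert torch.equal(out[:, 0], kv[:, 0]) and torch.equal(out[:, 3], kv[:, 0])  # heads 0-3 copy KV head 0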
+
 class H2OLlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __init__(self, config: LlamaConfig):
+    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
+        self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+
+        self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
@@ -145,24 +168,19 @@ class H2OLlamaAttention(nn.Module):
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
+        self.is_causal = True
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                 f" and `num_heads`: {self.num_heads})."
             )
-        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
-        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-        self._init_rope()
 
-        self.kv_cache = H2OKVCache_LayerWise(
-            hh_size=config.hh_size,
-            recent_size=config.recent_size,
-            k_seq_dim=2,
-            v_seq_dim=2,
-        )
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
+        self._init_rope()
 
     def _init_rope(self):
         if self.config.rope_scaling is None:
@@ -191,51 +209,34 @@ class H2OLlamaAttention(nn.Module):
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
 
-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
-    def _clean_cache(self):
-        self.kv_cache._clean_scores()
-
     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
         bsz, q_len, _ = hidden_states.size()
 
         if self.config.pretraining_tp > 1:
-            key_value_slicing = (
-                self.num_key_value_heads * self.head_dim
-            ) // self.config.pretraining_tp
+            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
             query_slices = self.q_proj.weight.split(
                 (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
             )
             key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
             value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
 
-            query_states = [
-                F.linear(hidden_states, query_slices[i])
-                for i in range(self.config.pretraining_tp)
-            ]
+            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
             query_states = torch.cat(query_states, dim=-1)
 
-            key_states = [
-                F.linear(hidden_states, key_slices[i])
-                for i in range(self.config.pretraining_tp)
-            ]
+            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
             key_states = torch.cat(key_states, dim=-1)
 
-            value_states = [
-                F.linear(hidden_states, value_slices[i])
-                for i in range(self.config.pretraining_tp)
-            ]
+            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
             value_states = torch.cat(value_states, dim=-1)
 
         else:
@@ -243,78 +244,31 @@ class H2OLlamaAttention(nn.Module):
             key_states = self.k_proj(hidden_states)
             value_states = self.v_proj(hidden_states)
 
-        query_states = query_states.view(
-            bsz, q_len, self.num_heads, self.head_dim
-        ).transpose(1, 2)
-        key_states = key_states.view(
-            bsz, q_len, self.num_key_value_heads, self.head_dim
-        ).transpose(1, 2)
-        value_states = value_states.view(
-            bsz, q_len, self.num_key_value_heads, self.head_dim
-        ).transpose(1, 2)
-
-        # remake causal mask
-        attention_mask = _make_causal_mask(
-            bsz=bsz,
-            tgt_len=q_len,
-            past_key_values_length=past_key_value[0].shape[-2] if past_key_value is not None else 0,
-            dtype=query_states.dtype,
-            device=query_states.device,
-        )
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
-
-        if not position_ids.nelement() > 1:
-            position_ids[0][0] = kv_seq_len - 1
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        ### Shift Pos: query pos is min(cache_size, idx)
-        # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-        query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
-        ###
+        past_key_value = getattr(self, "past_key_value", past_key_value)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-        past_key_value = (key_states, value_states) if use_cache else None
-
-        ### Shift Pos: key pos is the pos in cache (Rolling KV Cache and using relative pos emb)
-        key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
-        key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
-        ###
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
-            self.head_dim
-        )
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
-            attn_weights = attn_weights + attention_mask
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            attn_weights = attn_weights + causal_mask
 
         # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
-            query_states.dtype
-        )
-
-        past_key_value = self.kv_cache(past_key_value, attn_weights.detach().clone())
-
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -324,21 +278,13 @@ class H2OLlamaAttention(nn.Module):
             )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
+
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
         if self.config.pretraining_tp > 1:
-            attn_output = attn_output.split(
-                self.hidden_size // self.config.pretraining_tp, dim=2
-            )
-            o_proj_slices = self.o_proj.weight.split(
-                self.hidden_size // self.config.pretraining_tp, dim=1
-            )
-            attn_output = sum(
-                [
-                    F.linear(attn_output[i], o_proj_slices[i])
-                    for i in range(self.config.pretraining_tp)
-                ]
-            )
+            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
         else:
             attn_output = self.o_proj(attn_output)
 
@@ -347,6 +293,242 @@ class H2OLlamaAttention(nn.Module):
 
         return attn_output, attn_weights, past_key_value
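To see how the rewritten attention plugs into the new `Cache` API, here is a rough, self-contained sketch (hypothetical import paths; assumes a transformers version whose `LlamaRotaryEmbedding` and `apply_rotary_pos_emb` match the call signatures used above; `attention_mask` is omitted for brevity, a 4D causal mask would normally be passed):

import torch
from transformers import LlamaConfig
from cache_utils import DynamicCache
from utils_llama import H2OLlamaAttention  # hypothetical import paths

config = LlamaConfig(hidden_size=256, num_attention_heads=8, num_key_value_heads=8)
attn = H2OLlamaAttention(config, layer_idx=0)
cache = DynamicCache()

hidden = torch.randn(1, 5, config.hidden_size)   # a 5-token prompt
position_ids = torch.arange(5).unsqueeze(0)
out, attn_weights, cache = attn(
    hidden,
    position_ids=position_ids,
    past_key_value=cache,
    use_cache=True,
)
print(out.shape, cache.get_seq_length())         # torch.Size([1, 5, 256]) 5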
 
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+# class H2OLlamaAttention(nn.Module):
+#     """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+#     def __init__(self, config: LlamaConfig):
+#         super().__init__()
+#         self.config = config
+#         self.hidden_size = config.hidden_size
+#         self.num_heads = config.num_attention_heads
+#         self.head_dim = self.hidden_size // self.num_heads
+#         self.num_key_value_heads = config.num_key_value_heads
+#         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+#         self.max_position_embeddings = config.max_position_embeddings
+#         self.rope_theta = config.rope_theta
+
+#         if (self.head_dim * self.num_heads) != self.hidden_size:
+#             raise ValueError(
+#                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+#                 f" and `num_heads`: {self.num_heads})."
+#             )
+#         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+#         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+#         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+#         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+#         self._init_rope()
+
+#         self.kv_cache = H2OKVCache_LayerWise(
+#             hh_size=config.hh_size,
+#             recent_size=config.recent_size,
+#             k_seq_dim=2,
+#             v_seq_dim=2,
+#         )
+
+#     def _init_rope(self):
+#         if self.config.rope_scaling is None:
+#             self.rotary_emb = LlamaRotaryEmbedding(
+#                 self.head_dim,
+#                 max_position_embeddings=self.max_position_embeddings,
+#                 base=self.rope_theta,
+#             )
+#         else:
+#             scaling_type = self.config.rope_scaling["type"]
+#             scaling_factor = self.config.rope_scaling["factor"]
+#             if scaling_type == "linear":
+#                 self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
+#                     self.head_dim,
+#                     max_position_embeddings=self.max_position_embeddings,
+#                     scaling_factor=scaling_factor,
+#                     base=self.rope_theta,
+#                 )
+#             elif scaling_type == "dynamic":
+#                 self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
+#                     self.head_dim,
+#                     max_position_embeddings=self.max_position_embeddings,
+#                     scaling_factor=scaling_factor,
+#                     base=self.rope_theta,
+#                 )
+#             else:
+#                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+#     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+#         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+#     def _clean_cache(self):
+#         self.kv_cache._clean_scores()
+
+#     def forward(
+#         self,
+#         hidden_states: torch.Tensor,
+#         attention_mask: Optional[torch.Tensor] = None,
+#         position_ids: Optional[torch.LongTensor] = None,
+#         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+#         output_attentions: bool = False,
+#         use_cache: bool = False,
+#         cache_position: Optional[torch.LongTensor] = None,
+#     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
+#         bsz, q_len, _ = hidden_states.size()
+
+#         if self.config.pretraining_tp > 1:
+#             key_value_slicing = (
+#                 self.num_key_value_heads * self.head_dim
+#             ) // self.config.pretraining_tp
+#             query_slices = self.q_proj.weight.split(
+#                 (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+#             )
+#             key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+#             value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+#             query_states = [
+#                 F.linear(hidden_states, query_slices[i])
+#                 for i in range(self.config.pretraining_tp)
+#             ]
+#             query_states = torch.cat(query_states, dim=-1)
+
+#             key_states = [
+#                 F.linear(hidden_states, key_slices[i])
+#                 for i in range(self.config.pretraining_tp)
+#             ]
+#             key_states = torch.cat(key_states, dim=-1)
+
+#             value_states = [
+#                 F.linear(hidden_states, value_slices[i])
+#                 for i in range(self.config.pretraining_tp)
+#             ]
+#             value_states = torch.cat(value_states, dim=-1)
+
+#         else:
+#             query_states = self.q_proj(hidden_states)
+#             key_states = self.k_proj(hidden_states)
+#             value_states = self.v_proj(hidden_states)
+
+#         query_states = query_states.view(
+#             bsz, q_len, self.num_heads, self.head_dim
+#         ).transpose(1, 2)
+#         key_states = key_states.view(
+#             bsz, q_len, self.num_key_value_heads, self.head_dim
+#         ).transpose(1, 2)
+#         value_states = value_states.view(
+#             bsz, q_len, self.num_key_value_heads, self.head_dim
+#         ).transpose(1, 2)
+
+#         # remake causal mask
+#         attention_mask = _make_causal_mask(
+#             bsz=bsz,
+#             tgt_len=q_len,
+#             past_key_values_length=past_key_value[0].shape[-2] if past_key_value is not None else 0,
+#             dtype=query_states.dtype,
+#             device=query_states.device,
+#         )
+
+#         kv_seq_len = key_states.shape[-2]
+#         if past_key_value is not None:
+#             kv_seq_len += past_key_value[0].shape[-2]
+
+#         if not position_ids.nelement() > 1:
+#             position_ids[0][0] = kv_seq_len - 1
+
+#         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+#         ### Shift Pos: query pos is min(cache_size, idx)
+#         # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+#         query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
+#         ###
+
+#         if past_key_value is not None:
+#             # reuse k, v, self_attention
+#             key_states = torch.cat([past_key_value[0], key_states], dim=2)
+#             value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+#         past_key_value = (key_states, value_states) if use_cache else None
+
+#         ### Shift Pos: key pos is the pos in cache (Rolling KV Cache and using relative pos emb)
+#         key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
+#         key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
+#         ###
+
+#         # repeat k/v heads if n_kv_heads < n_heads
+#         key_states = repeat_kv(key_states, self.num_key_value_groups)
+#         value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+#         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
+#             self.head_dim
+#         )
+
+#         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+#             raise ValueError(
+#                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+#                 f" {attn_weights.size()}"
+#             )
+
+#         if attention_mask is not None:
+#             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+#                 raise ValueError(
+#                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+#                 )
+#             attn_weights = attn_weights + attention_mask
+
+#         # upcast attention to fp32
+#         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+#             query_states.dtype
+#         )
+
+#         past_key_value = self.kv_cache(past_key_value, attn_weights.detach().clone())
+
+#         attn_output = torch.matmul(attn_weights, value_states)
+
+#         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+#             raise ValueError(
+#                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+#                 f" {attn_output.size()}"
+#             )
+
+#         attn_output = attn_output.transpose(1, 2).contiguous()
+#         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+#         if self.config.pretraining_tp > 1:
+#             attn_output = attn_output.split(
+#                 self.hidden_size // self.config.pretraining_tp, dim=2
+#             )
+#             o_proj_slices = self.o_proj.weight.split(
+#                 self.hidden_size // self.config.pretraining_tp, dim=1
+#             )
+#             attn_output = sum(
+#                 [
+#                     F.linear(attn_output[i], o_proj_slices[i])
+#                     for i in range(self.config.pretraining_tp)
+#                 ]
+#             )
+#         else:
+#             attn_output = self.o_proj(attn_output)
+
+#         if not output_attentions:
+#             attn_weights = None
+
+#         return attn_output, attn_weights, past_key_value
+
+
+
+
 class H2OLlamaForCausalLM(LlamaForCausalLM):
     def __init__(self, config):
         super().__init__(config)