Kaynağa Gözat

version 1.0, inference with h2o on summarization tasks

Allen 1 yıl önce
ebeveyn
işleme
525de548aa

Dosya farkı çok büyük olduğundan ihmal edildi
+ 1 - 3
research/long-context-llama/H2O/README.md


+ 0 - 16
research/long-context-llama/H2O/exp.sh

@@ -1,16 +0,0 @@
-# CUDA_VISIBLE_DEVICES=$1 python -u generation.py \
-# --input-path data/summarization/xsum.jsonl \
-# --output-path summarization_output/xsum_baseline.jsonl \
-# --model-name meta-llama/Llama-2-7b-hf 
-# 20.46/4.9/15.11
-
-CUDA_VISIBLE_DEVICES=$1 python -u generation.py \
---input-path data/summarization/xsum.jsonl \
---output-path summarization_output/xsum_h2o.jsonl \
---model-name meta-llama/Llama-2-7b-hf \
---enable_h2o_generation 
-
-# CUDA_VISIBLE_DEVICES=$1 python -u generation.py \
-# --input-path data/summarization/xsum.jsonl \
-# --output-path summarization_output/xsum_h2o.jsonl \
-# --model-name meta-llama/Llama-2-7b-hf

+ 4 - 0
research/long-context-llama/H2O/requirements.txt

@@ -0,0 +1,4 @@
+transformers
+rouge
+xopen
+needlehaystack

+ 1 - 1
research/long-context-llama/H2O/generation.py

@@ -38,7 +38,7 @@ if __name__ == '__main__':
 
     parser.add_argument("--enable_position_rolling", action='store_true')
 
-    parser.add_argument("--sample_num", type=int, default=10)
+    parser.add_argument("--sample_num", type=int, default=1000)
     parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
 
     args = parser.parse_args()

+ 0 - 2
research/long-context-llama/H2O/setup.sh

@@ -1,2 +0,0 @@
-pip install rouge
-pip install xopen

+ 0 - 614
research/long-context-llama/H2O/utils_cache.py

@@ -1,614 +0,0 @@
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple
-
-import torch
-
-from transformers.configuration_utils import PretrainedConfig
-from transformers.utils import logging
-
-logger = logging.get_logger(__name__)
-
-@dataclass
-class Cache:
-    """
-    Base, abstract class for all caches. The actual data structure is specific to each subclass.
-    """
-
-    def update(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
-
-        Parameters:
-            key_states (`torch.Tensor`):
-                The new key states to cache.
-            value_states (`torch.Tensor`):
-                The new value states to cache.
-            layer_idx (`int`):
-                The index of the layer to cache the states for.
-            cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
-                cache to be created.
-
-        Return:
-            A tuple containing the updated key and value states.
-        """
-        raise NotImplementedError("Make sure to implement `update` in a subclass.")
-
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
-
-    def get_max_length(self) -> Optional[int]:
-        """Returns the maximum sequence length of the cached states, if there is any."""
-        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
-
-    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
-        """Given the sequence length of the new inputs, returns the usable length of the cache."""
-        # Cache without size limit -> all cache is usable
-        # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
-        #   length, we will need to evict part of the cache (and thus not all cache is usable)
-        max_length = self.get_max_length()
-        previous_seq_length = self.get_seq_length(layer_idx)
-        if max_length is not None and previous_seq_length + new_seq_length > max_length:
-            return max_length - new_seq_length
-        return previous_seq_length
-
-    @property
-    def seen_tokens(self):
-        logger.warning_once(
-            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
-            "model input instead."
-        )
-        if hasattr(self, "_seen_tokens"):
-            return self._seen_tokens
-        else:
-            return None
-
-
-class DynamicCache(Cache):
-    """
-    A cache that grows dynamically as more tokens are generated. This is the default for generative models.
-
-    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
-    `[batch_size, num_heads, seq_len, head_dim]`.
-    """
-
-    def __init__(self) -> None:
-        self.key_cache: List[torch.Tensor] = []
-        self.value_cache: List[torch.Tensor] = []
-        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
-
-    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
-        """
-        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
-        sequence length.
-        """
-        if layer_idx < len(self):
-            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
-        else:
-            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
-
-    def __iter__(self):
-        """
-        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
-        keys and values
-        """
-        for layer_idx in range(len(self)):
-            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
-
-    def __len__(self):
-        """
-        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
-        to the number of layers in the model.
-        """
-        return len(self.key_cache)
-
-    def update(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
-
-        Parameters:
-            key_states (`torch.Tensor`):
-                The new key states to cache.
-            value_states (`torch.Tensor`):
-                The new value states to cache.
-            layer_idx (`int`):
-                The index of the layer to cache the states for.
-            cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
-
-        Return:
-            A tuple containing the updated key and value states.
-        """
-        # Update the number of seen tokens
-        if layer_idx == 0:
-            self._seen_tokens += key_states.shape[-2]
-
-        # Update the cache
-        if len(self.key_cache) <= layer_idx:
-            self.key_cache.append(key_states)
-            self.value_cache.append(value_states)
-        else:
-            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
-
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]
-
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-        if len(self.key_cache) <= layer_idx:
-            return 0
-        return self.key_cache[layer_idx].shape[-2]
-
-    def get_max_length(self) -> Optional[int]:
-        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
-        return None
-
-    def reorder_cache(self, beam_idx: torch.LongTensor):
-        """Reorders the cache for beam search, given the selected beam indices."""
-        for layer_idx in range(len(self.key_cache)):
-            device = self.key_cache[layer_idx].device
-            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
-            device = self.value_cache[layer_idx].device
-            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
-
-    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
-        """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
-        legacy_cache = ()
-        for layer_idx in range(len(self)):
-            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
-        return legacy_cache
-
-    @classmethod
-    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
-        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
-        cache = cls()
-        if past_key_values is not None:
-            for layer_idx in range(len(past_key_values)):
-                key_states, value_states = past_key_values[layer_idx]
-                cache.update(key_states, value_states, layer_idx)
-        return cache
-
-
-class SinkCache(Cache):
-    """
-    A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
-    generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
-    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
-
-    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
-    `[batch_size, num_heads, seq_len, head_dim]`.
-
-    Parameters:
-        window_length (`int`):
-            The length of the context window.
-        num_sink_tokens (`int`):
-            The number of sink tokens. See the original paper for more information.
-    """
-
-    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
-        self.key_cache: List[torch.Tensor] = []
-        self.value_cache: List[torch.Tensor] = []
-        self.window_length = window_length
-        self.num_sink_tokens = num_sink_tokens
-        self.cos_sin_cache = {}
-        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
-
-    @staticmethod
-    def _rotate_half(x):
-        x1 = x[..., : x.shape[-1] // 2]
-        x2 = x[..., x.shape[-1] // 2 :]
-        return torch.cat((-x2, x1), dim=-1)
-
-    def _apply_key_rotary_pos_emb(
-        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-    ) -> torch.Tensor:
-        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
-        return rotated_key_states
-
-    def _get_rerotation_cos_sin(
-        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if key_states.shape[-2] not in self.cos_sin_cache:
-            # Upcast to float32 temporarily for better accuracy
-            cos = cos.to(torch.float32)
-            sin = sin.to(torch.float32)
-
-            # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
-            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
-            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
-            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
-            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
-            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
-            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
-
-            self.cos_sin_cache[key_states.shape[-2]] = (
-                rerotation_cos.to(key_states.dtype).unsqueeze(0),
-                rerotation_sin.to(key_states.dtype).unsqueeze(0),
-            )
-        return self.cos_sin_cache[key_states.shape[-2]]
-
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
-        if len(self.key_cache) <= layer_idx:
-            return 0
-        return self.key_cache[layer_idx].shape[-2]
-
-    def get_max_length(self) -> Optional[int]:
-        """Returns the maximum sequence length of the cached states."""
-        return self.window_length
-
-    def update(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
-
-        Parameters:
-            key_states (`torch.Tensor`):
-                The new key states to cache.
-            value_states (`torch.Tensor`):
-                The new value states to cache.
-            layer_idx (`int`):
-                The index of the layer to cache the states for.
-            cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
-                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
-                rotation as the tokens are shifted.
-
-        Return:
-            A tuple containing the updated key and value states.
-        """
-        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
-        # with partially rotated position embeddings, like Phi or Persimmon.
-        sin = cache_kwargs.get("sin")
-        cos = cache_kwargs.get("cos")
-        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
-        using_rope = cos is not None and sin is not None
-
-        # Update the number of seen tokens
-        if layer_idx == 0:
-            self._seen_tokens += key_states.shape[-2]
-
-        # [bsz, num_heads, seq_len, head_dim]
-        if len(self.key_cache) <= layer_idx:
-            # Empty cache
-            self.key_cache.append(key_states)
-            self.value_cache.append(value_states)
-
-        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
-            # Growing cache
-            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
-
-        else:
-            # Shifting cache
-            keys_to_keep = self.key_cache[layer_idx][
-                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
-            ]
-
-            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
-            if using_rope:
-                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
-                    key_states, cos[: self.window_length], sin[: self.window_length]
-                )
-                if partial_rotation_size is not None:
-                    keys_to_keep, keys_pass = (
-                        keys_to_keep[..., :partial_rotation_size],
-                        keys_to_keep[..., partial_rotation_size:],
-                    )
-                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
-                if partial_rotation_size is not None:
-                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
-
-            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
-            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
-            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
-
-            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
-            values_to_keep = self.value_cache[layer_idx][
-                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
-            ]
-            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
-
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]
-
-    def reorder_cache(self, beam_idx: torch.LongTensor):
-        """Reorders the cache for beam search, given the selected beam indices."""
-        for layer_idx in range(len(self.key_cache)):
-            device = self.key_cache[layer_idx].device
-            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
-            device = self.value_cache[layer_idx].device
-            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
-
-
-class HHCache(Cache):
-    """
-    A cache that apply heavy-hitter oracle (https://proceedings.neurips.cc/paper_files/paper/2023/file/6ceefa7b15572587b78ecfcebb2827f8-Paper-Conference.pdf).
-    Only the heavy-hitter and the recent tokens are stored in the cache.
-
-    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
-    `[batch_size, num_heads, seq_len, head_dim]`.
-
-    Parameters:
-        window_length (`int`):
-            The length of the context window.
-        num_hh_tokens (`int`):
-            The number of heavy hitter tokens. See the original paper for more information.
-    """
-
-    def __init__(self, window_length: int, num_hh_tokens: int) -> None:
-        self.key_cache: List[torch.Tensor] = []
-        self.value_cache: List[torch.Tensor] = []
-        self.window_length = window_length
-        self.num_hh_tokens = num_hh_tokens
-        self.accumulated_attention_scores: List[torch.Tensor] = []
-        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
-
-    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
-        """
-        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
-        sequence length.
-        """
-        if layer_idx < len(self):
-            return (self.key_cache[layer_idx], self.value_cache[layer_idx], self.accumulated_attention_scores[layer_idx])
-        else:
-            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
-
-    def __iter__(self):
-        """
-        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
-        keys and values
-        """
-        for layer_idx in range(len(self)):
-            yield (self.key_cache[layer_idx], self.value_cache[layer_idx], self.accumulated_attention_scores[layer_idx])
-
-    def __len__(self):
-        """
-        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
-        to the number of layers in the model.
-        """
-        return len(self.key_cache)
-
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
-        if len(self.key_cache) <= layer_idx:
-            return 0
-        return self.key_cache[layer_idx].shape[-2]
-
-    def get_max_length(self) -> Optional[int]:
-        """Returns the maximum sequence length of the cached states."""
-        return self.window_length
-
-    def update(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[Dict[str, Any]] = None,
-        accumulated_attention_scores: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
-
-        Parameters:
-            key_states (`torch.Tensor`):
-                The new key states to cache.
-            value_states (`torch.Tensor`):
-                The new value states to cache.
-            layer_idx (`int`):
-                The index of the layer to cache the states for.
-            cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
-
-        Return:
-            A tuple containing the updated key and value states.
-        """
-        # Update the number of seen tokens
-
-        if accumulated_attention_scores is not None:
-            self.accumulated_attention_scores.append(accumulated_attention_scores)
-
-        if layer_idx == 0:
-            self._seen_tokens += key_states.shape[-2]
-
-        # Update the cache
-        if len(self.key_cache) <= layer_idx:
-            self.key_cache.append(key_states)
-            self.value_cache.append(value_states)
-        else:
-            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
-
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]
-
-    def update_slimming(
-        self,
-        attention_scores: torch.Tensor,
-        num_kv_groups: int,
-        layer_idx: int,
-        cache_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Slimming the cache based on accumulated attention scores, only keep heavy-hitters + local tokens.
-
-        Parameters:
-            attention_scores (`torch.Tensor`):
-                Attention_scores for current steps.
-            num_kv_groups (`int`):
-                The number of kv groups in repeat kv.
-            layer_idx (`int`):
-                The index of the layer to cache the states for.
-            cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
-        Return:
-            A tuple containing the updated key and value states.
-        """
-
-        # Update score metrics (Accumulated attention scores)
-        if len(self.accumulated_attention_scores) <= layer_idx:
-            self.accumulated_attention_scores.append(attention_scores.sum(2)[:,::num_kv_groups, :]) # [bs, num_heads, key_len]
-        else:
-            num_new_tokens = attention_scores.shape[2]
-            updated_attention_scores = attention_scores.sum(2)[:,::num_kv_groups, :] # [bs, num_heads, key_len]
-            updated_attention_scores[:, :, :-num_new_tokens] += self.accumulated_attention_scores[layer_idx]
-            self.accumulated_attention_scores[layer_idx] = updated_attention_scores
-
-        # Update KV Cache
-        if self.get_seq_length(layer_idx) > self.window_length:
-
-            seq_scores = self.accumulated_attention_scores[layer_idx][:, :, :-self.window_length + self.num_hh_tokens]
-            _, keep_hh_index = torch.topk(seq_scores, self.num_hh_tokens, dim=-1)
-            keep_hh_index = keep_hh_index.sort().values
-
-            keep_local_index = torch.arange(self.get_seq_length(layer_idx) - self.window_length + self.num_hh_tokens, self.get_seq_length(layer_idx), device=keep_hh_index.device).repeat(keep_hh_index.shape[0], keep_hh_index.shape[1], 1)
-            keep_index = torch.cat([keep_hh_index, keep_local_index], dim=-1)
-
-            mask = torch.zeros(self.accumulated_attention_scores[layer_idx].shape, dtype=torch.bool).to(keep_hh_index.device)
-            mask = mask.scatter(-1, keep_index, 1)
-
-            bsz, num_heads, _, head_dim = self.key_cache[layer_idx].shape
-            self.key_cache[layer_idx] = self.key_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)
-            self.value_cache[layer_idx] = self.value_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)
-            self.accumulated_attention_scores[layer_idx] = self.accumulated_attention_scores[layer_idx][mask].view(bsz, num_heads, -1)
-
-
-    def reorder_cache(self, beam_idx: torch.LongTensor):
-        """Reorders the cache for beam search, given the selected beam indices."""
-        for layer_idx in range(len(self.key_cache)):
-            device = self.key_cache[layer_idx].device
-            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
-            device = self.value_cache[layer_idx].device
-            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
-
-    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
-        """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
-        legacy_cache = ()
-        for layer_idx in range(len(self)):
-            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx], self.accumulated_attention_scores[layer_idx],))
-        return legacy_cache
-
-    @classmethod
-    def from_legacy_cache(cls, window_length: int, num_hh_tokens: int, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
-        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
-        cache = cls(window_length, num_hh_tokens)
-        if past_key_values is not None:
-            for layer_idx in range(len(past_key_values) // 3):
-                key_states = past_key_values[layer_idx * 3]
-                value_states = past_key_values[layer_idx * 3 + 1]
-                accumulated_attention_scores = past_key_values[layer_idx * 3 + 2]
-                cache.update(key_states, value_states, layer_idx, accumulated_attention_scores=accumulated_attention_scores)
-        return cache
-
-
-class StaticCache(Cache):
-    """
-    Static Cache class to be used with `torch.compile(model)`.
-
-    Parameters:
-        config (`PretrainedConfig):
-            The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
-            required to initialize the static cache.
-        max_batch_size (`int`):
-            The maximum batch size with which the model will be used.
-        max_cache_len (`int`):
-            The maximum sequence length with which the model will be used.
-        device (`torch.device`):
-            The device on which the cache should be initialized. Should be the same as the layer.
-        dtype (*optional*, defaults to `torch.float32`):
-            The default `dtype` to use when initializing the layer.
-    """
-
-    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
-        super().__init__()
-        self.max_batch_size = max_batch_size
-        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
-        # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
-        self.head_dim = (
-            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
-        )
-
-        self.dtype = dtype if dtype is not None else torch.float32
-        self.num_key_value_heads = (
-            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
-        )
-
-        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
-        self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
-        self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
-
-    def update(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
-        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
-
-        Parameters:
-            key_states (`torch.Tensor`):
-                The new key states to cache.
-            value_states (`torch.Tensor`):
-                The new value states to cache.
-            layer_idx (`int`):
-                The index of the layer to cache the states for. Kept for backward compatibility
-            cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. The `StaticCache` just needs the `q_len`
-                to know how much of the cache it should overwrite.
-
-        Return:
-            A tuple containing the updated key and value states.
-        """
-        new_cache_positions = cache_kwargs.get("cache_position")
-        k_out = self.key_cache
-        v_out = self.value_cache
-
-        k_out[:, :, new_cache_positions] = key_states
-        v_out[:, :, new_cache_positions] = value_states
-
-        return k_out, v_out
-
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
-        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
-        # limit the check to the first batch member and head dimension.
-        # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
-        # https://github.com/pytorch/pytorch/issues/120248 is fixed
-        return (self.key_cache[0, 0].any(dim=-1)).sum()
-
-    def get_max_length(self) -> Optional[int]:
-        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
-        return self.max_cache_len
-
-    def reorder_cache(self, beam_idx: torch.LongTensor):
-        """Reorders the cache for beam search, given the selected beam indices."""
-        device = self.key_cache.device
-        self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
-        device = self.value_cache.device
-        self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))
-
-    def to_legacy_cache(self):
-        """Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it"""
-        return None
-

+ 0 - 449
research/long-context-llama/H2O/utils_llama.py

@@ -1,449 +0,0 @@
-import math
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import pdb
-import types
-import torch
-from torch import nn
-import torch.utils.checkpoint
-import torch.nn.functional as F
-
-from transformers.models.llama.configuration_llama import LlamaConfig
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-    rotate_half,
-    apply_rotary_pos_emb,
-    repeat_kv,
-    LlamaRotaryEmbedding,
-    LlamaForCausalLM,
-)
-from kv_cache import Cache, HHCache, StaticCache
-from transformers.utils import logging
-from transformers.modeling_outputs import BaseModelOutputWithPast
-
-logger = logging.get_logger(__name__)
-
-__all__ = ["H2OLlamaForCausalLM"]
-
-def _make_causal_mask(
-    bsz: int, tgt_len: int, past_key_values_length: int, dtype: torch.dtype, device: torch.device):
-    """
-    Make causal mask used for bi-directional self-attention.
-    """
-    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
-    mask_cond = torch.arange(mask.size(-1), device=device)
-    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
-    mask = mask.to(dtype)
-
-    if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
-    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
-
-def apply_rotary_pos_emb_single(x, cos, sin, position_ids=None, unsqueeze_dim=1):
-
-    cos = cos.unsqueeze(unsqueeze_dim)
-    sin = sin.unsqueeze(unsqueeze_dim)
-    x_embed = (x * cos) + (rotate_half(x) * sin)
-
-    return x_embed
-
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
-class H2OLlamaAttention(nn.Module):
-    """Multi-headed attention from 'Attention Is All You Need' paper"""
-
-    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
-        super().__init__()
-        self.config = config
-        self.layer_idx = layer_idx
-        if layer_idx is None:
-            logger.warning_once(
-                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
-                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
-                "when creating this class."
-            )
-
-        self.attention_dropout = config.attention_dropout
-        self.hidden_size = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
-        self.num_key_value_heads = config.num_key_value_heads
-        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-        self.max_position_embeddings = config.max_position_embeddings
-        self.rope_theta = config.rope_theta
-        self.is_causal = True
-        self.positional_rolling = config.enable_position_rolling
-
-        if (self.head_dim * self.num_heads) != self.hidden_size:
-            raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                f" and `num_heads`: {self.num_heads})."
-            )
-
-        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
-        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
-        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
-        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
-        self._init_rope()
-
-    def _init_rope(self):
-        if self.config.rope_scaling is None:
-            self.rotary_emb = LlamaRotaryEmbedding(
-                self.head_dim,
-                max_position_embeddings=self.max_position_embeddings,
-                base=self.rope_theta,
-            )
-        else:
-            scaling_type = self.config.rope_scaling["type"]
-            scaling_factor = self.config.rope_scaling["factor"]
-            if scaling_type == "linear":
-                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
-                    self.head_dim,
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                    base=self.rope_theta,
-                )
-            elif scaling_type == "dynamic":
-                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
-                    self.head_dim,
-                    max_position_embeddings=self.max_position_embeddings,
-                    scaling_factor=scaling_factor,
-                    base=self.rope_theta,
-                )
-            else:
-                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()
-
-        if self.config.pretraining_tp > 1:
-            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
-            query_slices = self.q_proj.weight.split(
-                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
-            )
-            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
-            query_states = torch.cat(query_states, dim=-1)
-
-            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
-            key_states = torch.cat(key_states, dim=-1)
-
-            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
-            value_states = torch.cat(value_states, dim=-1)
-
-        else:
-            query_states = self.q_proj(hidden_states)
-            key_states = self.k_proj(hidden_states)
-            value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        past_key_value = getattr(self, "past_key_value", past_key_value)
-
-        if not self.positional_rolling:
-            cos, sin = self.rotary_emb(value_states, position_ids)
-            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-            if past_key_value is not None:
-                # sin and cos are specific to RoPE models; cache_position needed for the static cache
-                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-        else:
-            if past_key_value is not None:
-                # sin and cos are specific to RoPE models; cache_position needed for the static cache
-                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
-
-            kv_seq_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else key_states.shape[-2]
-
-            if not position_ids.nelement() > 1:
-                # decoding stage
-                key_position_ids = torch.arange(kv_seq_len, device=hidden_states.device).unsqueeze(0)
-                query_position_ids = key_position_ids[:, -1].unsqueeze(0)
-            else:
-                query_position_ids = position_ids
-                key_position_ids = position_ids
-
-            key_cos, key_sin = self.rotary_emb(value_states, key_position_ids)
-            query_cos, query_sin = self.rotary_emb(value_states, query_position_ids)
-
-            if self.layer_idx == 0:
-                print(kv_seq_len, query_position_ids, key_position_ids)
-
-            query_states = apply_rotary_pos_emb_single(query_states, query_cos, query_sin)
-            key_states = apply_rotary_pos_emb_single(key_states, key_cos, key_sin)
-
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        if attention_mask is not None:  # no matter the length, we just slice it
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-            attn_weights = attn_weights + causal_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-
-        # Update KV Cache based on Heavy-Hitter Oracle
-        if past_key_value is not None:
-            past_key_value.update_slimming(attn_weights, self.num_key_value_groups, self.layer_idx)
-
-        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-        if self.config.pretraining_tp > 1:
-            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
-            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
-            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
-        else:
-            attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-        
-        return attn_output, attn_weights, past_key_value
-
-
-def enable_h2ocache_forward(
-    self,
-    input_ids: torch.LongTensor = None,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_values: Optional[List[torch.FloatTensor]] = None,
-    inputs_embeds: Optional[torch.FloatTensor] = None,
-    use_cache: Optional[bool] = None,
-    output_attentions: Optional[bool] = None,
-    output_hidden_states: Optional[bool] = None,
-    return_dict: Optional[bool] = None,
-    cache_position: Optional[torch.LongTensor] = None,
-) -> Union[Tuple, BaseModelOutputWithPast]:
-    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-    output_hidden_states = (
-        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-    )
-    use_cache = use_cache if use_cache is not None else self.config.use_cache
-    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-    if (input_ids is None) ^ (inputs_embeds is not None):
-        raise ValueError(
-            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
-        )
-
-    if self.gradient_checkpointing and self.training and use_cache:
-        logger.warning_once(
-            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
-        )
-        use_cache = False
-
-    if inputs_embeds is None:
-        inputs_embeds = self.embed_tokens(input_ids)
-
-    past_seen_tokens = 0
-    if use_cache:  # kept for BC (cache positions)
-        if not isinstance(past_key_values, StaticCache):
-            past_key_values = HHCache.from_legacy_cache(self.num_window_length, self.num_heavy_hitter_tokens, past_key_values)
-            past_seen_tokens = past_key_values.get_seq_length()
-
-    if cache_position is None:
-        if isinstance(past_key_values, StaticCache):
-            raise ValueError("cache_position is a required argument when using StaticCache.")
-        cache_position = torch.arange(
-            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
-        )
-
-    if position_ids is None:
-        position_ids = cache_position.unsqueeze(0)
-
-    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
-
-    # embed positions
-    hidden_states = inputs_embeds
-
-    # decoder layers
-    all_hidden_states = () if output_hidden_states else None
-    all_self_attns = () if output_attentions else None
-    next_decoder_cache = None
-
-    for decoder_layer in self.layers:
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        if self.gradient_checkpointing and self.training:
-            layer_outputs = self._gradient_checkpointing_func(
-                decoder_layer.__call__,
-                hidden_states,
-                causal_mask,
-                position_ids,
-                past_key_values,
-                output_attentions,
-                use_cache,
-                cache_position,
-            )
-        else:
-            layer_outputs = decoder_layer(
-                hidden_states,
-                attention_mask=causal_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                cache_position=cache_position,
-            )
-
-        hidden_states = layer_outputs[0]
-
-        if use_cache:
-            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
-        if output_attentions:
-            all_self_attns += (layer_outputs[1],)
-
-    hidden_states = self.norm(hidden_states)
-
-    # add hidden states from the last decoder layer
-    if output_hidden_states:
-        all_hidden_states += (hidden_states,)
-
-    next_cache = None
-    if use_cache:
-        next_cache = (
-            next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
-        )
-    if not return_dict:
-        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-    
-    return BaseModelOutputWithPast(
-        last_hidden_state=hidden_states,
-        past_key_values=next_cache,
-        hidden_states=all_hidden_states,
-        attentions=all_self_attns,
-    )
-
-class H2OLlamaForCausalLM(LlamaForCausalLM):
-    def __init__(self, config):
-        super().__init__(config)
-        num_layers = len(self.model.layers)
-        for layer_idx in range(num_layers):
-            self.model.layers[layer_idx].self_attn = H2OLlamaAttention(config, layer_idx)
-
-        self.model.forward = types.MethodType(enable_h2ocache_forward, self.model)
-        self.model.num_heavy_hitter_tokens = config.num_heavy_hitter_tokens
-        self.model.num_window_length = config.num_window_length
-    
-    def prepare_inputs_for_generation(
-        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
-    ):
-        # With static cache, the `past_key_values` is None
-        # TODO joao: standardize interface for the different Cache classes and remove of this if
-
-        has_static_cache = False
-        if past_key_values is None:
-            past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
-            has_static_cache = past_key_values is not None
-
-        past_length = 0
-        if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
-                past_length = cache_position[0]
-                max_cache_length = (
-                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
-                    if past_key_values.get_max_length() is not None
-                    else None
-                )
-                cache_length = past_key_values.get_seq_length()
-
-            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
-            else:
-                past_length = cache_position[0]
-                cache_length = past_key_values[0].shape[2] # length = num_layers * 3 (3 -> key, value, score)
-                max_cache_length = None
-
-            # Keep only the unprocessed tokens:
-            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-            # input)
-            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            # input_ids based on the past_length.
-            elif past_length < input_ids.shape[1]:
-                input_ids = input_ids[:, past_length:]
-            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
-            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-            if (
-                max_cache_length is not None
-                and attention_mask is not None
-                and cache_length + input_ids.shape[1] > max_cache_length
-            ):
-                attention_mask = attention_mask[:, -max_cache_length:]
-
-        position_ids = kwargs.get("position_ids", None)
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs = {"inputs_embeds": inputs_embeds}
-        else:
-            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
-            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
-            # TODO: use `next_tokens` directly instead.
-            model_inputs = {"input_ids": input_ids.contiguous()}
-
-        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
-        if cache_position is None:
-            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
-        else:
-            cache_position = cache_position[-input_length:]
-
-        if has_static_cache:
-            past_key_values = None
-
-        model_inputs.update(
-            {
-                "position_ids": position_ids,
-                "cache_position": cache_position,
-                "past_key_values": past_key_values,
-                "use_cache": kwargs.get("use_cache"),
-                "attention_mask": attention_mask,
-            }
-        )
-        return model_inputs