Allen 1 yıl önce
ebeveyn
işleme
266d5dcedc

+ 58 - 96
research/long-context-llama/H2O/cache_utils.py

@@ -341,9 +341,8 @@ class SinkCache(Cache):
 
 class HHCache(Cache):
     """
-    A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
-    generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
-    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
+    A cache that apply heavy-hitter oracle (https://proceedings.neurips.cc/paper_files/paper/2023/file/6ceefa7b15572587b78ecfcebb2827f8-Paper-Conference.pdf).
+    Only the heavy-hitter and the recent tokens are stored in the cache.
 
     It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
     `[batch_size, num_heads, seq_len, head_dim]`.
@@ -351,62 +350,39 @@ class HHCache(Cache):
     Parameters:
         window_length (`int`):
             The length of the context window.
-        num_sink_tokens (`int`):
-            The number of sink tokens. See the original paper for more information.
+        num_hh_tokens (`int`):
+            The number of heavy hitter tokens. See the original paper for more information.
     """
 
-    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
+    def __init__(self) -> None:
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
-        self.window_length = window_length
-        self.num_sink_tokens = num_sink_tokens
-        self.cos_sin_cache = {}
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
 
-    @staticmethod
-    def _rotate_half(x):
-        x1 = x[..., : x.shape[-1] // 2]
-        x2 = x[..., x.shape[-1] // 2 :]
-        return torch.cat((-x2, x1), dim=-1)
-
-    def _apply_key_rotary_pos_emb(
-        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-    ) -> torch.Tensor:
-        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
-        return rotated_key_states
-
-    def _get_rerotation_cos_sin(
-        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if key_states.shape[-2] not in self.cos_sin_cache:
-            # Upcast to float32 temporarily for better accuracy
-            cos = cos.to(torch.float32)
-            sin = sin.to(torch.float32)
-
-            # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
-            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
-            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
-            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
-            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
-            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
-            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
-
-            self.cos_sin_cache[key_states.shape[-2]] = (
-                rerotation_cos.to(key_states.dtype).unsqueeze(0),
-                rerotation_sin.to(key_states.dtype).unsqueeze(0),
-            )
-        return self.cos_sin_cache[key_states.shape[-2]]
+    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+        """
+        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
+        sequence length.
+        """
+        if layer_idx < len(self):
+            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
+        else:
+            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
 
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
-        if len(self.key_cache) <= layer_idx:
-            return 0
-        return self.key_cache[layer_idx].shape[-2]
+    def __iter__(self):
+        """
+        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
+        keys and values
+        """
+        for layer_idx in range(len(self)):
+            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
 
-    def get_max_length(self) -> Optional[int]:
-        """Returns the maximum sequence length of the cached states."""
-        return self.window_length
+    def __len__(self):
+        """
+        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
+        to the number of layers in the model.
+        """
+        return len(self.key_cache)
 
     def update(
         self,
@@ -426,66 +402,34 @@ class HHCache(Cache):
             layer_idx (`int`):
                 The index of the layer to cache the states for.
             cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
-                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
-                rotation as the tokens are shifted.
+                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
 
         Return:
             A tuple containing the updated key and value states.
         """
-        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
-        # with partially rotated position embeddings, like Phi or Persimmon.
-        sin = cache_kwargs.get("sin")
-        cos = cache_kwargs.get("cos")
-        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
-        using_rope = cos is not None and sin is not None
-
         # Update the number of seen tokens
         if layer_idx == 0:
             self._seen_tokens += key_states.shape[-2]
 
-        # [bsz, num_heads, seq_len, head_dim]
+        # Update the cache
         if len(self.key_cache) <= layer_idx:
-            # Empty cache
             self.key_cache.append(key_states)
             self.value_cache.append(value_states)
-
-        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
-            # Growing cache
+        else:
             self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
             self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
 
-        else:
-            # Shifting cache
-            keys_to_keep = self.key_cache[layer_idx][
-                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
-            ]
-
-            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
-            if using_rope:
-                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
-                    key_states, cos[: self.window_length], sin[: self.window_length]
-                )
-                if partial_rotation_size is not None:
-                    keys_to_keep, keys_pass = (
-                        keys_to_keep[..., :partial_rotation_size],
-                        keys_to_keep[..., partial_rotation_size:],
-                    )
-                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
-                if partial_rotation_size is not None:
-                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
-
-            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
-            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
-            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
 
-            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
-            values_to_keep = self.value_cache[layer_idx][
-                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
-            ]
-            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        return self.key_cache[layer_idx].shape[-2]
 
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
+        return None
 
     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
@@ -495,6 +439,23 @@ class HHCache(Cache):
             device = self.value_cache[layer_idx].device
             self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
 
+    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+        """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
+        legacy_cache = ()
+        for layer_idx in range(len(self)):
+            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
+        return legacy_cache
+
+    @classmethod
+    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
+        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
+        cache = cls()
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx]
+                cache.update(key_states, value_states, layer_idx)
+        return cache
+
 
 class StaticCache(Cache):
     """
@@ -587,4 +548,5 @@ class StaticCache(Cache):
 
     def to_legacy_cache(self):
         """Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it"""
-        return None
+        return None
+

+ 2 - 0
research/long-context-llama/H2O/utils_llama.py

@@ -252,6 +252,8 @@ class H2OLlamaAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
+        print(self.past_key_value)
+
         past_key_value = getattr(self, "past_key_value", past_key_value)
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)