Allen 1 year ago
commit 266d5dcedc

+ 58 - 96
research/long-context-llama/H2O/cache_utils.py

@@ -341,9 +341,8 @@ class SinkCache(Cache):
 
 
 class HHCache(Cache):
     """
-    A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
-    generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
-    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
+    A cache that applies the heavy-hitter oracle (H2O) eviction policy (https://proceedings.neurips.cc/paper_files/paper/2023/file/6ceefa7b15572587b78ecfcebb2827f8-Paper-Conference.pdf).
+    Only the heavy-hitter tokens and the most recent tokens are stored in the cache.
 
 
     It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
     `[batch_size, num_heads, seq_len, head_dim]`.
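For reference, the H2O eviction the new docstring refers to keeps a small budget of "heavy-hitter" tokens, ranked by accumulated attention score, together with a window of the most recent tokens. The eviction step itself is not part of this commit, so the sketch below is only illustrative; the accumulated-score tensor attn_score_sum and the num_recent_tokens budget are assumed names, not code from this repository.

import torch

def evict_for_space(key_cache, value_cache, attn_score_sum, num_hh_tokens, num_recent_tokens):
    """Sketch of H2O eviction: keep heavy-hitter tokens plus the most recent window.

    key_cache / value_cache: [batch_size, num_heads, seq_len, head_dim]
    attn_score_sum: [batch_size, num_heads, seq_len], attention mass accumulated per cached token.
    """
    seq_len = key_cache.shape[-2]
    if seq_len <= num_hh_tokens + num_recent_tokens:
        return key_cache, value_cache  # nothing to evict yet

    # Heavy hitters are chosen among the tokens outside the recent window.
    candidate_scores = attn_score_sum[..., : seq_len - num_recent_tokens]
    _, hh_idx = torch.topk(candidate_scores, num_hh_tokens, dim=-1)
    hh_idx, _ = torch.sort(hh_idx, dim=-1)  # keep the original token order

    # Gather the heavy-hitter entries per head, then append the recent window unchanged.
    hh_idx = hh_idx.unsqueeze(-1).expand(-1, -1, -1, key_cache.shape[-1])
    keys = torch.cat([key_cache.gather(-2, hh_idx), key_cache[..., -num_recent_tokens:, :]], dim=-2)
    values = torch.cat([value_cache.gather(-2, hh_idx), value_cache[..., -num_recent_tokens:, :]], dim=-2)
    return keys, values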
@@ -351,62 +350,39 @@ class HHCache(Cache):
     Parameters:
         window_length (`int`):
             The length of the context window.
-        num_sink_tokens (`int`):
-            The number of sink tokens. See the original paper for more information.
+        num_hh_tokens (`int`):
+            The number of heavy hitter tokens. See the original paper for more information.
     """
     """
 
 
-    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
+    def __init__(self) -> None:
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
-        self.window_length = window_length
-        self.num_sink_tokens = num_sink_tokens
-        self.cos_sin_cache = {}
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
 
 
-    @staticmethod
-    def _rotate_half(x):
-        x1 = x[..., : x.shape[-1] // 2]
-        x2 = x[..., x.shape[-1] // 2 :]
-        return torch.cat((-x2, x1), dim=-1)
-
-    def _apply_key_rotary_pos_emb(
-        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-    ) -> torch.Tensor:
-        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
-        return rotated_key_states
-
-    def _get_rerotation_cos_sin(
-        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if key_states.shape[-2] not in self.cos_sin_cache:
-            # Upcast to float32 temporarily for better accuracy
-            cos = cos.to(torch.float32)
-            sin = sin.to(torch.float32)
-
-            # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
-            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
-            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
-            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
-            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
-            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
-            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
-
-            self.cos_sin_cache[key_states.shape[-2]] = (
-                rerotation_cos.to(key_states.dtype).unsqueeze(0),
-                rerotation_sin.to(key_states.dtype).unsqueeze(0),
-            )
-        return self.cos_sin_cache[key_states.shape[-2]]
+    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+        """
+        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
+        sequence length.
+        """
+        if layer_idx < len(self):
+            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
+        else:
+            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
 
 
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
-        if len(self.key_cache) <= layer_idx:
-            return 0
-        return self.key_cache[layer_idx].shape[-2]
+    def __iter__(self):
+        """
+        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
+        keys and values
+        """
+        for layer_idx in range(len(self)):
+            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
 
 
-    def get_max_length(self) -> Optional[int]:
-        """Returns the maximum sequence length of the cached states."""
-        return self.window_length
+    def __len__(self):
+        """
+        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
+        to the number of layers in the model.
+        """
+        return len(self.key_cache)
 
 
     def update(
         self,
@@ -426,66 +402,34 @@ class HHCache(Cache):
             layer_idx (`int`):
                 The index of the layer to cache the states for.
             cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
-                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
-                rotation as the tokens are shifted.
+                Additional arguments for the cache subclass. No additional arguments are used in `HHCache`.
 
 
         Return:
             A tuple containing the updated key and value states.
         """
-        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
-        # with partially rotated position embeddings, like Phi or Persimmon.
-        sin = cache_kwargs.get("sin")
-        cos = cache_kwargs.get("cos")
-        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
-        using_rope = cos is not None and sin is not None
-
         # Update the number of seen tokens
         if layer_idx == 0:
             self._seen_tokens += key_states.shape[-2]
 
 
-        # [bsz, num_heads, seq_len, head_dim]
+        # Update the cache
         if len(self.key_cache) <= layer_idx:
-            # Empty cache
             self.key_cache.append(key_states)
             self.value_cache.append(value_states)
-
-        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
-            # Growing cache
+        else:
             self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
             self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
 
 
-        else:
-            # Shifting cache
-            keys_to_keep = self.key_cache[layer_idx][
-                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
-            ]
-
-            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
-            if using_rope:
-                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
-                    key_states, cos[: self.window_length], sin[: self.window_length]
-                )
-                if partial_rotation_size is not None:
-                    keys_to_keep, keys_pass = (
-                        keys_to_keep[..., :partial_rotation_size],
-                        keys_to_keep[..., partial_rotation_size:],
-                    )
-                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
-                if partial_rotation_size is not None:
-                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
-
-            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
-            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
-            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
 
 
-            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
-            values_to_keep = self.value_cache[layer_idx][
-                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
-            ]
-            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        return self.key_cache[layer_idx].shape[-2]
 
 
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states. `HHCache` does not have a maximum length."""
+        return None
 
 
     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
@@ -495,6 +439,23 @@ class HHCache(Cache):
             device = self.value_cache[layer_idx].device
             self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
 
 
+    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+        """Converts the `HHCache` instance into its equivalent in the legacy cache format."""
+        legacy_cache = ()
+        for layer_idx in range(len(self)):
+            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
+        return legacy_cache
+
+    @classmethod
+    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "HHCache":
+        """Converts a cache in the legacy cache format into an equivalent `HHCache`."""
+        cache = cls()
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx]
+                cache.update(key_states, value_states, layer_idx)
+        return cache
+
 
 
 class StaticCache(Cache):
     """
@@ -587,4 +548,5 @@ class StaticCache(Cache):
 
 
     def to_legacy_cache(self):
         """Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it"""
-        return None
+        return None
+
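With the SinkCache-specific re-rotation machinery removed, the reworked HHCache behaves like DynamicCache from the caller's point of view: per-layer update() calls, tuple-style indexing, and a round trip through the legacy tuple-of-tuples format. A small usage sketch against the API shown in this diff (the import path and tensor shapes are assumptions):

import torch
from cache_utils import HHCache  # assumed import path (research/long-context-llama/H2O/cache_utils.py)

bsz, num_heads, seq_len, head_dim = 1, 8, 16, 64
cache = HHCache()

# Fill two layers; update() simply appends along the sequence dimension.
for layer_idx in range(2):
    key = torch.randn(bsz, num_heads, seq_len, head_dim)
    value = torch.randn(bsz, num_heads, seq_len, head_dim)
    cache.update(key, value, layer_idx)

print(len(cache))               # 2 -> number of layers
print(cache.get_seq_length(0))  # 16 -> cached tokens for layer 0
key0, value0 = cache[0]         # backwards-compatible tuple-style indexing

# Round-trip through the legacy (tuple of tuples) format.
legacy = cache.to_legacy_cache()
restored = HHCache.from_legacy_cache(legacy)
assert restored.get_seq_length(0) == cache.get_seq_length(0)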

+ 2 - 0
research/long-context-llama/H2O/utils_llama.py

@@ -252,6 +252,8 @@ class H2OLlamaAttention(nn.Module):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
 
+        print(self.past_key_value)
+
         past_key_value = getattr(self, "past_key_value", past_key_value)
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
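For context, the rotary helpers deleted from HHCache above implemented the re-rotation that SinkCache needs when it shifts surviving tokens to new positions. The attention path here rotates queries and keys once via apply_rotary_pos_emb (last line above) and then caches the keys as-is. A simplified standalone sketch of that basic RoPE application (the real transformers helper also unsqueezes cos/sin so they broadcast over the head dimension):

import torch

def rotate_half(x):
    # Split the head dimension in half and rotate the pair: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_sketch(q, k, cos, sin):
    # RoPE: q' = q * cos + rotate_half(q) * sin, and likewise for k.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

A heavy-hitter cache keeps surviving tokens at their original positions, so cached keys retain their original rotation and no re-rotation pass is needed, which is presumably why the re-rotation cache could be dropped.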