Explorar el Código

Update cache_utils.py

Allen hace 1 año
padre
commit
905e546bfc
Se han modificado 1 ficheros con 95 adiciones y 11 borrados
  1. 95 11
      research/long-context-llama/H2O/cache_utils.py

+ 95 - 11
research/long-context-llama/H2O/cache_utils.py

@@ -354,11 +354,61 @@ class HHCache(Cache):
             The number of heavy hitter tokens. See the original paper for more information.
     """
 
-    def __init__(self) -> None:
+    def __init__(self, window_length: int, num_hh_tokens: int) -> None:
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
+        self.window_length = window_length
+        self.num_hh_tokens = num_hh_tokens
+        self.accumulated_attention_scores = {}
+        self.cos_sin_cache = {}
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
 
+    @staticmethod
+    def _rotate_half(x):
+        x1 = x[..., : x.shape[-1] // 2]
+        x2 = x[..., x.shape[-1] // 2 :]
+        return torch.cat((-x2, x1), dim=-1)
+
+    def _apply_key_rotary_pos_emb(
+        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+    ) -> torch.Tensor:
+        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
+        return rotated_key_states
+
+    def _get_rerotation_cos_sin(
+        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if key_states.shape[-2] not in self.cos_sin_cache:
+            # Upcast to float32 temporarily for better accuracy
+            cos = cos.to(torch.float32)
+            sin = sin.to(torch.float32)
+
+            # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
+            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
+            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
+            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
+            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
+            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
+            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
+
+            self.cos_sin_cache[key_states.shape[-2]] = (
+                rerotation_cos.to(key_states.dtype).unsqueeze(0),
+                rerotation_sin.to(key_states.dtype).unsqueeze(0),
+            )
+        return self.cos_sin_cache[key_states.shape[-2]]
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        return self.key_cache[layer_idx].shape[-2]
+
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states."""
+        return self.window_length
+
+
     def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
         """
         Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
@@ -388,6 +438,7 @@ class HHCache(Cache):
         self,
         key_states: torch.Tensor,
         value_states: torch.Tensor,
+        attention_scores: torch.Tensor,
         layer_idx: int,
         cache_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -413,24 +464,50 @@ class HHCache(Cache):
             self._seen_tokens += key_states.shape[-2]
 
         # Update the cache
+        # [bsz, num_heads, seq_len, head_dim]
         if len(self.key_cache) <= layer_idx:
+            # Empty cache
             self.key_cache.append(key_states)
             self.value_cache.append(value_states)
-        else:
+
+        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
+            # Growing cache
             self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
             self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
 
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+        else:
+            # Shifting cache 
+            # Keeping the local tokens
+            keys_to_keep = self.key_cache[layer_idx][
+                :, :, -self.window_length + self.num_hh_tokens + key_states.shape[-2] :
+            ]
 
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-        if len(self.key_cache) <= layer_idx:
-            return 0
-        return self.key_cache[layer_idx].shape[-2]
+            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
+            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
+            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
+
+            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
+            values_to_keep = self.value_cache[layer_idx][
+                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
+            ]
+            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
+
+            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
+            if using_rope:
+                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
+                    key_states, cos[: self.window_length], sin[: self.window_length]
+                )
+                if partial_rotation_size is not None:
+                    keys_to_keep, keys_pass = (
+                        keys_to_keep[..., :partial_rotation_size],
+                        keys_to_keep[..., partial_rotation_size:],
+                    )
+                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
+                if partial_rotation_size is not None:
+                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
 
-    def get_max_length(self) -> Optional[int]:
-        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
-        return None
 
     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
@@ -440,6 +517,12 @@ class HHCache(Cache):
             device = self.value_cache[layer_idx].device
             self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
 
+
+
+
+
+
+
     def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
         """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
         legacy_cache = ()
@@ -450,6 +533,7 @@ class HHCache(Cache):
     @classmethod
     def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
         """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
+        import pdb; pdb.set_trace()
         cache = cls()
         if past_key_values is not None:
             for layer_idx in range(len(past_key_values)):