Allen 1 year ago
parent
commit 9e2072fd2c

+ 0 - 35
research/long-context-llama/H2O/cache_utils.py

@@ -360,7 +360,6 @@ class HHCache(Cache):
         self.window_length = window_length
         self.num_hh_tokens = num_hh_tokens
         self.accumulated_attention_scores: List[torch.Tensor] = []
-        self.cos_sin_cache = {}
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
 
     def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
@@ -388,40 +387,6 @@ class HHCache(Cache):
         """
         return len(self.key_cache)
 
-    @staticmethod
-    def _rotate_half(x):
-        x1 = x[..., : x.shape[-1] // 2]
-        x2 = x[..., x.shape[-1] // 2 :]
-        return torch.cat((-x2, x1), dim=-1)
-
-    def _apply_key_rotary_pos_emb(
-        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-    ) -> torch.Tensor:
-        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
-        return rotated_key_states
-
-    def _get_rerotation_cos_sin(
-        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if key_states.shape[-2] not in self.cos_sin_cache:
-            # Upcast to float32 temporarily for better accuracy
-            cos = cos.to(torch.float32)
-            sin = sin.to(torch.float32)
-
-            # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
-            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
-            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
-            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
-            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
-            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
-            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
-
-            self.cos_sin_cache[key_states.shape[-2]] = (
-                rerotation_cos.to(key_states.dtype).unsqueeze(0),
-                rerotation_sin.to(key_states.dtype).unsqueeze(0),
-            )
-        return self.cos_sin_cache[key_states.shape[-2]]
-        
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
         # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
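
Note: the removed _get_rerotation_cos_sin / _apply_key_rotary_pos_emb helpers appear to be carried over from transformers' SinkCache; they re-rotate cached keys to an earlier RoPE position via the angle-difference identities, and they reference self.num_sink_tokens, which the HHCache __init__ shown above never sets. An illustrative, self-contained check of that identity (not part of the repository):

    import torch

    # The "re-rotation" cos/sin computed by the removed helper are just
    # cos(shifted - original) and sin(shifted - original): one elementwise
    # rotation moves a key from its original RoPE angle to the shifted one.
    original = torch.rand(16) * 6.283
    shifted = original - 1.0  # pretend the key slides one position earlier

    rerotation_cos = torch.cos(original) * torch.cos(shifted) + torch.sin(original) * torch.sin(shifted)
    rerotation_sin = -torch.sin(original) * torch.cos(shifted) + torch.cos(original) * torch.sin(shifted)

    assert torch.allclose(rerotation_cos, torch.cos(shifted - original), atol=1e-6)
    assert torch.allclose(rerotation_sin, torch.sin(shifted - original), atol=1e-6)
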

+ 1 - 2
research/long-context-llama/H2O/utils_llama.py

@@ -175,8 +175,7 @@ class H2OLlamaAttention(nn.Module):
         else:
             if past_key_value is not None:
                 # sin and cos are specific to RoPE models; cache_position needed for the static cache
-                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
 
             kv_seq_len = past_key_value.get_seq_length() if past_key_value is not None else key_states.shape[-2]
             cos, sin = self.rotary_emb(value_states, kv_seq_len)
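
Note: with sin/cos no longer routed through cache_kwargs, the cache update reduces to per-layer concatenation of key/value states, and the attention layer recomputes the RoPE tables from the updated sequence length (the hunk above). A minimal, hypothetical stand-in sketching that contract; names and shapes are illustrative, not the repository's actual HHCache:

    from typing import List, Tuple

    import torch

    class MiniKVCache:
        # Hypothetical stand-in: update() only appends or concatenates the new
        # key/value states for a layer and returns the full cached tensors.
        def __init__(self) -> None:
            self.key_cache: List[torch.Tensor] = []
            self.value_cache: List[torch.Tensor] = []

        def update(
            self,
            key_states: torch.Tensor,
            value_states: torch.Tensor,
            layer_idx: int,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            if len(self.key_cache) <= layer_idx:
                # First time this layer is seen: start its cache.
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
            else:
                # Append along the sequence dimension.
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
            return self.key_cache[layer_idx], self.value_cache[layer_idx]

    # Usage: tensors are (batch, num_heads, seq_len, head_dim).
    cache = MiniKVCache()
    k_all, v_all = cache.update(torch.randn(1, 8, 4, 64), torch.randn(1, 8, 4, 64), layer_idx=0)
    k_all, v_all = cache.update(torch.randn(1, 8, 1, 64), torch.randn(1, 8, 1, 64), layer_idx=0)
    assert k_all.shape[-2] == 5  # 4 prompt tokens + 1 generated token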