@@ -360,7 +360,6 @@ class HHCache(Cache):
         self.window_length = window_length
         self.num_hh_tokens = num_hh_tokens
         self.accumulated_attention_scores: List[torch.Tensor] = []
-        self.cos_sin_cache = {}
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

     def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
@@ -388,40 +387,6 @@ class HHCache(Cache):
         """
         return len(self.key_cache)

-    @staticmethod
-    def _rotate_half(x):
-        x1 = x[..., : x.shape[-1] // 2]
-        x2 = x[..., x.shape[-1] // 2 :]
-        return torch.cat((-x2, x1), dim=-1)
-
-    def _apply_key_rotary_pos_emb(
-        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-    ) -> torch.Tensor:
-        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
-        return rotated_key_states
-
-    def _get_rerotation_cos_sin(
-        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if key_states.shape[-2] not in self.cos_sin_cache:
-            # Upcast to float32 temporarily for better accuracy
-            cos = cos.to(torch.float32)
-            sin = sin.to(torch.float32)
-
-            # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
-            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
-            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
-            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
-            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
-            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
-            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
-
-            self.cos_sin_cache[key_states.shape[-2]] = (
-                rerotation_cos.to(key_states.dtype).unsqueeze(0),
-                rerotation_sin.to(key_states.dtype).unsqueeze(0),
-            )
-        return self.cos_sin_cache[key_states.shape[-2]]
-
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
         # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
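Context on the deleted block: `_rotate_half`, `_apply_key_rotary_pos_emb`, and `_get_rerotation_cos_sin` appear to mirror the sink-token helpers in `transformers`' `SinkCache`, which re-rotate evicted RoPE keys by the angle *difference* between their original and shifted positions, memoizing the result per shift length in the `cos_sin_cache` dict removed above. The sketch below is my own illustration of the angle-difference identity those helpers relied on; the shapes, shift, and RoPE base 10000 are hypothetical values, not anything from this PR.

```python
import torch

def rotate_half(x):
    # Same swap-and-negate trick as the deleted _rotate_half.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(x, cos, sin):
    # Same form as the deleted _apply_key_rotary_pos_emb.
    return (x * cos) + (rotate_half(x) * sin)

# Hypothetical sizes: 8 positions, head_dim 16, RoPE base 10000.
dim, seq = 16, 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
angles = torch.outer(torch.arange(seq).float(), inv_freq)  # (seq, dim/2)
emb = torch.cat((angles, angles), dim=-1)                  # (seq, dim)
cos, sin = emb.cos(), emb.sin()

key = torch.randn(seq, dim)
rotated = apply_rotary(key, cos, sin)  # keys as they would sit in the cache

# Angle-difference identity, as in the deleted helper: build
# cos(new - orig) and sin(new - orig) from the cached tables alone.
shift = 3
orig_cos, orig_sin = cos[shift:], sin[shift:]
new_cos, new_sin = cos[:-shift], sin[:-shift]
rerotation_cos = orig_cos * new_cos + orig_sin * new_sin
rerotation_sin = -orig_sin * new_cos + orig_cos * new_sin

# Re-rotating cached keys from positions [shift:] lands them at [:-shift],
# matching keys rotated directly at the earlier positions.
rerotated = apply_rotary(rotated[shift:], rerotation_cos, rerotation_sin)
direct = apply_rotary(key[shift:], new_cos, new_sin)
assert torch.allclose(rerotated, direct, atol=1e-5)
```

Since the rerotation tensors depend only on the shift length, memoizing them per `key_states.shape[-2]` (the role of the removed `cos_sin_cache`) avoided recomputing the four products on every eviction.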