@@ -412,7 +412,6 @@ class HHCache(Cache):
         self,
         key_states: torch.Tensor,
         value_states: torch.Tensor,
-        attention_scores: torch.Tensor,
         layer_idx: int,
         cache_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -436,52 +435,42 @@ class HHCache(Cache):
         if layer_idx == 0:
             self._seen_tokens += key_states.shape[-2]
 
-        import pdb; pdb.set_trace()
-
         # Update the cache
-        # [bsz, num_heads, seq_len, head_dim]
-        if key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
-            if len(self.key_cache) <= layer_idx:
-                # Empty cache
-                self.key_cache.append(key_states)
-                self.value_cache.append(value_states)
-            else:
-                # Growing cache
-                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
-
+        if len(self.key_cache) <= layer_idx:
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
         else:
-            # Shifting cache
-            # Keeping the local tokens
-            keys_to_keep = self.key_cache[layer_idx][
-                :, :, -self.window_length + self.num_hh_tokens + key_states.shape[-2] :
-            ]
+            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
 
-            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
-            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
-            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
 
-            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
-            values_to_keep = self.value_cache[layer_idx][
-                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
-            ]
-            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
+
+    def update_slimming(
+        self,
+        attention_scores: torch.Tensor,
+        num_kv_groups: int,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Slim the cache based on accumulated attention scores, keeping only heavy-hitter and local tokens.
+
+        Parameters:
+            attention_scores (`torch.Tensor`):
+                Attention scores for the current step.
+            num_kv_groups (`int`):
+                The number of key/value groups used by `repeat_kv` (grouped-query attention).
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. No additional arguments are used in `HHCache`.
+        Return:
+            A tuple containing the updated key and value states.
+        """
+
+        raise NotImplementedError("Cache slimming is not implemented yet.")
 
-        # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
-        if using_rope:
-            rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
-                key_states, cos[: self.window_length], sin[: self.window_length]
-            )
-            if partial_rotation_size is not None:
-                keys_to_keep, keys_pass = (
-                    keys_to_keep[..., :partial_rotation_size],
-                    keys_to_keep[..., partial_rotation_size:],
-                )
-            keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
-            if partial_rotation_size is not None:
-                keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
-
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]
 
     def reorder_cache(self, beam_idx: torch.LongTensor):
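
The new `update_slimming` method is still a stub in this diff (its body raises until implemented). For reference, a minimal sketch of what the slimming step could look like, under these assumptions: `attention_scores` has shape [bsz, num_heads, q_len, kv_len]; the cache tensors have shape [bsz, num_kv_heads, kv_len, head_dim] with query heads laid out KV-head-major, as `repeat_kv` produces them; and `window_length` / `num_hh_tokens` mirror the `HHCache` attributes referenced elsewhere in the diff. The helper name `slim_kv_cache` and the score-per-step accumulation policy are illustrative, not the PR's actual implementation.

    import torch

    def slim_kv_cache(
        key_cache: torch.Tensor,         # [bsz, num_kv_heads, kv_len, head_dim]
        value_cache: torch.Tensor,       # [bsz, num_kv_heads, kv_len, head_dim]
        attention_scores: torch.Tensor,  # [bsz, num_heads, q_len, kv_len]
        num_kv_groups: int,              # num_heads // num_kv_heads (grouped-query attention)
        window_length: int,              # total budget: heavy hitters + local tokens
        num_hh_tokens: int,              # heavy-hitter budget within the window
    ):
        bsz, num_heads, _, kv_len = attention_scores.shape
        if kv_len <= window_length:
            return key_cache, value_cache  # within budget, nothing to evict yet

        # Attention mass each cached token received in the current step(s).
        scores = attention_scores.sum(dim=-2)  # [bsz, num_heads, kv_len]

        # Fold the repeated query heads back onto their shared KV head
        # (the inverse of repeat_kv's head expansion).
        scores = scores.view(bsz, num_heads // num_kv_groups, num_kv_groups, kv_len).sum(dim=2)

        # The most recent tokens are always kept (the "local" part of the window);
        # heavy hitters are chosen among the older tokens by accumulated score.
        num_local = window_length - num_hh_tokens
        hh_scores = scores[..., : kv_len - num_local]
        _, hh_idx = hh_scores.topk(num_hh_tokens, dim=-1)
        hh_idx, _ = hh_idx.sort(dim=-1)  # preserve original token order

        local_idx = torch.arange(kv_len - num_local, kv_len, device=scores.device)
        local_idx = local_idx.view(1, 1, -1).expand(bsz, scores.shape[1], num_local)
        keep_idx = torch.cat([hh_idx, local_idx], dim=-1)  # [bsz, num_kv_heads, window_length]

        # Gather the surviving positions along the sequence dimension.
        keep_idx = keep_idx.unsqueeze(-1).expand(-1, -1, -1, key_cache.shape[-1])
        return key_cache.gather(2, keep_idx), value_cache.gather(2, keep_idx)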
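
With `attention_scores` dropped from `update`, the call order inside an attention forward pass changes: `update` now only appends the step's key/value states, and slimming runs as a separate pass once the step's attention weights exist. A hypothetical caller fragment (names like `past_key_value` and `self.num_key_value_groups` follow common transformers attention-module conventions and are assumptions, not part of this diff):

    # Append this step's KV states first...
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

    # ...then slim the cache using the attention weights from this step.
    past_key_value.update_slimming(attn_weights, self.num_key_value_groups, self.layer_idx)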