|  | @@ -412,7 +412,6 @@ class HHCache(Cache):
 | 
	
		
			
				|  |  |          self,
 | 
	
		
			
				|  |  |          key_states: torch.Tensor,
 | 
	
		
			
				|  |  |          value_states: torch.Tensor,
 | 
	
		
			
				|  |  | -        attention_scores: torch.Tensor,
 | 
	
		
			
				|  |  |          layer_idx: int,
 | 
	
		
			
				|  |  |          cache_kwargs: Optional[Dict[str, Any]] = None,
 | 
	
		
			
				|  |  |      ) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
	
	
		
			
				|  | @@ -436,52 +435,42 @@ class HHCache(Cache):
 | 
	
		
			
				|  |  |          if layer_idx == 0:
 | 
	
		
			
				|  |  |              self._seen_tokens += key_states.shape[-2]
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        import pdb; pdb.set_trace()
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |          # Update the cache
 | 
	
		
			
				|  |  | -        # [bsz, num_heads, seq_len, head_dim]
 | 
	
		
			
				|  |  | -        if key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
 | 
	
		
			
				|  |  | -            if len(self.key_cache) <= layer_idx:
 | 
	
		
			
				|  |  | -                # Empty cache
 | 
	
		
			
				|  |  | -                self.key_cache.append(key_states)
 | 
	
		
			
				|  |  | -                self.value_cache.append(value_states)
 | 
	
		
			
				|  |  | -            else:
 | 
	
		
			
				|  |  | -                # Growing cache
 | 
	
		
			
				|  |  | -                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
 | 
	
		
			
				|  |  | -                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | +        if len(self.key_cache) <= layer_idx:
 | 
	
		
			
				|  |  | +            self.key_cache.append(key_states)
 | 
	
		
			
				|  |  | +            self.value_cache.append(value_states)
 | 
	
		
			
				|  |  |          else:
 | 
	
		
			
				|  |  | -            # Shifting cache 
 | 
	
		
			
				|  |  | -            # Keeping the local tokens
 | 
	
		
			
				|  |  | -            keys_to_keep = self.key_cache[layer_idx][
 | 
	
		
			
				|  |  | -                :, :, -self.window_length + self.num_hh_tokens + key_states.shape[-2] :
 | 
	
		
			
				|  |  | -            ]
 | 
	
		
			
				|  |  | +            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
 | 
	
		
			
				|  |  | +            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
 | 
	
		
			
				|  |  | -            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
 | 
	
		
			
				|  |  | -            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
 | 
	
		
			
				|  |  | +        return self.key_cache[layer_idx], self.value_cache[layer_idx]
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
 | 
	
		
			
				|  |  | -            values_to_keep = self.value_cache[layer_idx][
 | 
	
		
			
				|  |  | -                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
 | 
	
		
			
				|  |  | -            ]
 | 
	
		
			
				|  |  | -            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
 | 
	
		
			
				|  |  | +    def update_slimming(
 | 
	
		
			
				|  |  | +        self,
 | 
	
		
			
				|  |  | +        attention_scores: torch.Tensor,
 | 
	
		
			
				|  |  | +        num_kv_groups: int,
 | 
	
		
			
				|  |  | +        layer_idx: int,
 | 
	
		
			
				|  |  | +        cache_kwargs: Optional[Dict[str, Any]] = None,
 | 
	
		
			
				|  |  | +    ) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
	
		
			
				|  |  | +        """
 | 
	
		
			
				|  |  | +        Slim the cache based on accumulated attention scores, keeping only heavy-hitter and local tokens.
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        Parameters:
 | 
	
		
			
				|  |  | +            attention_scores (`torch.Tensor`):
 | 
	
		
			
				|  |  | +                Attention scores for the current steps.
 | 
	
		
			
				|  |  | +            num_kv_groups (`int`):
 | 
	
		
			
				|  |  | +                The number of kv groups in repeat kv.
 | 
	
		
			
				|  |  | +            layer_idx (`int`):
 | 
	
		
			
				|  |  | +                The index of the layer to cache the states for.
 | 
	
		
			
				|  |  | +            cache_kwargs (`Dict[str, Any]`, `optional`):
 | 
	
		
			
				|  |  | +                Additional arguments for the cache subclass. No additional arguments are currently used by this method.
 | 
	
		
			
				|  |  | +        Return:
 | 
	
		
			
				|  |  | +            A tuple containing the updated key and value states.
 | 
	
		
			
				|  |  | +        """
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        import pdb; pdb.set_trace()
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
 | 
	
		
			
				|  |  | -            if using_rope:
 | 
	
		
			
				|  |  | -                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
 | 
	
		
			
				|  |  | -                    key_states, cos[: self.window_length], sin[: self.window_length]
 | 
	
		
			
				|  |  | -                )
 | 
	
		
			
				|  |  | -                if partial_rotation_size is not None:
 | 
	
		
			
				|  |  | -                    keys_to_keep, keys_pass = (
 | 
	
		
			
				|  |  | -                        keys_to_keep[..., :partial_rotation_size],
 | 
	
		
			
				|  |  | -                        keys_to_keep[..., partial_rotation_size:],
 | 
	
		
			
				|  |  | -                    )
 | 
	
		
			
				|  |  | -                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
 | 
	
		
			
				|  |  | -                if partial_rotation_size is not None:
 | 
	
		
			
				|  |  | -                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        return self.key_cache[layer_idx], self.value_cache[layer_idx]
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      def reorder_cache(self, beam_idx: torch.LongTensor):
 |