|  | @@ -439,6 +439,7 @@ class HHCache(Cache):
 | 
	
		
			
				|  |  |          value_states: torch.Tensor,
 | 
	
		
			
				|  |  |          layer_idx: int,
 | 
	
		
			
				|  |  |          cache_kwargs: Optional[Dict[str, Any]] = None,
 | 
	
		
			
				|  |  | +        accumulated_attention_scores: Optional[torch.Tensor] = None,
 | 
	
		
			
				|  |  |      ) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
	
		
			
				|  |  |          """
 | 
	
		
			
				|  |  |          Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
 | 
	
	
		
			
				|  | @@ -457,7 +458,10 @@ class HHCache(Cache):
 | 
	
		
			
				|  |  |              A tuple containing the updated key and value states.
 | 
	
		
			
				|  |  |          """
 | 
	
		
			
				|  |  |          # Update the number of seen tokens
 | 
	
		
			
				|  |  | -        
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        if accumulated_attention_scores is not None:
 | 
	
		
			
				|  |  | +            self.accumulated_attention_scores.append(accumulated_attention_scores)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          if layer_idx == 0:
 | 
	
		
			
				|  |  |              self._seen_tokens += key_states.shape[-2]
 | 
	
		
			
				|  |  |              import pdb; pdb.set_trace()
 | 
	
	
		
			
				|  | @@ -542,7 +546,7 @@ class HHCache(Cache):
 | 
	
		
			
				|  |  |          """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
 | 
	
		
			
				|  |  |          legacy_cache = ()
 | 
	
		
			
				|  |  |          for layer_idx in range(len(self)):
 | 
	
		
			
				|  |  | -            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
 | 
	
		
			
				|  |  | +            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]), self.accumulated_attention_scores[layer_idx], )
 | 
	
		
			
				|  |  |          return legacy_cache
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @classmethod
 | 
	
	
		
			
				|  | @@ -551,8 +555,8 @@ class HHCache(Cache):
 | 
	
		
			
				|  |  |          cache = cls(window_length, num_hh_tokens)
 | 
	
		
			
				|  |  |          if past_key_values is not None:
 | 
	
		
			
				|  |  |              for layer_idx in range(len(past_key_values)):
 | 
	
		
			
				|  |  | -                key_states, value_states = past_key_values[layer_idx]
 | 
	
		
			
				|  |  | -                cache.update(key_states, value_states, layer_idx)
 | 
	
		
			
				|  |  | +                key_states, value_states, accumulated_attention_scores = past_key_values[layer_idx]
 | 
	
		
			
				|  |  | +                cache.update(key_states, value_states, layer_idx, accumulated_attention_scores=accumulated_attention_scores)
 | 
	
		
			
				|  |  |          return cache
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  
 |