|  | @@ -363,6 +363,31 @@ class HHCache(Cache):
 | 
	
		
			
				|  |  |          self.cos_sin_cache = {}
 | 
	
		
			
				|  |  |          self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
 | 
	
		
			
				|  |  | +        """
 | 
	
		
			
				|  |  | +        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
 | 
	
		
			
				|  |  | +        sequence length.
 | 
	
		
			
				|  |  | +        """
 | 
	
		
			
				|  |  | +        if layer_idx < len(self):
 | 
	
		
			
				|  |  | +            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
 | 
	
		
			
				|  |  | +        else:
 | 
	
		
			
				|  |  | +            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def __iter__(self):
 | 
	
		
			
				|  |  | +        """
 | 
	
		
			
				|  |  | +        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
 | 
	
		
			
				|  |  | +        keys and values
 | 
	
		
			
				|  |  | +        """
 | 
	
		
			
				|  |  | +        for layer_idx in range(len(self)):
 | 
	
		
			
				|  |  | +            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    def __len__(self):
 | 
	
		
			
				|  |  | +        """
 | 
	
		
			
				|  |  | +        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
 | 
	
		
			
				|  |  | +        to the number of layers in the model.
 | 
	
		
			
				|  |  | +        """
 | 
	
		
			
				|  |  | +        return len(self.key_cache)
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      @staticmethod
 | 
	
		
			
				|  |  |      def _rotate_half(x):
 | 
	
		
			
				|  |  |          x1 = x[..., : x.shape[-1] // 2]
 | 
	
	
		
			
				|  | @@ -396,7 +421,7 @@ class HHCache(Cache):
 | 
	
		
			
				|  |  |                  rerotation_sin.to(key_states.dtype).unsqueeze(0),
 | 
	
		
			
				|  |  |              )
 | 
	
		
			
				|  |  |          return self.cos_sin_cache[key_states.shape[-2]]
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | +        
 | 
	
		
			
				|  |  |      def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
 | 
	
		
			
				|  |  |          """Returns the sequence length of the cached states. A layer index can be optionally passed."""
 | 
	
		
			
				|  |  |          # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
 |