소스 검색

Update cache_utils.py

Allen 1 년 전
부모
커밋
a0240ea73d
1개의 변경된 파일: 26줄 추가, 1줄 삭제
  1. 26 1
      research/long-context-llama/H2O/cache_utils.py

+ 26 - 1
research/long-context-llama/H2O/cache_utils.py

@@ -363,6 +363,31 @@ class HHCache(Cache):
         self.cos_sin_cache = {}
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
 
+    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+        """
+        Return the `(key_cache[layer_idx], value_cache[layer_idx])` pair so legacy `past_key_value[0][0].shape[2]`
+        indexing keeps working. Raises `KeyError` if `layer_idx >= len(self)`. NOTE: returns a tuple, not a list.
+        """
+        if layer_idx < len(self):
+            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
+        else:
+            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+    def __iter__(self):
+        """
+        Yield one `(key_cache[i], value_cache[i])` tuple for each index in `range(len(self))`, so that legacy
+        `for x in past_key_value:` loops over the cache keep working.
+        """
+        for layer_idx in range(len(self)):
+            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
+
+    def __len__(self):
+        """
+        Return `len(self.key_cache)`, so that legacy `len(past_key_value)` calls keep working. Presumably this is
+        the number of model layers (one `key_cache` entry per layer) - confirm against the `update` logic.
+        """
+        return len(self.key_cache)
+
     @staticmethod
     def _rotate_half(x):
         x1 = x[..., : x.shape[-1] // 2]
@@ -396,7 +421,7 @@ class HHCache(Cache):
                 rerotation_sin.to(key_states.dtype).unsqueeze(0),
             )
         return self.cos_sin_cache[key_states.shape[-2]]
-
+        
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
         # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length