Update cache_utils.py

Allen 1 year ago
parent commit 9adb645476
1 changed file with 8 additions and 4 deletions

+ 8 - 4
research/long-context-llama/H2O/cache_utils.py

@@ -439,6 +439,7 @@ class HHCache(Cache):
         value_states: torch.Tensor,
         layer_idx: int,
         cache_kwargs: Optional[Dict[str, Any]] = None,
+        accumulated_attention_scores: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
@@ -457,7 +458,10 @@ class HHCache(Cache):
             A tuple containing the updated key and value states.
         """
         # Update the number of seen tokens
-        
+
+        if accumulated_attention_scores is not None:
+            self.accumulated_attention_scores.append(accumulated_attention_scores)
+
         if layer_idx == 0:
             self._seen_tokens += key_states.shape[-2]
             import pdb; pdb.set_trace()
@@ -542,7 +546,7 @@ class HHCache(Cache):
         """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
         legacy_cache = ()
         for layer_idx in range(len(self)):
-            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
+            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]), self.accumulated_attention_scores[layer_idx], )
         return legacy_cache
 
     @classmethod
@@ -551,8 +555,8 @@ class HHCache(Cache):
         cache = cls(window_length, num_hh_tokens)
         if past_key_values is not None:
             for layer_idx in range(len(past_key_values)):
-                key_states, value_states = past_key_values[layer_idx]
-                cache.update(key_states, value_states, layer_idx)
+                key_states, value_states, accumulated_attention_scores = past_key_values[layer_idx]
+                cache.update(key_states, value_states, layer_idx, accumulated_attention_scores=accumulated_attention_scores)
         return cache
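
Stepping back from the raw diff: below is a minimal, self-contained sketch of the behaviour this commit appears to be after, with HHCacheSketch as a hypothetical stand-in for the real HHCache (window eviction, heavy-hitter selection, and the pdb breakpoint are omitted). One assumption to flag: the sketch has to_legacy_cache emit one (key, value, scores) triple per layer, since that is the shape the patched from_legacy_cache unpacks; the committed to_legacy_cache instead appends the score tensor as a separate top-level tuple element, which would not round-trip as written.

from typing import Any, Dict, List, Optional, Tuple

import torch


class HHCacheSketch:
    """Hypothetical, simplified stand-in for HHCache, showing only how the
    new accumulated_attention_scores argument threads through the cache."""

    def __init__(self, window_length: int, num_hh_tokens: int) -> None:
        self.window_length = window_length
        self.num_hh_tokens = num_hh_tokens
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.accumulated_attention_scores: List[torch.Tensor] = []
        self._seen_tokens = 0

    def __len__(self) -> int:
        return len(self.key_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
        accumulated_attention_scores: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # As in the patch: when a score tensor is handed in (e.g. while
        # rebuilding from the legacy format), store it alongside the layer.
        if accumulated_attention_scores is not None:
            self.accumulated_attention_scores.append(accumulated_attention_scores)
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]
        if layer_idx < len(self.key_cache):
            # Layer already cached: append new tokens along the sequence dim.
            self.key_cache[layer_idx] = torch.cat(
                [self.key_cache[layer_idx], key_states], dim=-2
            )
            self.value_cache[layer_idx] = torch.cat(
                [self.value_cache[layer_idx], value_states], dim=-2
            )
        else:
            # First time this layer is seen: start its cache.
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def to_legacy_cache(self) -> Tuple:
        # Assumption: one (key, value, scores) triple per layer, matching
        # what from_legacy_cache below unpacks.
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += (
                (
                    self.key_cache[layer_idx],
                    self.value_cache[layer_idx],
                    self.accumulated_attention_scores[layer_idx],
                ),
            )
        return legacy_cache

    @classmethod
    def from_legacy_cache(cls, past_key_values, window_length, num_hh_tokens):
        cache = cls(window_length, num_hh_tokens)
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states, scores = past_key_values[layer_idx]
                cache.update(
                    key_states,
                    value_states,
                    layer_idx,
                    accumulated_attention_scores=scores,
                )
        return cache

The design motivation is consistent with H2O (Heavy-Hitter Oracle): the cache accumulates per-layer attention mass so that, when the KV cache is evicted down to the window, the tokens with the highest accumulated scores (the heavy hitters) can be kept. Persisting those scores through the legacy-format round-trip keeps that statistic from being lost between generation steps.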