瀏覽代碼

Update cache_utils.py

Allen 1 年之前
父節點
當前提交
ab4eee2542
共有 1 個文件被更改,包括 2 次插入2 次删除
  1. 2 2
      research/long-context-llama/H2O/cache_utils.py

+ 2 - 2
research/long-context-llama/H2O/cache_utils.py

@@ -369,7 +369,7 @@ class HHCache(Cache):
         sequence length.
         """
         if layer_idx < len(self):
-            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
+            return (self.key_cache[layer_idx], self.value_cache[layer_idx], self.accumulated_attention_scores[layer_idx])
         else:
             raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
 
@@ -379,7 +379,7 @@ class HHCache(Cache):
         keys and values
         """
         for layer_idx in range(len(self)):
-            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
+            yield (self.key_cache[layer_idx], self.value_cache[layer_idx], self.accumulated_attention_scores[layer_idx])
 
     def __len__(self):
         """