Allen 1 year ago
parent commit
ad2ed665cf
2 changed files with 2 additions and 2 deletions
  1. research/long-context-llama/H2O/cache_utils.py (+0 -2)
  2. research/long-context-llama/H2O/utils_llama.py (+2 -0)

research/long-context-llama/H2O/cache_utils.py (+0 -2)

@@ -369,7 +369,6 @@ class HHCache(Cache):
         sequence length.
         """
         if layer_idx < len(self):
-            print(layer_idx, self.key_cache[layer_idx].shape)
             return (self.key_cache[layer_idx], self.value_cache[layer_idx])
         else:
             raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
@@ -380,7 +379,6 @@ class HHCache(Cache):
         keys and values
         """
         for layer_idx in range(len(self)):
-            print(layer_idx, self.key_cache[layer_idx].shape)
             yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
 
     def __len__(self):
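
Note: the two deleted lines were per-layer debug prints in HHCache.__getitem__ and HHCache.__iter__; both accessors run once per layer on every forward pass, so the prints produced pure log noise. For reference, a minimal sketch of the accessor pattern this diff leaves behind, with key_cache/value_cache assumed to be per-layer lists of tensors (names taken from the diff context; everything else about the class is elided):

    from typing import Iterator, List, Tuple
    import torch

    class HHCache:
        """Sketch: per-layer KV cache held as parallel lists of tensors."""

        def __init__(self) -> None:
            self.key_cache: List[torch.Tensor] = []
            self.value_cache: List[torch.Tensor] = []

        def __len__(self) -> int:
            # One entry per layer filled so far.
            return len(self.key_cache)

        def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
            # Return the (keys, values) pair for one layer, or fail loudly.
            if layer_idx < len(self):
                return (self.key_cache[layer_idx], self.value_cache[layer_idx])
            raise KeyError(
                f"Cache only has {len(self)} layers, "
                f"attempted to access layer with index {layer_idx}"
            )

        def __iter__(self) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
            # Yield (keys, values) layer by layer.
            for layer_idx in range(len(self)):
                yield (self.key_cache[layer_idx], self.value_cache[layer_idx])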

research/long-context-llama/H2O/utils_llama.py (+2 -0)

@@ -301,6 +301,8 @@ class H2OLlamaAttention(nn.Module):
         if not output_attentions:
             attn_weights = None
 
+        print(past_key_value.key_cache[self.layer_idx].shape, past_key_value.accumulated_attention_scores[self.layer_idx].shape)
+
         return attn_output, attn_weights, past_key_value
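
Note: the added line prints the per-layer key-cache shape next to the per-layer accumulated-attention-scores shape at the end of H2OLlamaAttention.forward, presumably to confirm the two structures stay aligned along the sequence dimension while the H2O cache evicts tokens. A hedged sketch of the same check written as an assertion; the tensor layouts below are assumptions for illustration, not confirmed by the diff:

    import torch

    def check_h2o_cache_alignment(key_cache: torch.Tensor,
                                  accumulated_attention_scores: torch.Tensor) -> None:
        """Assert cached keys and their accumulated scores agree on sequence length.

        Assumed layouts (not confirmed by the diff):
          key_cache:                    [batch, num_heads, seq_len, head_dim]
          accumulated_attention_scores: [batch, num_heads, seq_len]
        """
        assert key_cache.shape[-2] == accumulated_attention_scores.shape[-1], (
            f"seq-length mismatch: keys {tuple(key_cache.shape)} vs "
            f"scores {tuple(accumulated_attention_scores.shape)}"
        )

    # Hypothetical usage at the point of the added print:
    # check_h2o_cache_alignment(
    #     past_key_value.key_cache[self.layer_idx],
    #     past_key_value.accumulated_attention_scores[self.layer_idx],
    # )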