Allen 1 jaar geleden
bovenliggende commit
d9b9857293
2 gewijzigde bestanden met 2 toevoegingen en 2 verwijderingen
  1. 2 0
      research/long-context-llama/H2O/cache_utils.py
  2. 0 2
      research/long-context-llama/H2O/utils_llama.py

+ 2 - 0
research/long-context-llama/H2O/cache_utils.py

@@ -369,6 +369,7 @@ class HHCache(Cache):
         sequence length.
         """
         if layer_idx < len(self):
+            print(layer_idx, self.key_cache[layer_idx].shape)
             return (self.key_cache[layer_idx], self.value_cache[layer_idx])
         else:
             raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
@@ -379,6 +380,7 @@ class HHCache(Cache):
         keys and values
         """
         for layer_idx in range(len(self)):
+            print(layer_idx, self.key_cache[layer_idx].shape)
             yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
 
     def __len__(self):

+ 0 - 2
research/long-context-llama/H2O/utils_llama.py

@@ -301,8 +301,6 @@ class H2OLlamaAttention(nn.Module):
         if not output_attentions:
             attn_weights = None
 
-        pdb.set_trace()
-
         return attn_output, attn_weights, past_key_value