Allen 1 year ago
parent commit d860109bc3

+ 1 - 1
research/long-context-llama/H2O/generation.py

@@ -34,7 +34,7 @@ if __name__ == '__main__':
 
     parser.add_argument("--enable_h2o_generation", action='store_true')
     parser.add_argument("--num_heavy_hitter_tokens", type=int, default=256)
-    parser.add_argument("--num_window_length", type=int, default=4096)
+    parser.add_argument("--num_window_length", type=int, default=1024)
 
     parser.add_argument("--enable_position_rolling", action='store_true')
 

+ 1 - 1
research/long-context-llama/H2O/utils_llama.py

@@ -770,7 +770,7 @@ class H2OLlamaForCausalLM(LlamaForCausalLM):
                 cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
             # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
             else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
+                cache_length = past_length = past_key_values[0].shape[2]  # len(past_key_values) == num_layers * 3 (key, value, score per layer)
                 max_cache_length = None
 
             # Keep only the unprocessed tokens:
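The old line indexed `past_key_values[0][0]`, which matches Hugging Face's legacy layout of one `(key, value)` pair per layer. Per the added comment, the H2O cache instead flattens `(key, value, score)` triples across layers, so `past_key_values[0]` is already the first layer's key tensor and its `shape[2]` is the cached sequence length. A minimal sketch of that assumed layout follows; the tensor shapes, in particular the score tensor's, are assumptions for illustration, not taken from the repository.

```python
import torch

num_layers, batch, num_heads, seq_len, head_dim = 2, 1, 4, 8, 16

flat_cache = []
for _ in range(num_layers):
    key = torch.zeros(batch, num_heads, seq_len, head_dim)
    value = torch.zeros(batch, num_heads, seq_len, head_dim)
    score = torch.zeros(batch, num_heads, seq_len)  # accumulated attention mass
    flat_cache.extend([key, value, score])          # flattened, not (k, v) pairs

past_key_values = tuple(flat_cache)
assert len(past_key_values) == num_layers * 3       # as the added comment states

# The patched line reads the cache length straight off the first tensor:
cache_length = past_key_values[0].shape[2]
print(cache_length)  # 8
```

With the old indexing, `past_key_values[0][0]` would have selected batch element 0 of the key tensor, making `.shape[2]` return `seq_len` only by coincidence of rank, so the fix also removes a latent shape bug.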