
Update utils_llama.py

Allen, 1 year ago
Parent
Commit
0ba99ca0b8
1 changed file with 5 additions and 3 deletions

+5 −3
research/long-context-llama/H2O/utils_llama.py

@@ -761,16 +761,18 @@ class H2OLlamaForCausalLM(LlamaForCausalLM):
         past_length = 0
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
-                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
+                past_length = cache_position[0]
                 max_cache_length = (
                     torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
                     if past_key_values.get_max_length() is not None
                     else None
                 )
-                cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
+                cache_length = past_key_values.get_seq_length()
+
             # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
             else:
-                cache_length = past_length = past_key_values[0].shape[2] # length = num_layers * 3 (3 -> key, value, score)
+                past_length = cache_position[0]
+                cache_length = past_key_values[0].shape[2] # length = num_layers * 3 (3 -> key, value, score)
                 max_cache_length = None
 
             # Keep only the unprocessed tokens:
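
Reading the hunk, the apparent intent is to stop conflating two different quantities. Before the change, `past_length` could fall back to `past_key_values.get_seq_length()` and `cache_length` was clamped to `past_length`. With H2O's heavy-hitter eviction, the cache may hold fewer KV entries than the number of tokens already generated, so the two must be tracked separately: `past_length` now always comes from `cache_position[0]` (tokens consumed so far), while `cache_length` comes from `get_seq_length()` (entries still resident after eviction). Below is a minimal sketch of that distinction; `ToyH2OCache`, `hh_budget`, `seen_tokens`, and `kept_tokens` are illustrative names, not identifiers from this repository or from transformers.

# A minimal sketch of the past_length / cache_length split for an
# evicting cache. All names here are illustrative, not from the repo.

class ToyH2OCache:
    """Counts every token seen, but keeps at most `hh_budget` KV entries."""

    def __init__(self, hh_budget: int):
        self.hh_budget = hh_budget
        self.seen_tokens = 0   # tokens processed so far
        self.kept_tokens = 0   # KV entries still resident after eviction

    def update(self, num_new_tokens: int) -> None:
        self.seen_tokens += num_new_tokens
        # Heavy-hitter eviction: keep at most hh_budget entries.
        self.kept_tokens = min(self.seen_tokens, self.hh_budget)

    def get_seq_length(self) -> int:
        return self.kept_tokens


cache = ToyH2OCache(hh_budget=4)
cache.update(10)

# Mirrors the hunk: past_length tracks positions already consumed
# (cache_position[0]); cache_length tracks resident entries
# (get_seq_length()). With eviction, the latter can be smaller.
past_length = cache.seen_tokens        # stands in for cache_position[0] -> 10
cache_length = cache.get_seq_length()  # entries left after eviction     -> 4

assert cache_length <= past_length

Under this reading, the old `torch.min(max_cache_length, past_length)` expression could over-report the resident cache size whenever eviction had dropped entries, which is what reading `cache_length` directly from `get_seq_length()` avoids.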