@@ -761,16 +761,18 @@ class H2OLlamaForCausalLM(LlamaForCausalLM):
         past_length = 0
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
-                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
+                past_length = cache_position[0]
                 max_cache_length = (
                     torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
                     if past_key_values.get_max_length() is not None
                     else None
                 )
-                cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
+                cache_length = past_key_values.get_seq_length()
+
             # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
             else:
-                cache_length = past_length = past_key_values[0].shape[2]  # length = num_layers * 3 (3 -> key, value, score)
+                past_length = cache_position[0]
+                cache_length = past_key_values[0].shape[2]  # length = num_layers * 3 (3 -> key, value, score)
                 max_cache_length = None
 
         # Keep only the unprocessed tokens: