
Update utils_llama.py

Allen, 1 year ago
commit 9daddd4b34
1 changed file with 84 additions and 1 deletion

+ 84 - 1
research/long-context-llama/H2O/utils_llama.py

@@ -420,7 +420,6 @@ def enable_h2ocache_forward(
     if not return_dict:
         return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

-    import pdb;pdb.set_trace()
     return BaseModelOutputWithPast(
         last_hidden_state=hidden_states,
         past_key_values=next_cache,
@@ -428,6 +427,88 @@ def enable_h2ocache_forward(
         attentions=all_self_attns,
     )
 
+def prepare_inputs_for_generation_w_h2o(
+    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
+):
+    # With static cache, the `past_key_values` is None
+    # TODO joao: standardize interface for the different Cache classes and remove this if
+    has_static_cache = False
+    if past_key_values is None:
+        past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
+        has_static_cache = past_key_values is not None
+
+    past_length = 0
+    if past_key_values is not None:
+        if isinstance(past_key_values, Cache):
+            past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
+            max_cache_length = (
+                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
+                if past_key_values.get_max_length() is not None
+                else None
+            )
+            cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
+        # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
+        else:
+            cache_length = past_length = past_key_values[0][0].shape[2]
+            max_cache_length = None
+
+        # Keep only the unprocessed tokens:
+        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+        # input)
+        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+        # 2 - If past_length is smaller than the length of input_ids, then input_ids holds all input tokens. We can discard
+        # input_ids based on the past_length.
+        elif past_length < input_ids.shape[1]:
+            input_ids = input_ids[:, past_length:]
+        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+        if (
+            max_cache_length is not None
+            and attention_mask is not None
+            and cache_length + input_ids.shape[1] > max_cache_length
+        ):
+            attention_mask = attention_mask[:, -max_cache_length:]
+
+    position_ids = kwargs.get("position_ids", None)
+    if attention_mask is not None and position_ids is None:
+        # create position_ids on the fly for batch generation
+        position_ids = attention_mask.long().cumsum(-1) - 1
+        position_ids.masked_fill_(attention_mask == 0, 1)
+        if past_key_values:
+            position_ids = position_ids[:, -input_ids.shape[1] :]
+
+    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+    if inputs_embeds is not None and past_key_values is None:
+        model_inputs = {"inputs_embeds": inputs_embeds}
+    else:
+        # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+        # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
+        # TODO: use `next_tokens` directly instead.
+        model_inputs = {"input_ids": input_ids.contiguous()}
+
+    input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+    if cache_position is None:
+        cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+    else:
+        cache_position = cache_position[-input_length:]
+
+    if has_static_cache:
+        past_key_values = None
+
+    model_inputs.update(
+        {
+            "position_ids": position_ids,
+            "cache_position": cache_position,
+            "past_key_values": past_key_values,
+            "use_cache": kwargs.get("use_cache"),
+            "attention_mask": attention_mask,
+        }
+    )
+    return model_inputs
 
 
 
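For reference, the position_ids construction in prepare_inputs_for_generation_w_h2o above (cumsum over the attention mask minus one, with padded positions filled with a dummy 1) is the standard trick for left-padded batch generation. A minimal standalone sketch with toy values (illustrative only, not taken from this commit):

    import torch

    # Toy left-padded batch: 0 = padding position, 1 = real token (illustrative values).
    attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                   [1, 1, 1, 1, 1]])

    # Same construction as in prepare_inputs_for_generation_w_h2o:
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)

    print(position_ids)
    # tensor([[1, 1, 0, 1, 2],
    #         [0, 1, 2, 3, 4]])

The dummy 1 at padded positions is harmless because those positions are masked out of attention anyway; during decoding only the last input_ids.shape[1] columns are kept, matching the unprocessed tokens.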
@@ -661,3 +742,5 @@ class H2OLlamaForCausalLM(LlamaForCausalLM):
         self.model.forward = types.MethodType(enable_h2ocache_forward, self.model)
         self.model.num_heavy_hitter_tokens = config.num_heavy_hitter_tokens
         self.model.num_window_length = config.num_window_length
+
+        self.prepare_inputs_for_generation = types.MethodType(prepare_inputs_for_generation_w_h2o, self)
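
For context, here is a minimal usage sketch of the patched class (not part of the commit; the checkpoint name, import path, and the heavy-hitter/window sizes below are assumptions for illustration):

    import torch
    from transformers import AutoTokenizer, LlamaConfig

    # Assumes this is run from research/long-context-llama/H2O/ so utils_llama is importable.
    from utils_llama import H2OLlamaForCausalLM

    model_name = "meta-llama/Llama-2-7b-hf"      # assumed checkpoint
    config = LlamaConfig.from_pretrained(model_name)
    config.num_heavy_hitter_tokens = 128         # assumed heavy-hitter budget
    config.num_window_length = 256               # assumed recent-window length

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = H2OLlamaForCausalLM.from_pretrained(model_name, config=config, torch_dtype=torch.float16)

    inputs = tokenizer("The H2O cache keeps", return_tensors="pt")
    # generate() calls prepare_inputs_for_generation on every decoding step,
    # which is where the method bound above takes effect.
    out = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(out[0], skip_special_tokens=True))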