Allen пре 1 година
родитељ
комит
4cbd593e2b
1 измењених фајлова са 86 додато и 1 уклоњено
  1. 86 1
      research/long-context-llama/H2O/utils_llama.py

+ 86 - 1
research/long-context-llama/H2O/utils_llama.py

@@ -432,7 +432,9 @@ def prepare_inputs_for_generation_w_h2o(
 ):
     # With static cache, the `past_key_values` is None
     # TODO joao: standardize interface for the different Cache classes and remove of this if
+
     import pdb; pdb.set_trace()
+
     has_static_cache = False
     if past_key_values is None:
         past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
@@ -742,5 +744,88 @@ class H2OLlamaForCausalLM(LlamaForCausalLM):
         self.model.forward = types.MethodType(enable_h2ocache_forward, self.model)
         self.model.num_heavy_hitter_tokens = config.num_heavy_hitter_tokens
         self.model.num_window_length = config.num_window_length
+    
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
+    ):
+        # With static cache, the `past_key_values` is None
+        # TODO joao: standardize interface for the different Cache classes and remove of this if
+
+        import pdb; pdb.set_trace()
 
-        self.forward = types.MethodType(prepare_inputs_for_generation_w_h2o, self)
+        has_static_cache = False
+        if past_key_values is None:
+            past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
+            has_static_cache = past_key_values is not None
+
+        past_length = 0
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
+                max_cache_length = (
+                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
+                    if past_key_values.get_max_length() is not None
+                    else None
+                )
+                cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
+            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
+            else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
+
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+            # input)
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+            # input_ids based on the past_length.
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
+            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs = {"input_ids": input_ids.contiguous()}
+
+        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+        if cache_position is None:
+            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+        else:
+            cache_position = cache_position[-input_length:]
+
+        if has_static_cache:
+            past_key_values = None
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "cache_position": cache_position,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs