Allen 1 year ago
parent
commit
db9e84246a

+ 2 - 0
research/long-context-llama/H2O/cache_utils.py

@@ -457,8 +457,10 @@ class HHCache(Cache):
             A tuple containing the updated key and value states.
         """
         # Update the number of seen tokens
+        
         if layer_idx == 0:
             self._seen_tokens += key_states.shape[-2]
+            import pdb; pdb.set_trace()
 
         # Update the cache
         if len(self.key_cache) <= layer_idx:

+ 0 - 66
research/long-context-llama/H2O/utils_llama.py

@@ -401,8 +401,6 @@ def enable_h2ocache_forward(
             all_self_attns += (layer_outputs[1],)
 
     hidden_states = self.norm(hidden_states)
-    
-    import pdb; pdb.set_trace()
 
     # add hidden states from the last decoder layer
     if output_hidden_states:
@@ -422,70 +420,6 @@ def enable_h2ocache_forward(
         attentions=all_self_attns,
     )
 
-# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
-# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
-# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
-# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
-def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
-    if self.config._attn_implementation == "flash_attention_2":
-        if attention_mask is not None and 0.0 in attention_mask:
-            return attention_mask
-        return None
-
-    dtype, device = input_tensor.dtype, input_tensor.device
-    min_dtype = torch.finfo(dtype).min
-    sequence_length = input_tensor.shape[1]
-    if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache
-        target_length = self.config.max_position_embeddings
-    else:  # dynamic cache
-        target_length = (
-            attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
-        )
-
-    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
-    if sequence_length != 1:
-        causal_mask = torch.triu(causal_mask, diagonal=1)
-    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-    causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
-    if attention_mask is not None:
-        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-        if attention_mask.dim() == 2:
-            mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
-            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
-        elif attention_mask.dim() == 4:
-            # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
-            # cache. In that case, the 4D attention mask attends to the newest tokens only.
-            if attention_mask.shape[-2] < cache_position[0] + sequence_length:
-                offset = cache_position[0]
-            else:
-                offset = 0
-            mask_shape = attention_mask.shape
-            mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-            causal_mask[
-                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-            ] = mask_slice
-
-    if (
-        self.config._attn_implementation == "sdpa"
-        and attention_mask is not None
-        and attention_mask.device.type == "cuda"
-    ):
-        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-        # Details: https://github.com/pytorch/pytorch/issues/110213
-        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
-
-    return causal_mask
-
-
-
-
-
-
-
-
-
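For reference, a minimal runnable sketch of the causal-mask construction performed by the removed _update_causal_mask helper (not part of this commit; the helper name build_causal_mask and the example sizes are hypothetical): start from a fully masked (sequence_length, target_length) grid, unmask the lower triangle, then re-mask key positions beyond each query token's absolute cache position.

import torch

def build_causal_mask(sequence_length, target_length, cache_position,
                      dtype=torch.float32, device="cpu"):
    # Fill every slot with the most negative representable value, i.e. "masked".
    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full((sequence_length, target_length),
                             fill_value=min_dtype, dtype=dtype, device=device)
    # Zero out (unmask) the lower triangle so each row can attend to earlier columns.
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    # Keep the mask only for columns beyond each query token's absolute cache position.
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    # Add broadcastable batch and head dimensions: (1, 1, q_len, kv_len).
    return causal_mask[None, None, :, :]

# Example: 4 new query tokens appended after 2 cached tokens, 8 key/value slots in total.
mask = build_causal_mask(4, 8, cache_position=torch.arange(2, 6))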