| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454 | 
							- import math
 
- from typing import Any, Dict, List, Optional, Tuple, Union
 
- import warnings
 
- warnings.filterwarnings("ignore")
 
- import pdb
 
- import types
 
- import torch
 
- from torch import nn
 
- import torch.utils.checkpoint
 
- import torch.nn.functional as F
 
- from transformers.models.llama.configuration_llama import LlamaConfig
 
- from transformers.models.llama.modeling_llama import (
 
-     LlamaAttention,
 
-     rotate_half,
 
-     apply_rotary_pos_emb,
 
-     repeat_kv,
 
-     LlamaRotaryEmbedding,
 
-     LlamaForCausalLM,
 
- )
 
- from utils.cache import Cache, HHCache, StaticCache
 
- from transformers.utils import logging
 
- from transformers.modeling_outputs import BaseModelOutputWithPast
 
- logger = logging.get_logger(__name__)
 
- __all__ = ["H2OLlamaForCausalLM"]
 
- def _make_causal_mask(
 
-     bsz: int, tgt_len: int, past_key_values_length: int, dtype: torch.dtype, device: torch.device):
 
-     """
 
-     Make causal mask used for bi-directional self-attention.
 
-     """
 
-     mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
 
-     mask_cond = torch.arange(mask.size(-1), device=device)
 
-     mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
 
-     mask = mask.to(dtype)
 
-     if past_key_values_length > 0:
 
-         mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
 
-     return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
 
def apply_rotary_pos_emb_single(x, cos, sin, position_ids=None, unsqueeze_dim=1):
    """
    Apply rotary position embedding to a single tensor.

    Args:
        x: Tensor to rotate (queries or keys).
        cos: Cosine table; a singleton axis is inserted at ``unsqueeze_dim``
            before it is multiplied with ``x``.
        sin: Sine table, treated the same way as ``cos``.
        position_ids: Accepted for signature compatibility; not used.
        unsqueeze_dim: Axis at which ``cos``/``sin`` get the extra dimension.

    Returns:
        ``x * cos + rotate_half(x) * sin`` with the broadcast tables.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    straight_part = x * cos_b
    rotated_part = rotate_half(x) * sin_b
    return straight_part + rotated_part
 
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head ``n_rep`` times along the head axis.

    Equivalent to ``torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)``:
    the input of shape ``(batch, num_key_value_heads, seqlen, head_dim)`` becomes
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. With ``n_rep == 1``
    the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold
    # it into the head axis.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
 
class H2OLlamaAttention(nn.Module):
    """Llama multi-head self-attention that reports its scores to an H2O KV cache.

    The projections and the attention arithmetic follow the usual Llama layout
    (key/value heads are repeated to match the number of query heads). Two
    things are specific to this class:

    * After the softmax, the attention probabilities are passed to
      ``past_key_value.update_slimming(...)`` so the cache can update itself
      from them (heavy-hitter bookkeeping lives in the cache implementation).
    * When ``config.enable_position_rolling`` is true, keys are stored in the
      cache *before* RoPE is applied, and rotary positions for the keys are
      re-derived from the current cache length on every call.
    """

    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        """
        Args:
            config: Model configuration. Reads ``attention_dropout``,
                ``hidden_size``, ``num_attention_heads``, ``num_key_value_heads``,
                ``max_position_embeddings``, ``rope_theta``, ``rope_scaling``,
                ``attention_bias`` and ``enable_position_rolling``.
            layer_idx: Index of this layer, forwarded to the cache on every
                update. A warning is logged when it is ``None``.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by the number of
                attention heads, or the RoPE scaling type is unknown.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.positional_rolling = config.enable_position_rolling

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

    def _init_rope(self):
        """Create ``self.rotary_emb`` according to ``config.rope_scaling``.

        Raises:
            ValueError: If ``rope_scaling["type"]`` is neither ``"linear"`` nor
                ``"dynamic"``.
        """
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
            return

        scaling_type = self.config.rope_scaling["type"]
        scaling_factor = self.config.rope_scaling["factor"]
        if scaling_type not in ("linear", "dynamic"):
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

        # The scaled variants are not part of this module's top-level imports;
        # import them here so configs with `rope_scaling` do not hit a NameError.
        from transformers.models.llama.modeling_llama import (
            LlamaDynamicNTKScalingRotaryEmbedding,
            LlamaLinearScalingRotaryEmbedding,
        )

        rope_cls = (
            LlamaLinearScalingRotaryEmbedding
            if scaling_type == "linear"
            else LlamaDynamicNTKScalingRotaryEmbedding
        )
        self.rotary_emb = rope_cls(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            scaling_factor=scaling_factor,
            base=self.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run self-attention and let the cache update itself from the scores.

        Args:
            hidden_states: Input of shape ``(batch, q_len, hidden_size)``.
            attention_mask: Optional additive mask; it is sliced on its last
                axis to the number of keys and added to the raw scores.
            position_ids: Positions used for RoPE. Required (must not be
                ``None``) when position rolling is enabled.
            past_key_value: Cache object. A ``past_key_value`` attribute set on
                this module takes precedence over the argument.
            output_attentions: Whether to return the attention probabilities.
            use_cache: Unused here; accepted for interface compatibility.
            cache_position: Forwarded to the cache in the non-rolling path.
            **kwargs: Ignored.

        Returns:
            ``(attn_output, attn_weights or None, past_key_value)``.

        Raises:
            ValueError: If the attention output has an unexpected shape.
        """
        bsz, q_len, _ = hidden_states.size()
        tp = self.config.pretraining_tp

        if tp > 1:
            # Reproduce tensor-parallel pretraining numerics: project with
            # weight slices and concatenate the partial results.
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // tp
            query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // tp, dim=0)
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = torch.cat([F.linear(hidden_states, query_slices[i]) for i in range(tp)], dim=-1)
            key_states = torch.cat([F.linear(hidden_states, key_slices[i]) for i in range(tp)], dim=-1)
            value_states = torch.cat([F.linear(hidden_states, value_slices[i]) for i in range(tp)], dim=-1)
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # A cache attached directly to the module overrides the argument.
        past_key_value = getattr(self, "past_key_value", past_key_value)

        if not self.positional_rolling:
            # Standard path: rotate first, then cache the rotated keys.
            cos, sin = self.rotary_emb(value_states, position_ids)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

            if past_key_value is not None:
                # sin and cos are specific to RoPE models; cache_position needed for the static cache
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        else:
            # Position rolling: cache the *un-rotated* keys, then assign rotary
            # positions relative to what is currently held in the cache.
            if past_key_value is not None:
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
            kv_seq_len = (
                past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else key_states.shape[-2]
            )

            if position_ids.nelement() <= 1:
                # decoding stage: keys get positions 0..kv_seq_len-1, the single query takes the last one
                key_position_ids = torch.arange(kv_seq_len, device=hidden_states.device).unsqueeze(0)
                query_position_ids = key_position_ids[:, -1].unsqueeze(0)
            elif kv_seq_len != position_ids.shape[-1]:
                # prefilling stage with evicting: cache length differs from the number of query positions
                query_position_ids = position_ids
                key_position_ids = torch.arange(kv_seq_len, device=hidden_states.device).unsqueeze(0)
            else:
                # prefilling stage
                query_position_ids = position_ids
                key_position_ids = position_ids

            key_cos, key_sin = self.rotary_emb(value_states, key_position_ids)
            query_cos, query_sin = self.rotary_emb(value_states, query_position_ids)
            query_states = apply_rotary_pos_emb_single(query_states, query_cos, query_sin)
            key_states = apply_rotary_pos_emb_single(key_states, key_cos, key_sin)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        # Update KV Cache based on Heavy-Hitter Oracle. This happens before
        # dropout, so the cache sees the un-dropped probabilities.
        if past_key_value is not None:
            past_key_value.update_slimming(attn_weights, self.num_key_value_groups, self.layer_idx)

        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if tp > 1:
            attn_output = attn_output.split(self.hidden_size // tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
 
def enable_h2ocache_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Decoder-stack forward pass that wraps legacy caches in an ``HHCache``.

    Written to be bound as the ``forward`` method of the inner model, so ``self``
    must expose ``config``, ``embed_tokens``, ``layers``, ``norm``,
    ``gradient_checkpointing``, ``_update_causal_mask``, ``num_window_length``
    and ``num_heavy_hitter_tokens``.

    Exactly one of ``input_ids`` / ``inputs_embeds`` must be given. Flags left
    as ``None`` fall back to the values in ``self.config``.

    Returns:
        A ``BaseModelOutputWithPast`` when ``return_dict`` is true, otherwise a
        tuple of the non-``None`` entries among (last hidden state, cache, all
        hidden states, all attentions).

    Raises:
        ValueError: If both or neither of ``input_ids``/``inputs_embeds`` are
            passed, or a ``StaticCache`` is used without ``cache_position``.
    """
    # Resolve unset flags from the model configuration.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if use_cache is None:
        use_cache = self.config.use_cache
    if return_dict is None:
        return_dict = self.config.use_return_dict

    # Both missing or both present is an error.
    if (input_ids is None) == (inputs_embeds is None):
        raise ValueError(
            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
        )

    if self.gradient_checkpointing and self.training and use_cache:
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
        )
        use_cache = False

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    # kept for BC (cache positions): anything that is not a StaticCache is
    # converted into an HHCache configured with this model's window/heavy-hitter sizes.
    tokens_seen = 0
    if use_cache and not isinstance(past_key_values, StaticCache):
        past_key_values = HHCache.from_legacy_cache(
            self.num_window_length, self.num_heavy_hitter_tokens, past_key_values
        )
        tokens_seen = past_key_values.get_seq_length()

    if cache_position is None:
        if isinstance(past_key_values, StaticCache):
            raise ValueError("cache_position is a required argument when using StaticCache.")
        new_tokens = inputs_embeds.shape[1]
        cache_position = torch.arange(tokens_seen, tokens_seen + new_tokens, device=inputs_embeds.device)

    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    layer_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    hidden_trace = [] if output_hidden_states else None
    attention_trace = [] if output_attentions else None
    latest_cache = None
    # The cache sits after the attention weights in a layer's output tuple
    # when those weights are returned.
    cache_slot = 2 if output_attentions else 1
    checkpointing = self.gradient_checkpointing and self.training

    for decoder_layer in self.layers:
        if output_hidden_states:
            hidden_trace.append(hidden_states)

        if checkpointing:
            outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                layer_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
            )
        else:
            outputs = decoder_layer(
                hidden_states,
                attention_mask=layer_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        hidden_states = outputs[0]
        if use_cache:
            latest_cache = outputs[cache_slot]
        if output_attentions:
            attention_trace.append(outputs[1])

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        hidden_trace.append(hidden_states)

    all_hidden_states = tuple(hidden_trace) if hidden_trace is not None else None
    all_self_attns = tuple(attention_trace) if attention_trace is not None else None

    next_cache = None
    if use_cache:
        if isinstance(latest_cache, Cache):
            next_cache = latest_cache.to_legacy_cache()
        else:
            next_cache = latest_cache

    if return_dict:
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    candidates = [hidden_states, next_cache, all_hidden_states, all_self_attns]
    return tuple(item for item in candidates if item is not None)
 
class H2OLlamaForCausalLM(LlamaForCausalLM):
    """Llama causal LM whose attention layers and model forward use the H2O cache.

    On construction every decoder layer's ``self_attn`` is replaced by an
    ``H2OLlamaAttention`` and the inner model's ``forward`` is rebound to
    ``enable_h2ocache_forward``.
    """

    def __init__(self, config):
        """
        Args:
            config: Model configuration. In addition to what the base class
                needs, ``num_heavy_hitter_tokens`` and ``num_window_length``
                are read here and stored on the inner model.
        """
        super().__init__(config)
        for layer_idx, layer in enumerate(self.model.layers):
            layer.self_attn = H2OLlamaAttention(config, layer_idx)

        self.model.forward = types.MethodType(enable_h2ocache_forward, self.model)
        self.model.num_heavy_hitter_tokens = config.num_heavy_hitter_tokens
        self.model.num_window_length = config.num_window_length

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
    ):
        """Assemble the keyword arguments for one generation step.

        The number of already-processed tokens is taken from ``cache_position[0]``
        rather than from the cache. When ``cache_position`` is ``None`` but a
        cache is present, the cache's current sequence length is used instead
        (the same fallback the model forward applies).
        NOTE(review): that fallback equals the number of *retained* entries;
        confirm it is acceptable when the cache has already dropped tokens.

        Args:
            input_ids: Token ids generated/provided so far.
            past_key_values: Cache object or legacy cache; ``None`` on the
                first step or when the cache is attached to the attention module.
            attention_mask: Optional padding mask.
            inputs_embeds: Optional embeddings, only used while no cache exists.
            cache_position: Optional positions of the tokens being fed.
            **kwargs: ``position_ids`` and ``use_cache`` are read if present.

        Returns:
            Dict with ``input_ids`` or ``inputs_embeds`` plus ``position_ids``,
            ``cache_position``, ``past_key_values``, ``use_cache`` and
            ``attention_mask``.
        """
        # With static cache, the `past_key_values` is None
        # TODO joao: standardize interface for the different Cache classes and remove of this if
        has_static_cache = False
        if past_key_values is None:
            past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
            has_static_cache = past_key_values is not None

        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                max_length = past_key_values.get_max_length()
                max_cache_length = (
                    torch.tensor(max_length, device=input_ids.device) if max_length is not None else None
                )
            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
            else:
                # length = num_layers * 3 (3 -> key, value, score)
                cache_length = past_key_values[0].shape[2]
                max_cache_length = None

            # `cache_position` defaults to None; indexing it unconditionally
            # raised TypeError, so fall back to the cache length in that case.
            past_length = cache_position[0] if cache_position is not None else cache_length

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        if cache_position is None:
            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
        else:
            cache_position = cache_position[-input_length:]

        # A cache that lives on the attention module must not also be passed as an argument.
        if has_static_cache:
            past_key_values = None

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
 
 
  |