import math
from typing import Any, Dict, List, Optional, Tuple, Union
import warnings

warnings.filterwarnings("ignore")
import pdb
import types

import torch
from torch import nn
import torch.utils.checkpoint
import torch.nn.functional as F

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    rotate_half,
    apply_rotary_pos_emb,
    repeat_kv,
    LlamaRotaryEmbedding,
    LlamaForCausalLM,
)
from utils.cache import Cache, HHCache, StaticCache
from transformers.utils import logging
from transformers.modeling_outputs import BaseModelOutputWithPast

logger = logging.get_logger(__name__)

__all__ = ["H2OLlamaForCausalLM"]


def _make_causal_mask(
    bsz: int, tgt_len: int, past_key_values_length: int, dtype: torch.dtype, device: torch.device
):
    """
    Make the additive causal (unidirectional) mask used for decoder self-attention.

    Returns a ``(bsz, 1, tgt_len, tgt_len + past_key_values_length)`` tensor that holds ``0``
    where a query may attend and ``torch.finfo(dtype).min`` where it may not. The leading
    ``past_key_values_length`` columns (already-cached positions) are always visible.
    """
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    # Zero the lower triangle including the diagonal: query i may see keys 0..i.
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


def apply_rotary_pos_emb_single(x, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embedding to a single tensor ``x``.

    ``cos``/``sin`` are unsqueezed at ``unsqueeze_dim`` so they broadcast against ``x``
    (with the default of 1, this inserts the head axis). ``position_ids`` is accepted only
    for signature compatibility with ``apply_rotary_pos_emb`` and is not used.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class H2OLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper, with Heavy-Hitter (H2O) cache slimming.

    After the softmax, the attention weights are handed to ``past_key_value.update_slimming``
    so the cache can decide which key/value entries to keep.

    When ``config.enable_position_rolling`` is set, keys are stored in the cache *without*
    rotary embedding. They are then re-rotated on every call using their index inside the
    (possibly evicted) cache rather than their absolute token position.
    """

    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        # Default to the standard (non-rolling) RoPE path when the config does not define the flag,
        # so a plain LlamaConfig does not raise AttributeError.
        self.positional_rolling = getattr(config, "enable_position_rolling", False)

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

    def _init_rope(self):
        """Create ``self.rotary_emb`` according to ``config.rope_scaling``.

        Raises:
            ValueError: if the scaling type is neither ``"linear"`` nor ``"dynamic"``.
        """
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
            return

        rope_scaling = self.config.rope_scaling
        # Newer configs may name the key "rope_type" instead of "type".
        scaling_type = rope_scaling.get("type", rope_scaling.get("rope_type"))
        scaling_factor = rope_scaling["factor"]
        if scaling_type == "linear":
            # Imported lazily: these classes were previously referenced without being imported
            # (NameError), and importing them only here keeps module import unaffected.
            from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding

            self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                scaling_factor=scaling_factor,
                base=self.rope_theta,
            )
        elif scaling_type == "dynamic":
            from transformers.models.llama.modeling_llama import LlamaDynamicNTKScalingRotaryEmbedding

            self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                scaling_factor=scaling_factor,
                base=self.rope_theta,
            )
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    @staticmethod
    def _rolling_position_ids(
        q_len: int,
        kv_seq_len: int,
        position_ids: Optional[torch.LongTensor],
        device: torch.device,
    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
        """Choose query/key position ids for the position-rolling RoPE path.

        Returns:
            ``(query_position_ids, key_position_ids)``.
        """
        if q_len == 1:
            # Decoding stage: keys take their cache slot index and the single new query takes the last slot.
            # Decided on q_len rather than position_ids.nelement() so that batched decoding
            # (position_ids of shape (bsz, 1)) also takes this branch.
            key_position_ids = torch.arange(kv_seq_len, device=device).unsqueeze(0)
            query_position_ids = key_position_ids[:, -1].unsqueeze(0)
        elif kv_seq_len != position_ids.shape[-1]:
            # Prefilling stage with evicting: the cache length differs from the number of query positions,
            # so queries keep their given positions while keys use their cache slot indices.
            query_position_ids = position_ids
            key_position_ids = torch.arange(kv_seq_len, device=device).unsqueeze(0)
        else:
            # Prefilling stage: keys and queries line up one-to-one.
            query_position_ids = position_ids
            key_position_ids = position_ids
        return query_position_ids, key_position_ids

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run attention, update the cache, and let the cache slim itself using the attention weights.

        Returns:
            ``(attn_output, attn_weights or None, past_key_value)``; ``attn_weights`` is only
            returned when ``output_attentions`` is true.
        """
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # A cache attached directly to the module (static-cache setup) takes precedence over the argument.
        past_key_value = getattr(self, "past_key_value", past_key_value)

        if not self.positional_rolling:
            cos, sin = self.rotary_emb(value_states, position_ids)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

            if past_key_value is not None:
                # sin and cos are specific to RoPE models; cache_position needed for the static cache
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        else:
            # Keys go into the cache un-rotated; RoPE is applied below on the full cached sequence.
            if past_key_value is not None:
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

            kv_seq_len = (
                past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else key_states.shape[-2]
            )
            query_position_ids, key_position_ids = self._rolling_position_ids(
                q_len, kv_seq_len, position_ids, hidden_states.device
            )

            key_cos, key_sin = self.rotary_emb(value_states, key_position_ids)
            query_cos, query_sin = self.rotary_emb(value_states, query_position_ids)
            query_states = apply_rotary_pos_emb_single(query_states, query_cos, query_sin)
            key_states = apply_rotary_pos_emb_single(key_states, key_cos, key_sin)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        # Update KV Cache based on Heavy-Hitter Oracle (uses the pre-dropout attention weights).
        if past_key_value is not None:
            past_key_value.update_slimming(attn_weights, self.num_key_value_groups, self.layer_idx)

        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


def enable_h2ocache_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Replacement ``forward`` for the inner Llama model that uses an ``HHCache``.

    ``H2OLlamaForCausalLM.__init__`` binds this function onto ``self.model`` via
    ``types.MethodType``. It reads ``self.num_window_length`` and
    ``self.num_heavy_hitter_tokens``, which that constructor also sets.

    When ``use_cache`` is on and ``past_key_values`` is not a ``StaticCache``, the incoming
    cache is converted with ``HHCache.from_legacy_cache``. The returned cache is converted
    back with ``to_legacy_cache`` when it is a ``Cache`` instance.

    Raises:
        ValueError: if both or neither of ``input_ids``/``inputs_embeds`` are given, or if a
            ``StaticCache`` is used without ``cache_position``.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError(
            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
        )

    if self.gradient_checkpointing and self.training and use_cache:
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
        )
        use_cache = False

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    past_seen_tokens = 0
    if use_cache:  # kept for BC (cache positions)
        if not isinstance(past_key_values, StaticCache):
            past_key_values = HHCache.from_legacy_cache(self.num_window_length, self.num_heavy_hitter_tokens, past_key_values)
            past_seen_tokens = past_key_values.get_seq_length()

    if cache_position is None:
        if isinstance(past_key_values, StaticCache):
            raise ValueError("cache_position is a required argument when using StaticCache.")
        # NOTE(review): with an evicting HHCache, get_seq_length() may be smaller than the number
        # of tokens actually seen; callers such as generate() pass cache_position explicitly.
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )

    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            # Layer outputs are (hidden, [attn_weights], [cache]); the cache index shifts with output_attentions.
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = None
    if use_cache:
        next_cache = (
            next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
        )
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )


class H2OLlamaForCausalLM(LlamaForCausalLM):
    """``LlamaForCausalLM`` whose attention layers and inner-model forward use the H2O cache.

    Expects ``config`` to define ``num_heavy_hitter_tokens`` and ``num_window_length``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Swap every decoder layer's attention module for the H2O variant.
        for layer_idx, decoder_layer in enumerate(self.model.layers):
            decoder_layer.self_attn = H2OLlamaAttention(config, layer_idx)

        # Rebind the inner model's forward on this instance only, so it builds an HHCache.
        self.model.forward = types.MethodType(enable_h2ocache_forward, self.model)
        self.model.num_heavy_hitter_tokens = config.num_heavy_hitter_tokens
        self.model.num_window_length = config.num_window_length

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
    ):
        """Trim ``input_ids`` to the unprocessed tokens and assemble model inputs for a generation step.

        NOTE(review): when ``past_key_values`` is given, ``cache_position`` must be provided too
        (``cache_position[0]`` is read unconditionally). With an evicting cache, the cache length
        no longer equals the number of tokens seen, so it cannot stand in for ``past_length``.
        """
        # With static cache, the `past_key_values` is None
        # TODO joao: standardize interface for the different Cache classes and remove of this if
        has_static_cache = False
        if past_key_values is None:
            past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
            has_static_cache = past_key_values is not None

        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                past_length = cache_position[0]
                max_cache_length = (
                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
                    if past_key_values.get_max_length() is not None
                    else None
                )
                cache_length = past_key_values.get_seq_length()
            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
            else:
                past_length = cache_position[0]
                cache_length = past_key_values[0].shape[2]  # length = num_layers * 3 (3 -> key, value, score)
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        if cache_position is None:
            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
        else:
            cache_position = cache_position[-input_length:]

        if has_static_cache:
            past_key_values = None

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs