123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454 |
- import math
- from typing import Any, Dict, List, Optional, Tuple, Union
- import warnings
- warnings.filterwarnings("ignore")
- import pdb
- import types
- import torch
- from torch import nn
- import torch.utils.checkpoint
- import torch.nn.functional as F
- from transformers.models.llama.configuration_llama import LlamaConfig
- from transformers.models.llama.modeling_llama import (
- LlamaAttention,
- rotate_half,
- apply_rotary_pos_emb,
- repeat_kv,
- LlamaRotaryEmbedding,
- LlamaForCausalLM,
- )
- from utils.cache import Cache, HHCache, StaticCache
- from transformers.utils import logging
- from transformers.modeling_outputs import BaseModelOutputWithPast
# Module-level logger obtained through transformers' logging utilities.
logger = logging.get_logger(__name__)
# Public API of this module for `from ... import *`.
__all__ = ["H2OLlamaForCausalLM"]
- def _make_causal_mask(
- bsz: int, tgt_len: int, past_key_values_length: int, dtype: torch.dtype, device: torch.device):
- """
- Make causal mask used for bi-directional self-attention.
- """
- mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
- mask_cond = torch.arange(mask.size(-1), device=device)
- mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
- mask = mask.to(dtype)
- if past_key_values_length > 0:
- mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
- return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
def apply_rotary_pos_emb_single(x, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embedding to a single tensor (query *or* key).

    Unlike ``apply_rotary_pos_emb``, which rotates queries and keys together, this
    rotates one tensor so that queries and keys can use different position tables.

    Args:
        x: tensor to rotate.
        cos, sin: precomputed rotary tables. A singleton axis is inserted at
            ``unsqueeze_dim`` so they broadcast against ``x`` (presumably over the
            head axis of a ``(bsz, heads, seq, head_dim)`` tensor — confirm with callers).
        position_ids: unused; accepted only for signature compatibility.
        unsqueeze_dim: axis at which the broadcast dimension is inserted.

    Returns:
        ``x * cos + rotate_half(x) * sin``.
    """
    cos_b, sin_b = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    direct_part = x * cos_b
    rotated_part = rotate_half(x) * sin_b
    return direct_part + rotated_part
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Duplicate every key/value head ``n_rep`` times (grouped-query attention).

    Same result as ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    ``(batch, num_key_value_heads, seqlen, head_dim)`` becomes
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``, with the copies of
    each head placed next to one another.

    When ``n_rep == 1`` the input tensor object itself is returned unchanged.
    """
    n_batch, n_kv_heads, seq_len, dim = hidden_states.shape
    if n_rep != 1:
        # Insert a repeat axis after the head axis, broadcast it, then fold it in.
        widened = hidden_states.unsqueeze(2).expand(n_batch, n_kv_heads, n_rep, seq_len, dim)
        return widened.reshape(n_batch, n_kv_heads * n_rep, seq_len, dim)
    return hidden_states
class H2OLlamaAttention(nn.Module):
    """Llama multi-head attention ('Attention Is All You Need') with H2O cache slimming.

    Uses the same projections and scaled dot-product attention as ``LlamaAttention``,
    plus two additions:

    * After the softmax, the attention probabilities are passed to
      ``past_key_value.update_slimming`` so the cache can update its heavy-hitter
      bookkeeping for this layer. The eviction policy itself lives in the cache class.
    * Optional *positional rolling* (``config.enable_position_rolling``). Keys are
      written to the cache *before* RoPE is applied. Rotary embeddings are applied
      after the cache update, using each key's slot index in the (possibly slimmed)
      cache rather than its original token position.
    """

    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        # Configs that do not define the flag fall back to standard H2O (RoPE applied
        # before caching) instead of raising AttributeError.
        self.positional_rolling = getattr(config, "enable_position_rolling", False)

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

    def _init_rope(self):
        """Create ``self.rotary_emb`` according to ``config.rope_scaling``.

        Raises:
            ValueError: if ``rope_scaling["type"]`` is neither "linear" nor "dynamic".
        """
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
            return

        scaling_type = self.config.rope_scaling["type"]
        scaling_factor = self.config.rope_scaling["factor"]
        # The scaled variants were referenced but never imported at module level, so
        # these branches used to raise NameError. They are imported lazily here so the
        # module still imports even if a transformers release lacks them and rope
        # scaling is unused.
        if scaling_type == "linear":
            from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding

            rope_cls = LlamaLinearScalingRotaryEmbedding
        elif scaling_type == "dynamic":
            from transformers.models.llama.modeling_llama import LlamaDynamicNTKScalingRotaryEmbedding

            rope_cls = LlamaDynamicNTKScalingRotaryEmbedding
        else:
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

        self.rotary_emb = rope_cls(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            scaling_factor=scaling_factor,
            base=self.rope_theta,
        )

    def _rolling_position_ids(self, position_ids, q_len, kv_seq_len, device):
        """Choose the RoPE position ids for queries and keys under positional rolling.

        Returns:
            ``(query_position_ids, key_position_ids)``.
        """
        if q_len == 1:
            # Decoding: keys are indexed by cache slot and the single new query takes
            # the last slot. Testing q_len (rather than position_ids.nelement()) keeps
            # this branch correct for batch sizes > 1.
            key_position_ids = torch.arange(kv_seq_len, device=device).unsqueeze(0)
            query_position_ids = key_position_ids[:, -1].unsqueeze(0)
        elif kv_seq_len != position_ids.shape[-1]:
            # Prefill where the cache length differs from the number of incoming
            # positions (original note: "prefilling stage with evicting").
            # NOTE(review): queries keep absolute positions while keys use slot
            # indices here — confirm this mix is intended.
            query_position_ids = position_ids
            key_position_ids = torch.arange(kv_seq_len, device=device).unsqueeze(0)
        else:
            # Plain prefill: the cache holds exactly the incoming tokens.
            query_position_ids = position_ids
            key_position_ids = position_ids
        return query_position_ids, key_position_ids

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run attention for one layer and let the cache slim itself.

        Args:
            hidden_states: input of shape ``(bsz, q_len, hidden_size)``.
            attention_mask: optional additive mask. Its last dimension is sliced to
                the current key length before it is added to the scores.
            position_ids: token positions used for RoPE. Required for multi-token
                inputs when positional rolling is enabled.
            past_key_value: cache providing ``update``, ``get_seq_length`` and
                ``update_slimming``. A ``self.past_key_value`` attribute, if set,
                takes precedence over this argument.
            output_attentions: when False, ``None`` is returned in place of the
                attention probabilities.
            use_cache: not used by this method; accepted for interface compatibility.
            cache_position: forwarded to ``past_key_value.update`` in the
                non-rolling path.
            **kwargs: ignored.

        Returns:
            ``(attn_output, attn_weights or None, past_key_value)``, where
            ``attn_output`` has shape ``(bsz, q_len, hidden_size)``.

        Raises:
            ValueError: if the attention output has an unexpected shape.
        """
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            # Project with weight slices and concatenate (pretraining_tp-compatible path).
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # A cache attached to the module (e.g. a static cache) overrides the argument.
        past_key_value = getattr(self, "past_key_value", past_key_value)

        if not self.positional_rolling:
            # Standard path: rotate with absolute positions, then cache rotated keys.
            cos, sin = self.rotary_emb(value_states, position_ids)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

            if past_key_value is not None:
                # sin and cos are specific to RoPE models; cache_position needed for the static cache
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        else:
            # Rolling path: cache un-rotated keys, then rotate using cache-slot positions.
            if past_key_value is not None:
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

            kv_seq_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else key_states.shape[-2]
            query_position_ids, key_position_ids = self._rolling_position_ids(
                position_ids, q_len, kv_seq_len, hidden_states.device
            )

            key_cos, key_sin = self.rotary_emb(value_states, key_position_ids)
            query_cos, query_sin = self.rotary_emb(value_states, query_position_ids)

            query_states = apply_rotary_pos_emb_single(query_states, query_cos, query_sin)
            key_states = apply_rotary_pos_emb_single(key_states, key_cos, key_sin)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        # Update KV Cache based on Heavy-Hitter Oracle: the cache receives the
        # post-softmax probabilities (before dropout) for this layer.
        if past_key_value is not None:
            past_key_value.update_slimming(attn_weights, self.num_key_value_groups, self.layer_idx)

        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
def enable_h2ocache_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Decoder-stack forward pass that runs with an H2O heavy-hitter KV cache.

    Written as a free function taking ``self``. ``H2OLlamaForCausalLM.__init__``
    binds it to its inner Llama model with ``types.MethodType``. It reads
    ``self.num_window_length`` and ``self.num_heavy_hitter_tokens``, which that
    constructor sets on the model, to build the cache.

    Args:
        input_ids: token ids. Exactly one of ``input_ids`` / ``inputs_embeds`` must
            be given.
        attention_mask: passed to ``self._update_causal_mask`` to build the mask
            given to every decoder layer.
        position_ids: defaults to ``cache_position.unsqueeze(0)``.
        past_key_values: legacy cache or a ``StaticCache``. When ``use_cache`` is
            true and it is not a ``StaticCache``, it is converted with
            ``HHCache.from_legacy_cache``.
        inputs_embeds: precomputed embeddings; computed from ``input_ids`` when None.
        use_cache, output_attentions, output_hidden_states, return_dict: fall back
            to ``self.config`` (``use_cache``, ``output_attentions``,
            ``output_hidden_states``, ``use_return_dict``) when None. ``use_cache``
            is forced off (with a warning) under gradient checkpointing in training.
        cache_position: defaults to
            ``arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1])``.

    Returns:
        ``BaseModelOutputWithPast``, or with ``return_dict=False`` a tuple of its
        non-None fields. The returned cache is converted with ``to_legacy_cache()``
        when the layers produced a ``Cache`` object.

    Raises:
        ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are
            given, or if a ``StaticCache`` is used without ``cache_position``.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # XOR: true when both or neither inputs are provided.
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError(
            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
        )

    if self.gradient_checkpointing and self.training and use_cache:
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
        )
        use_cache = False

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    past_seen_tokens = 0
    if use_cache:  # kept for BC (cache positions)
        # Wrap the incoming (legacy) cache in an HHCache sized by the window and
        # heavy-hitter budgets attached to the model.
        if not isinstance(past_key_values, StaticCache):
            past_key_values = HHCache.from_legacy_cache(self.num_window_length, self.num_heavy_hitter_tokens, past_key_values)
        # NOTE(review): this is the cache's reported length. If HHCache has evicted
        # entries it may be smaller than the number of tokens processed — confirm
        # against HHCache.get_seq_length before relying on the default cache_position.
        past_seen_tokens = past_key_values.get_seq_length()

    if cache_position is None:
        if isinstance(past_key_values, StaticCache):
            raise ValueError("cache_position is a required argument when using StaticCache.")
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )

    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    # NOTE(review): the argument list of `_update_causal_mask` differs across
    # transformers releases — confirm it matches the installed version.
    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and self.training:
            # Arguments are positional here; their order must match the decoder
            # layer's signature.
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        hidden_states = layer_outputs[0]

        # Layer output layout: (hidden, [attn_weights if output_attentions], [cache if use_cache]).
        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = None
    if use_cache:
        # Return the legacy format so callers can pass it straight back in.
        next_cache = (
            next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
        )
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)

    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
class H2OLlamaForCausalLM(LlamaForCausalLM):
    """``LlamaForCausalLM`` whose attention layers use H2O (heavy-hitter) KV-cache slimming.

    Every decoder layer's ``self_attn`` is replaced by ``H2OLlamaAttention``, and the
    inner model's ``forward`` is rebound to ``enable_h2ocache_forward``. The config
    must define ``num_heavy_hitter_tokens`` and ``num_window_length``; they are copied
    onto ``self.model`` for use when building the cache.
    """

    def __init__(self, config):
        super().__init__(config)
        for layer_idx, layer in enumerate(self.model.layers):
            layer.self_attn = H2OLlamaAttention(config, layer_idx)
        # Bind the cache-aware forward to this particular `self.model` instance.
        self.model.forward = types.MethodType(enable_h2ocache_forward, self.model)
        # Read by the rebound forward when constructing the HHCache.
        self.model.num_heavy_hitter_tokens = config.num_heavy_hitter_tokens
        self.model.num_window_length = config.num_window_length

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
    ):
        """Build the model inputs for one ``generate`` step.

        Drops already-processed tokens from ``input_ids``, derives ``position_ids``
        from ``attention_mask`` when not given, and fills in ``cache_position``.

        The number of already-processed tokens comes from ``cache_position``
        rather than from the cache size, because the H2O cache may hold fewer
        entries than tokens seen.

        Returns:
            dict with ``input_ids`` (or ``inputs_embeds`` on the first step),
            ``position_ids``, ``cache_position``, ``past_key_values``,
            ``use_cache`` and ``attention_mask``.

        Raises:
            ValueError: if a cache is present but ``cache_position`` is None.
        """
        # With static cache, the `past_key_values` is None
        # TODO joao: standardize interface for the different Cache classes and remove of this if
        has_static_cache = False
        if past_key_values is None:
            past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
            has_static_cache = past_key_values is not None

        past_length = 0
        if past_key_values is not None:
            if cache_position is None:
                # Previously this crashed with an opaque TypeError on `cache_position[0]`.
                # The cache length cannot stand in for it once entries are slimmed away.
                raise ValueError(
                    "`cache_position` is required when generating with a cache: the H2O cache length "
                    "does not track the number of tokens already processed."
                )
            past_length = cache_position[0]
            if isinstance(past_key_values, Cache):
                max_cache_length = (
                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
                    if past_key_values.get_max_length() is not None
                    else None
                )
                cache_length = past_key_values.get_seq_length()
            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
            else:
                # Legacy flat tuple; per the original note its length is num_layers * 3
                # (key, value, score), so entry 0 is presumably layer 0's key tensor.
                cache_length = past_key_values[0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        if cache_position is None:
            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
        else:
            cache_position = cache_position[-input_length:]

        if has_static_cache:
            past_key_values = None

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
|