import math
from typing import Optional, Tuple

import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.checkpoint

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaForCausalLM,
    LlamaRotaryEmbedding,
    # The two scaling variants are required by `_init_rope` below; they ship with
    # transformers >= 4.31 alongside LlamaRotaryEmbedding.
    LlamaLinearScalingRotaryEmbedding,
    LlamaDynamicNTKScalingRotaryEmbedding,
    apply_rotary_pos_emb,
    repeat_kv,
    rotate_half,
)

__all__ = ["H2OLlamaForCausalLM"]


def _make_causal_mask(
    bsz: int, tgt_len: int, past_key_values_length: int, dtype: torch.dtype, device: torch.device
):
    """
    Make the causal mask used for auto-regressive self-attention.
    """
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
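

# For example, with bsz=1, tgt_len=3 and past_key_values_length=2 the mask has
# shape [1, 1, 3, 5]: the cached positions are always visible and the new block
# is lower-triangular (the dtype's minimum value shown as -inf):
#   [[0, 0, 0, -inf, -inf],
#    [0, 0, 0,    0, -inf],
#    [0, 0, 0,    0,    0]]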


def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed
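

# `H2OKVCache_LayerWise` below looks up its slicing helpers in `DIM_TO_SLICE`,
# which is not defined elsewhere in this snippet. The definitions here are a
# minimal sketch, assuming each helper slices a cached tensor along the given
# sequence dimension (dim 2 for the usual [bsz, heads, seq, head_dim] layout).
def slice1d(x, start, end):
    return x[:, start:end, ...]


def slice2d(x, start, end):
    return x[:, :, start:end, ...]


def slice3d(x, start, end):
    return x[:, :, :, start:end, ...]


DIM_TO_SLICE = {
    1: slice1d,
    2: slice2d,
    3: slice3d,
}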


class H2OKVCache_LayerWise:
    def __init__(
        self,
        hh_size=4,
        recent_size=512,
        k_seq_dim=2,
        v_seq_dim=2,
    ):
        self.hh_size = hh_size
        self.recent_size = recent_size
        self.cache_size = hh_size + recent_size
        self.k_seq_dim = k_seq_dim
        self.v_seq_dim = v_seq_dim
        self.k_slice = DIM_TO_SLICE[k_seq_dim]
        self.v_slice = DIM_TO_SLICE[v_seq_dim]
        self.hh_score = None

    def __call__(self, past_key_values, attn_score_cache):
        self._update_hh_score(attn_score_cache)

        if past_key_values is None:
            return None
        seq_len = past_key_values[0].size(self.k_seq_dim)
        if seq_len <= self.cache_size:
            return past_key_values

        # hh-selection
        bsz, num_heads, _, head_dim = past_key_values[0].shape

        select_hh_scores = self.hh_score[:, :seq_len - self.recent_size]
        _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
        keep_topk = keep_topk.sort().values

        # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
        keep_recent = torch.arange(seq_len - self.recent_size, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)

        mask = torch.zeros(self.hh_score.shape, dtype=torch.bool).to(past_key_values[0].device)
        mask = mask.scatter(-1, keep_idx, 1)

        k_hh_recent = past_key_values[0].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
        v_hh_recent = past_key_values[1].squeeze()[mask].view(bsz, num_heads, -1, head_dim)

        self.hh_score = self.hh_score[mask].view(num_heads, self.cache_size)

        return (k_hh_recent, v_hh_recent)

    def evict_for_space(self, past_key_values, num_coming):
        if past_key_values is None:
            return None
        seq_len = past_key_values[0].size(self.k_seq_dim)
        if seq_len + num_coming <= self.cache_size:
            return past_key_values

        # hh-selection
        bsz, num_heads, _, head_dim = past_key_values[0].shape

        select_hh_scores = self.hh_score[:, :seq_len - self.recent_size + num_coming]
        _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
        keep_topk = keep_topk.sort().values

        # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
        keep_recent = torch.arange(seq_len - self.recent_size + num_coming, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)

        mask = torch.zeros(self.hh_score.shape, dtype=torch.bool).to(past_key_values[0].device)
        mask = mask.scatter(-1, keep_idx, 1)

        k_hh_recent = past_key_values[0].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
        v_hh_recent = past_key_values[1].squeeze()[mask].view(bsz, num_heads, -1, head_dim)

        # Fewer than `cache_size` entries survive here (room is left for `num_coming`
        # tokens), so infer the kept length instead of hard-coding `self.cache_size`.
        self.hh_score = self.hh_score[mask].view(num_heads, -1)

        return (k_hh_recent, v_hh_recent)

    def _update_hh_score(self, attn_score_cache):
        # attn_score_cache: [bsz, num_heads, num_new_tokens, kv_seq_len] attention
        # probabilities from the current forward pass.
        num_new_tokens = attn_score_cache.shape[2]

        if self.hh_score is None:
            self.hh_score = attn_score_cache.sum(0).sum(1)
        else:
            attn_score_cache = attn_score_cache.sum(0).sum(1)
            attn_score_cache[:, :-num_new_tokens] += self.hh_score
            self.hh_score = attn_score_cache

    def _clean_scores(self):
        self.hh_score = None
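

# Illustrative sanity check for the eviction logic above (not part of the
# original module; the function name and shapes are hypothetical). A cache with
# hh_size=4 and recent_size=16 is fed a 24-token prefill, so 4 heavy-hitter
# positions plus the 16 most recent positions (20 in total) should survive.
def _demo_h2o_eviction():
    torch.manual_seed(0)
    bsz, num_heads, seq_len, head_dim = 1, 2, 24, 8
    cache = H2OKVCache_LayerWise(hh_size=4, recent_size=16, k_seq_dim=2, v_seq_dim=2)

    keys = torch.randn(bsz, num_heads, seq_len, head_dim)
    values = torch.randn(bsz, num_heads, seq_len, head_dim)
    attn_scores = torch.softmax(torch.randn(bsz, num_heads, seq_len, seq_len), dim=-1)

    keys, values = cache((keys, values), attn_scores)
    print(keys.shape, values.shape)  # expected: torch.Size([1, 2, 20, 8]) twice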


class H2OLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self._init_rope()

        self.kv_cache = H2OKVCache_LayerWise(
            hh_size=config.hh_size,
            recent_size=config.recent_size,
            k_seq_dim=2,
            v_seq_dim=2,
        )

    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _clean_cache(self):
        self.kv_cache._clean_scores()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (
                self.num_key_value_heads * self.head_dim
            ) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [
                F.linear(hidden_states, query_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [
                F.linear(hidden_states, key_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [
                F.linear(hidden_states, value_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            value_states = torch.cat(value_states, dim=-1)
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        # Rebuild the causal mask here: after H2O eviction the cached sequence is
        # shorter than the original context, so the mask passed in by the model
        # would no longer match the key length.
        attention_mask = _make_causal_mask(
            bsz=bsz,
            tgt_len=q_len,
            past_key_values_length=past_key_value[0].shape[-2] if past_key_value is not None else 0,
            dtype=query_states.dtype,
            device=query_states.device,
        )

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        # During single-token decoding, re-index the query position relative to the
        # (possibly evicted) cache so it always sits at the end of the kept keys.
        if not position_ids.nelement() > 1:
            position_ids[0][0] = kv_seq_len - 1

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        ### Shift Pos: query pos is min(cache_size, idx)
        # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
        ###

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        ### Shift Pos: key pos is the pos in cache (Rolling KV Cache and using relative pos emb)
        key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
        key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
        ###

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
            self.head_dim
        )

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
            query_states.dtype
        )

        past_key_value = self.kv_cache(past_key_value, attn_weights.detach().clone())

        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(
                self.hidden_size // self.config.pretraining_tp, dim=2
            )
            o_proj_slices = self.o_proj.weight.split(
                self.hidden_size // self.config.pretraining_tp, dim=1
            )
            attn_output = sum(
                [
                    F.linear(attn_output[i], o_proj_slices[i])
                    for i in range(self.config.pretraining_tp)
                ]
            )
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class H2OLlamaForCausalLM(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        # Replace every decoder layer's attention with the H2O variant so that each
        # layer maintains its own heavy-hitter/recent KV cache.
        num_layers = len(self.model.layers)
        for layer_idx in range(num_layers):
            self.model.layers[layer_idx].self_attn = H2OLlamaAttention(config)
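

# Minimal usage sketch, not part of the original module. It assumes a local or
# hub Llama checkpoint (the name below is a placeholder), a transformers release
# that still uses tuple-style KV caches compatible with the attention class
# above (roughly the 4.33 era), and that the config is augmented with the two
# H2O budgets before the model is built.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
    config = LlamaConfig.from_pretrained(model_name)
    config.hh_size = 64       # heavy-hitter token budget per head
    config.recent_size = 256  # recent token budget per head

    model = H2OLlamaForCausalLM.from_pretrained(model_name, config=config)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))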
|