@@ -132,12 +132,35 @@ class H2OKVCache_LayerWise:
     def _clean_scores(self):
         self.hh_score = None

+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 class H2OLlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

-    def __init__(self, config: LlamaConfig):
+    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
+        self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+
+        self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
@@ -145,24 +168,19 @@ class H2OLlamaAttention(nn.Module):
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
         self.rope_theta = config.rope_theta
+        self.is_causal = True

         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                 f" and `num_heads`: {self.num_heads})."
             )
-        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
-        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-        self._init_rope()
-
-        self.kv_cache = H2OKVCache_LayerWise(
-            hh_size=config.hh_size,
-            recent_size=config.recent_size,
-            k_seq_dim=2,
-            v_seq_dim=2,
-        )
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
+        self._init_rope()

     def _init_rope(self):
         if self.config.rope_scaling is None:
@@ -191,51 +209,34 @@ class H2OLlamaAttention(nn.Module):
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
-    def _clean_cache(self):
-        self.kv_cache._clean_scores()
-
     def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
         bsz, q_len, _ = hidden_states.size()

         if self.config.pretraining_tp > 1:
-            key_value_slicing = (
-                self.num_key_value_heads * self.head_dim
-            ) // self.config.pretraining_tp
+            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
             query_slices = self.q_proj.weight.split(
                 (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
             )
             key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
             value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

-            query_states = [
-                F.linear(hidden_states, query_slices[i])
-                for i in range(self.config.pretraining_tp)
-            ]
+            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

-            key_states = [
-                F.linear(hidden_states, key_slices[i])
-                for i in range(self.config.pretraining_tp)
-            ]
+            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
             key_states = torch.cat(key_states, dim=-1)

-            value_states = [
-                F.linear(hidden_states, value_slices[i])
-                for i in range(self.config.pretraining_tp)
-            ]
+            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
             value_states = torch.cat(value_states, dim=-1)

         else:
@@ -243,78 +244,31 @@ class H2OLlamaAttention(nn.Module):
             key_states = self.k_proj(hidden_states)
             value_states = self.v_proj(hidden_states)

-        query_states = query_states.view(
-            bsz, q_len, self.num_heads, self.head_dim
-        ).transpose(1, 2)
-        key_states = key_states.view(
-            bsz, q_len, self.num_key_value_heads, self.head_dim
-        ).transpose(1, 2)
-        value_states = value_states.view(
-            bsz, q_len, self.num_key_value_heads, self.head_dim
-        ).transpose(1, 2)
-
-        # remake causal mask
-        attention_mask = _make_causal_mask(
-            bsz=bsz,
-            tgt_len=q_len,
-            past_key_values_length=past_key_value[0].shape[-2] if past_key_value is not None else 0,
-            dtype=query_states.dtype,
-            device=query_states.device,
-        )
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
-
-        if not position_ids.nelement() > 1:
-            position_ids[0][0] = kv_seq_len - 1
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        ### Shift Pos: query pos is min(cache_size, idx)
-        # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-        query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
-        ###
+        past_key_value = getattr(self, "past_key_value", past_key_value)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

         if past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-        past_key_value = (key_states, value_states) if use_cache else None
-
-        ### Shift Pos: key pos is the pos in cache (Rolling KV Cache and using relative pos emb)
-        key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
-        key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
-        ###
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
-            self.head_dim
-        )
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
-            attn_weights = attn_weights + attention_mask
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            attn_weights = attn_weights + causal_mask

         # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
-            query_states.dtype
-        )
-
-        past_key_value = self.kv_cache(past_key_value, attn_weights.detach().clone())
-
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)

         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -324,21 +278,13 @@ class H2OLlamaAttention(nn.Module):
             )

         attn_output = attn_output.transpose(1, 2).contiguous()
+
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

         if self.config.pretraining_tp > 1:
-            attn_output = attn_output.split(
-                self.hidden_size // self.config.pretraining_tp, dim=2
-            )
-            o_proj_slices = self.o_proj.weight.split(
-                self.hidden_size // self.config.pretraining_tp, dim=1
-            )
-            attn_output = sum(
-                [
-                    F.linear(attn_output[i], o_proj_slices[i])
-                    for i in range(self.config.pretraining_tp)
-                ]
-            )
+            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
         else:
             attn_output = self.o_proj(attn_output)

@@ -347,6 +293,242 @@ class H2OLlamaAttention(nn.Module):
         return attn_output, attn_weights, past_key_value


+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+# class H2OLlamaAttention(nn.Module):
+#     """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+#     def __init__(self, config: LlamaConfig):
+#         super().__init__()
+#         self.config = config
+#         self.hidden_size = config.hidden_size
+#         self.num_heads = config.num_attention_heads
+#         self.head_dim = self.hidden_size // self.num_heads
+#         self.num_key_value_heads = config.num_key_value_heads
+#         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+#         self.max_position_embeddings = config.max_position_embeddings
+#         self.rope_theta = config.rope_theta
+
+#         if (self.head_dim * self.num_heads) != self.hidden_size:
+#             raise ValueError(
+#                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+#                 f" and `num_heads`: {self.num_heads})."
+#             )
+#         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+#         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+#         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+#         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+#         self._init_rope()
+
+#         self.kv_cache = H2OKVCache_LayerWise(
+#             hh_size=config.hh_size,
+#             recent_size=config.recent_size,
+#             k_seq_dim=2,
+#             v_seq_dim=2,
+#         )
+
+#     def _init_rope(self):
+#         if self.config.rope_scaling is None:
+#             self.rotary_emb = LlamaRotaryEmbedding(
+#                 self.head_dim,
+#                 max_position_embeddings=self.max_position_embeddings,
+#                 base=self.rope_theta,
+#             )
+#         else:
+#             scaling_type = self.config.rope_scaling["type"]
+#             scaling_factor = self.config.rope_scaling["factor"]
+#             if scaling_type == "linear":
+#                 self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
+#                     self.head_dim,
+#                     max_position_embeddings=self.max_position_embeddings,
+#                     scaling_factor=scaling_factor,
+#                     base=self.rope_theta,
+#                 )
+#             elif scaling_type == "dynamic":
+#                 self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
+#                     self.head_dim,
+#                     max_position_embeddings=self.max_position_embeddings,
+#                     scaling_factor=scaling_factor,
+#                     base=self.rope_theta,
+#                 )
+#             else:
+#                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+#     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+#         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+#     def _clean_cache(self):
+#         self.kv_cache._clean_scores()
+
+#     def forward(
+#         self,
+#         hidden_states: torch.Tensor,
+#         attention_mask: Optional[torch.Tensor] = None,
+#         position_ids: Optional[torch.LongTensor] = None,
+#         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+#         output_attentions: bool = False,
+#         use_cache: bool = False,
+#         cache_position: Optional[torch.LongTensor] = None,
+#     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
+#         bsz, q_len, _ = hidden_states.size()
+
+#         if self.config.pretraining_tp > 1:
+#             key_value_slicing = (
+#                 self.num_key_value_heads * self.head_dim
+#             ) // self.config.pretraining_tp
+#             query_slices = self.q_proj.weight.split(
+#                 (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+#             )
+#             key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+#             value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+#             query_states = [
+#                 F.linear(hidden_states, query_slices[i])
+#                 for i in range(self.config.pretraining_tp)
+#             ]
+#             query_states = torch.cat(query_states, dim=-1)
+
+#             key_states = [
+#                 F.linear(hidden_states, key_slices[i])
+#                 for i in range(self.config.pretraining_tp)
+#             ]
+#             key_states = torch.cat(key_states, dim=-1)
+
+#             value_states = [
+#                 F.linear(hidden_states, value_slices[i])
+#                 for i in range(self.config.pretraining_tp)
+#             ]
+#             value_states = torch.cat(value_states, dim=-1)
+
+#         else:
+#             query_states = self.q_proj(hidden_states)
+#             key_states = self.k_proj(hidden_states)
+#             value_states = self.v_proj(hidden_states)
+
+#         query_states = query_states.view(
+#             bsz, q_len, self.num_heads, self.head_dim
+#         ).transpose(1, 2)
+#         key_states = key_states.view(
+#             bsz, q_len, self.num_key_value_heads, self.head_dim
+#         ).transpose(1, 2)
+#         value_states = value_states.view(
+#             bsz, q_len, self.num_key_value_heads, self.head_dim
+#         ).transpose(1, 2)
+
+#         # remake causal mask
+#         attention_mask = _make_causal_mask(
+#             bsz=bsz,
+#             tgt_len=q_len,
+#             past_key_values_length=past_key_value[0].shape[-2] if past_key_value is not None else 0,
+#             dtype=query_states.dtype,
+#             device=query_states.device,
+#         )
+
+#         kv_seq_len = key_states.shape[-2]
+#         if past_key_value is not None:
+#             kv_seq_len += past_key_value[0].shape[-2]
+
+#         if not position_ids.nelement() > 1:
+#             position_ids[0][0] = kv_seq_len - 1
+
+#         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+#         ### Shift Pos: query pos is min(cache_size, idx)
+#         # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+#         query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
+#         ###
+
+#         if past_key_value is not None:
+#             # reuse k, v, self_attention
+#             key_states = torch.cat([past_key_value[0], key_states], dim=2)
+#             value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+#         past_key_value = (key_states, value_states) if use_cache else None
+
+#         ### Shift Pos: key pos is the pos in cache (Rolling KV Cache and using relative pos emb)
+#         key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
+#         key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
+#         ###
+
+#         # repeat k/v heads if n_kv_heads < n_heads
+#         key_states = repeat_kv(key_states, self.num_key_value_groups)
+#         value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+#         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
+#             self.head_dim
+#         )
+
+#         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+#             raise ValueError(
+#                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+#                 f" {attn_weights.size()}"
+#             )
+
+#         if attention_mask is not None:
+#             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+#                 raise ValueError(
+#                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+#                 )
+#             attn_weights = attn_weights + attention_mask
+
+#         # upcast attention to fp32
+#         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
+#             query_states.dtype
+#         )
+
+#         past_key_value = self.kv_cache(past_key_value, attn_weights.detach().clone())
+
+#         attn_output = torch.matmul(attn_weights, value_states)
+
+#         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+#             raise ValueError(
+#                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+#                 f" {attn_output.size()}"
+#             )
+
+#         attn_output = attn_output.transpose(1, 2).contiguous()
+#         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+#         if self.config.pretraining_tp > 1:
+#             attn_output = attn_output.split(
+#                 self.hidden_size // self.config.pretraining_tp, dim=2
+#             )
+#             o_proj_slices = self.o_proj.weight.split(
+#                 self.hidden_size // self.config.pretraining_tp, dim=1
+#             )
+#             attn_output = sum(
+#                 [
+#                     F.linear(attn_output[i], o_proj_slices[i])
+#                     for i in range(self.config.pretraining_tp)
+#                 ]
+#             )
+#         else:
+#             attn_output = self.o_proj(attn_output)
+
+#         if not output_attentions:
+#             attn_weights = None
+
+#         return attn_output, attn_weights, past_key_value
+
+
+
+
 class H2OLlamaForCausalLM(LlamaForCausalLM):
     def __init__(self, config):
         super().__init__(config)