@@ -369,9 +369,6 @@ def enable_h2ocache_forward(
     all_self_attns = () if output_attentions else None
     next_decoder_cache = None
 
-    import pdb;pdb.set_trace()
-
-
     for decoder_layer in self.layers:
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
@@ -427,313 +424,6 @@ def enable_h2ocache_forward(
         attentions=all_self_attns,
     )
 
-def prepare_inputs_for_generation_w_h2o(
-    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
-):
-    # With static cache, the `past_key_values` is None
-    # TODO joao: standardize interface for the different Cache classes and remove of this if
-
-    import pdb; pdb.set_trace()
-
-    has_static_cache = False
-    if past_key_values is None:
-        past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
-        has_static_cache = past_key_values is not None
-
-    past_length = 0
-    if past_key_values is not None:
-        if isinstance(past_key_values, Cache):
-            past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
-            max_cache_length = (
-                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
-                if past_key_values.get_max_length() is not None
-                else None
-            )
-            cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
-        # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
-        else:
-            cache_length = past_length = past_key_values[0][0].shape[2]
-            max_cache_length = None
-
-    # Keep only the unprocessed tokens:
-    # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-    # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-    # input)
-    if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-        input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-    # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-    # input_ids based on the past_length.
-    elif past_length < input_ids.shape[1]:
-        input_ids = input_ids[:, past_length:]
-    # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
-    # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-    if (
-        max_cache_length is not None
-        and attention_mask is not None
-        and cache_length + input_ids.shape[1] > max_cache_length
-    ):
-        attention_mask = attention_mask[:, -max_cache_length:]
-
-    position_ids = kwargs.get("position_ids", None)
-    if attention_mask is not None and position_ids is None:
-        # create position_ids on the fly for batch generation
-        position_ids = attention_mask.long().cumsum(-1) - 1
-        position_ids.masked_fill_(attention_mask == 0, 1)
-        if past_key_values:
-            position_ids = position_ids[:, -input_ids.shape[1] :]
-
-    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-    if inputs_embeds is not None and past_key_values is None:
-        model_inputs = {"inputs_embeds": inputs_embeds}
-    else:
-        # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
-        # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
-        # TODO: use `next_tokens` directly instead.
-        model_inputs = {"input_ids": input_ids.contiguous()}
-
-    input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
-    if cache_position is None:
-        cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
-    else:
-        cache_position = cache_position[-input_length:]
-
-    if has_static_cache:
-        past_key_values = None
-
-    model_inputs.update(
-        {
-            "position_ids": position_ids,
-            "cache_position": cache_position,
-            "past_key_values": past_key_values,
-            "use_cache": kwargs.get("use_cache"),
-            "attention_mask": attention_mask,
-        }
-    )
-    return model_inputs
-
-
-
-
-
-# class H2OLlamaAttention(nn.Module):
-#     """Multi-headed attention from 'Attention Is All You Need' paper"""
-
-#     def __init__(self, config: LlamaConfig):
-#         super().__init__()
-#         self.config = config
-#         self.hidden_size = config.hidden_size
-#         self.num_heads = config.num_attention_heads
-#         self.head_dim = self.hidden_size // self.num_heads
-#         self.num_key_value_heads = config.num_key_value_heads
-#         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-#         self.max_position_embeddings = config.max_position_embeddings
-#         self.rope_theta = config.rope_theta
-
-#         if (self.head_dim * self.num_heads) != self.hidden_size:
-#             raise ValueError(
-#                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-#                 f" and `num_heads`: {self.num_heads})."
-#             )
-#         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
-#         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-#         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
-#         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-#         self._init_rope()
-
-#         self.kv_cache = H2OKVCache_LayerWise(
-#             hh_size=config.hh_size,
-#             recent_size=config.recent_size,
-#             k_seq_dim=2,
-#             v_seq_dim=2,
-#         )
-
-#     def _init_rope(self):
-#         if self.config.rope_scaling is None:
-#             self.rotary_emb = LlamaRotaryEmbedding(
-#                 self.head_dim,
-#                 max_position_embeddings=self.max_position_embeddings,
-#                 base=self.rope_theta,
-#             )
-#         else:
-#             scaling_type = self.config.rope_scaling["type"]
-#             scaling_factor = self.config.rope_scaling["factor"]
-#             if scaling_type == "linear":
-#                 self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
-#                     self.head_dim,
-#                     max_position_embeddings=self.max_position_embeddings,
-#                     scaling_factor=scaling_factor,
-#                     base=self.rope_theta,
-#                 )
-#             elif scaling_type == "dynamic":
-#                 self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
-#                     self.head_dim,
-#                     max_position_embeddings=self.max_position_embeddings,
-#                     scaling_factor=scaling_factor,
-#                     base=self.rope_theta,
-#                 )
-#             else:
-#                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
-
-#     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-#         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
-#     def _clean_cache(self):
-#         self.kv_cache._clean_scores()
-
-#     def forward(
-#         self,
-#         hidden_states: torch.Tensor,
-#         attention_mask: Optional[torch.Tensor] = None,
-#         position_ids: Optional[torch.LongTensor] = None,
-#         past_key_value: Optional[Tuple[torch.Tensor]] = None,
-#         output_attentions: bool = False,
-#         use_cache: bool = False,
-#         cache_position: Optional[torch.LongTensor] = None,
-#     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
-#         bsz, q_len, _ = hidden_states.size()
-
-#         if self.config.pretraining_tp > 1:
-#             key_value_slicing = (
-#                 self.num_key_value_heads * self.head_dim
-#             ) // self.config.pretraining_tp
-#             query_slices = self.q_proj.weight.split(
-#                 (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
-#             )
-#             key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-#             value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-#             query_states = [
-#                 F.linear(hidden_states, query_slices[i])
-#                 for i in range(self.config.pretraining_tp)
-#             ]
-#             query_states = torch.cat(query_states, dim=-1)
-
-#             key_states = [
-#                 F.linear(hidden_states, key_slices[i])
-#                 for i in range(self.config.pretraining_tp)
-#             ]
-#             key_states = torch.cat(key_states, dim=-1)
-
-#             value_states = [
-#                 F.linear(hidden_states, value_slices[i])
-#                 for i in range(self.config.pretraining_tp)
-#             ]
-#             value_states = torch.cat(value_states, dim=-1)
-
-#         else:
-#             query_states = self.q_proj(hidden_states)
-#             key_states = self.k_proj(hidden_states)
-#             value_states = self.v_proj(hidden_states)
-
-#         query_states = query_states.view(
-#             bsz, q_len, self.num_heads, self.head_dim
-#         ).transpose(1, 2)
-#         key_states = key_states.view(
-#             bsz, q_len, self.num_key_value_heads, self.head_dim
-#         ).transpose(1, 2)
-#         value_states = value_states.view(
-#             bsz, q_len, self.num_key_value_heads, self.head_dim
-#         ).transpose(1, 2)
-
-#         # remake causal mask
-#         attention_mask = _make_causal_mask(
-#             bsz=bsz,
-#             tgt_len=q_len,
-#             past_key_values_length=past_key_value[0].shape[-2] if past_key_value is not None else 0,
-#             dtype=query_states.dtype,
-#             device=query_states.device,
-#         )
-
-#         kv_seq_len = key_states.shape[-2]
-#         if past_key_value is not None:
-#             kv_seq_len += past_key_value[0].shape[-2]
-
-#         if not position_ids.nelement() > 1:
-#             position_ids[0][0] = kv_seq_len - 1
-
-#         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-#         ### Shift Pos: query pos is min(cache_size, idx)
-#         # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-#         query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
-#         ###
-
-#         if past_key_value is not None:
-#             # reuse k, v, self_attention
-#             key_states = torch.cat([past_key_value[0], key_states], dim=2)
-#             value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-#         past_key_value = (key_states, value_states) if use_cache else None
-
-#         ### Shift Pos: key pos is the pos in cache (Rolling KV Cache and using relative pos emb)
-#         key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
-#         key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
-#         ###
-
-#         # repeat k/v heads if n_kv_heads < n_heads
-#         key_states = repeat_kv(key_states, self.num_key_value_groups)
-#         value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-#         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
-#             self.head_dim
-#         )
-
-#         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-#             raise ValueError(
-#                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-#                 f" {attn_weights.size()}"
-#             )
-
-#         if attention_mask is not None:
-#             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-#                 raise ValueError(
-#                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-#                 )
-#             attn_weights = attn_weights + attention_mask
-
-#         # upcast attention to fp32
-#         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
-#             query_states.dtype
-#         )
-
-#         past_key_value = self.kv_cache(past_key_value, attn_weights.detach().clone())
-
-#         attn_output = torch.matmul(attn_weights, value_states)
-
-#         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-#             raise ValueError(
-#                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-#                 f" {attn_output.size()}"
-#             )
-
-#         attn_output = attn_output.transpose(1, 2).contiguous()
-#         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-#         if self.config.pretraining_tp > 1:
-#             attn_output = attn_output.split(
-#                 self.hidden_size // self.config.pretraining_tp, dim=2
-#             )
-#             o_proj_slices = self.o_proj.weight.split(
-#                 self.hidden_size // self.config.pretraining_tp, dim=1
-#             )
-#             attn_output = sum(
-#                 [
-#                     F.linear(attn_output[i], o_proj_slices[i])
-#                     for i in range(self.config.pretraining_tp)
-#                 ]
-#             )
-#         else:
-#             attn_output = self.o_proj(attn_output)
-
-#         if not output_attentions:
-#             attn_weights = None
-
-#         return attn_output, attn_weights, past_key_value
-
-
-
-
 class H2OLlamaForCausalLM(LlamaForCausalLM):
     def __init__(self, config):
         super().__init__(config)
@@ -751,8 +441,6 @@ class H2OLlamaForCausalLM(LlamaForCausalLM):
         # With static cache, the `past_key_values` is None
         # TODO joao: standardize interface for the different Cache classes and remove of this if
 
-        import pdb; pdb.set_trace()
-
         has_static_cache = False
         if past_key_values is None:
             past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)