@@ -0,0 +1,449 @@
+import math
+import types
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch import nn
+
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaDynamicNTKScalingRotaryEmbedding,
+    LlamaForCausalLM,
+    LlamaLinearScalingRotaryEmbedding,
+    LlamaRotaryEmbedding,
+    apply_rotary_pos_emb,
+    rotate_half,
+)
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.utils import logging
+
+from utils.cache import Cache, HHCache, StaticCache
+
+logger = logging.get_logger(__name__)
+
+__all__ = ["H2OLlamaForCausalLM"]
+
+def _make_causal_mask(
+    bsz: int, tgt_len: int, past_key_values_length: int, dtype: torch.dtype, device: torch.device):
+    """
+    Make the causal mask used for autoregressive self-attention: each query position i
+    may attend to cached past tokens and to positions j <= i; future positions are
+    filled with the dtype minimum.
+    """
+    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
+    mask_cond = torch.arange(mask.size(-1), device=device)
+    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+    mask = mask.to(dtype)
+
+    if past_key_values_length > 0:
+        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
+    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
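+
+# Illustrative example: tgt_len=3 with past_key_values_length=1 produces, per batch
+# and head, a (3, 4) mask in which row i unmasks the past token plus positions <= i
+# (writing "min" for torch.finfo(dtype).min):
+#     [[0, 0, min, min],
+#      [0, 0,   0, min],
+#      [0, 0,   0,   0]]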
+
+
+def apply_rotary_pos_emb_single(x, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Apply rotary position embeddings to a single tensor (a query *or* a key),
+    unlike `apply_rotary_pos_emb`, which rotates the (query, key) pair together.
+    `position_ids` is unused and kept only for signature parity."""
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    x_embed = (x * cos) + (rotate_half(x) * sin)
+    return x_embed
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    Equivalent of torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep): the hidden states
+    go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
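+
+# e.g. a (2, 8, 128, 64) grouped-query KV tensor with n_rep=4 becomes (2, 32, 128, 64),
+# so each of the 8 key/value heads is shared by 4 query heads.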
+
+
+class H2OLlamaAttention(nn.Module):
+    """Llama multi-headed attention with a KV cache slimmed by the Heavy-Hitter Oracle (H2O)."""
+
+    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+
+        self.attention_dropout = config.attention_dropout
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+        self.is_causal = True
+        self.positional_rolling = config.enable_position_rolling
+
+        if (self.head_dim * self.num_heads) != self.hidden_size:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {self.num_heads})."
+            )
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
+        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
+        self._init_rope()
+
+    def _init_rope(self):
+        if self.config.rope_scaling is None:
+            self.rotary_emb = LlamaRotaryEmbedding(
+                self.head_dim,
+                max_position_embeddings=self.max_position_embeddings,
+                base=self.rope_theta,
+            )
+        else:
+            scaling_type = self.config.rope_scaling["type"]
+            scaling_factor = self.config.rope_scaling["factor"]
+            if scaling_type == "linear":
+                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
+                    self.head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    scaling_factor=scaling_factor,
+                    base=self.rope_theta,
+                )
+            elif scaling_type == "dynamic":
+                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
+                    self.head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    scaling_factor=scaling_factor,
+                    base=self.rope_theta,
+                )
+            else:
+                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
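+    # For reference, `rope_scaling` is read from the config as a dict; illustrative
+    # values: config.rope_scaling = {"type": "linear", "factor": 2.0}
+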
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        if self.config.pretraining_tp > 1:
+            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
+            query_slices = self.q_proj.weight.split(
+                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
+            )
+            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
+            query_states = torch.cat(query_states, dim=-1)
+
+            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
+            key_states = torch.cat(key_states, dim=-1)
+
+            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
+            value_states = torch.cat(value_states, dim=-1)
+
+        else:
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        past_key_value = getattr(self, "past_key_value", past_key_value)
+
+        if not self.positional_rolling:
+            cos, sin = self.rotary_emb(value_states, position_ids)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+            if past_key_value is not None:
+                # sin and cos are specific to RoPE models; cache_position needed for the static cache
+                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+        else:
+            # Positional rolling: cache the *unrotated* keys, then re-apply RoPE each step
+            # with positions renumbered 0..kv_seq_len-1, so position ids never grow beyond
+            # the (evicted) cache window.
+            if past_key_value is not None:
+                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
+
+            kv_seq_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else key_states.shape[-2]
+
+            if position_ids.nelement() <= 1:
+                # decoding stage: a single new token queries the whole rolled cache
+                key_position_ids = torch.arange(kv_seq_len, device=hidden_states.device).unsqueeze(0)
+                query_position_ids = key_position_ids[:, -1].unsqueeze(0)
+            else:
+                # prefill stage: queries and keys share the prompt positions
+                query_position_ids = position_ids
+                key_position_ids = position_ids
+
+            key_cos, key_sin = self.rotary_emb(value_states, key_position_ids)
+            query_cos, query_sin = self.rotary_emb(value_states, query_position_ids)
+
+            query_states = apply_rotary_pos_emb_single(query_states, query_cos, query_sin)
+            key_states = apply_rotary_pos_emb_single(key_states, key_cos, key_sin)
+
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            attn_weights = attn_weights + causal_mask
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+
+        # Update the KV cache based on the Heavy-Hitter Oracle: accumulated attention
+        # scores decide which cached tokens survive the next eviction.
+        if past_key_value is not None:
+            past_key_value.update_slimming(attn_weights, self.num_key_value_groups, self.layer_idx)
+
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        if self.config.pretraining_tp > 1:
+            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
+        else:
+            attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
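+
+
+# Illustrative sketch only: the real eviction policy lives in utils/cache.py
+# (HHCache.update_slimming) and is not reproduced here. This hypothetical helper
+# shows the H2O selection rule in isolation: always keep the most recent tokens,
+# plus the cached positions with the largest accumulated attention mass.
+def _h2o_keep_mask(scores: torch.Tensor, num_heavy: int, num_recent: int) -> torch.Tensor:
+    """scores: (kv_seq_len,) accumulated attention per cached position; assumes num_recent >= 1."""
+    seq_len = scores.shape[0]
+    keep = torch.zeros(seq_len, dtype=torch.bool, device=scores.device)
+    keep[-num_recent:] = True  # the recent window is always retained
+    older = scores[: seq_len - num_recent]
+    if older.numel() > 0 and num_heavy > 0:
+        heavy = older.topk(min(num_heavy, older.numel())).indices
+        keep[heavy] = True  # heavy hitters outside the window survive eviction
+    return keep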
+
+
+def enable_h2ocache_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    if (input_ids is None) ^ (inputs_embeds is not None):
+        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+    if self.gradient_checkpointing and self.training and use_cache:
+        logger.warning_once(
+            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+        )
+        use_cache = False
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    past_seen_tokens = 0
+    if use_cache:  # kept for BC (cache positions)
+        if not isinstance(past_key_values, StaticCache):
+            past_key_values = HHCache.from_legacy_cache(self.num_window_length, self.num_heavy_hitter_tokens, past_key_values)
+        past_seen_tokens = past_key_values.get_seq_length()
+
+    if cache_position is None:
+        if isinstance(past_key_values, StaticCache):
+            raise ValueError("cache_position is a required argument when using StaticCache.")
+        cache_position = torch.arange(
+            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+        )
+
+    if position_ids is None:
+        position_ids = cache_position.unsqueeze(0)
+
+    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+
+    # embed positions
+    hidden_states = inputs_embeds
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = None
+
+    for decoder_layer in self.layers:
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if self.gradient_checkpointing and self.training:
+            layer_outputs = self._gradient_checkpointing_func(
+                decoder_layer.__call__,
+                hidden_states,
+                causal_mask,
+                position_ids,
+                past_key_values,
+                output_attentions,
+                use_cache,
+                cache_position,
+            )
+        else:
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
+
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = None
+    if use_cache:
+        next_cache = (
+            next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
+        )
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
+
+
+class H2OLlamaForCausalLM(LlamaForCausalLM):
+    def __init__(self, config):
+        super().__init__(config)
+        num_layers = len(self.model.layers)
+        for layer_idx in range(num_layers):
+            self.model.layers[layer_idx].self_attn = H2OLlamaAttention(config, layer_idx)
+
+        # Route the base model through the HHCache-aware forward defined above and
+        # expose the H2O budgets it reads.
+        self.model.forward = types.MethodType(enable_h2ocache_forward, self.model)
+        self.model.num_heavy_hitter_tokens = config.num_heavy_hitter_tokens
+        self.model.num_window_length = config.num_window_length
+
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
+    ):
+        # With a static cache, `past_key_values` is None
+        # TODO joao: standardize interface for the different Cache classes and remove this if
+        has_static_cache = False
+        if past_key_values is None:
+            past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
+            has_static_cache = past_key_values is not None
+
+        past_length = 0
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                past_length = cache_position[0]
+                max_cache_length = (
+                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
+                    if past_key_values.get_max_length() is not None
+                    else None
+                )
+                cache_length = past_key_values.get_seq_length()
+            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
+            else:
+                past_length = cache_position[0]
+                cache_length = past_key_values[0].shape[2]  # legacy cache: a flat tuple of num_layers * 3 tensors (key, value, accumulated score)
+                max_cache_length = None
+
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
+            #     input)
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+            #     input_ids based on the past_length.
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
+            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs = {"input_ids": input_ids.contiguous()}
+
+        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+        if cache_position is None:
+            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
+        else:
+            cache_position = cache_position[-input_length:]
+
+        if has_static_cache:
+            past_key_values = None
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "cache_position": cache_position,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
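+
+
+# Usage sketch (illustrative, not part of the module): a standard Llama checkpoint
+# can be wrapped by attaching the three H2O fields this file reads from `config`.
+# The checkpoint name and budget values below are assumptions for the example.
+#
+#     from transformers import AutoConfig
+#
+#     name = "meta-llama/Llama-2-7b-hf"
+#     config = AutoConfig.from_pretrained(name)
+#     config.num_heavy_hitter_tokens = 128    # heavy-hitter budget kept by HHCache
+#     config.num_window_length = 256          # overall cache window passed to HHCache
+#     config.enable_position_rolling = False  # re-rotate keys inside the window if True
+#     model = H2OLlamaForCausalLM.from_pretrained(name, config=config)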