utils_llama.py

"""H2O (Heavy-Hitter Oracle) KV-cache eviction utilities for Llama models."""

import math
import types
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    rotate_half,
    apply_rotary_pos_emb,
    repeat_kv,
    LlamaRotaryEmbedding,
    LlamaLinearScalingRotaryEmbedding,
    LlamaDynamicNTKScalingRotaryEmbedding,
    LlamaForCausalLM,
)
from cache_utils import Cache, HHCache, StaticCache
from transformers.utils import logging
from transformers.modeling_outputs import BaseModelOutputWithPast

logger = logging.get_logger(__name__)

__all__ = ["H2OLlamaForCausalLM"]

def _make_causal_mask(
    bsz: int, tgt_len: int, past_key_values_length: int, dtype: torch.dtype, device: torch.device
):
    """
    Make the lower-triangular causal mask used for auto-regressive self-attention.
    """
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

def apply_rotary_pos_emb_single(x, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embeddings to a single tensor (query or key) instead of the usual q/k pair."""
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

class H2OLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.positional_rolling = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        past_key_value = getattr(self, "past_key_value", past_key_value)

        if not self.positional_rolling:
            cos, sin = self.rotary_emb(value_states, position_ids)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

            if past_key_value is not None:
                # sin and cos are specific to RoPE models; cache_position needed for the static cache
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        else:
            # Positional rolling: RoPE is applied with positions in the compacted (post-eviction)
            # cache rather than the original token positions.
            if past_key_value is not None:
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

            kv_seq_len = past_key_value.get_seq_length() if past_key_value is not None else key_states.shape[-2]

            if not position_ids.nelement() > 1:
                # decoding stage: keys are re-indexed 0..kv_seq_len-1, the query takes the last position
                key_position_ids = torch.arange(kv_seq_len, device=hidden_states.device).unsqueeze(0)
                query_position_ids = key_position_ids[:, -1].unsqueeze(0)
            else:
                query_position_ids = position_ids
                key_position_ids = position_ids

            key_cos, key_sin = self.rotary_emb(value_states, key_position_ids)
            query_cos, query_sin = self.rotary_emb(value_states, query_position_ids)

            query_states = apply_rotary_pos_emb_single(query_states, query_cos, query_sin)
            key_states = apply_rotary_pos_emb_single(key_states, key_cos, key_sin)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        # Update KV cache based on the Heavy-Hitter Oracle (accumulated attention scores)
        if past_key_value is not None:
            past_key_value.update_slimming(attn_weights, self.num_key_value_groups, self.layer_idx)

        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

def enable_h2ocache_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError(
            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
        )

    if self.gradient_checkpointing and self.training and use_cache:
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
        )
        use_cache = False

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    past_seen_tokens = 0
    if use_cache:  # kept for BC (cache positions)
        if not isinstance(past_key_values, StaticCache):
            past_key_values = HHCache.from_legacy_cache(self.num_window_length, self.num_heavy_hitter_tokens, past_key_values)
        past_seen_tokens = past_key_values.get_seq_length()

    if cache_position is None:
        if isinstance(past_key_values, StaticCache):
            raise ValueError("cache_position is a required argument when using StaticCache.")
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )

    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = None
    if use_cache:
        next_cache = (
            next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
        )
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )

class H2OLlamaForCausalLM(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        num_layers = len(self.model.layers)
        for layer_idx in range(num_layers):
            self.model.layers[layer_idx].self_attn = H2OLlamaAttention(config, layer_idx)
        self.model.forward = types.MethodType(enable_h2ocache_forward, self.model)
        self.model.num_heavy_hitter_tokens = config.num_heavy_hitter_tokens
        self.model.num_window_length = config.num_window_length
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
    ):
        # With static cache, the `past_key_values` is None
        # TODO joao: standardize interface for the different Cache classes and remove this `if`
        has_static_cache = False
        if past_key_values is None:
            past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
            has_static_cache = past_key_values is not None

        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                past_length = cache_position[0]
                max_cache_length = (
                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
                    if past_key_values.get_max_length() is not None
                    else None
                )
                cache_length = past_key_values.get_seq_length()
            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
            else:
                past_length = cache_position[0]
                cache_length = past_key_values[0].shape[2]  # length = num_layers * 3 (3 -> key, value, score)
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            #     input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            #     input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        if cache_position is None:
            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
        else:
            cache_position = cache_position[-input_length:]

        if has_static_cache:
            past_key_values = None

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
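

# Minimal usage sketch (not part of the original module): shows how the H2O cache could be
# wired into generation. The checkpoint name and the two budget values are placeholders;
# the only requirement this file imposes is that `num_heavy_hitter_tokens` and
# `num_window_length` are present on the config before the model is constructed.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoTokenizer

    model_name = "meta-llama/Llama-2-7b-hf"  # placeholder checkpoint
    config = AutoConfig.from_pretrained(model_name)
    config.num_heavy_hitter_tokens = 128  # heavy-hitter budget kept by the HHCache (assumed value)
    config.num_window_length = 256        # total cache budget: heavy hitters + recent window (assumed value)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = H2OLlamaForCausalLM.from_pretrained(model_name, config=config)

    inputs = tokenizer("The H2O cache keeps heavy-hitter and recent tokens:", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))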