utils_llama.py

import math
from typing import Any, Dict, List, Optional, Tuple, Union
import types

import torch
from torch import nn
import torch.utils.checkpoint
import torch.nn.functional as F

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    rotate_half,
    apply_rotary_pos_emb,
    repeat_kv,
    LlamaRotaryEmbedding,
    LlamaLinearScalingRotaryEmbedding,
    LlamaDynamicNTKScalingRotaryEmbedding,
    LlamaForCausalLM,
)
from transformers.utils import logging
from transformers.modeling_outputs import BaseModelOutputWithPast

from cache_utils import Cache, HHCache, StaticCache

logger = logging.get_logger(__name__)

__all__ = ["H2OLlamaForCausalLM"]

def _make_causal_mask(
    bsz: int, tgt_len: int, past_key_values_length: int, dtype: torch.dtype, device: torch.device
):
    """
    Make the additive causal mask used for uni-directional self-attention.
    """
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
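
# Illustrative note (not part of the original module): for tgt_len=3 and
# past_key_values_length=2 the returned mask has shape (bsz, 1, 3, 5). Row i is
# zero for the 2 past positions and for target positions j <= i, and the dtype's
# minimum value elsewhere, e.g. (with "min" standing for torch.finfo(dtype).min):
#   [[0, 0, 0, min, min],
#    [0, 0, 0, 0,   min],
#    [0, 0, 0, 0,   0  ]]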

def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
    # Applies RoPE to a single tensor (query or key). This lets queries and keys use
    # different position ids, which the positional-rolling path below relies on.
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
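
# Illustrative note (not part of the original module): with num_key_value_heads=2
# and n_rep=4, a (1, 2, seq, dim) tensor becomes (1, 8, seq, dim), where each KV
# head is repeated 4 times consecutively -- the same layout that
# torch.repeat_interleave(hidden_states, repeats=4, dim=1) would produce.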

class H2OLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.positional_rolling = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)
            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)
            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        past_key_value = getattr(self, "past_key_value", past_key_value)

        if not self.positional_rolling:
            cos, sin = self.rotary_emb(value_states, position_ids)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

            if past_key_value is not None:
                # sin and cos are specific to RoPE models; cache_position needed for the static cache
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        else:
            # Positional rolling: the (possibly evicted) cached keys are re-indexed with
            # contiguous positions 0..kv_seq_len-1, so RoPE is applied to queries and keys
            # separately with their own position ids.
            if past_key_value is not None:
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

            kv_seq_len = past_key_value.get_seq_length() if past_key_value is not None else key_states.shape[-2]

            if position_ids.nelement() <= 1:
                # Decoding stage: a single new token attends to the whole compacted cache.
                key_position_ids = torch.arange(kv_seq_len, device=hidden_states.device).unsqueeze(0)
                query_position_ids = key_position_ids[:, -1].unsqueeze(0)
            else:
                # Prefill stage: queries and keys share the original positions.
                query_position_ids = position_ids
                key_position_ids = position_ids

            key_cos, key_sin = self.rotary_emb(value_states, key_position_ids)
            query_cos, query_sin = self.rotary_emb(value_states, query_position_ids)

            query_states = apply_rotary_pos_emb_single(query_states, query_cos, query_sin, query_position_ids)
            key_states = apply_rotary_pos_emb_single(key_states, key_cos, key_sin, key_position_ids)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        # Update the KV cache based on the Heavy-Hitter Oracle: accumulated attention
        # scores decide which cached tokens are kept and which are evicted.
        if past_key_value is not None:
            past_key_value.update_slimming(attn_weights, self.num_key_value_groups, self.layer_idx)

        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

def enable_h2ocache_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError(
            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
        )

    if self.gradient_checkpointing and self.training and use_cache:
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
        )
        use_cache = False

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    past_seen_tokens = 0
    if use_cache:  # kept for BC (cache positions)
        if not isinstance(past_key_values, StaticCache):
            # Wrap any legacy cache into an HHCache so heavy-hitter scores are tracked.
            past_key_values = HHCache.from_legacy_cache(
                self.num_window_length, self.num_heavy_hitter_tokens, past_key_values
            )
        past_seen_tokens = past_key_values.get_seq_length()

    if cache_position is None:
        if isinstance(past_key_values, StaticCache):
            raise ValueError("cache_position is a required argument when using StaticCache.")
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )

    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = None
    if use_cache:
        next_cache = (
            next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
        )
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )

class H2OLlamaForCausalLM(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        num_layers = len(self.model.layers)
        # Replace every self-attention module with the H2O variant and bind the
        # HHCache-aware forward pass onto the underlying LlamaModel.
        for layer_idx in range(num_layers):
            self.model.layers[layer_idx].self_attn = H2OLlamaAttention(config, layer_idx)
        self.model.forward = types.MethodType(enable_h2ocache_forward, self.model)
        self.model.num_heavy_hitter_tokens = config.num_heavy_hitter_tokens
        self.model.num_window_length = config.num_window_length

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
    ):
        # With static cache, the `past_key_values` is None
        # TODO joao: standardize interface for the different Cache classes and remove of this if
        has_static_cache = False
        if past_key_values is None:
            past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
            has_static_cache = past_key_values is not None

        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                past_length = cache_position[0]
                max_cache_length = (
                    torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
                    if past_key_values.get_max_length() is not None
                    else None
                )
                cache_length = past_key_values.get_seq_length()
            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
            else:
                past_length = cache_position[0]
                cache_length = past_key_values[0].shape[2]  # length = num_layers * 3 (3 -> key, value, score)
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            #     input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            #     input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        if cache_position is None:
            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
        else:
            cache_position = cache_position[-input_length:]

        if has_static_cache:
            past_key_values = None

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
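
if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. It assumes a standard
    # Llama checkpoint name and illustrative H2O budgets; adjust both to your setup.
    from transformers import AutoTokenizer

    model_name = "meta-llama/Llama-2-7b-hf"  # assumed checkpoint; any Llama checkpoint should work
    config = LlamaConfig.from_pretrained(model_name)
    config.num_heavy_hitter_tokens = 128  # illustrative heavy-hitter budget read by H2OLlamaForCausalLM
    config.num_window_length = 256  # illustrative window length passed to HHCache

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = H2OLlamaForCausalLM.from_pretrained(model_name, config=config)

    inputs = tokenizer("Heavy-hitter tokens keep the KV cache small because", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))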