utils_llama.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. import math
  2. from typing import Any, Dict, List, Optional, Tuple, Union
  3. import pdb
  4. import types
  5. import torch
  6. from torch import nn
  7. import torch.utils.checkpoint
  8. import torch.nn.functional as F
  9. from transformers.models.llama.configuration_llama import LlamaConfig
  10. from transformers.models.llama.modeling_llama import (
  11. LlamaAttention,
  12. rotate_half,
  13. apply_rotary_pos_emb,
  14. repeat_kv,
  15. LlamaRotaryEmbedding,
  16. apply_rotary_pos_emb,
  17. LlamaForCausalLM,
  18. )
  19. from cache_utils import Cache, HHCache, StaticCache
  20. from transformers.utils import logging
  21. from transformers.modeling_outputs import BaseModelOutputWithPast
  22. logger = logging.get_logger(__name__)
  23. __all__ = ["H2OLlamaForCausalLM"]
  24. def _make_causal_mask(
  25. bsz: int, tgt_len: int, past_key_values_length: int, dtype: torch.dtype, device: torch.device):
  26. """
  27. Make causal mask used for bi-directional self-attention.
  28. """
  29. mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
  30. mask_cond = torch.arange(mask.size(-1), device=device)
  31. mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
  32. mask = mask.to(dtype)
  33. if past_key_values_length > 0:
  34. mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
  35. return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
  36. def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
  37. # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
  38. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
  39. sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
  40. cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
  41. sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
  42. x_embed = (x * cos) + (rotate_half(x) * sin)
  43. return x_embed
  44. class H2OKVCache_LayerWise:
  45. def __init__(
  46. self,
  47. hh_size=4,
  48. recent_size=512,
  49. k_seq_dim=2,
  50. v_seq_dim=2,
  51. ):
  52. self.hh_size = hh_size
  53. self.recent_size = recent_size
  54. self.cache_size = hh_size + recent_size
  55. self.k_seq_dim = k_seq_dim
  56. self.v_seq_dim = v_seq_dim
  57. self.hh_score = None
  58. def __call__(self, past_key_values, attn_score_cache):
  59. self._update_hh_score(attn_score_cache)
  60. if past_key_values is None:
  61. return None
  62. seq_len = past_key_values[0].size(self.k_seq_dim)
  63. if seq_len <= self.cache_size:
  64. return past_key_values
  65. # hh-selection
  66. bsz, num_heads, _, head_dim = past_key_values[0].shape
  67. select_hh_scores = self.hh_score[:, :seq_len - self.recent_size]
  68. _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
  69. keep_topk = keep_topk.sort().values
  70. # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
  71. keep_recent = torch.arange(seq_len - self.recent_size, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
  72. keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)
  73. mask = torch.zeros(self.hh_score.shape, dtype=torch.bool).to(past_key_values[0].device)
  74. mask = mask.scatter(-1, keep_idx, 1)
  75. k_hh_recent = past_key_values[0].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
  76. v_hh_recent = past_key_values[1].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
  77. self.hh_score= self.hh_score[mask].view(num_heads, self.cache_size)
  78. return (k_hh_recent, v_hh_recent)
  79. def evict_for_space(self, past_key_values, num_coming):
  80. if past_key_values is None:
  81. return None
  82. seq_len = past_key_values[0][0].size(self.k_seq_dim)
  83. if seq_len + num_coming <= self.cache_size:
  84. return past_key_values
  85. # hh-selection
  86. bsz, num_heads, _, head_dim = past_key_values[0].shape
  87. select_hh_scores = self.hh_score[:, :seq_len - self.recent_size + num_coming]
  88. _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
  89. keep_topk = keep_topk.sort().values
  90. # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
  91. keep_recent = torch.arange(seq_len - self.recent_size + num_coming, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
  92. keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)
  93. mask = torch.zeros(self.hh_score.shape, dtype=torch.bool).to(past_key_values[0].device)
  94. mask = mask.scatter(-1, keep_idx, 1)
  95. k_hh_recent = past_key_values[0].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
  96. v_hh_recent = past_key_values[1].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
  97. self.hh_score= self.hh_score[mask].view(num_heads, self.cache_size)
  98. return (k_hh_recent, v_hh_recent)
  99. def _update_hh_score(self, attn_score_cache):
  100. num_new_tokens = attn_score_cache.shape[2]
  101. if self.hh_score is None:
  102. self.hh_score = attn_score_cache.sum(0).sum(1)
  103. else:
  104. attn_score_cache = attn_score_cache.sum(0).sum(1)
  105. attn_score_cache[:, :-num_new_tokens] += self.hh_score
  106. self.hh_score = attn_score_cache
  107. def _clean_scores(self):
  108. self.hh_score = None
  109. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  110. """
  111. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  112. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  113. """
  114. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  115. if n_rep == 1:
  116. return hidden_states
  117. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  118. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  119. class H2OLlamaAttention(nn.Module):
  120. """Multi-headed attention from 'Attention Is All You Need' paper"""
  121. def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
  122. super().__init__()
  123. self.config = config
  124. self.layer_idx = layer_idx
  125. if layer_idx is None:
  126. logger.warning_once(
  127. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  128. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  129. "when creating this class."
  130. )
  131. self.attention_dropout = config.attention_dropout
  132. self.hidden_size = config.hidden_size
  133. self.num_heads = config.num_attention_heads
  134. self.head_dim = self.hidden_size // self.num_heads
  135. self.num_key_value_heads = config.num_key_value_heads
  136. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  137. self.max_position_embeddings = config.max_position_embeddings
  138. self.rope_theta = config.rope_theta
  139. self.is_causal = True
  140. if (self.head_dim * self.num_heads) != self.hidden_size:
  141. raise ValueError(
  142. f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
  143. f" and `num_heads`: {self.num_heads})."
  144. )
  145. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
  146. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  147. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  148. self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
  149. self._init_rope()
  150. def _init_rope(self):
  151. if self.config.rope_scaling is None:
  152. self.rotary_emb = LlamaRotaryEmbedding(
  153. self.head_dim,
  154. max_position_embeddings=self.max_position_embeddings,
  155. base=self.rope_theta,
  156. )
  157. else:
  158. scaling_type = self.config.rope_scaling["type"]
  159. scaling_factor = self.config.rope_scaling["factor"]
  160. if scaling_type == "linear":
  161. self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
  162. self.head_dim,
  163. max_position_embeddings=self.max_position_embeddings,
  164. scaling_factor=scaling_factor,
  165. base=self.rope_theta,
  166. )
  167. elif scaling_type == "dynamic":
  168. self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
  169. self.head_dim,
  170. max_position_embeddings=self.max_position_embeddings,
  171. scaling_factor=scaling_factor,
  172. base=self.rope_theta,
  173. )
  174. else:
  175. raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
  176. def forward(
  177. self,
  178. hidden_states: torch.Tensor,
  179. attention_mask: Optional[torch.Tensor] = None,
  180. position_ids: Optional[torch.LongTensor] = None,
  181. past_key_value: Optional[Cache] = None,
  182. output_attentions: bool = False,
  183. use_cache: bool = False,
  184. cache_position: Optional[torch.LongTensor] = None,
  185. **kwargs,
  186. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  187. bsz, q_len, _ = hidden_states.size()
  188. if self.config.pretraining_tp > 1:
  189. key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
  190. query_slices = self.q_proj.weight.split(
  191. (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
  192. )
  193. key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
  194. value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
  195. query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
  196. query_states = torch.cat(query_states, dim=-1)
  197. key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
  198. key_states = torch.cat(key_states, dim=-1)
  199. value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
  200. value_states = torch.cat(value_states, dim=-1)
  201. else:
  202. query_states = self.q_proj(hidden_states)
  203. key_states = self.k_proj(hidden_states)
  204. value_states = self.v_proj(hidden_states)
  205. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  206. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  207. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  208. past_key_value = getattr(self, "past_key_value", past_key_value)
  209. cos, sin = self.rotary_emb(value_states, position_ids)
  210. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  211. if past_key_value is not None:
  212. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  213. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  214. key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
  215. key_states = repeat_kv(key_states, self.num_key_value_groups)
  216. value_states = repeat_kv(value_states, self.num_key_value_groups)
  217. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
  218. if attention_mask is not None: # no matter the length, we just slice it
  219. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  220. attn_weights = attn_weights + causal_mask
  221. # upcast attention to fp32
  222. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
  223. # Update KV Cache based on Heavy-Hitter Oracle
  224. if past_key_value is not None:
  225. past_key_value.update_slimming(attn_weights, self.num_key_value_groups, self.layer_idx, cache_kwargs)
  226. attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
  227. attn_output = torch.matmul(attn_weights, value_states)
  228. if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
  229. raise ValueError(
  230. f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
  231. f" {attn_output.size()}"
  232. )
  233. attn_output = attn_output.transpose(1, 2).contiguous()
  234. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
  235. if self.config.pretraining_tp > 1:
  236. attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
  237. o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
  238. attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
  239. else:
  240. attn_output = self.o_proj(attn_output)
  241. if not output_attentions:
  242. attn_weights = None
  243. if self.layer_idx == 0:
  244. print(past_key_value.key_cache[0].shape, past_key_value.value_cache[0].shape, past_key_value.accumulated_attention_scores[0][0,0,0].item())
  245. return attn_output, attn_weights, past_key_value
  246. def enable_h2ocache_forward(
  247. self,
  248. input_ids: torch.LongTensor = None,
  249. attention_mask: Optional[torch.Tensor] = None,
  250. position_ids: Optional[torch.LongTensor] = None,
  251. past_key_values: Optional[List[torch.FloatTensor]] = None,
  252. inputs_embeds: Optional[torch.FloatTensor] = None,
  253. use_cache: Optional[bool] = None,
  254. output_attentions: Optional[bool] = None,
  255. output_hidden_states: Optional[bool] = None,
  256. return_dict: Optional[bool] = None,
  257. cache_position: Optional[torch.LongTensor] = None,
  258. ) -> Union[Tuple, BaseModelOutputWithPast]:
  259. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  260. output_hidden_states = (
  261. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  262. )
  263. use_cache = use_cache if use_cache is not None else self.config.use_cache
  264. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  265. if (input_ids is None) ^ (inputs_embeds is not None):
  266. raise ValueError(
  267. "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
  268. )
  269. if self.gradient_checkpointing and self.training and use_cache:
  270. logger.warning_once(
  271. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  272. )
  273. use_cache = False
  274. if inputs_embeds is None:
  275. inputs_embeds = self.embed_tokens(input_ids)
  276. past_seen_tokens = 0
  277. if use_cache: # kept for BC (cache positions)
  278. if not isinstance(past_key_values, StaticCache):
  279. past_key_values = HHCache.from_legacy_cache(self.num_window_length, self.num_heavy_hitter_tokens, past_key_values)
  280. past_seen_tokens = past_key_values.get_seq_length()
  281. if cache_position is None:
  282. if isinstance(past_key_values, StaticCache):
  283. raise ValueError("cache_position is a required argument when using StaticCache.")
  284. cache_position = torch.arange(
  285. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  286. )
  287. if position_ids is None:
  288. position_ids = cache_position.unsqueeze(0)
  289. causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
  290. # embed positions
  291. hidden_states = inputs_embeds
  292. # decoder layers
  293. all_hidden_states = () if output_hidden_states else None
  294. all_self_attns = () if output_attentions else None
  295. next_decoder_cache = None
  296. for decoder_layer in self.layers:
  297. if output_hidden_states:
  298. all_hidden_states += (hidden_states,)
  299. if self.gradient_checkpointing and self.training:
  300. layer_outputs = self._gradient_checkpointing_func(
  301. decoder_layer.__call__,
  302. hidden_states,
  303. causal_mask,
  304. position_ids,
  305. past_key_values,
  306. output_attentions,
  307. use_cache,
  308. cache_position,
  309. )
  310. else:
  311. layer_outputs = decoder_layer(
  312. hidden_states,
  313. attention_mask=causal_mask,
  314. position_ids=position_ids,
  315. past_key_value=past_key_values,
  316. output_attentions=output_attentions,
  317. use_cache=use_cache,
  318. cache_position=cache_position,
  319. )
  320. hidden_states = layer_outputs[0]
  321. if use_cache:
  322. next_decoder_cache = layer_outputs[2 if output_attentions else 1]
  323. if output_attentions:
  324. all_self_attns += (layer_outputs[1],)
  325. hidden_states = self.norm(hidden_states)
  326. # add hidden states from the last decoder layer
  327. if output_hidden_states:
  328. all_hidden_states += (hidden_states,)
  329. next_cache = None
  330. if use_cache:
  331. next_cache = (
  332. next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
  333. )
  334. if not return_dict:
  335. return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
  336. return BaseModelOutputWithPast(
  337. last_hidden_state=hidden_states,
  338. past_key_values=next_cache,
  339. hidden_states=all_hidden_states,
  340. attentions=all_self_attns,
  341. )
  342. class H2OLlamaForCausalLM(LlamaForCausalLM):
  343. def __init__(self, config):
  344. super().__init__(config)
  345. num_layers = len(self.model.layers)
  346. for layer_idx in range(num_layers):
  347. self.model.layers[layer_idx].self_attn = H2OLlamaAttention(config, layer_idx)
  348. self.model.forward = types.MethodType(enable_h2ocache_forward, self.model)
  349. self.model.num_heavy_hitter_tokens = config.num_heavy_hitter_tokens
  350. self.model.num_window_length = config.num_window_length
  351. def prepare_inputs_for_generation(
  352. self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
  353. ):
  354. # With static cache, the `past_key_values` is None
  355. # TODO joao: standardize interface for the different Cache classes and remove of this if
  356. has_static_cache = False
  357. if past_key_values is None:
  358. past_key_values = getattr(self.model.layers[0].self_attn, "past_key_value", None)
  359. has_static_cache = past_key_values is not None
  360. past_length = 0
  361. if past_key_values is not None:
  362. if isinstance(past_key_values, Cache):
  363. past_length = cache_position[0]
  364. max_cache_length = (
  365. torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
  366. if past_key_values.get_max_length() is not None
  367. else None
  368. )
  369. cache_length = past_key_values.get_seq_length()
  370. # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
  371. else:
  372. past_length = cache_position[0]
  373. cache_length = past_key_values[0].shape[2] # length = num_layers * 3 (3 -> key, value, score)
  374. max_cache_length = None
  375. # Keep only the unprocessed tokens:
  376. # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
  377. # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
  378. # input)
  379. if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
  380. input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
  381. # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
  382. # input_ids based on the past_length.
  383. elif past_length < input_ids.shape[1]:
  384. input_ids = input_ids[:, past_length:]
  385. # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
  386. # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
  387. if (
  388. max_cache_length is not None
  389. and attention_mask is not None
  390. and cache_length + input_ids.shape[1] > max_cache_length
  391. ):
  392. attention_mask = attention_mask[:, -max_cache_length:]
  393. position_ids = kwargs.get("position_ids", None)
  394. if attention_mask is not None and position_ids is None:
  395. # create position_ids on the fly for batch generation
  396. position_ids = attention_mask.long().cumsum(-1) - 1
  397. position_ids.masked_fill_(attention_mask == 0, 1)
  398. if past_key_values:
  399. position_ids = position_ids[:, -input_ids.shape[1] :]
  400. # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
  401. if inputs_embeds is not None and past_key_values is None:
  402. model_inputs = {"inputs_embeds": inputs_embeds}
  403. else:
  404. # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
  405. # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
  406. # TODO: use `next_tokens` directly instead.
  407. model_inputs = {"input_ids": input_ids.contiguous()}
  408. input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
  409. if cache_position is None:
  410. cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
  411. else:
  412. cache_position = cache_position[-input_length:]
  413. if has_static_cache:
  414. past_key_values = None
  415. model_inputs.update(
  416. {
  417. "position_ids": position_ids,
  418. "cache_position": cache_position,
  419. "past_key_values": past_key_values,
  420. "use_cache": kwargs.get("use_cache"),
  421. "attention_mask": attention_mask,
  422. }
  423. )
  424. return model_inputs