# utils_llama.py

import math
import types
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
import torch.utils.checkpoint
import torch.nn.functional as F

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    rotate_half,
    apply_rotary_pos_emb,
    repeat_kv,
    LlamaRotaryEmbedding,
    LlamaLinearScalingRotaryEmbedding,
    LlamaDynamicNTKScalingRotaryEmbedding,
    LlamaForCausalLM,
)
# Used by enable_h2ocache_forward and _update_causal_mask below; the import paths assume the
# transformers >= 4.38 module layout that the rest of this file is written against.
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.cache_utils import StaticCache
from transformers.utils import logging

from cache_utils import Cache, HHCache

logger = logging.get_logger(__name__)

__all__ = ["H2OLlamaForCausalLM"]


def _make_causal_mask(
    bsz: int, tgt_len: int, past_key_values_length: int, dtype: torch.dtype, device: torch.device
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed
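

# Illustrative sketch (not part of the original module): applying the single-tensor RoPE helper
# to keys whose positions refer to their slot in a rolled cache, mirroring the commented-out
# legacy attention at the bottom of this file. `rotary_emb`, `value_states`, `key_states`, and
# `kv_seq_len` are assumed to exist in the caller.
#
#   cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
#   key_position_ids = torch.arange(kv_seq_len, device=key_states.device).unsqueeze(0)
#   key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)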


class H2OKVCache_LayerWise:
    def __init__(
        self,
        hh_size=4,
        recent_size=512,
        k_seq_dim=2,
        v_seq_dim=2,
    ):
        self.hh_size = hh_size
        self.recent_size = recent_size
        self.cache_size = hh_size + recent_size
        self.k_seq_dim = k_seq_dim
        self.v_seq_dim = v_seq_dim
        self.hh_score = None

    def __call__(self, past_key_values, attn_score_cache):
        self._update_hh_score(attn_score_cache)

        if past_key_values is None:
            return None
        seq_len = past_key_values[0].size(self.k_seq_dim)
        if seq_len <= self.cache_size:
            return past_key_values

        # hh-selection
        bsz, num_heads, _, head_dim = past_key_values[0].shape

        select_hh_scores = self.hh_score[:, :seq_len - self.recent_size]
        _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
        keep_topk = keep_topk.sort().values

        # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
        keep_recent = torch.arange(seq_len - self.recent_size, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)

        mask = torch.zeros(self.hh_score.shape, dtype=torch.bool).to(past_key_values[0].device)
        mask = mask.scatter(-1, keep_idx, 1)

        k_hh_recent = past_key_values[0].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
        v_hh_recent = past_key_values[1].squeeze()[mask].view(bsz, num_heads, -1, head_dim)

        self.hh_score = self.hh_score[mask].view(num_heads, self.cache_size)

        return (k_hh_recent, v_hh_recent)

    def evict_for_space(self, past_key_values, num_coming):
        if past_key_values is None:
            return None
        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        if seq_len + num_coming <= self.cache_size:
            return past_key_values

        # hh-selection
        bsz, num_heads, _, head_dim = past_key_values[0].shape

        select_hh_scores = self.hh_score[:, :seq_len - self.recent_size + num_coming]
        _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
        keep_topk = keep_topk.sort().values

        # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
        keep_recent = torch.arange(seq_len - self.recent_size + num_coming, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)

        mask = torch.zeros(self.hh_score.shape, dtype=torch.bool).to(past_key_values[0].device)
        mask = mask.scatter(-1, keep_idx, 1)

        k_hh_recent = past_key_values[0].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
        v_hh_recent = past_key_values[1].squeeze()[mask].view(bsz, num_heads, -1, head_dim)

        self.hh_score = self.hh_score[mask].view(num_heads, self.cache_size)

        return (k_hh_recent, v_hh_recent)

    def _update_hh_score(self, attn_score_cache):
        num_new_tokens = attn_score_cache.shape[2]

        if self.hh_score is None:
            self.hh_score = attn_score_cache.sum(0).sum(1)
        else:
            attn_score_cache = attn_score_cache.sum(0).sum(1)
            attn_score_cache[:, :-num_new_tokens] += self.hh_score
            self.hh_score = attn_score_cache

    def _clean_scores(self):
        self.hh_score = None
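

# Illustrative sketch (not part of the original module): how the legacy attention at the bottom
# of this file drives H2OKVCache_LayerWise. `attn_weights` has shape [bsz, num_heads, q_len,
# kv_len] and `past_key_value` is a (key, value) tuple, each of shape [bsz, num_heads, kv_len,
# head_dim]; the eviction path assumes bsz == 1.
#
#   kv_cache = H2OKVCache_LayerWise(hh_size=4, recent_size=512)
#   # Accumulate per-token attention mass and, once the cache exceeds hh_size + recent_size,
#   # keep the top heavy-hitter columns plus the most recent tokens.
#   past_key_value = kv_cache(past_key_value, attn_weights.detach().clone())
#   # Reset accumulated scores between sequences.
#   kv_cache._clean_scores()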


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
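

# Shape sketch (illustrative numbers, not part of the original module): repeat_kv expands
# grouped KV heads up to the full attention-head count.
#
#   kv = torch.randn(1, 8, 16, 128)   # (batch, num_key_value_heads, seqlen, head_dim)
#   repeat_kv(kv, 4).shape            # torch.Size([1, 32, 16, 128])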


class H2OLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        past_key_value = getattr(self, "past_key_value", past_key_value)
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


def enable_h2ocache_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError(
            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
        )

    if self.gradient_checkpointing and self.training and use_cache:
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
        )
        use_cache = False

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    past_seen_tokens = 0
    if use_cache:  # kept for BC (cache positions)
        if not isinstance(past_key_values, StaticCache):
            past_key_values = HHCache.from_legacy_cache(past_key_values)
            past_seen_tokens = past_key_values.get_seq_length()

    if cache_position is None:
        if isinstance(past_key_values, StaticCache):
            raise ValueError("cache_position is a required argument when using StaticCache.")
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )

    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = None
    if use_cache:
        next_cache = (
            next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
        )
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )


# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
    if self.config._attn_implementation == "flash_attention_2":
        if attention_mask is not None and 0.0 in attention_mask:
            return attention_mask
        return None

    dtype, device = input_tensor.dtype, input_tensor.device
    min_dtype = torch.finfo(dtype).min
    sequence_length = input_tensor.shape[1]
    if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache
        target_length = self.config.max_position_embeddings
    else:  # dynamic cache
        target_length = (
            attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
        )

    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
        if attention_mask.dim() == 2:
            mask_length = attention_mask.shape[-1]
            padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
        elif attention_mask.dim() == 4:
            # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
            # cache. In that case, the 4D attention mask attends to the newest tokens only.
            if attention_mask.shape[-2] < cache_position[0] + sequence_length:
                offset = cache_position[0]
            else:
                offset = 0
            mask_shape = attention_mask.shape
            mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
            causal_mask[
                : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
            ] = mask_slice

    if (
        self.config._attn_implementation == "sdpa"
        and attention_mask is not None
        and attention_mask.device.type == "cuda"
    ):
        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        # Details: https://github.com/pytorch/pytorch/issues/110213
        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

    return causal_mask


# class H2OLlamaAttention(nn.Module):
#     """Multi-headed attention from 'Attention Is All You Need' paper"""
#
#     def __init__(self, config: LlamaConfig):
#         super().__init__()
#         self.config = config
#         self.hidden_size = config.hidden_size
#         self.num_heads = config.num_attention_heads
#         self.head_dim = self.hidden_size // self.num_heads
#         self.num_key_value_heads = config.num_key_value_heads
#         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
#         self.max_position_embeddings = config.max_position_embeddings
#         self.rope_theta = config.rope_theta
#
#         if (self.head_dim * self.num_heads) != self.hidden_size:
#             raise ValueError(
#                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
#                 f" and `num_heads`: {self.num_heads})."
#             )
#         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
#         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
#         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
#         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
#         self._init_rope()
#
#         self.kv_cache = H2OKVCache_LayerWise(
#             hh_size=config.hh_size,
#             recent_size=config.recent_size,
#             k_seq_dim=2,
#             v_seq_dim=2,
#         )
#
#     def _init_rope(self):
#         if self.config.rope_scaling is None:
#             self.rotary_emb = LlamaRotaryEmbedding(
#                 self.head_dim,
#                 max_position_embeddings=self.max_position_embeddings,
#                 base=self.rope_theta,
#             )
#         else:
#             scaling_type = self.config.rope_scaling["type"]
#             scaling_factor = self.config.rope_scaling["factor"]
#             if scaling_type == "linear":
#                 self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
#                     self.head_dim,
#                     max_position_embeddings=self.max_position_embeddings,
#                     scaling_factor=scaling_factor,
#                     base=self.rope_theta,
#                 )
#             elif scaling_type == "dynamic":
#                 self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
#                     self.head_dim,
#                     max_position_embeddings=self.max_position_embeddings,
#                     scaling_factor=scaling_factor,
#                     base=self.rope_theta,
#                 )
#             else:
#                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
#
#     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
#         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
#
#     def _clean_cache(self):
#         self.kv_cache._clean_scores()
#
#     def forward(
#         self,
#         hidden_states: torch.Tensor,
#         attention_mask: Optional[torch.Tensor] = None,
#         position_ids: Optional[torch.LongTensor] = None,
#         past_key_value: Optional[Tuple[torch.Tensor]] = None,
#         output_attentions: bool = False,
#         use_cache: bool = False,
#         cache_position: Optional[torch.LongTensor] = None,
#     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
#         bsz, q_len, _ = hidden_states.size()
#
#         if self.config.pretraining_tp > 1:
#             key_value_slicing = (
#                 self.num_key_value_heads * self.head_dim
#             ) // self.config.pretraining_tp
#             query_slices = self.q_proj.weight.split(
#                 (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
#             )
#             key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
#             value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
#             query_states = [
#                 F.linear(hidden_states, query_slices[i])
#                 for i in range(self.config.pretraining_tp)
#             ]
#             query_states = torch.cat(query_states, dim=-1)
#             key_states = [
#                 F.linear(hidden_states, key_slices[i])
#                 for i in range(self.config.pretraining_tp)
#             ]
#             key_states = torch.cat(key_states, dim=-1)
#             value_states = [
#                 F.linear(hidden_states, value_slices[i])
#                 for i in range(self.config.pretraining_tp)
#             ]
#             value_states = torch.cat(value_states, dim=-1)
#         else:
#             query_states = self.q_proj(hidden_states)
#             key_states = self.k_proj(hidden_states)
#             value_states = self.v_proj(hidden_states)
#
#         query_states = query_states.view(
#             bsz, q_len, self.num_heads, self.head_dim
#         ).transpose(1, 2)
#         key_states = key_states.view(
#             bsz, q_len, self.num_key_value_heads, self.head_dim
#         ).transpose(1, 2)
#         value_states = value_states.view(
#             bsz, q_len, self.num_key_value_heads, self.head_dim
#         ).transpose(1, 2)
#
#         # remake causal mask
#         attention_mask = _make_causal_mask(
#             bsz=bsz,
#             tgt_len=q_len,
#             past_key_values_length=past_key_value[0].shape[-2] if past_key_value is not None else 0,
#             dtype=query_states.dtype,
#             device=query_states.device,
#         )
#
#         kv_seq_len = key_states.shape[-2]
#         if past_key_value is not None:
#             kv_seq_len += past_key_value[0].shape[-2]
#
#         if not position_ids.nelement() > 1:
#             position_ids[0][0] = kv_seq_len - 1
#
#         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
#         ### Shift Pos: query pos is min(cache_size, idx)
#         # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
#         query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
#         ###
#
#         if past_key_value is not None:
#             # reuse k, v, self_attention
#             key_states = torch.cat([past_key_value[0], key_states], dim=2)
#             value_states = torch.cat([past_key_value[1], value_states], dim=2)
#
#         past_key_value = (key_states, value_states) if use_cache else None
#
#         ### Shift Pos: key pos is the pos in cache (Rolling KV Cache and using relative pos emb)
#         key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
#         key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
#         ###
#
#         # repeat k/v heads if n_kv_heads < n_heads
#         key_states = repeat_kv(key_states, self.num_key_value_groups)
#         value_states = repeat_kv(value_states, self.num_key_value_groups)
#
#         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
#             self.head_dim
#         )
#
#         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
#             raise ValueError(
#                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
#                 f" {attn_weights.size()}"
#             )
#
#         if attention_mask is not None:
#             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
#                 raise ValueError(
#                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
#                 )
#             attn_weights = attn_weights + attention_mask
#
#         # upcast attention to fp32
#         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
#             query_states.dtype
#         )
#
#         past_key_value = self.kv_cache(past_key_value, attn_weights.detach().clone())
#
#         attn_output = torch.matmul(attn_weights, value_states)
#
#         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
#             raise ValueError(
#                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
#                 f" {attn_output.size()}"
#             )
#
#         attn_output = attn_output.transpose(1, 2).contiguous()
#         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
#
#         if self.config.pretraining_tp > 1:
#             attn_output = attn_output.split(
#                 self.hidden_size // self.config.pretraining_tp, dim=2
#             )
#             o_proj_slices = self.o_proj.weight.split(
#                 self.hidden_size // self.config.pretraining_tp, dim=1
#             )
#             attn_output = sum(
#                 [
#                     F.linear(attn_output[i], o_proj_slices[i])
#                     for i in range(self.config.pretraining_tp)
#                 ]
#             )
#         else:
#             attn_output = self.o_proj(attn_output)
#
#         if not output_attentions:
#             attn_weights = None
#
#         return attn_output, attn_weights, past_key_value


class H2OLlamaForCausalLM(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        num_layers = len(self.model.layers)
        for layer_idx in range(num_layers):
            self.model.layers[layer_idx].self_attn = H2OLlamaAttention(config, layer_idx)
        self.model.forward = types.MethodType(enable_h2ocache_forward, self.model)
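

# Illustrative usage sketch (not part of the original module). The checkpoint name is an
# assumption, and any heavy-hitter/recent budgets consumed by the companion
# `cache_utils.HHCache` would need to be set on `config` according to that module, which is
# not shown in this file.
#
#   from transformers import AutoConfig, AutoTokenizer
#
#   name = "meta-llama/Llama-2-7b-hf"
#   config = AutoConfig.from_pretrained(name)
#   tokenizer = AutoTokenizer.from_pretrained(name)
#   model = H2OLlamaForCausalLM.from_pretrained(name, config=config)
#
#   inputs = tokenizer("Heavy-hitter oracle keeps the important KV entries.", return_tensors="pt")
#   output_ids = model.generate(**inputs, max_new_tokens=32)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))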