# utils_llama.py

import math
import types
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn
import torch.utils.checkpoint
import torch.nn.functional as F

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    rotate_half,
    apply_rotary_pos_emb,
    repeat_kv,
    LlamaRotaryEmbedding,
    LlamaLinearScalingRotaryEmbedding,  # used by _init_rope below
    LlamaDynamicNTKScalingRotaryEmbedding,  # used by _init_rope below
    LlamaForCausalLM,
)
from cache_utils import Cache, HHCache, StaticCache
from transformers.utils import logging
from transformers.modeling_outputs import BaseModelOutputWithPast

logger = logging.get_logger(__name__)

__all__ = ["H2OLlamaForCausalLM"]

def _make_causal_mask(
    bsz: int, tgt_len: int, past_key_values_length: int, dtype: torch.dtype, device: torch.device
):
    """
    Make the causal (uni-directional) mask used for self-attention.
    """
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed

class H2OKVCache_LayerWise:
    """Per-layer KV cache that keeps `hh_size` heavy-hitter tokens plus the `recent_size` most recent tokens."""

    def __init__(
        self,
        hh_size=4,
        recent_size=512,
        k_seq_dim=2,
        v_seq_dim=2,
    ):
        self.hh_size = hh_size
        self.recent_size = recent_size
        self.cache_size = hh_size + recent_size
        self.k_seq_dim = k_seq_dim
        self.v_seq_dim = v_seq_dim
        self.hh_score = None

    def __call__(self, past_key_values, attn_score_cache):
        self._update_hh_score(attn_score_cache)

        if past_key_values is None:
            return None
        seq_len = past_key_values[0].size(self.k_seq_dim)
        if seq_len <= self.cache_size:
            return past_key_values

        # hh-selection
        bsz, num_heads, _, head_dim = past_key_values[0].shape

        select_hh_scores = self.hh_score[:, :seq_len - self.recent_size]
        _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
        keep_topk = keep_topk.sort().values

        # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
        keep_recent = torch.arange(seq_len - self.recent_size, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)

        mask = torch.zeros(self.hh_score.shape, dtype=torch.bool).to(past_key_values[0].device)
        mask = mask.scatter(-1, keep_idx, 1)

        k_hh_recent = past_key_values[0].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
        v_hh_recent = past_key_values[1].squeeze()[mask].view(bsz, num_heads, -1, head_dim)

        self.hh_score = self.hh_score[mask].view(num_heads, self.cache_size)

        return (k_hh_recent, v_hh_recent)

    def evict_for_space(self, past_key_values, num_coming):
        if past_key_values is None:
            return None
        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        if seq_len + num_coming <= self.cache_size:
            return past_key_values

        # hh-selection
        bsz, num_heads, _, head_dim = past_key_values[0].shape

        select_hh_scores = self.hh_score[:, :seq_len - self.recent_size + num_coming]
        _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
        keep_topk = keep_topk.sort().values

        # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
        keep_recent = torch.arange(seq_len - self.recent_size + num_coming, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)

        mask = torch.zeros(self.hh_score.shape, dtype=torch.bool).to(past_key_values[0].device)
        mask = mask.scatter(-1, keep_idx, 1)

        k_hh_recent = past_key_values[0].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
        v_hh_recent = past_key_values[1].squeeze()[mask].view(bsz, num_heads, -1, head_dim)

        self.hh_score = self.hh_score[mask].view(num_heads, self.cache_size)

        return (k_hh_recent, v_hh_recent)

    def _update_hh_score(self, attn_score_cache):
        # attn_score_cache: [bsz, num_heads, q_len, kv_len]; accumulate over batch and query positions.
        num_new_tokens = attn_score_cache.shape[2]

        if self.hh_score is None:
            self.hh_score = attn_score_cache.sum(0).sum(1)
        else:
            attn_score_cache = attn_score_cache.sum(0).sum(1)
            attn_score_cache[:, :-num_new_tokens] += self.hh_score
            self.hh_score = attn_score_cache

    def _clean_scores(self):
        self.hh_score = None
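
# A minimal usage sketch for H2OKVCache_LayerWise (illustrative shapes; it assumes batch
# size 1 because the eviction path above relies on `squeeze()`): with hh_size=2 and
# recent_size=4, a cache of 8 tokens is pruned to 6 per head, namely the 4 most recent
# positions plus the 2 older positions with the highest accumulated attention scores.
#
#   cache = H2OKVCache_LayerWise(hh_size=2, recent_size=4)
#   keys = torch.randn(1, 2, 8, 4)      # [bsz, num_heads, seq_len, head_dim]
#   values = torch.randn(1, 2, 8, 4)
#   attn = torch.rand(1, 2, 8, 8)       # [bsz, num_heads, q_len, kv_len]
#   new_k, new_v = cache((keys, values), attn)   # each has shape [1, 2, 6, 4]
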
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
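
# Shape sketch for repeat_kv (illustrative values, assuming 8 KV heads repeated 4x):
#   repeat_kv(torch.randn(2, 8, 128, 64), n_rep=4).shape == torch.Size([2, 32, 128, 64])
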
class H2OLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        past_key_value = getattr(self, "past_key_value", past_key_value)
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

        # Update KV Cache based on Heavy-Hitter Oracle
        if past_key_value is not None:
            past_key_value.update_slimming(attn_weights, self.num_key_value_groups, self.layer_idx, cache_kwargs)

        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

def enable_h2ocache_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError(
            "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
        )

    if self.gradient_checkpointing and self.training and use_cache:
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
        )
        use_cache = False

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    past_seen_tokens = 0
    if use_cache:  # kept for BC (cache positions)
        if not isinstance(past_key_values, StaticCache):
            past_key_values = HHCache.from_legacy_cache(self.num_window_length, self.num_heavy_hitter_tokens, past_key_values)
            past_seen_tokens = past_key_values.get_seq_length()

    if cache_position is None:
        if isinstance(past_key_values, StaticCache):
            raise ValueError("cache_position is a required argument when using StaticCache.")
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )

    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = None
    if use_cache:
        next_cache = (
            next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
        )
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )

# class H2OLlamaAttention(nn.Module):
#     """Multi-headed attention from 'Attention Is All You Need' paper"""
#
#     def __init__(self, config: LlamaConfig):
#         super().__init__()
#         self.config = config
#         self.hidden_size = config.hidden_size
#         self.num_heads = config.num_attention_heads
#         self.head_dim = self.hidden_size // self.num_heads
#         self.num_key_value_heads = config.num_key_value_heads
#         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
#         self.max_position_embeddings = config.max_position_embeddings
#         self.rope_theta = config.rope_theta
#
#         if (self.head_dim * self.num_heads) != self.hidden_size:
#             raise ValueError(
#                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
#                 f" and `num_heads`: {self.num_heads})."
#             )
#         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
#         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
#         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
#         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
#         self._init_rope()
#
#         self.kv_cache = H2OKVCache_LayerWise(
#             hh_size=config.hh_size,
#             recent_size=config.recent_size,
#             k_seq_dim=2,
#             v_seq_dim=2,
#         )
#
#     def _init_rope(self):
#         if self.config.rope_scaling is None:
#             self.rotary_emb = LlamaRotaryEmbedding(
#                 self.head_dim,
#                 max_position_embeddings=self.max_position_embeddings,
#                 base=self.rope_theta,
#             )
#         else:
#             scaling_type = self.config.rope_scaling["type"]
#             scaling_factor = self.config.rope_scaling["factor"]
#             if scaling_type == "linear":
#                 self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
#                     self.head_dim,
#                     max_position_embeddings=self.max_position_embeddings,
#                     scaling_factor=scaling_factor,
#                     base=self.rope_theta,
#                 )
#             elif scaling_type == "dynamic":
#                 self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
#                     self.head_dim,
#                     max_position_embeddings=self.max_position_embeddings,
#                     scaling_factor=scaling_factor,
#                     base=self.rope_theta,
#                 )
#             else:
#                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
#
#     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
#         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
#
#     def _clean_cache(self):
#         self.kv_cache._clean_scores()
#
#     def forward(
#         self,
#         hidden_states: torch.Tensor,
#         attention_mask: Optional[torch.Tensor] = None,
#         position_ids: Optional[torch.LongTensor] = None,
#         past_key_value: Optional[Tuple[torch.Tensor]] = None,
#         output_attentions: bool = False,
#         use_cache: bool = False,
#         cache_position: Optional[torch.LongTensor] = None,
#     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
#         bsz, q_len, _ = hidden_states.size()
#
#         if self.config.pretraining_tp > 1:
#             key_value_slicing = (
#                 self.num_key_value_heads * self.head_dim
#             ) // self.config.pretraining_tp
#             query_slices = self.q_proj.weight.split(
#                 (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
#             )
#             key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
#             value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
#             query_states = [
#                 F.linear(hidden_states, query_slices[i])
#                 for i in range(self.config.pretraining_tp)
#             ]
#             query_states = torch.cat(query_states, dim=-1)
#             key_states = [
#                 F.linear(hidden_states, key_slices[i])
#                 for i in range(self.config.pretraining_tp)
#             ]
#             key_states = torch.cat(key_states, dim=-1)
#             value_states = [
#                 F.linear(hidden_states, value_slices[i])
#                 for i in range(self.config.pretraining_tp)
#             ]
#             value_states = torch.cat(value_states, dim=-1)
#         else:
#             query_states = self.q_proj(hidden_states)
#             key_states = self.k_proj(hidden_states)
#             value_states = self.v_proj(hidden_states)
#
#         query_states = query_states.view(
#             bsz, q_len, self.num_heads, self.head_dim
#         ).transpose(1, 2)
#         key_states = key_states.view(
#             bsz, q_len, self.num_key_value_heads, self.head_dim
#         ).transpose(1, 2)
#         value_states = value_states.view(
#             bsz, q_len, self.num_key_value_heads, self.head_dim
#         ).transpose(1, 2)
#
#         # remake causal mask
#         attention_mask = _make_causal_mask(
#             bsz=bsz,
#             tgt_len=q_len,
#             past_key_values_length=past_key_value[0].shape[-2] if past_key_value is not None else 0,
#             dtype=query_states.dtype,
#             device=query_states.device,
#         )
#
#         kv_seq_len = key_states.shape[-2]
#         if past_key_value is not None:
#             kv_seq_len += past_key_value[0].shape[-2]
#
#         if not position_ids.nelement() > 1:
#             position_ids[0][0] = kv_seq_len - 1
#
#         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
#         ### Shift Pos: query pos is min(cache_size, idx)
#         # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
#         query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
#         ###
#
#         if past_key_value is not None:
#             # reuse k, v, self_attention
#             key_states = torch.cat([past_key_value[0], key_states], dim=2)
#             value_states = torch.cat([past_key_value[1], value_states], dim=2)
#
#         past_key_value = (key_states, value_states) if use_cache else None
#
#         ### Shift Pos: key pos is the pos in cache (Rolling KV Cache and using relative pos emb)
#         key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
#         key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
#         ###
#
#         # repeat k/v heads if n_kv_heads < n_heads
#         key_states = repeat_kv(key_states, self.num_key_value_groups)
#         value_states = repeat_kv(value_states, self.num_key_value_groups)
#
#         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
#             self.head_dim
#         )
#
#         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
#             raise ValueError(
#                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
#                 f" {attn_weights.size()}"
#             )
#
#         if attention_mask is not None:
#             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
#                 raise ValueError(
#                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
#                 )
#             attn_weights = attn_weights + attention_mask
#
#         # upcast attention to fp32
#         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
#             query_states.dtype
#         )
#
#         past_key_value = self.kv_cache(past_key_value, attn_weights.detach().clone())
#
#         attn_output = torch.matmul(attn_weights, value_states)
#
#         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
#             raise ValueError(
#                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
#                 f" {attn_output.size()}"
#             )
#
#         attn_output = attn_output.transpose(1, 2).contiguous()
#         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
#
#         if self.config.pretraining_tp > 1:
#             attn_output = attn_output.split(
#                 self.hidden_size // self.config.pretraining_tp, dim=2
#             )
#             o_proj_slices = self.o_proj.weight.split(
#                 self.hidden_size // self.config.pretraining_tp, dim=1
#             )
#             attn_output = sum(
#                 [
#                     F.linear(attn_output[i], o_proj_slices[i])
#                     for i in range(self.config.pretraining_tp)
#                 ]
#             )
#         else:
#             attn_output = self.o_proj(attn_output)
#
#         if not output_attentions:
#             attn_weights = None
#
#         return attn_output, attn_weights, past_key_value

class H2OLlamaForCausalLM(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        num_layers = len(self.model.layers)
        for layer_idx in range(num_layers):
            self.model.layers[layer_idx].self_attn = H2OLlamaAttention(config, layer_idx)

        self.model.forward = types.MethodType(enable_h2ocache_forward, self.model)
        self.model.num_heavy_hitter_tokens = config.num_heavy_hitter_tokens
        self.model.num_window_length = config.num_window_length
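

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module). It assumes a
    # transformers version compatible with the imports above and uses illustrative
    # config sizes and H2O budgets; `num_heavy_hitter_tokens` and `num_window_length`
    # are the extra attributes H2OLlamaForCausalLM reads from the config.
    config = LlamaConfig(
        vocab_size=128,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        max_position_embeddings=128,
    )
    config.num_heavy_hitter_tokens = 4  # heavy-hitter budget (illustrative)
    config.num_window_length = 32       # total KV window length (illustrative)
    model = H2OLlamaForCausalLM(config)
    print(type(model.model.layers[0].self_attn).__name__)  # -> H2OLlamaAttention

    # For a real checkpoint one would instead do something along these lines
    # (hedged sketch; the repository id and budget values are placeholders):
    #   config = LlamaConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
    #   config.num_heavy_hitter_tokens = 128
    #   config.num_window_length = 256
    #   model = H2OLlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", config=config)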