utils_llama.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544
  1. import math
  2. from typing import Optional, Tuple
  3. import pdb
  4. import types
  5. import torch
  6. from torch import nn
  7. import torch.utils.checkpoint
  8. import torch.nn.functional as F
  9. from transformers.models.llama.configuration_llama import LlamaConfig
  10. from transformers.models.llama.modeling_llama import (
  11. LlamaAttention,
  12. rotate_half,
  13. apply_rotary_pos_emb,
  14. repeat_kv,
  15. LlamaRotaryEmbedding,
  16. apply_rotary_pos_emb,
  17. LlamaForCausalLM,
  18. )
  19. from cache_utils import Cache
  20. from transformers.utils import logging
  # Module-level logger, obtained through transformers' logging wrapper.
  logger = logging.get_logger(name=__name__)
  # Names exported by `from <module> import *`.
  __all__ = ["H2OLlamaForCausalLM"]
  23. def _make_causal_mask(
  24. bsz: int, tgt_len: int, past_key_values_length: int, dtype: torch.dtype, device: torch.device):
  25. """
  26. Make causal mask used for bi-directional self-attention.
  27. """
  28. mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
  29. mask_cond = torch.arange(mask.size(-1), device=device)
  30. mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
  31. mask = mask.to(dtype)
  32. if past_key_values_length > 0:
  33. mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
  34. return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
  def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
      """
      Apply rotary position embedding to a single tensor (a query *or* a key).

      ``apply_rotary_pos_emb`` rotates queries and keys together. This variant rotates
      one tensor, so callers can use different position ids for queries and keys.

      Each of ``cos``/``sin`` is processed the same way:
      - squeezed on dim 1, then dim 0 (each squeeze only removes a size-1 dim);
      - indexed along its leading remaining dim by ``position_ids``;
      - given a singleton dim at position 1.

      NOTE(review): the shape notes below assume cos/sin caches shaped
      ``[1, 1, seq_len, dim]`` and ``x`` shaped ``[bs, heads, seq_len, dim]``.
      Confirm this against the rotary embedding in use.
      """

      def _table_at_positions(table):
          # The first two dims of the cache are expected to be 1.
          flat = table.squeeze(1).squeeze(0)  # [seq_len, dim]
          return flat[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]

      cos_at_pos = _table_at_positions(cos)
      sin_at_pos = _table_at_positions(sin)
      return (x * cos_at_pos) + (rotate_half(x) * sin_at_pos)
  class H2OKVCache_LayerWise:
      """
      Heavy-hitter (H2O) eviction policy for a single layer's key/value cache.

      The cache keeps at most ``hh_size + recent_size`` positions:
      - the ``recent_size`` most recent positions;
      - per head, the ``hh_size`` older positions with the largest accumulated
        attention score ("heavy hitters").

      Accumulated scores are stored in ``self.hh_score`` with shape
      ``(num_heads, cached_len)``.

      Key/value tensors are handled in the ``(bsz, num_heads, seq_len, head_dim)``
      layout and are gathered along dim 2. Scores are summed over the batch, so the
      same per-head selection is applied to every batch element.

      NOTE(review): ``num_heads`` of the key/value tensors must match the head count
      of the attention scores. Grouped-query caches therefore need their heads
      repeated first. Confirm this against the caller.
      """

      def __init__(
          self,
          hh_size=4,
          recent_size=512,
          k_seq_dim=2,
          v_seq_dim=2,
      ):
          self.hh_size = hh_size
          self.recent_size = recent_size
          self.cache_size = hh_size + recent_size
          self.k_seq_dim = k_seq_dim
          # Kept for symmetry with k_seq_dim; this class does not read it.
          self.v_seq_dim = v_seq_dim
          # Accumulated attention per (head, cached position); None until the first update.
          self.hh_score = None

      def __call__(self, past_key_values, attn_score_cache):
          """
          Accumulate new attention scores, then evict down to ``cache_size`` if needed.

          Args:
              past_key_values: ``(key, value)`` pair, or None.
              attn_score_cache: attention weights for the latest step, presumably shaped
                  ``(bsz, num_heads, q_len, k_len)``. They are always folded into
                  ``hh_score``, even when ``past_key_values`` is None.

          Returns:
              - None if ``past_key_values`` is None.
              - The input unchanged if it already fits.
              - Otherwise a new ``(key, value)`` tuple holding ``cache_size`` positions.
          """
          self._update_hh_score(attn_score_cache)
          if past_key_values is None:
              return None
          seq_len = past_key_values[0].size(self.k_seq_dim)
          if seq_len <= self.cache_size:
              return past_key_values
          keep_idx = self._keep_indices(seq_len, self.recent_size)
          return self._gather_kept(past_key_values, keep_idx)

      def evict_for_space(self, past_key_values, num_coming):
          """
          Evict positions so that ``num_coming`` new tokens fit within ``cache_size``.

          Keeps the ``hh_size`` heavy hitters plus the ``recent_size - num_coming`` most
          recent positions. ``hh_score`` is not updated with new attention here, so it
          must already cover every cached position.

          Returns:
              - None if ``past_key_values`` is None.
              - The input unchanged if the incoming tokens already fit.
              - Otherwise a new ``(key, value)`` tuple with ``cache_size - num_coming``
                positions.
          """
          if past_key_values is None:
              return None
          # Read the length from the key tensor itself. ``[0][0]`` would select batch
          # element 0, whose dim 2 is head_dim rather than the sequence length.
          seq_len = past_key_values[0].size(self.k_seq_dim)
          if seq_len + num_coming <= self.cache_size:
              return past_key_values
          keep_idx = self._keep_indices(seq_len, self.recent_size - num_coming)
          return self._gather_kept(past_key_values, keep_idx)

      def _keep_indices(self, seq_len, num_recent):
          """
          Return the per-head positions to keep, shape ``(num_heads, hh_size + num_recent)``.

          The result combines the top-``hh_size`` scores among positions
          ``[0, seq_len - num_recent)`` with the last ``num_recent`` positions.
          """
          split = seq_len - num_recent
          _, keep_topk = torch.topk(self.hh_score[:, :split], self.hh_size, dim=-1)
          keep_topk = keep_topk.sort().values
          keep_recent = torch.arange(split, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
          # Every heavy hitter index is < split, so the result is already ascending.
          # The kept positions therefore stay in their original order.
          return torch.cat([keep_topk, keep_recent], dim=-1)

      def _gather_kept(self, past_key_values, keep_idx):
          """Select the ``keep_idx`` positions (per head) from keys, values and ``hh_score``."""
          key_states, value_states = past_key_values[0], past_key_values[1]
          kv_idx = keep_idx.to(key_states.device)
          bsz, num_heads, _, _ = key_states.shape
          num_keep = kv_idx.shape[-1]

          def _select(states):
              # Broadcast the per-head index over batch and feature dims, then gather
              # along the sequence dim (2). Unlike squeeze + boolean mask, this works
              # for any batch size.
              index = kv_idx[None, :, :, None].expand(bsz, num_heads, num_keep, states.shape[-1])
              return states.gather(2, index)

          k_hh_recent = _select(key_states)
          v_hh_recent = _select(value_states)
          self.hh_score = self.hh_score.gather(-1, keep_idx)
          return (k_hh_recent, v_hh_recent)

      def _update_hh_score(self, attn_score_cache):
          """
          Fold new attention weights into the accumulated per-position scores.

          ``attn_score_cache`` is summed over its dims 0 and 2, giving one score per
          (head, key position) if it is laid out as ``(bsz, heads, q_len, k_len)``.
          Previously accumulated scores are added onto all but the last ``shape[2]``
          positions, which are the newly appended tokens.
          """
          num_new_tokens = attn_score_cache.shape[2]
          scores = attn_score_cache.sum(0).sum(1)
          if self.hh_score is not None:
              scores[:, :-num_new_tokens] += self.hh_score
          self.hh_score = scores

      def _clean_scores(self):
          """Forget all accumulated scores (e.g. before processing a new sequence)."""
          self.hh_score = None
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
      """
      Repeat every key/value head ``n_rep`` times along the head dimension.

      The result matches ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
      ``(batch, num_key_value_heads, seqlen, head_dim)`` becomes
      ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``.
      When ``n_rep == 1``, the very same tensor object is returned.
      """
      batch, kv_heads, seq_len, head_dim = hidden_states.shape
      if n_rep != 1:
          # Insert a repeat axis right after the head axis, broadcast it, then fold it in.
          tiled = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
          hidden_states = tiled.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
      return hidden_states
  class H2OLlamaAttention(nn.Module):
      """Multi-headed attention from 'Attention Is All You Need' paper.

      Llama-style attention layer used in place of the stock attention in every decoder
      layer by ``H2OLlamaForCausalLM``. It supports:
      - rotary position embeddings;
      - grouped-query key/value heads;
      - optional ``pretraining_tp`` weight slicing.

      NOTE(review): ``forward`` attends over whatever ``past_key_value.update`` returns.
      No heavy-hitter eviction is performed inside this class.
      """

      def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
          super().__init__()
          self.config = config
          self.layer_idx = layer_idx
          if layer_idx is None:
              logger.warning_once(
                  f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                  "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                  "when creating this class."
              )
          self.attention_dropout = config.attention_dropout
          self.hidden_size = config.hidden_size
          self.num_heads = config.num_attention_heads
          self.head_dim = self.hidden_size // self.num_heads
          self.num_key_value_heads = config.num_key_value_heads
          # Number of query heads that share one key/value head.
          self.num_key_value_groups = self.num_heads // self.num_key_value_heads
          self.max_position_embeddings = config.max_position_embeddings
          self.rope_theta = config.rope_theta
          self.is_causal = True
          if (self.head_dim * self.num_heads) != self.hidden_size:
              raise ValueError(
                  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                  f" and `num_heads`: {self.num_heads})."
              )
          self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
          self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
          self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
          self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
          self._init_rope()

      def _init_rope(self):
          """Create ``self.rotary_emb`` according to ``config.rope_scaling``.

          Raises:
              ValueError: if ``rope_scaling["type"]`` is neither ``"linear"`` nor ``"dynamic"``.
          """
          if self.config.rope_scaling is None:
              self.rotary_emb = LlamaRotaryEmbedding(
                  self.head_dim,
                  max_position_embeddings=self.max_position_embeddings,
                  base=self.rope_theta,
              )
              return
          scaling_type = self.config.rope_scaling["type"]
          scaling_factor = self.config.rope_scaling["factor"]
          # The scaled RoPE classes are not imported at module level. Importing them
          # here keeps the module loadable on transformers releases that no longer
          # ship them; they are only needed when this scaling type is requested.
          if scaling_type == "linear":
              from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding

              rope_cls = LlamaLinearScalingRotaryEmbedding
          elif scaling_type == "dynamic":
              from transformers.models.llama.modeling_llama import LlamaDynamicNTKScalingRotaryEmbedding

              rope_cls = LlamaDynamicNTKScalingRotaryEmbedding
          else:
              raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
          self.rotary_emb = rope_cls(
              self.head_dim,
              max_position_embeddings=self.max_position_embeddings,
              scaling_factor=scaling_factor,
              base=self.rope_theta,
          )

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          past_key_value: Optional[Cache] = None,
          output_attentions: bool = False,
          use_cache: bool = False,
          cache_position: Optional[torch.LongTensor] = None,
          **kwargs,
      ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
          """Compute self-attention for ``hidden_states``.

          Args:
              hidden_states: input of shape ``(bsz, q_len, hidden_size)``.
              attention_mask: optional additive mask. Its last dim is sliced to the key length.
              position_ids: positions passed to ``self.rotary_emb``.
              past_key_value: optional cache object. Its ``update`` method receives this
                  layer's new keys/values and returns the keys/values to attend over.
              output_attentions: when False, the returned attention weights are None.
              use_cache: accepted for interface compatibility; not read here.
              cache_position: forwarded to the cache inside ``cache_kwargs``.

          Returns:
              ``(attn_output, attn_weights or None, past_key_value)``.

          Raises:
              ValueError: if the attention output does not have shape
                  ``(bsz, num_heads, q_len, head_dim)``.
          """
          bsz, q_len, _ = hidden_states.size()

          if self.config.pretraining_tp > 1:
              # Reproduce tensor-parallel pretraining numerics by projecting slice-by-slice.
              key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
              query_slices = self.q_proj.weight.split(
                  (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
              )
              key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
              value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
              query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
              query_states = torch.cat(query_states, dim=-1)
              key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
              key_states = torch.cat(key_states, dim=-1)
              value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
              value_states = torch.cat(value_states, dim=-1)
          else:
              query_states = self.q_proj(hidden_states)
              key_states = self.k_proj(hidden_states)
              value_states = self.v_proj(hidden_states)

          query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
          key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
          value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

          # A cache attached to the module as `self.past_key_value` takes precedence over
          # the argument. The attribute is optional, so read it only via getattr.
          past_key_value = getattr(self, "past_key_value", past_key_value)
          cos, sin = self.rotary_emb(value_states, position_ids)
          query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

          if past_key_value is not None:
              # sin and cos are specific to RoPE models; cache_position needed for the static cache
              cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
              key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

          # Expand shared key/value heads so every query head has a matching key/value head.
          key_states = repeat_kv(key_states, self.num_key_value_groups)
          value_states = repeat_kv(value_states, self.num_key_value_groups)

          attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
          if attention_mask is not None:  # no matter the length, we just slice it
              causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
              attn_weights = attn_weights + causal_mask

          # upcast attention to fp32
          attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
          attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
          attn_output = torch.matmul(attn_weights, value_states)

          if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
              raise ValueError(
                  f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                  f" {attn_output.size()}"
              )

          attn_output = attn_output.transpose(1, 2).contiguous()
          attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

          if self.config.pretraining_tp > 1:
              attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
              o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
              attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
          else:
              attn_output = self.o_proj(attn_output)

          if not output_attentions:
              attn_weights = None

          return attn_output, attn_weights, past_key_value
  241. # class H2OLlamaAttention(nn.Module):
  242. # """Multi-headed attention from 'Attention Is All You Need' paper"""
  243. # def __init__(self, config: LlamaConfig):
  244. # super().__init__()
  245. # self.config = config
  246. # self.hidden_size = config.hidden_size
  247. # self.num_heads = config.num_attention_heads
  248. # self.head_dim = self.hidden_size // self.num_heads
  249. # self.num_key_value_heads = config.num_key_value_heads
  250. # self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  251. # self.max_position_embeddings = config.max_position_embeddings
  252. # self.rope_theta = config.rope_theta
  253. # if (self.head_dim * self.num_heads) != self.hidden_size:
  254. # raise ValueError(
  255. # f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
  256. # f" and `num_heads`: {self.num_heads})."
  257. # )
  258. # self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
  259. # self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  260. # self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  261. # self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
  262. # self._init_rope()
  263. # self.kv_cache = H2OKVCache_LayerWise(
  264. # hh_size=config.hh_size,
  265. # recent_size=config.recent_size,
  266. # k_seq_dim=2,
  267. # v_seq_dim=2,
  268. # )
  269. # def _init_rope(self):
  270. # if self.config.rope_scaling is None:
  271. # self.rotary_emb = LlamaRotaryEmbedding(
  272. # self.head_dim,
  273. # max_position_embeddings=self.max_position_embeddings,
  274. # base=self.rope_theta,
  275. # )
  276. # else:
  277. # scaling_type = self.config.rope_scaling["type"]
  278. # scaling_factor = self.config.rope_scaling["factor"]
  279. # if scaling_type == "linear":
  280. # self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
  281. # self.head_dim,
  282. # max_position_embeddings=self.max_position_embeddings,
  283. # scaling_factor=scaling_factor,
  284. # base=self.rope_theta,
  285. # )
  286. # elif scaling_type == "dynamic":
  287. # self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
  288. # self.head_dim,
  289. # max_position_embeddings=self.max_position_embeddings,
  290. # scaling_factor=scaling_factor,
  291. # base=self.rope_theta,
  292. # )
  293. # else:
  294. # raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
  295. # def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  296. # return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
  297. # def _clean_cache(self):
  298. # self.kv_cache._clean_scores()
  299. # def forward(
  300. # self,
  301. # hidden_states: torch.Tensor,
  302. # attention_mask: Optional[torch.Tensor] = None,
  303. # position_ids: Optional[torch.LongTensor] = None,
  304. # past_key_value: Optional[Tuple[torch.Tensor]] = None,
  305. # output_attentions: bool = False,
  306. # use_cache: bool = False,
  307. # cache_position: Optional[torch.LongTensor] = None,
  308. # ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  309. # bsz, q_len, _ = hidden_states.size()
  310. # if self.config.pretraining_tp > 1:
  311. # key_value_slicing = (
  312. # self.num_key_value_heads * self.head_dim
  313. # ) // self.config.pretraining_tp
  314. # query_slices = self.q_proj.weight.split(
  315. # (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
  316. # )
  317. # key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
  318. # value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
  319. # query_states = [
  320. # F.linear(hidden_states, query_slices[i])
  321. # for i in range(self.config.pretraining_tp)
  322. # ]
  323. # query_states = torch.cat(query_states, dim=-1)
  324. # key_states = [
  325. # F.linear(hidden_states, key_slices[i])
  326. # for i in range(self.config.pretraining_tp)
  327. # ]
  328. # key_states = torch.cat(key_states, dim=-1)
  329. # value_states = [
  330. # F.linear(hidden_states, value_slices[i])
  331. # for i in range(self.config.pretraining_tp)
  332. # ]
  333. # value_states = torch.cat(value_states, dim=-1)
  334. # else:
  335. # query_states = self.q_proj(hidden_states)
  336. # key_states = self.k_proj(hidden_states)
  337. # value_states = self.v_proj(hidden_states)
  338. # query_states = query_states.view(
  339. # bsz, q_len, self.num_heads, self.head_dim
  340. # ).transpose(1, 2)
  341. # key_states = key_states.view(
  342. # bsz, q_len, self.num_key_value_heads, self.head_dim
  343. # ).transpose(1, 2)
  344. # value_states = value_states.view(
  345. # bsz, q_len, self.num_key_value_heads, self.head_dim
  346. # ).transpose(1, 2)
  347. # # remake causal mask
  348. # attention_mask = _make_causal_mask(
  349. # bsz=bsz,
  350. # tgt_len=q_len,
  351. # past_key_values_length=past_key_value[0].shape[-2] if past_key_value is not None else 0,
  352. # dtype=query_states.dtype,
  353. # device=query_states.device,
  354. # )
  355. # kv_seq_len = key_states.shape[-2]
  356. # if past_key_value is not None:
  357. # kv_seq_len += past_key_value[0].shape[-2]
  358. # if not position_ids.nelement() > 1:
  359. # position_ids[0][0] = kv_seq_len - 1
  360. # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
  361. # ### Shift Pos: query pos is min(cache_size, idx)
  362. # # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
  363. # query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
  364. # ###
  365. # if past_key_value is not None:
  366. # # reuse k, v, self_attention
  367. # key_states = torch.cat([past_key_value[0], key_states], dim=2)
  368. # value_states = torch.cat([past_key_value[1], value_states], dim=2)
  369. # past_key_value = (key_states, value_states) if use_cache else None
  370. # ### Shift Pos: key pos is the pos in cache (Rolling KV Cache and using relative pos emb)
  371. # key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
  372. # key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
  373. # ###
  374. # # repeat k/v heads if n_kv_heads < n_heads
  375. # key_states = repeat_kv(key_states, self.num_key_value_groups)
  376. # value_states = repeat_kv(value_states, self.num_key_value_groups)
  377. # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
  378. # self.head_dim
  379. # )
  380. # if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
  381. # raise ValueError(
  382. # f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
  383. # f" {attn_weights.size()}"
  384. # )
  385. # if attention_mask is not None:
  386. # if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
  387. # raise ValueError(
  388. # f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
  389. # )
  390. # attn_weights = attn_weights + attention_mask
  391. # # upcast attention to fp32
  392. # attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
  393. # query_states.dtype
  394. # )
  395. # past_key_value = self.kv_cache(past_key_value, attn_weights.detach().clone())
  396. # attn_output = torch.matmul(attn_weights, value_states)
  397. # if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
  398. # raise ValueError(
  399. # f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
  400. # f" {attn_output.size()}"
  401. # )
  402. # attn_output = attn_output.transpose(1, 2).contiguous()
  403. # attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
  404. # if self.config.pretraining_tp > 1:
  405. # attn_output = attn_output.split(
  406. # self.hidden_size // self.config.pretraining_tp, dim=2
  407. # )
  408. # o_proj_slices = self.o_proj.weight.split(
  409. # self.hidden_size // self.config.pretraining_tp, dim=1
  410. # )
  411. # attn_output = sum(
  412. # [
  413. # F.linear(attn_output[i], o_proj_slices[i])
  414. # for i in range(self.config.pretraining_tp)
  415. # ]
  416. # )
  417. # else:
  418. # attn_output = self.o_proj(attn_output)
  419. # if not output_attentions:
  420. # attn_weights = None
  421. # return attn_output, attn_weights, past_key_value
  422. class H2OLlamaForCausalLM(LlamaForCausalLM):
  423. def __init__(self, config):
  424. super().__init__(config)
  425. num_layers = len(self.model.layers)
  426. for layer_idx in range(num_layers):
  427. self.model.layers[layer_idx].self_attn = H2OLlamaAttention(config, layer_idx)