# utils_llama.py

import math
from typing import Optional, Tuple
import pdb
import types

import torch
from torch import nn
import torch.utils.checkpoint
import torch.nn.functional as F

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    rotate_half,
    apply_rotary_pos_emb,
    repeat_kv,
    LlamaRotaryEmbedding,
    LlamaLinearScalingRotaryEmbedding,
    LlamaDynamicNTKScalingRotaryEmbedding,
    LlamaForCausalLM,
)
from transformers.utils import logging
from cache_utils import Cache

logger = logging.get_logger(__name__)

__all__ = ["H2OLlamaForCausalLM"]

def _make_causal_mask(
    bsz: int, tgt_len: int, past_key_values_length: int, dtype: torch.dtype, device: torch.device):
    """
    Make causal mask used for bi-directional self-attention.
    """
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
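
# Example (illustration only): for bsz=1, tgt_len=3, past_key_values_length=0 the returned mask has
# shape (1, 1, 3, 3); entries strictly above the diagonal hold torch.finfo(dtype).min and all others
# are 0, so adding it to the attention logits blocks attention to future positions.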

def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed
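
# Unlike `apply_rotary_pos_emb`, which rotates queries and keys with the same position ids, this
# single-tensor variant lets queries and keys be rotated with *different* positions (used in the
# commented-out H2O attention below, where key positions are re-indexed to their slots in the
# rolling KV cache).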

class H2OKVCache_LayerWise:
    """Layer-wise H2O KV cache: keeps the `hh_size` heavy-hitter tokens (highest accumulated
    attention score) plus the `recent_size` most recent tokens, and evicts everything else."""

    def __init__(
        self,
        hh_size=4,
        recent_size=512,
        k_seq_dim=2,
        v_seq_dim=2,
    ):
        self.hh_size = hh_size
        self.recent_size = recent_size
        self.cache_size = hh_size + recent_size
        self.k_seq_dim = k_seq_dim
        self.v_seq_dim = v_seq_dim
        self.hh_score = None

    def __call__(self, past_key_values, attn_score_cache):
        self._update_hh_score(attn_score_cache)

        if past_key_values is None:
            return None
        seq_len = past_key_values[0].size(self.k_seq_dim)
        if seq_len <= self.cache_size:
            return past_key_values

        # hh-selection
        bsz, num_heads, _, head_dim = past_key_values[0].shape

        select_hh_scores = self.hh_score[:, :seq_len - self.recent_size]
        _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
        keep_topk = keep_topk.sort().values

        # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
        keep_recent = torch.arange(seq_len - self.recent_size, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)

        mask = torch.zeros(self.hh_score.shape, dtype=torch.bool).to(past_key_values[0].device)
        mask = mask.scatter(-1, keep_idx, 1)

        k_hh_recent = past_key_values[0].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
        v_hh_recent = past_key_values[1].squeeze()[mask].view(bsz, num_heads, -1, head_dim)

        self.hh_score = self.hh_score[mask].view(num_heads, self.cache_size)

        return (k_hh_recent, v_hh_recent)

    def evict_for_space(self, past_key_values, num_coming):
        if past_key_values is None:
            return None
        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        if seq_len + num_coming <= self.cache_size:
            return past_key_values

        # hh-selection
        bsz, num_heads, _, head_dim = past_key_values[0].shape

        select_hh_scores = self.hh_score[:, :seq_len - self.recent_size + num_coming]
        _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
        keep_topk = keep_topk.sort().values

        # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
        keep_recent = torch.arange(seq_len - self.recent_size + num_coming, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)

        mask = torch.zeros(self.hh_score.shape, dtype=torch.bool).to(past_key_values[0].device)
        mask = mask.scatter(-1, keep_idx, 1)

        k_hh_recent = past_key_values[0].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
        v_hh_recent = past_key_values[1].squeeze()[mask].view(bsz, num_heads, -1, head_dim)

        self.hh_score = self.hh_score[mask].view(num_heads, self.cache_size)

        return (k_hh_recent, v_hh_recent)

    def _update_hh_score(self, attn_score_cache):
        # Accumulate attention mass received by each cached key, summed over batch and query dims.
        num_new_tokens = attn_score_cache.shape[2]

        if self.hh_score is None:
            self.hh_score = attn_score_cache.sum(0).sum(1)
        else:
            attn_score_cache = attn_score_cache.sum(0).sum(1)
            attn_score_cache[:, :-num_new_tokens] += self.hh_score
            self.hh_score = attn_score_cache

    def _clean_scores(self):
        self.hh_score = None
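
# Usage sketch (assumed calling convention, mirroring the commented-out attention below): after
# computing attention weights for a layer, pass the running (key, value) tuple and the detached
# attention matrix through the cache to keep only heavy hitters plus recent tokens:
#   kv_cache = H2OKVCache_LayerWise(hh_size=64, recent_size=256)
#   past_key_value = kv_cache(past_key_value, attn_weights.detach().clone())
# `past_key_value` holds tensors shaped [bsz, num_heads, seq_len, head_dim]; `attn_weights` is the
# [bsz, num_heads, q_len, kv_len] attention matrix for the same layer. The hh_size/recent_size
# values above are illustrative only.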

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
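
# Note: this local `repeat_kv` shadows the version imported from transformers above.
# Example (sketch): with 32 query heads and 8 KV heads (n_rep = 4), a key tensor of shape
# (1, 8, 128, 64) expands to (1, 32, 128, 64), repeating each KV head for its group of query heads.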

class H2OLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        past_key_value = getattr(self, "past_key_value", past_key_value)
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
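
# The commented-out class below is the earlier H2O attention implementation, kept for reference:
# it rebuilds its own causal mask, applies rotary embeddings to queries and keys with separate
# (cache-relative) position ids, and evicts cache entries through H2OKVCache_LayerWise each step.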
# class H2OLlamaAttention(nn.Module):
#     """Multi-headed attention from 'Attention Is All You Need' paper"""
#
#     def __init__(self, config: LlamaConfig):
#         super().__init__()
#         self.config = config
#         self.hidden_size = config.hidden_size
#         self.num_heads = config.num_attention_heads
#         self.head_dim = self.hidden_size // self.num_heads
#         self.num_key_value_heads = config.num_key_value_heads
#         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
#         self.max_position_embeddings = config.max_position_embeddings
#         self.rope_theta = config.rope_theta
#
#         if (self.head_dim * self.num_heads) != self.hidden_size:
#             raise ValueError(
#                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
#                 f" and `num_heads`: {self.num_heads})."
#             )
#         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
#         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
#         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
#         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
#         self._init_rope()
#
#         self.kv_cache = H2OKVCache_LayerWise(
#             hh_size=config.hh_size,
#             recent_size=config.recent_size,
#             k_seq_dim=2,
#             v_seq_dim=2,
#         )
#
#     def _init_rope(self):
#         if self.config.rope_scaling is None:
#             self.rotary_emb = LlamaRotaryEmbedding(
#                 self.head_dim,
#                 max_position_embeddings=self.max_position_embeddings,
#                 base=self.rope_theta,
#             )
#         else:
#             scaling_type = self.config.rope_scaling["type"]
#             scaling_factor = self.config.rope_scaling["factor"]
#             if scaling_type == "linear":
#                 self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
#                     self.head_dim,
#                     max_position_embeddings=self.max_position_embeddings,
#                     scaling_factor=scaling_factor,
#                     base=self.rope_theta,
#                 )
#             elif scaling_type == "dynamic":
#                 self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
#                     self.head_dim,
#                     max_position_embeddings=self.max_position_embeddings,
#                     scaling_factor=scaling_factor,
#                     base=self.rope_theta,
#                 )
#             else:
#                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
#
#     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
#         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
#
#     def _clean_cache(self):
#         self.kv_cache._clean_scores()
#
#     def forward(
#         self,
#         hidden_states: torch.Tensor,
#         attention_mask: Optional[torch.Tensor] = None,
#         position_ids: Optional[torch.LongTensor] = None,
#         past_key_value: Optional[Tuple[torch.Tensor]] = None,
#         output_attentions: bool = False,
#         use_cache: bool = False,
#         cache_position: Optional[torch.LongTensor] = None,
#     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
#         bsz, q_len, _ = hidden_states.size()
#
#         if self.config.pretraining_tp > 1:
#             key_value_slicing = (
#                 self.num_key_value_heads * self.head_dim
#             ) // self.config.pretraining_tp
#             query_slices = self.q_proj.weight.split(
#                 (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
#             )
#             key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
#             value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
#
#             query_states = [
#                 F.linear(hidden_states, query_slices[i])
#                 for i in range(self.config.pretraining_tp)
#             ]
#             query_states = torch.cat(query_states, dim=-1)
#
#             key_states = [
#                 F.linear(hidden_states, key_slices[i])
#                 for i in range(self.config.pretraining_tp)
#             ]
#             key_states = torch.cat(key_states, dim=-1)
#
#             value_states = [
#                 F.linear(hidden_states, value_slices[i])
#                 for i in range(self.config.pretraining_tp)
#             ]
#             value_states = torch.cat(value_states, dim=-1)
#         else:
#             query_states = self.q_proj(hidden_states)
#             key_states = self.k_proj(hidden_states)
#             value_states = self.v_proj(hidden_states)
#
#         query_states = query_states.view(
#             bsz, q_len, self.num_heads, self.head_dim
#         ).transpose(1, 2)
#         key_states = key_states.view(
#             bsz, q_len, self.num_key_value_heads, self.head_dim
#         ).transpose(1, 2)
#         value_states = value_states.view(
#             bsz, q_len, self.num_key_value_heads, self.head_dim
#         ).transpose(1, 2)
#
#         # remake causal mask
#         attention_mask = _make_causal_mask(
#             bsz=bsz,
#             tgt_len=q_len,
#             past_key_values_length=past_key_value[0].shape[-2] if past_key_value is not None else 0,
#             dtype=query_states.dtype,
#             device=query_states.device,
#         )
#
#         kv_seq_len = key_states.shape[-2]
#         if past_key_value is not None:
#             kv_seq_len += past_key_value[0].shape[-2]
#
#         if not position_ids.nelement() > 1:
#             position_ids[0][0] = kv_seq_len - 1
#
#         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
#         ### Shift Pos: query pos is min(cache_size, idx)
#         # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
#         query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
#         ###
#
#         if past_key_value is not None:
#             # reuse k, v, self_attention
#             key_states = torch.cat([past_key_value[0], key_states], dim=2)
#             value_states = torch.cat([past_key_value[1], value_states], dim=2)
#
#         past_key_value = (key_states, value_states) if use_cache else None
#
#         ### Shift Pos: key pos is the pos in cache (Rolling KV Cache and using relative pos emb)
#         key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
#         key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
#         ###
#
#         # repeat k/v heads if n_kv_heads < n_heads
#         key_states = repeat_kv(key_states, self.num_key_value_groups)
#         value_states = repeat_kv(value_states, self.num_key_value_groups)
#
#         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
#             self.head_dim
#         )
#
#         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
#             raise ValueError(
#                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
#                 f" {attn_weights.size()}"
#             )
#
#         if attention_mask is not None:
#             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
#                 raise ValueError(
#                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
#                 )
#             attn_weights = attn_weights + attention_mask
#
#         # upcast attention to fp32
#         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
#             query_states.dtype
#         )
#
#         past_key_value = self.kv_cache(past_key_value, attn_weights.detach().clone())
#
#         attn_output = torch.matmul(attn_weights, value_states)
#
#         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
#             raise ValueError(
#                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
#                 f" {attn_output.size()}"
#             )
#
#         attn_output = attn_output.transpose(1, 2).contiguous()
#         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
#
#         if self.config.pretraining_tp > 1:
#             attn_output = attn_output.split(
#                 self.hidden_size // self.config.pretraining_tp, dim=2
#             )
#             o_proj_slices = self.o_proj.weight.split(
#                 self.hidden_size // self.config.pretraining_tp, dim=1
#             )
#             attn_output = sum(
#                 [
#                     F.linear(attn_output[i], o_proj_slices[i])
#                     for i in range(self.config.pretraining_tp)
#                 ]
#             )
#         else:
#             attn_output = self.o_proj(attn_output)
#
#         if not output_attentions:
#             attn_weights = None
#
#         return attn_output, attn_weights, past_key_value

class H2OLlamaForCausalLM(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        num_layers = len(self.model.layers)
        for layer_idx in range(num_layers):
            # Pass layer_idx so the attention module can update the per-layer KV cache correctly.
            self.model.layers[layer_idx].self_attn = H2OLlamaAttention(config, layer_idx=layer_idx)
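
if __name__ == "__main__":
    # Usage sketch, not part of the original module. It assumes H2OLlamaForCausalLM loads like any
    # LlamaForCausalLM checkpoint; the model name, dtype, and device placement are examples only.
    from transformers import AutoTokenizer

    model_name = "meta-llama/Llama-2-7b-hf"  # example checkpoint (requires access/download)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = H2OLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")

    prompt = "The H2O policy keeps heavy-hitter and recent tokens in the KV cache because"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))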