utils_llama.py

import math
from typing import Optional, Tuple

import torch
from torch import nn
import torch.utils.checkpoint
import torch.nn.functional as F

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    rotate_half,
    apply_rotary_pos_emb,
    repeat_kv,
    LlamaRotaryEmbedding,
    LlamaLinearScalingRotaryEmbedding,
    LlamaDynamicNTKScalingRotaryEmbedding,
    LlamaForCausalLM,
)

__all__ = ["H2OLlamaForCausalLM"]

def _make_causal_mask(
    bsz: int, tgt_len: int, past_key_values_length: int, dtype: torch.dtype, device: torch.device
):
    """
    Make a causal mask for self-attention, prepending zero columns for cached past key/values.
    """
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat(
            [torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1
        )
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed

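# NOTE: `DIM_TO_SLICE` is referenced by `H2OKVCache_LayerWise.__init__` below but was
# missing from this listing. The mapping here is a minimal sketch, assuming the usual
# "sequence-dimension -> slicing function" helpers used by StreamingLLM-style cache
# utilities; adjust if the original repository defines them differently.
def _slice_1d(x, start, end):
    return x[:, start:end, ...]


def _slice_2d(x, start, end):
    return x[:, :, start:end, ...]


def _slice_3d(x, start, end):
    return x[:, :, :, start:end, ...]


DIM_TO_SLICE = {
    1: _slice_1d,
    2: _slice_2d,
    3: _slice_3d,
}
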
class H2OKVCache_LayerWise:
    """Layer-wise H2O KV cache: keeps `hh_size` heavy-hitter tokens (ranked by
    accumulated attention score) plus the `recent_size` most recent tokens."""

    def __init__(
        self,
        hh_size=4,
        recent_size=512,
        k_seq_dim=2,
        v_seq_dim=2,
    ):
        self.hh_size = hh_size
        self.recent_size = recent_size
        self.cache_size = hh_size + recent_size
        self.k_seq_dim = k_seq_dim
        self.v_seq_dim = v_seq_dim
        self.k_slice = DIM_TO_SLICE[k_seq_dim]
        self.v_slice = DIM_TO_SLICE[v_seq_dim]
        self.hh_score = None
    def __call__(self, past_key_values, attn_score_cache):
        self._update_hh_score(attn_score_cache)

        if past_key_values is None:
            return None
        seq_len = past_key_values[0].size(self.k_seq_dim)
        if seq_len <= self.cache_size:
            return past_key_values

        # hh-selection: pick the top-`hh_size` heavy hitters among the non-recent tokens.
        bsz, num_heads, _, head_dim = past_key_values[0].shape

        select_hh_scores = self.hh_score[:, :seq_len - self.recent_size]
        _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
        keep_topk = keep_topk.sort().values

        keep_recent = torch.arange(seq_len - self.recent_size, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)

        mask = torch.zeros(self.hh_score.shape, dtype=torch.bool).to(past_key_values[0].device)
        mask = mask.scatter(-1, keep_idx, 1)

        # NOTE: `squeeze()` assumes a batch size of 1.
        k_hh_recent = past_key_values[0].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
        v_hh_recent = past_key_values[1].squeeze()[mask].view(bsz, num_heads, -1, head_dim)

        self.hh_score = self.hh_score[mask].view(num_heads, self.cache_size)

        return (k_hh_recent, v_hh_recent)
    def evict_for_space(self, past_key_values, num_coming):
        if past_key_values is None:
            return None
        seq_len = past_key_values[0].size(self.k_seq_dim)
        if seq_len + num_coming <= self.cache_size:
            return past_key_values

        # hh-selection: evict enough tokens to leave room for `num_coming` new ones.
        bsz, num_heads, _, head_dim = past_key_values[0].shape

        select_hh_scores = self.hh_score[:, :seq_len - self.recent_size + num_coming]
        _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
        keep_topk = keep_topk.sort().values

        keep_recent = torch.arange(seq_len - self.recent_size + num_coming, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
        keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)

        mask = torch.zeros(self.hh_score.shape, dtype=torch.bool).to(past_key_values[0].device)
        mask = mask.scatter(-1, keep_idx, 1)

        # NOTE: `squeeze()` assumes a batch size of 1.
        k_hh_recent = past_key_values[0].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
        v_hh_recent = past_key_values[1].squeeze()[mask].view(bsz, num_heads, -1, head_dim)

        # Fewer than `cache_size` tokens remain until the incoming tokens arrive.
        self.hh_score = self.hh_score[mask].view(num_heads, -1)

        return (k_hh_recent, v_hh_recent)
    def _update_hh_score(self, attn_score_cache):
        # attn_score_cache: (bsz, num_heads, q_len, kv_seq_len); accumulate into (num_heads, kv_seq_len).
        num_new_tokens = attn_score_cache.shape[2]

        if self.hh_score is None:
            self.hh_score = attn_score_cache.sum(0).sum(1)
        else:
            attn_score_cache = attn_score_cache.sum(0).sum(1)
            attn_score_cache[:, :-num_new_tokens] += self.hh_score
            self.hh_score = attn_score_cache

    def _clean_scores(self):
        self.hh_score = None

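# A minimal usage sketch of H2OKVCache_LayerWise on random tensors, assuming a batch
# size of 1 (the eviction code squeezes the batch dimension). The sizes and the helper
# name below are illustrative only and are not taken from the original repository.
def _example_h2o_kv_cache():
    bsz, num_heads, seq_len, head_dim = 1, 8, 1024, 64
    cache = H2OKVCache_LayerWise(hh_size=64, recent_size=256)

    keys = torch.randn(bsz, num_heads, seq_len, head_dim)
    values = torch.randn(bsz, num_heads, seq_len, head_dim)
    # Attention scores for the current step: (bsz, num_heads, q_len, kv_seq_len).
    attn_scores = torch.softmax(torch.randn(bsz, num_heads, seq_len, seq_len), dim=-1)

    keys, values = cache((keys, values), attn_scores)
    # After eviction, the cached length equals hh_size + recent_size.
    assert keys.size(2) == cache.cache_size
    return keys, values
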
class H2OLlamaAttention(nn.Module):
    """Multi-headed attention from the 'Attention Is All You Need' paper, extended with a
    layer-wise H2O KV cache (heavy-hitter + recent-token eviction)."""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self._init_rope()

        # `hh_size` and `recent_size` are expected as extra attributes on the LlamaConfig.
        self.kv_cache = H2OKVCache_LayerWise(
            hh_size=config.hh_size,
            recent_size=config.recent_size,
            k_seq_dim=2,
            v_seq_dim=2,
        )
    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "linear":
                self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=scaling_factor,
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _clean_cache(self):
        self.kv_cache._clean_scores()
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()
        if self.config.pretraining_tp > 1:
            key_value_slicing = (
                self.num_key_value_heads * self.head_dim
            ) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [
                F.linear(hidden_states, query_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [
                F.linear(hidden_states, key_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [
                F.linear(hidden_states, value_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            value_states = torch.cat(value_states, dim=-1)
        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = query_states.view(
            bsz, q_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        # Remake the causal mask from the current cache length: with H2O eviction the
        # cache can be shorter than the original position indices.
        attention_mask = _make_causal_mask(
            bsz=bsz,
            tgt_len=q_len,
            past_key_values_length=past_key_value[0].shape[-2] if past_key_value is not None else 0,
            dtype=query_states.dtype,
            device=query_states.device,
        )

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        if not position_ids.nelement() > 1:
            # Single-token decoding step: pin the query position to the end of the (possibly evicted) cache.
            position_ids[0][0] = kv_seq_len - 1

        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        ### Shift Pos: query pos is min(cache_size, idx)
        # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
        ###
        if past_key_value is not None:
            # Reuse cached k, v for self-attention.
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        ### Shift Pos: key pos is the position in the cache (rolling KV cache with relative position embeddings)
        key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
        key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
        ###

        # Repeat k/v heads if n_kv_heads < n_heads (grouped-query attention).
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
            self.head_dim
        )

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # Upcast attention to fp32 for the softmax, then cast back.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
            query_states.dtype
        )

        # Evict from the KV cache using the (detached) attention scores of this step.
        past_key_value = self.kv_cache(past_key_value, attn_weights.detach().clone())
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(
                self.hidden_size // self.config.pretraining_tp, dim=2
            )
            o_proj_slices = self.o_proj.weight.split(
                self.hidden_size // self.config.pretraining_tp, dim=1
            )
            attn_output = sum(
                F.linear(attn_output[i], o_proj_slices[i])
                for i in range(self.config.pretraining_tp)
            )
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

class H2OLlamaForCausalLM(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        num_layers = len(self.model.layers)
        # Swap every layer's attention module for the H2O variant (same parameter
        # names, so pretrained weights still load).
        for layer_idx in range(num_layers):
            self.model.layers[layer_idx].self_attn = H2OLlamaAttention(config)
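

# A minimal end-to-end sketch, assuming a transformers version whose Llama decoder layer
# calls attention with the `forward` signature used above (roughly the 4.33-era API).
# `hh_size` and `recent_size` are not standard LlamaConfig fields; the attention module
# expects them on the config, so they are attached by hand here. The checkpoint name and
# budget values are placeholders.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_name = "meta-llama/Llama-2-7b-hf"
    config = LlamaConfig.from_pretrained(model_name)
    config.hh_size = 64        # heavy-hitter budget per layer
    config.recent_size = 256   # recent-token budget per layer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Parameter names are unchanged, so pretrained weights load into the H2O attention modules.
    model = H2OLlamaForCausalLM.from_pretrained(model_name, config=config)

    inputs = tokenizer("The H2O cache keeps heavy hitters plus recent tokens.", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))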