utils_llama.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. import math
  2. from typing import Optional, Tuple
  3. import pdb
  4. import types
  5. import torch
  6. from torch import nn
  7. import torch.utils.checkpoint
  8. import torch.nn.functional as F
  9. from transformers.models.llama.configuration_llama import LlamaConfig
  10. from transformers.models.llama.modeling_llama import (
  11. LlamaAttention,
  12. rotate_half,
  13. apply_rotary_pos_emb,
  14. repeat_kv,
  15. LlamaRotaryEmbedding,
  16. apply_rotary_pos_emb,
  17. LlamaForCausalLM,
  18. )
# Public API of this module: `from <module> import *` exports only the
# H2O-patched causal LM class; the helpers below are internal.
__all__ = ["H2OLlamaForCausalLM"]
  20. def _make_causal_mask(
  21. bsz: int, tgt_len: int, past_key_values_length: int, dtype: torch.dtype, device: torch.device):
  22. """
  23. Make causal mask used for bi-directional self-attention.
  24. """
  25. mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
  26. mask_cond = torch.arange(mask.size(-1), device=device)
  27. mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
  28. mask = mask.to(dtype)
  29. if past_key_values_length > 0:
  30. mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
  31. return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
  32. def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
  33. # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
  34. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
  35. sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
  36. cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
  37. sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
  38. x_embed = (x * cos) + (rotate_half(x) * sin)
  39. return x_embed
  40. class H2OKVCache_LayerWise:
  41. def __init__(
  42. self,
  43. hh_size=4,
  44. recent_size=512,
  45. k_seq_dim=2,
  46. v_seq_dim=2,
  47. ):
  48. self.hh_size = hh_size
  49. self.recent_size = recent_size
  50. self.cache_size = hh_size + recent_size
  51. self.k_seq_dim = k_seq_dim
  52. self.v_seq_dim = v_seq_dim
  53. self.hh_score = None
  54. def __call__(self, past_key_values, attn_score_cache):
  55. self._update_hh_score(attn_score_cache)
  56. if past_key_values is None:
  57. return None
  58. seq_len = past_key_values[0].size(self.k_seq_dim)
  59. if seq_len <= self.cache_size:
  60. return past_key_values
  61. # hh-selection
  62. bsz, num_heads, _, head_dim = past_key_values[0].shape
  63. select_hh_scores = self.hh_score[:, :seq_len - self.recent_size]
  64. _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
  65. keep_topk = keep_topk.sort().values
  66. # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
  67. keep_recent = torch.arange(seq_len - self.recent_size, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
  68. keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)
  69. mask = torch.zeros(self.hh_score.shape, dtype=torch.bool).to(past_key_values[0].device)
  70. mask = mask.scatter(-1, keep_idx, 1)
  71. k_hh_recent = past_key_values[0].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
  72. v_hh_recent = past_key_values[1].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
  73. self.hh_score= self.hh_score[mask].view(num_heads, self.cache_size)
  74. return (k_hh_recent, v_hh_recent)
  75. def evict_for_space(self, past_key_values, num_coming):
  76. if past_key_values is None:
  77. return None
  78. seq_len = past_key_values[0][0].size(self.k_seq_dim)
  79. if seq_len + num_coming <= self.cache_size:
  80. return past_key_values
  81. # hh-selection
  82. bsz, num_heads, _, head_dim = past_key_values[0].shape
  83. select_hh_scores = self.hh_score[:, :seq_len - self.recent_size + num_coming]
  84. _, keep_topk = torch.topk(select_hh_scores, self.hh_size, dim=-1)
  85. keep_topk = keep_topk.sort().values
  86. # keep_recent = torch.arange(seq_len - self.recent_size, seq_len).expand(keep_topk.shape[0], 1).to(keep_topk.device)
  87. keep_recent = torch.arange(seq_len - self.recent_size + num_coming, seq_len, device=keep_topk.device).repeat(keep_topk.shape[0], 1)
  88. keep_idx = torch.cat([keep_topk, keep_recent], dim=-1)
  89. mask = torch.zeros(self.hh_score.shape, dtype=torch.bool).to(past_key_values[0].device)
  90. mask = mask.scatter(-1, keep_idx, 1)
  91. k_hh_recent = past_key_values[0].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
  92. v_hh_recent = past_key_values[1].squeeze()[mask].view(bsz, num_heads, -1, head_dim)
  93. self.hh_score= self.hh_score[mask].view(num_heads, self.cache_size)
  94. return (k_hh_recent, v_hh_recent)
  95. def _update_hh_score(self, attn_score_cache):
  96. num_new_tokens = attn_score_cache.shape[2]
  97. if self.hh_score is None:
  98. self.hh_score = attn_score_cache.sum(0).sum(1)
  99. else:
  100. attn_score_cache = attn_score_cache.sum(0).sum(1)
  101. attn_score_cache[:, :-num_new_tokens] += self.hh_score
  102. self.hh_score = attn_score_cache
  103. def _clean_scores(self):
  104. self.hh_score = None
  105. class H2OLlamaAttention(nn.Module):
  106. """Multi-headed attention from 'Attention Is All You Need' paper"""
  107. def __init__(self, config: LlamaConfig):
  108. super().__init__()
  109. self.config = config
  110. self.hidden_size = config.hidden_size
  111. self.num_heads = config.num_attention_heads
  112. self.head_dim = self.hidden_size // self.num_heads
  113. self.num_key_value_heads = config.num_key_value_heads
  114. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  115. self.max_position_embeddings = config.max_position_embeddings
  116. self.rope_theta = config.rope_theta
  117. if (self.head_dim * self.num_heads) != self.hidden_size:
  118. raise ValueError(
  119. f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
  120. f" and `num_heads`: {self.num_heads})."
  121. )
  122. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
  123. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  124. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  125. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
  126. self._init_rope()
  127. self.kv_cache = H2OKVCache_LayerWise(
  128. hh_size=config.hh_size,
  129. recent_size=config.recent_size,
  130. k_seq_dim=2,
  131. v_seq_dim=2,
  132. )
  133. def _init_rope(self):
  134. if self.config.rope_scaling is None:
  135. self.rotary_emb = LlamaRotaryEmbedding(
  136. self.head_dim,
  137. max_position_embeddings=self.max_position_embeddings,
  138. base=self.rope_theta,
  139. )
  140. else:
  141. scaling_type = self.config.rope_scaling["type"]
  142. scaling_factor = self.config.rope_scaling["factor"]
  143. if scaling_type == "linear":
  144. self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
  145. self.head_dim,
  146. max_position_embeddings=self.max_position_embeddings,
  147. scaling_factor=scaling_factor,
  148. base=self.rope_theta,
  149. )
  150. elif scaling_type == "dynamic":
  151. self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
  152. self.head_dim,
  153. max_position_embeddings=self.max_position_embeddings,
  154. scaling_factor=scaling_factor,
  155. base=self.rope_theta,
  156. )
  157. else:
  158. raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
  159. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  160. return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
  161. def _clean_cache(self):
  162. self.kv_cache._clean_scores()
  163. def forward(
  164. self,
  165. hidden_states: torch.Tensor,
  166. attention_mask: Optional[torch.Tensor] = None,
  167. position_ids: Optional[torch.LongTensor] = None,
  168. past_key_value: Optional[Tuple[torch.Tensor]] = None,
  169. output_attentions: bool = False,
  170. use_cache: bool = False,
  171. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  172. bsz, q_len, _ = hidden_states.size()
  173. if self.config.pretraining_tp > 1:
  174. key_value_slicing = (
  175. self.num_key_value_heads * self.head_dim
  176. ) // self.config.pretraining_tp
  177. query_slices = self.q_proj.weight.split(
  178. (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
  179. )
  180. key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
  181. value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
  182. query_states = [
  183. F.linear(hidden_states, query_slices[i])
  184. for i in range(self.config.pretraining_tp)
  185. ]
  186. query_states = torch.cat(query_states, dim=-1)
  187. key_states = [
  188. F.linear(hidden_states, key_slices[i])
  189. for i in range(self.config.pretraining_tp)
  190. ]
  191. key_states = torch.cat(key_states, dim=-1)
  192. value_states = [
  193. F.linear(hidden_states, value_slices[i])
  194. for i in range(self.config.pretraining_tp)
  195. ]
  196. value_states = torch.cat(value_states, dim=-1)
  197. else:
  198. query_states = self.q_proj(hidden_states)
  199. key_states = self.k_proj(hidden_states)
  200. value_states = self.v_proj(hidden_states)
  201. query_states = query_states.view(
  202. bsz, q_len, self.num_heads, self.head_dim
  203. ).transpose(1, 2)
  204. key_states = key_states.view(
  205. bsz, q_len, self.num_key_value_heads, self.head_dim
  206. ).transpose(1, 2)
  207. value_states = value_states.view(
  208. bsz, q_len, self.num_key_value_heads, self.head_dim
  209. ).transpose(1, 2)
  210. # remake causal mask
  211. attention_mask = _make_causal_mask(
  212. bsz=bsz,
  213. tgt_len=q_len,
  214. past_key_values_length=past_key_value[0].shape[-2] if past_key_value is not None else 0,
  215. dtype=query_states.dtype,
  216. device=query_states.device,
  217. )
  218. kv_seq_len = key_states.shape[-2]
  219. if past_key_value is not None:
  220. kv_seq_len += past_key_value[0].shape[-2]
  221. if not position_ids.nelement() > 1:
  222. position_ids[0][0] = kv_seq_len - 1
  223. cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
  224. ### Shift Pos: query pos is min(cache_size, idx)
  225. # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
  226. query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
  227. ###
  228. if past_key_value is not None:
  229. # reuse k, v, self_attention
  230. key_states = torch.cat([past_key_value[0], key_states], dim=2)
  231. value_states = torch.cat([past_key_value[1], value_states], dim=2)
  232. past_key_value = (key_states, value_states) if use_cache else None
  233. ### Shift Pos: key pos is the pos in cache (Rolling KV Cache and using relative pos emb)
  234. key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
  235. key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
  236. ###
  237. # repeat k/v heads if n_kv_heads < n_heads
  238. key_states = repeat_kv(key_states, self.num_key_value_groups)
  239. value_states = repeat_kv(value_states, self.num_key_value_groups)
  240. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
  241. self.head_dim
  242. )
  243. if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
  244. raise ValueError(
  245. f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
  246. f" {attn_weights.size()}"
  247. )
  248. if attention_mask is not None:
  249. if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
  250. raise ValueError(
  251. f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
  252. )
  253. attn_weights = attn_weights + attention_mask
  254. # upcast attention to fp32
  255. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
  256. query_states.dtype
  257. )
  258. past_key_value = self.kv_cache(past_key_value, attn_weights.detach().clone())
  259. attn_output = torch.matmul(attn_weights, value_states)
  260. if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
  261. raise ValueError(
  262. f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
  263. f" {attn_output.size()}"
  264. )
  265. attn_output = attn_output.transpose(1, 2).contiguous()
  266. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
  267. if self.config.pretraining_tp > 1:
  268. attn_output = attn_output.split(
  269. self.hidden_size // self.config.pretraining_tp, dim=2
  270. )
  271. o_proj_slices = self.o_proj.weight.split(
  272. self.hidden_size // self.config.pretraining_tp, dim=1
  273. )
  274. attn_output = sum(
  275. [
  276. F.linear(attn_output[i], o_proj_slices[i])
  277. for i in range(self.config.pretraining_tp)
  278. ]
  279. )
  280. else:
  281. attn_output = self.o_proj(attn_output)
  282. if not output_attentions:
  283. attn_weights = None
  284. return attn_output, attn_weights, past_key_value
  285. class H2OLlamaForCausalLM(LlamaForCausalLM):
  286. def __init__(self, config):
  287. super().__init__(config)
  288. num_layers = len(self.model.layers)
  289. for layer_idx in range(num_layers):
  290. self.model.layers[layer_idx].self_attn = H2OLlamaAttention(config)