# cache.py

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

@dataclass
class Cache:
    """
    Base, abstract class for all caches. The actual data structure is specific to each subclass.
    """

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
                cache to be created.

        Return:
            A tuple containing the updated key and value states.
        """
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states, if there is any."""
        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")

    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
        """Given the sequence length of the new inputs, returns the usable length of the cache."""
        # Cache without size limit -> all cache is usable
        # Cache with size limit -> if the length of the cache plus the length of the new inputs is larger than the
        # maximum cache length, we will need to evict part of the cache (and thus not all cache is usable)
        max_length = self.get_max_length()
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            return max_length - new_seq_length
        return previous_seq_length

    @property
    def seen_tokens(self):
        logger.warning_once(
            "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
            "model input instead."
        )
        if hasattr(self, "_seen_tokens"):
            return self._seen_tokens
        else:
            return None

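
# Worked example of the arithmetic in `get_usable_length` above (added for
# illustration, not part of the original file): with a bounded cache whose
# `get_max_length()` is 8 and which already holds 6 tokens, feeding 4 new tokens
# would overflow the limit, so only `8 - 4 = 4` of the cached positions remain
# usable and the older ones are expected to be evicted.
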

class DynamicCache(Cache):
    """
    A cache that grows dynamically as more tokens are generated. This is the default for generative models.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.
    """

    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.
        """
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """
        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
        keys and values.
        """
        for layer_idx in range(len(self)):
            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.key_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states. `DynamicCache` does not have a maximum length."""
        return None

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `DynamicCache` instance into its equivalent in the legacy cache format."""
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
        return legacy_cache

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
        cache = cls()
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, layer_idx)
        return cache

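
# The sketch below is added for illustration and is not part of the original module.
# It shows how `DynamicCache` is typically driven: one `update` call per layer per
# decoding step. All tensor shapes (`batch_size, num_heads, seq_len, head_dim`) are
# arbitrary assumptions chosen for the example.
def _demo_dynamic_cache() -> None:
    """Illustrative usage sketch for `DynamicCache` with random stand-in tensors."""
    num_layers, batch_size, num_heads, head_dim = 2, 1, 4, 8
    cache = DynamicCache()
    # Prefill with a 5-token prompt, then decode 3 tokens one at a time.
    for step_len in (5, 1, 1, 1):
        for layer_idx in range(num_layers):
            k = torch.randn(batch_size, num_heads, step_len, head_dim)
            v = torch.randn(batch_size, num_heads, step_len, head_dim)
            cache.update(k, v, layer_idx)
    assert cache.get_seq_length() == 8 and len(cache) == num_layers
    # Round-trip through the legacy `past_key_values` format.
    assert DynamicCache.from_legacy_cache(cache.to_legacy_cache()).get_seq_length() == 8
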

class SinkCache(Cache):
    """
    A cache, as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
    generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    Parameters:
        window_length (`int`):
            The length of the context window.
        num_sink_tokens (`int`):
            The number of sink tokens. See the original paper for more information.
    """

    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        self.cos_sin_cache = {}
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

    @staticmethod
    def _rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
        return rotated_key_states

    def _get_rerotation_cos_sin(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if key_states.shape[-2] not in self.cos_sin_cache:
            # Upcast to float32 temporarily for better accuracy
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

            # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

            self.cos_sin_cache[key_states.shape[-2]] = (
                rerotation_cos.to(key_states.dtype).unsqueeze(0),
                rerotation_sin.to(key_states.dtype).unsqueeze(0),
            )
        return self.cos_sin_cache[key_states.shape[-2]]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.window_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
                rotation as the tokens are shifted.

        Return:
            A tuple containing the updated key and value states.
        """
        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
        # with partially rotated position embeddings, like Phi or Persimmon.
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
        using_rope = cos is not None and sin is not None

        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # [bsz, num_heads, seq_len, head_dim]
        if len(self.key_cache) <= layer_idx:
            # Empty cache
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
            # Growing cache
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        else:
            # Shifting cache
            keys_to_keep = self.key_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
            ]

            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
            if using_rope:
                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
                    key_states, cos[: self.window_length], sin[: self.window_length]
                )
                if partial_rotation_size is not None:
                    keys_to_keep, keys_pass = (
                        keys_to_keep[..., :partial_rotation_size],
                        keys_to_keep[..., partial_rotation_size:],
                    )
                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
                if partial_rotation_size is not None:
                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)

            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
            values_to_keep = self.value_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
            ]
            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

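
# The sketch below is added for illustration and is not part of the original module.
# It shows the `SinkCache` window behavior: once the cache is full, updates keep the
# sink tokens plus the most recent tokens. RoPE re-rotation is skipped here by
# omitting `sin`/`cos` from `cache_kwargs`; shapes and sizes are arbitrary assumptions.
def _demo_sink_cache() -> None:
    """Illustrative usage sketch for `SinkCache` with random stand-in tensors."""
    cache = SinkCache(window_length=8, num_sink_tokens=2)
    batch_size, num_heads, head_dim = 1, 2, 4
    for _ in range(12):  # feed one token at a time, past the window length
        k = torch.randn(batch_size, num_heads, 1, head_dim)
        v = torch.randn(batch_size, num_heads, 1, head_dim)
        cache.update(k, v, layer_idx=0, cache_kwargs={})
    # Once full, the cache holds exactly `window_length` positions: 2 sinks + recent tokens.
    assert cache.get_seq_length(0) == cache.get_max_length()
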

class HHCache(Cache):
    """
    A cache that applies the heavy-hitter oracle
    (https://proceedings.neurips.cc/paper_files/paper/2023/file/6ceefa7b15572587b78ecfcebb2827f8-Paper-Conference.pdf).
    Only the heavy-hitter tokens and the most recent tokens are stored in the cache.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    Parameters:
        window_length (`int`):
            The length of the context window.
        num_hh_tokens (`int`):
            The number of heavy-hitter tokens. See the original paper for more information.
    """

    def __init__(self, window_length: int, num_hh_tokens: int) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.window_length = window_length
        self.num_hh_tokens = num_hh_tokens
        self.accumulated_attention_scores: List[torch.Tensor] = []
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.
        """
        if layer_idx < len(self):
            return (
                self.key_cache[layer_idx],
                self.value_cache[layer_idx],
                self.accumulated_attention_scores[layer_idx],
            )
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """
        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
        keys, values and accumulated attention scores.
        """
        for layer_idx in range(len(self)):
            yield (
                self.key_cache[layer_idx],
                self.value_cache[layer_idx],
                self.accumulated_attention_scores[layer_idx],
            )

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.key_cache)

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.window_length

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
        accumulated_attention_scores: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `HHCache`.
            accumulated_attention_scores (`torch.Tensor`, `optional`):
                Previously accumulated attention scores for this layer, used when restoring the cache from the legacy
                format.

        Return:
            A tuple containing the updated key and value states.
        """
        # Restore the accumulated attention scores, if provided
        if accumulated_attention_scores is not None:
            self.accumulated_attention_scores.append(accumulated_attention_scores)

        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def update_slimming(
        self,
        attention_scores: torch.Tensor,
        num_kv_groups: int,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Slims the cache based on the accumulated attention scores, keeping only the heavy-hitter and local tokens.

        Parameters:
            attention_scores (`torch.Tensor`):
                Attention scores for the current step.
            num_kv_groups (`int`):
                The number of key/value groups used when repeating the key/value heads (grouped-query attention).
            layer_idx (`int`):
                The index of the layer to slim the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `HHCache`.

        Return:
            None; the key/value cache and the accumulated attention scores are updated in place.
        """
        # Update score metrics (accumulated attention scores)
        if len(self.accumulated_attention_scores) <= layer_idx:
            self.accumulated_attention_scores.append(
                attention_scores.sum(2)[:, ::num_kv_groups, :]  # [bs, num_heads, key_len]
            )
        else:
            num_new_tokens = attention_scores.shape[2]
            updated_attention_scores = attention_scores.sum(2)[:, ::num_kv_groups, :]  # [bs, num_heads, key_len]
            updated_attention_scores[:, :, :-num_new_tokens] += self.accumulated_attention_scores[layer_idx]
            self.accumulated_attention_scores[layer_idx] = updated_attention_scores

        # Update KV Cache
        if self.get_seq_length(layer_idx) > self.window_length:
            seq_scores = self.accumulated_attention_scores[layer_idx][
                :, :, : -self.window_length + self.num_hh_tokens
            ]
            _, keep_hh_index = torch.topk(seq_scores, self.num_hh_tokens, dim=-1)
            keep_hh_index = keep_hh_index.sort().values

            keep_local_index = torch.arange(
                self.get_seq_length(layer_idx) - self.window_length + self.num_hh_tokens,
                self.get_seq_length(layer_idx),
                device=keep_hh_index.device,
            ).repeat(keep_hh_index.shape[0], keep_hh_index.shape[1], 1)
            keep_index = torch.cat([keep_hh_index, keep_local_index], dim=-1)

            mask = torch.zeros(self.accumulated_attention_scores[layer_idx].shape, dtype=torch.bool).to(
                keep_hh_index.device
            )
            mask = mask.scatter(-1, keep_index, 1)

            bsz, num_heads, _, head_dim = self.key_cache[layer_idx].shape
            self.key_cache[layer_idx] = self.key_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)
            self.value_cache[layer_idx] = self.value_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)

            self.accumulated_attention_scores[layer_idx] = self.accumulated_attention_scores[layer_idx][mask].view(
                bsz, num_heads, -1
            )

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `HHCache` instance into its equivalent in the legacy cache format."""
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += (
                (
                    self.key_cache[layer_idx],
                    self.value_cache[layer_idx],
                    self.accumulated_attention_scores[layer_idx],
                ),
            )
        return legacy_cache

    @classmethod
    def from_legacy_cache(
        cls,
        window_length: int,
        num_hh_tokens: int,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    ) -> "HHCache":
        """Converts a cache in the legacy cache format into an equivalent `HHCache`."""
        cache = cls(window_length, num_hh_tokens)
        if past_key_values is not None:
            # `past_key_values` is indexed here as a flat tuple laid out as
            # (key_0, value_0, scores_0, key_1, value_1, scores_1, ...)
            for layer_idx in range(len(past_key_values) // 3):
                key_states = past_key_values[layer_idx * 3]
                value_states = past_key_values[layer_idx * 3 + 1]
                accumulated_attention_scores = past_key_values[layer_idx * 3 + 2]
                cache.update(
                    key_states, value_states, layer_idx, accumulated_attention_scores=accumulated_attention_scores
                )
        return cache

    def evict_for_space(self, space_needed: int):
        """Evicts cached tokens so that `space_needed` new tokens fit within the window."""
        num_layers = len(self.key_cache)

        # The accumulated attention scores must already exist for every layer
        if len(self.accumulated_attention_scores) < num_layers:
            raise ValueError("The accumulated_attention_scores should be updated before evicting the cache.")

        for layer_idx in range(num_layers):
            # Update KV Cache, evict for newly arriving prompts
            if self.get_seq_length(layer_idx) + space_needed > self.window_length:
                if self.window_length - self.num_hh_tokens <= space_needed:
                    raise ValueError("The space_needed should be less than the window_length - num_hh_tokens.")

                seq_scores = self.accumulated_attention_scores[layer_idx][
                    :, :, : -self.window_length + self.num_hh_tokens + space_needed
                ]
                _, keep_hh_index = torch.topk(seq_scores, self.num_hh_tokens, dim=-1)
                keep_hh_index = keep_hh_index.sort().values

                keep_local_index = torch.arange(
                    self.get_seq_length(layer_idx) - self.window_length + self.num_hh_tokens + space_needed,
                    self.get_seq_length(layer_idx),
                    device=keep_hh_index.device,
                ).repeat(keep_hh_index.shape[0], keep_hh_index.shape[1], 1)
                keep_index = torch.cat([keep_hh_index, keep_local_index], dim=-1)

                mask = torch.zeros(self.accumulated_attention_scores[layer_idx].shape, dtype=torch.bool).to(
                    keep_hh_index.device
                )
                mask = mask.scatter(-1, keep_index, 1)

                bsz, num_heads, _, head_dim = self.key_cache[layer_idx].shape
                self.key_cache[layer_idx] = self.key_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)
                self.value_cache[layer_idx] = self.value_cache[layer_idx][mask].view(bsz, num_heads, -1, head_dim)

                self.accumulated_attention_scores[layer_idx] = self.accumulated_attention_scores[layer_idx][
                    mask
                ].view(bsz, num_heads, -1)

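
# The sketch below is added for illustration and is not part of the original module.
# It shows the `HHCache` flow: `update` appends new key/value states, then
# `update_slimming` drops everything except the heavy-hitter tokens and the local
# window. The attention scores are random stand-ins, `num_kv_groups=1` assumes no
# grouped-query attention, and all sizes are arbitrary assumptions.
def _demo_hh_cache() -> None:
    """Illustrative usage sketch for `HHCache` with random stand-in tensors."""
    cache = HHCache(window_length=6, num_hh_tokens=2)
    batch_size, num_heads, head_dim = 1, 2, 4

    # Prefill with an 8-token prompt, then slim down to heavy hitters + local tokens.
    k = torch.randn(batch_size, num_heads, 8, head_dim)
    v = torch.randn(batch_size, num_heads, 8, head_dim)
    cache.update(k, v, layer_idx=0)
    attn = torch.rand(batch_size, num_heads, 8, 8).softmax(dim=-1)  # [bs, heads, q_len, key_len]
    cache.update_slimming(attn, num_kv_groups=1, layer_idx=0)
    assert cache.get_seq_length(0) == 6  # only `window_length` positions survive

    # Decode one token: append it, then slim again.
    cache.update(
        torch.randn(batch_size, num_heads, 1, head_dim),
        torch.randn(batch_size, num_heads, 1, head_dim),
        layer_idx=0,
    )
    attn_step = torch.rand(batch_size, num_heads, 1, 7).softmax(dim=-1)
    cache.update_slimming(attn_step, num_kv_groups=1, layer_idx=0)
    assert cache.get_seq_length(0) == 6
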

class StaticCache(Cache):
    """
    Static Cache class to be used with `torch.compile(model)`.

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
            required to initialize the static cache.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used.
        device (`torch.device`):
            The device on which the cache should be initialized. Should be the same as the layer.
        dtype (*optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.
    """

    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
        super().__init__()
        self.max_batch_size = max_batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )

        self.dtype = dtype if dtype is not None else torch.float32
        self.num_key_value_heads = (
            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
        )

        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
        self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
        self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for. Kept for backward compatibility.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The `StaticCache` just needs the `cache_position` input
                to know which positions of the cache it should overwrite.

        Return:
            A tuple containing the updated key and value states.
        """
        new_cache_positions = cache_kwargs.get("cache_position")
        k_out = self.key_cache
        v_out = self.value_cache

        k_out[:, :, new_cache_positions] = key_states
        v_out[:, :, new_cache_positions] = value_states

        return k_out, v_out

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC."""
        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
        # limit the check to the first batch member and head dimension.
        # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
        # https://github.com/pytorch/pytorch/issues/120248 is fixed
        return (self.key_cache[0, 0].any(dim=-1)).sum()

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.max_cache_len

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        device = self.key_cache.device
        self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
        device = self.value_cache.device
        self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))

    def to_legacy_cache(self):
        """Dummy function for BC. We have to keep it because otherwise the call in the `forward` method of models
        would break."""
        return None
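

# The sketch below is added for illustration and is not part of the original module.
# It shows how `StaticCache` pre-allocates fixed-size buffers and overwrites the slots
# selected by `cache_position`. It assumes `PretrainedConfig` keeps the extra kwargs
# passed below as plain attributes; all sizes are arbitrary assumptions.
def _demo_static_cache() -> None:
    """Illustrative usage sketch for `StaticCache` with random stand-in tensors."""
    config = PretrainedConfig(
        max_position_embeddings=32, hidden_size=32, num_attention_heads=4, num_key_value_heads=4
    )
    cache = StaticCache(config, max_batch_size=1, max_cache_len=16, device="cpu", dtype=torch.float32)
    # Prefill 6 positions; `cache_position` selects which slots to overwrite.
    k = torch.randn(1, 4, 6, 8)  # head_dim = hidden_size // num_attention_heads = 8
    v = torch.randn(1, 4, 6, 8)
    cache.update(k, v, layer_idx=0, cache_kwargs={"cache_position": torch.arange(6)})
    assert int(cache.get_seq_length()) == 6 and cache.get_max_length() == 16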