|
@@ -432,22 +432,21 @@ class HHCache(Cache):
|
|
|
Return:
|
|
|
A tuple containing the updated key and value states.
|
|
|
"""
|
|
|
- import pdb; pdb.set_trace()
|
|
|
# Update the number of seen tokens
|
|
|
if layer_idx == 0:
|
|
|
self._seen_tokens += key_states.shape[-2]
|
|
|
|
|
|
|
|
|
|
|
|
# Update the cache
|
|
|
# [bsz, num_heads, seq_len, head_dim]
|
|
|
- if len(self.key_cache) <= layer_idx:
|
|
|
- # Empty cache
|
|
|
- self.key_cache.append(key_states)
|
|
|
- self.value_cache.append(value_states)
|
|
|
-
|
|
|
- elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
|
|
|
- # Growing cache
|
|
|
- self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
|
|
|
- self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
|
|
|
+ if key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
|
|
|
+ if len(self.key_cache) <= layer_idx:
|
|
|
+ # Empty cache
|
|
|
+ self.key_cache.append(key_states)
|
|
|
+ self.value_cache.append(value_states)
|
|
|
+ else:
|
|
|
+ # Growing cache
|
|
|
+ self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
|
|
|
+ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
|
|
|
|
|
|
else:
|
|
|
# Shifting cache
|