Allen, 1 year ago
parent commit 6b447d661c

research/long-context-llama/H2O/cache_utils.py (+30, −41)

@@ -412,7 +412,6 @@ class HHCache(Cache):
         self,
         key_states: torch.Tensor,
         value_states: torch.Tensor,
-        attention_scores: torch.Tensor,
         layer_idx: int,
         cache_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -436,52 +435,42 @@ class HHCache(Cache):
         if layer_idx == 0:
             self._seen_tokens += key_states.shape[-2]
 
-        import pdb; pdb.set_trace()
-
         # Update the cache
-        # [bsz, num_heads, seq_len, head_dim]
-        if key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
-            if len(self.key_cache) <= layer_idx:
-                # Empty cache
-                self.key_cache.append(key_states)
-                self.value_cache.append(value_states)
-            else:
-                # Growing cache
-                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
-
+        if len(self.key_cache) <= layer_idx:
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
         else:
-            # Shifting cache 
-            # Keeping the local tokens
-            keys_to_keep = self.key_cache[layer_idx][
-                :, :, -self.window_length + self.num_hh_tokens + key_states.shape[-2] :
-            ]
+            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
 
-            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
-            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
-            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
 
-            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
-            values_to_keep = self.value_cache[layer_idx][
-                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
-            ]
-            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
+    def update_slimming(
+        self,
+        attention_scores: torch.Tensor,
+        num_kv_groups: int,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Slim the cache based on accumulated attention scores, keeping only the heavy-hitter and local tokens.
+
+        Parameters:
+            attention_scores (`torch.Tensor`):
+                Attention scores for the current step.
+            num_kv_groups (`int`):
+                The number of key/value groups used by `repeat_kv`, i.e. query heads per kv head.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. No additional arguments are used in `HHCache`.
+        Return:
+            A tuple containing the updated key and value states.
+        """
+
 
-            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
-            if using_rope:
-                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
-                    key_states, cos[: self.window_length], sin[: self.window_length]
-                )
-                if partial_rotation_size is not None:
-                    keys_to_keep, keys_pass = (
-                        keys_to_keep[..., :partial_rotation_size],
-                        keys_to_keep[..., partial_rotation_size:],
-                    )
-                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
-                if partial_rotation_size is not None:
-                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
 
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]
 
 
     def reorder_cache(self, beam_idx: torch.LongTensor):
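Note: the new `update_slimming` method is declared above with its signature and docstring, but the slimming logic itself is not implemented in this hunk. For orientation only, here is a minimal sketch of the selection the docstring describes — keep the `num_hh_tokens` tokens with the largest accumulated attention scores plus the most recent local tokens, within a total budget of `window_length`. It is written as a standalone helper rather than as the class method; the name `h2o_slim_layer`, the `prev_scores` argument, and the omission of sink-token handling are illustrative assumptions, not the repository's actual implementation.

import torch
from typing import Optional, Tuple


def h2o_slim_layer(
    key_cache: torch.Tensor,                     # [bsz, num_kv_heads, kv_len, head_dim]
    value_cache: torch.Tensor,                   # [bsz, num_kv_heads, kv_len, head_dim]
    attention_scores: torch.Tensor,              # [bsz, num_heads, q_len, kv_len], post-softmax
    num_kv_groups: int,                          # query heads per kv head (repeat_kv factor)
    window_length: int,                          # total budget: heavy hitters + local tokens
    num_hh_tokens: int,                          # heavy-hitter budget within the window
    prev_scores: Optional[torch.Tensor] = None,  # accumulated scores from earlier steps (assumed bookkeeping)
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Illustrative sketch: keep heavy-hitter + local tokens for one layer."""
    bsz, num_heads, _, kv_len = attention_scores.shape

    # Attention mass each cached token has received, folded back onto its kv head.
    scores = attention_scores.sum(dim=-2)
    scores = scores.view(bsz, num_heads // num_kv_groups, num_kv_groups, kv_len).sum(dim=2)
    if prev_scores is not None:
        scores[..., : prev_scores.shape[-1]] += prev_scores

    if kv_len <= window_length:
        return key_cache, value_cache, scores

    # Heavy hitters are chosen among the tokens that fall outside the local window.
    num_local = window_length - num_hh_tokens
    _, hh_idx = scores[..., : kv_len - num_local].topk(num_hh_tokens, dim=-1)
    local_idx = torch.arange(kv_len - num_local, kv_len, device=scores.device)
    local_idx = local_idx.expand(bsz, num_heads // num_kv_groups, -1)
    keep_idx = torch.cat([hh_idx.sort(dim=-1).values, local_idx], dim=-1)

    gather_idx = keep_idx.unsqueeze(-1).expand(-1, -1, -1, key_cache.shape[-1])
    return (
        key_cache.gather(dim=2, index=gather_idx),
        value_cache.gather(dim=2, index=gather_idx),
        scores.gather(dim=-1, index=keep_idx),
    )

A real `HHCache.update_slimming` would keep the running scores per layer inside the cache object (alongside `key_cache`/`value_cache`) and overwrite `self.key_cache[layer_idx]`/`self.value_cache[layer_idx]` with the gathered tensors.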

research/long-context-llama/H2O/utils_llama.py (+9, −3)

@@ -257,6 +257,11 @@ class H2OLlamaAttention(nn.Module):
         cos, sin = self.rotary_emb(value_states, position_ids)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
@@ -269,10 +274,11 @@ class H2OLlamaAttention(nn.Module):
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
 
+        # Slim the KV cache using the Heavy-Hitter Oracle's accumulated attention scores
         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_value.update(key_states, value_states, attn_weights, self.layer_idx, cache_kwargs)
+            past_key_value.update_slimming(attn_weights, self.num_key_value_groups, self.layer_idx, cache_kwargs)
 
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
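Taken together, this hunk splits the old single `update(key_states, value_states, attn_weights, ...)` call into two steps: the plain cache append now happens before `repeat_kv` (so the cache stores per-kv-head tensors, as in `DynamicCache`), and the heavy-hitter bookkeeping moves to a separate `update_slimming` call after the softmax, once the attention weights exist. The new `num_key_value_groups` argument is needed because, with grouped-query attention, `attn_weights` carries one row of scores per query head while the cache holds one tensor per kv head. A small self-contained snippet (toy sizes, all variable names assumed) showing how those scores can be folded back onto the kv heads:

import torch

# Toy shapes: 2 sequences, 4 kv heads, 8 query heads per kv head, 16 cached tokens.
bsz, num_kv_heads, n_rep, q_len, kv_len = 2, 4, 8, 1, 16
attn_weights = torch.rand(bsz, num_kv_heads * n_rep, q_len, kv_len).softmax(dim=-1)

# Attention mass each cached token received, summed over query positions...
per_token = attn_weights.sum(dim=-2)                               # [bsz, num_heads, kv_len]
# ...then folded over the n_rep query heads that share each kv head
# (repeat_kv lays the copies of a kv head out contiguously).
per_kv_head = per_token.view(bsz, num_kv_heads, n_rep, kv_len).sum(dim=2)
print(per_kv_head.shape)                                           # torch.Size([2, 4, 16])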