Allen 1 year ago
parent
commit
67f3666930

+ 2 - 2
research/long-context-llama/H2O/cache_utils.py

@@ -531,10 +531,10 @@ class HHCache(Cache):
         return legacy_cache
 
     @classmethod
-    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
+    def from_legacy_cache(cls, window_length: int, num_hh_tokens: int, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
         """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
         import pdb; pdb.set_trace()
-        cache = cls()
+        cache = cls(window_length, num_hh_tokens)
         if past_key_values is not None:
             for layer_idx in range(len(past_key_values)):
                 key_states, value_states = past_key_values[layer_idx]
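Two things are worth flagging in this hunk: the `import pdb; pdb.set_trace()` context line will drop any caller into the debugger, and the docstring and return annotation still say `DynamicCache` even though the method now builds an `HHCache`. Below is a minimal sketch of the updated call as of this commit, assuming the repo's `cache_utils` module is importable and using arbitrary illustrative tensor shapes:

```python
import torch
from cache_utils import HHCache  # repo-local module shown in this diff

# Legacy cache format: one (key, value) pair per layer, each of shape
# (batch, num_heads, seq_len, head_dim); shapes here are illustrative.
legacy = tuple(
    (torch.randn(1, 32, 16, 128), torch.randn(1, 32, 16, 128))
    for _ in range(2)
)

# After this commit, window_length and num_hh_tokens are required
# positional arguments that precede the legacy tuple itself.
cache = HHCache.from_legacy_cache(1024, 256, legacy)
```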

+ 3 - 3
research/long-context-llama/H2O/generation.py

@@ -34,7 +34,7 @@ if __name__ == '__main__':
 
     parser.add_argument("--enable_h2o_generation", action='store_true')
     parser.add_argument("--num_heavy_hitter_tokens", type=int, default=256)
-    parser.add_argument("--num_local_windows", type=int, default=256)
+    parser.add_argument("--num_window_length", type=int, default=1024)
 
     parser.add_argument("--enable_position_rolling", action='store_true')
 
@@ -56,8 +56,8 @@ if __name__ == '__main__':
     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
 
     if args.enable_h2o_generation:
-        config.hh_size = args.num_heavy_hitter_tokens
-        config.recent_size = args.num_local_windows
+        config.num_heavy_hitter_tokens = args.num_heavy_hitter_tokens
+        config.num_window_length = args.num_window_length
         config.enable_position_rolling = args.enable_position_rolling
         model = H2OLlamaForCausalLM.from_pretrained(model_name, config=config)
     else:
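These renames keep the CLI flag, the config attribute, and the model-side attribute (see the utils_llama.py hunk below) under a single name each; note that the default window length also changes from 256 to 1024. A condensed sketch of the resulting flag-to-config flow, with a placeholder model name:

```python
import argparse
from transformers import AutoConfig

parser = argparse.ArgumentParser()
parser.add_argument("--num_heavy_hitter_tokens", type=int, default=256)
parser.add_argument("--num_window_length", type=int, default=1024)
args = parser.parse_args([])  # use the defaults for illustration

# model_name is a placeholder; the config attribute names now match
# the CLI flag names one-to-one.
model_name = "meta-llama/Llama-2-7b-hf"
config = AutoConfig.from_pretrained(model_name)
config.num_heavy_hitter_tokens = args.num_heavy_hitter_tokens
config.num_window_length = args.num_window_length
```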

+ 3 - 2
research/long-context-llama/H2O/utils_llama.py

@@ -336,8 +336,7 @@ def enable_h2ocache_forward(
     past_seen_tokens = 0
     if use_cache:  # kept for BC (cache positions)
         if not isinstance(past_key_values, StaticCache):
-            pdb.set_trace()
-            past_key_values = HHCache.from_legacy_cache(past_key_values)
+            past_key_values = HHCache.from_legacy_cache(self.num_window_length, self.num_heavy_hitter_tokens, past_key_values)
             past_seen_tokens = past_key_values.get_seq_length()
 
     if cache_position is None:
@@ -709,3 +708,5 @@ class H2OLlamaForCausalLM(LlamaForCausalLM):
             self.model.layers[layer_idx].self_attn = H2OLlamaAttention(config, layer_idx)
 
         self.model.forward = types.MethodType(enable_h2ocache_forward, self.model)
+        self.model.num_heavy_hitter_tokens = config.num_heavy_hitter_tokens
+        self.model.num_window_length = config.num_window_length
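The two attributes appended in `__init__` are what the patched forward reads: `enable_h2ocache_forward` is bound to `self.model` via `types.MethodType`, so `self.num_window_length` and `self.num_heavy_hitter_tokens` inside it resolve against the model instance. A self-contained sketch of that binding pattern, with placeholder values standing in for the real model:

```python
import types

class _Model:
    pass

def _forward(self):
    # Mirrors the updated enable_h2ocache_forward: the eviction
    # hyperparameters are read back off the model instance that
    # this commit attached them to.
    return self.num_window_length, self.num_heavy_hitter_tokens

model = _Model()
model.forward = types.MethodType(_forward, model)
model.num_heavy_hitter_tokens = 256
model.num_window_length = 1024
print(model.forward())  # (1024, 256)
```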