Allen 1 year ago
parent commit 92389d0a85

File diff suppressed because it is too large
+ 2 - 2   recipes/experimental/long-context/H2O/README.md


+ 5 - 1   recipes/experimental/long-context/H2O/run_needle_haystack_test.py

@@ -30,7 +30,7 @@ if __name__ == '__main__':
    parser.add_argument("--model-name", type=str, default="")

    parser.add_argument("--enable_h2o_generation", action='store_true')
-    parser.add_argument("--num_heavy_hitter_tokens", type=int, default=128)
+    parser.add_argument("--num_heavy_hitter_tokens", type=int, default=-1)
    parser.add_argument("--num_window_length", type=int, default=256)
    parser.add_argument("--num_chunk_size", type=int, default=2048)

@@ -53,6 +53,10 @@ if __name__ == '__main__':
    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

+    if args.num_heavy_hitter_tokens == -1:
+        print('num_heavy_hitter_tokens not specified; defaulting to half of the cache size: {}'.format(args.num_window_length // 2))
+        args.num_heavy_hitter_tokens = args.num_window_length // 2
+
    if args.enable_h2o_generation:
        config.num_heavy_hitter_tokens = args.num_heavy_hitter_tokens
        config.num_window_length = args.num_window_length

+ 4 - 1   recipes/experimental/long-context/H2O/run_summarization.py

@@ -32,7 +32,7 @@ if __name__ == '__main__':
    parser.add_argument("--model-name", type=str, default="")

    parser.add_argument("--enable_h2o_generation", action='store_true')
-    parser.add_argument("--num_heavy_hitter_tokens", type=int, default=128)
+    parser.add_argument("--num_heavy_hitter_tokens", type=int, default=-1)
    parser.add_argument("--num_window_length", type=int, default=256)

    parser.add_argument("--enable_position_rolling", action='store_true')
@@ -51,6 +51,9 @@ if __name__ == '__main__':

    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+    if args.num_heavy_hitter_tokens == -1:
+        print('num_heavy_hitter_tokens not specified; defaulting to half of the cache size: {}'.format(args.num_window_length // 2))
+        args.num_heavy_hitter_tokens = args.num_window_length // 2

    if args.enable_h2o_generation:
        config.num_heavy_hitter_tokens = args.num_heavy_hitter_tokens
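
Both scripts now use -1 as a sentinel default for --num_heavy_hitter_tokens and, when the flag is left unset, fall back to half of --num_window_length before the value is written into the model config. A minimal sketch of that shared fallback, assuming only the argparse flags shown in the diffs above; resolve_heavy_hitters is a hypothetical helper, not part of either recipe:

    # Resolve the heavy-hitter budget from the command-line values.
    def resolve_heavy_hitters(num_heavy_hitter_tokens: int, num_window_length: int) -> int:
        """Return the heavy-hitter budget, defaulting to half of the cache size."""
        if num_heavy_hitter_tokens == -1:  # sentinel: flag was not passed
            return num_window_length // 2
        return num_heavy_hitter_tokens

    # With the new defaults (-1 and a 256-token window) the budget resolves to 128,
    # matching the previous hard-coded default; an explicit value is kept as-is.
    assert resolve_heavy_hitters(-1, 256) == 128
    assert resolve_heavy_hitters(64, 256) == 64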