|
@@ -30,7 +30,7 @@ if __name__ == '__main__':
|
|
parser.add_argument("--model-name", type=str, default="")
|
|
parser.add_argument("--model-name", type=str, default="")
|
|
|
|
|
|
parser.add_argument("--enable_h2o_generation", action='store_true')
|
|
parser.add_argument("--enable_h2o_generation", action='store_true')
|
|
- parser.add_argument("--num_heavy_hitter_tokens", type=int, default=128)
|
|
|
|
|
|
+ parser.add_argument("--num_heavy_hitter_tokens", type=int, default=-1)
|
|
parser.add_argument("--num_window_length", type=int, default=256)
|
|
parser.add_argument("--num_window_length", type=int, default=256)
|
|
parser.add_argument("--num_chunk_size", type=int, default=2048)
|
|
parser.add_argument("--num_chunk_size", type=int, default=2048)
|
|
|
|
|
|
@@ -53,6 +53,10 @@ if __name__ == '__main__':
|
|
config = AutoConfig.from_pretrained(model_name)
|
|
config = AutoConfig.from_pretrained(model_name)
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
|
|
|
|
|
|
|
|
+ if args.num_heavy_hitter_tokens == -1:
|
|
|
|
+ print('not assign number of heavy hitter tokens, use half of the cache size: {}'.format(args.num_window_length // 2))
|
|
|
|
+ args.num_heavy_hitter_tokens = args.num_window_length // 2
|
|
|
|
+
|
|
if args.enable_h2o_generation:
|
|
if args.enable_h2o_generation:
|
|
config.num_heavy_hitter_tokens = args.num_heavy_hitter_tokens
|
|
config.num_heavy_hitter_tokens = args.num_heavy_hitter_tokens
|
|
config.num_window_length = args.num_window_length
|
|
config.num_window_length = args.num_window_length
|