@@ -34,7 +34,7 @@ if __name__ == '__main__':
     parser.add_argument("--enable_h2o_generation", action='store_true')
     parser.add_argument("--num_heavy_hitter_tokens", type=int, default=256)
-    parser.add_argument("--num_local_windows", type=int, default=256)
+    parser.add_argument("--num_window_length", type=int, default=1024)
     parser.add_argument("--enable_position_rolling", action='store_true')
@@ -56,8 +56,8 @@ if __name__ == '__main__':
     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
     if args.enable_h2o_generation:
-        config.hh_size = args.num_heavy_hitter_tokens
-        config.recent_size = args.num_local_windows
+        config.num_heavy_hitter_tokens = args.num_heavy_hitter_tokens
+        config.num_window_length = args.num_window_length
         config.enable_position_rolling = args.enable_position_rolling
         model = H2OLlamaForCausalLM.from_pretrained(model_name, config=config)
     else: