|
@@ -59,7 +59,7 @@ if __name__ == '__main__':
|
|
|
config.hh_size = args.num_heavy_hitter_tokens
|
|
|
config.recent_size = args.num_local_windows
|
|
|
config.enable_position_rolling = args.enable_position_rolling
|
|
|
- model = H2OLlamaForCausalLM.from_pretrained(config)
|
|
|
+ model = H2OLlamaForCausalLM.from_pretrained(model_name, config=config)
|
|
|
else:
|
|
|
model = AutoModelForCausalLM.from_pretrained(model_name)
|
|
|
|