@@ -20,13 +20,17 @@ def load_model(model_name, quantization, use_fast_kernels, **kwargs):
     bnb_config = quant_config.create_bnb_config(quantization)

     print(f"use_fast_kernels{use_fast_kernels}")
+
+    kwargs = {}
+    if bnb_config:
+        kwargs["quantization_config"] = bnb_config
+    kwargs["device_map"] = "auto"
+    kwargs["low_cpu_mem_usage"] = True
+    kwargs["attn_implementation"] = "sdpa" if use_fast_kernels else None
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
         return_dict=True,
-        quantization_config=bnb_config,
-        device_map="auto",
-        low_cpu_mem_usage=True,
-        attn_implementation="sdpa" if use_fast_kernels else None,
+        **kwargs,
     )
     return model
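For readability, the hunk above amounts to the following plain-Python view of the updated block. This is a sketch only: it assumes `quant_config` and the `transformers` import of `AutoModelForCausalLM` are defined earlier in the file, outside this hunk.

def load_model(model_name, quantization, use_fast_kernels, **kwargs):
    # Code before this hunk (argument handling, quant_config setup) is omitted.
    bnb_config = quant_config.create_bnb_config(quantization)

    print(f"use_fast_kernels{use_fast_kernels}")

    # Build the from_pretrained keyword arguments in one place; the
    # quantization config is only passed when one was actually created.
    kwargs = {}
    if bnb_config:
        kwargs["quantization_config"] = bnb_config
    kwargs["device_map"] = "auto"
    kwargs["low_cpu_mem_usage"] = True
    kwargs["attn_implementation"] = "sdpa" if use_fast_kernels else None

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        return_dict=True,
        **kwargs,
    )
    return model

One thing to note about this design: `kwargs = {}` replaces whatever keyword arguments the caller passed to `load_model`, so only the values built here reach `from_pretrained`.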