|
@@ -107,7 +107,7 @@ def main(**kwargs):
|
|
|
use_double_quant=True,
|
|
|
quant_storage=torch.bfloat16
|
|
|
) if train_config.quantization else None
|
|
|
- bnb_config = quant_config.create_bnb_config()
|
|
|
+ bnb_config = quant_config.create_bnb_config() if quant_config else None
|
|
|
|
|
|
# Load the pre-trained model and setup its configuration
|
|
|
use_cache = False if train_config.enable_fsdp else None
|