|
@@ -208,7 +208,7 @@ def main(**kwargs):
|
|
|
)
|
|
|
if fsdp_config.fsdp_activation_checkpointing:
|
|
|
model.enable_input_require_grads()
|
|
|
- #model.gradient_checkpointing_enable()
|
|
|
+ model.gradient_checkpointing_enable()
|
|
|
apply_fsdp_checkpointing(model)
|
|
|
elif not train_config.quantization and not train_config.enable_fsdp:
|
|
|
if is_xpu_available():
|