Преглед изворни кода

gradient_checkpointing_enable()

Kai Wu пре 7 месеци
родитељ
комит
50dff0b78e
1 измењених фајлова са 1 додато и 1 уклоњено
  1. 1 1
      src/llama_recipes/finetuning.py

+ 1 - 1
src/llama_recipes/finetuning.py

@@ -208,7 +208,7 @@ def main(**kwargs):
         )
         if fsdp_config.fsdp_activation_checkpointing:            
             model.enable_input_require_grads()
-            #model.gradient_checkpointing_enable()
+            model.gradient_checkpointing_enable()
             apply_fsdp_checkpointing(model)                      
     elif not train_config.quantization and not train_config.enable_fsdp:
         if is_xpu_available():