Pārlūkot izejas kodu

Fix param_init_fn: move if-statement out of lambda

Pavel Belevich 11 mēneši atpakaļ
vecāks
revīzija
fb2e802cef
1 mainītis faili ar 1 papildinājumiem un 1 dzēšanām
  1. 1 1
      src/llama_recipes/finetuning.py

+ 1 - 1
src/llama_recipes/finetuning.py

@@ -188,7 +188,7 @@ def main(**kwargs):
             device_id=device_id,
             limit_all_gathers=True,
             sync_module_states=train_config.low_cpu_fsdp,
-            param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
+            param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
             if train_config.low_cpu_fsdp and rank != 0 else None,
         )
         if fsdp_config.fsdp_activation_checkpointing: