
fixing comments

Hamid Shojanazeri 1 year ago
parent commit ea5d0d47db
2 changed files with 18 additions and 17 deletions
  1. README.md (+1, -1)
  2. src/llama_recipes/finetuning.py (+17, -16)

README.md (+1, -1)

@@ -153,7 +153,7 @@ torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --u
 
 ## FSDP optimizer overlap
 
-setting `optimizer_overlap` in [fsdp_config](./src/llama_recipes/configs/fsdp.py) enable optimizer overlap that lowers the GPU memory footprint during the training. The main idea here is the fusion of gradient calculation and parameter update in the one step. This is a new feature in FSDP and is only available on PyTorch nightly binaries for versions before 2.1.0.
+Setting `optimizer_overlap` in [fsdp_config](./src/llama_recipes/configs/fsdp.py) enables optimizer overlap, which lowers the GPU memory footprint during training. The main idea is to fuse the gradient computation and the parameter update into a single step. This is a new feature in FSDP and is only available from PyTorch 2.1.0 onward.
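As a side note (not part of this diff), here is a minimal sketch of the underlying PyTorch API the new code path relies on, `torch.distributed.optim._apply_optimizer_in_backward`, assuming PyTorch >= 2.1.0; the toy `nn.Linear` model and learning rate are placeholders:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import _apply_optimizer_in_backward

model = nn.Linear(16, 4)  # placeholder standing in for the wrapped Llama model

# Attach a per-parameter AdamW that steps as soon as each gradient is accumulated,
# so the gradient can be freed right away instead of living until optimizer.step().
_apply_optimizer_in_backward(
    optimizer_class=optim.AdamW,
    params=model.parameters(),
    optimizer_kwargs={"lr": 1e-4},  # placeholder learning rate
)

loss = model(torch.randn(2, 16)).sum()
loss.backward()  # parameters are updated during backward; no separate optimizer.step()
```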
 
 ```bash
 torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --model_name /path_of_model_folder/7B --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model --optimizer_overlap

src/llama_recipes/finetuning.py (+17, -16)

@@ -218,6 +218,13 @@ def main(**kwargs):
 
     # Initialize the optimizer and learning rate scheduler
     if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
+        if optimizer_in_backward_available:
+            # Attach an in-backward AnyPrecisionAdamW to each parameter;
+            # register_hook=False leaves stepping to FSDP's optimizer-in-backward support.
+            print("setting up optimizer overlap")
+            _apply_optimizer_in_backward(
+                optimizer_class=AnyPrecisionAdamW,
+                params=model.parameters(),
+                optimizer_kwargs={"lr": train_config.lr},
+                register_hook=False,
+            )
         optimizer = AnyPrecisionAdamW(
             model.parameters(),
             lr=train_config.lr,
@@ -226,23 +233,17 @@ def main(**kwargs):
             use_kahan_summation=False,
             weight_decay=train_config.weight_decay,
         )
-    elif optimizer_in_backward_available:
-        print(f"setting up optimizer overlap")
-        optim_kwargs = {"lr": train_config.lr}
-        _apply_optimizer_in_backward(
-            optimizer_class=optim.AdamW,
-            params=model.parameters(),
-            optimizer_kwargs=optim_kwargs,
-            register_hook=False,
-        )
-        for p in model.parameters():
-            assert hasattr(p, "_in_backward_optimizers")
-        optim_kwargs = {"lr": train_config.lr, "weight_decay":0.0}
-        optimizer = optim.AdamW(
-            model.parameters(),
-            **optim_kwargs
-        )
     else:
+        if optimizer_in_backward_available:
+            print("setting up optimizer overlap")
+            _apply_optimizer_in_backward(
+                optimizer_class=optim.AdamW,
+                params=model.parameters(),
+                optimizer_kwargs={"lr": train_config.lr},
+                register_hook=False,
+            )
+            for p in model.parameters():
+                assert hasattr(p, "_in_backward_optimizers")
         optimizer = optim.AdamW(
             model.parameters(),
             lr=train_config.lr,