
fixing comments

Hamid Shojanazeri 1 year ago
parent
commit
ea5d0d47db
2 changed files with 18 additions and 17 deletions
  1. README.md (+1, -1)
  2. src/llama_recipes/finetuning.py (+17, -16)

README.md (+1, -1)

@@ -153,7 +153,7 @@ torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --u
 
 ## FSDP optimizer overlap
 
-Setting `optimizer_overlap` in [fsdp_config](./src/llama_recipes/configs/fsdp.py) enables optimizer overlap, which lowers the GPU memory footprint during training. The main idea is to fuse the gradient computation and the parameter update into a single step. This is a new feature in FSDP and is only available on PyTorch nightly binaries for versions before 2.1.0.
+Setting `optimizer_overlap` in [fsdp_config](./src/llama_recipes/configs/fsdp.py) enables optimizer overlap, which lowers the GPU memory footprint during training. The main idea is to fuse the gradient computation and the parameter update into a single step. This is a new feature in FSDP and is only available from PyTorch 2.1.0 onward.
 
 ```bash
 torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model --optimizer_overlap

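For orientation, the snippet below is a minimal standalone sketch (not part of this commit) of the optimizer-in-backward pattern that the finetuning.py change relies on. It assumes PyTorch 2.1+ with `_apply_optimizer_in_backward` importable from `torch.distributed.optim`; the toy model, learning rate, and tensor shapes are illustrative only.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import _apply_optimizer_in_backward

# Toy model; any nn.Module works the same way.
model = nn.Linear(16, 4)

# register_hook defaults to True here, so an autograd hook steps each
# parameter's optimizer as soon as that parameter's gradient is ready during
# backward. finetuning.py passes register_hook=False, presumably so that
# FSDP's optimizer-overlap machinery drives the per-parameter optimizers.
_apply_optimizer_in_backward(
    optimizer_class=optim.AdamW,
    params=model.parameters(),
    optimizer_kwargs={"lr": 1e-4},
)

# Each parameter now carries its in-backward optimizer(s); this is what the
# assert in finetuning.py checks.
for p in model.parameters():
    assert hasattr(p, "_in_backward_optimizers")

x = torch.randn(8, 16)
loss = model(x).sum()
loss.backward()  # parameters are updated inside this call; no separate optimizer.step()
```
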
src/llama_recipes/finetuning.py (+17, -16)

@@ -218,6 +218,13 @@ def main(**kwargs):
 
     # Initialize the optimizer and learning rate scheduler
     if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
+        if optimizer_in_backward_available:
+            print("setting up optimizer overlap")
+            _apply_optimizer_in_backward(
+                optimizer_class=AnyPrecisionAdamW,
+                params=model.parameters(),
+                optimizer_kwargs={"lr": train_config.lr},
+                register_hook=False,
+            )
         optimizer = AnyPrecisionAdamW(
             model.parameters(),
             lr=train_config.lr,
@@ -226,23 +233,17 @@ def main(**kwargs):
             use_kahan_summation=False,
             weight_decay=train_config.weight_decay,
         )
-    elif optimizer_in_backward_available:
-        print(f"setting up optimizer overlap")
-        optim_kwargs = {"lr": train_config.lr}
-        _apply_optimizer_in_backward(
-            optimizer_class=optim.AdamW,
-            params=model.parameters(),
-            optimizer_kwargs=optim_kwargs,
-            register_hook=False,
-        )
-        for p in model.parameters():
-            assert hasattr(p, "_in_backward_optimizers")
-        optim_kwargs = {"lr": train_config.lr, "weight_decay":0.0}
-        optimizer = optim.AdamW(
-            model.parameters(),
-            **optim_kwargs
-        )
     else:
+        if optimizer_in_backward_available:
+            print("setting up optimizer overlap")
+            _apply_optimizer_in_backward(
+                optimizer_class=optim.AdamW,
+                params=model.parameters(),
+                optimizer_kwargs={"lr": train_config.lr},
+                register_hook=False,
+            )
+            for p in model.parameters():
+                assert hasattr(p, "_in_backward_optimizers")
         optimizer = optim.AdamW(
             model.parameters(),
             lr=train_config.lr,