
add the condition to guard overlap behind a flag

Hamid Shojanazeri, 1 year ago
Commit a3210058d2
1 file changed, 28 insertions, 26 deletions

llama_finetuning.py (+28, -26)

@@ -217,36 +217,38 @@ def main(**kwargs):
         )
         
     #Initialize the optimizer and learning rate scheduler
-    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
-        optimizer = AnyPrecisionAdamW(
-            model.parameters(),
-            lr=train_config.lr,
-            momentum_dtype=torch.bfloat16,
-            variance_dtype=torch.bfloat16,
-            use_kahan_summation=False,
+    if not fsdp_config.optimizer_overlap:
+        if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
+            optimizer = AnyPrecisionAdamW(
+                model.parameters(),
+                lr=train_config.lr,
+                momentum_dtype=torch.bfloat16,
+                variance_dtype=torch.bfloat16,
+                use_kahan_summation=False,
+            )
+        else:
+            optimizer = optim.AdamW(
+                model.parameters(),
+                lr=train_config.lr,
+                weight_decay=0.0,
+            )
+    if fsdp_config.optimizer_overlap:
+        print(f"setting up optimizer overlap")
+        print("===============================")
+        optim_kwargs = {"lr": train_config.lr}
+        _apply_optimizer_in_backward(
+            optimizer_class=optim.AdamW,
+            params=model.parameters(),
+            optimizer_kwargs=optim_kwargs,
+            register_hook=False,
         )
-    else:
+        for p in model.parameters():
+            assert hasattr(p, "_in_backward_optimizers")
+        optim_kwargs = {"lr": train_config.lr, "weight_decay":0.0}
         optimizer = optim.AdamW(
             model.parameters(),
-            lr=train_config.lr,
-            weight_decay=0.0,
+            **optim_kwargs
         )
-    # if fsdp_config.optimizer_overlap:
-    #     print("we are hereeeeeeeee**************************************")
-    #     optim_kwargs = {"lr": train_config.lr}
-    #     _apply_optimizer_in_backward(
-    #         optimizer_class=optim.AdamW,
-    #         params=model.parameters(),
-    #         optimizer_kwargs=optim_kwargs,
-    #         register_hook=False,
-    #     )
-    #     for p in model.parameters():
-    #         assert hasattr(p, "_in_backward_optimizers")
-    #     optim_kwargs = {"lr": train_config.lr, "weight_decay":0.0}
-    #     optimizer = optim.AdamW(
-    #         model.parameters(),
-    #         **optim_kwargs
-    #     )
         
         
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
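
Below is a minimal, self-contained sketch of the flag-guarded setup this commit introduces, kept separate from the diff above. It only illustrates the control flow: the FSDPConfig dataclass, its optimizer_overlap field, and the build_optimizer helper are illustrative names assumed here (the real flag lives in the repo's fsdp_config), the AnyPrecisionAdamW branch is omitted for brevity, and _apply_optimizer_in_backward is a private/prototype PyTorch API that may change between releases.

    # Sketch of the flag-guarded optimizer-overlap path (assumptions noted above).
    from dataclasses import dataclass

    import torch.nn as nn
    import torch.optim as optim
    from torch.distributed.optim import _apply_optimizer_in_backward
    from torch.optim.lr_scheduler import StepLR


    @dataclass
    class FSDPConfig:
        pure_bf16: bool = False
        optimizer: str = "AdamW"
        optimizer_overlap: bool = False  # assumed new flag guarding the overlap path


    def build_optimizer(model: nn.Module, fsdp_config: FSDPConfig, lr: float = 1e-4):
        if not fsdp_config.optimizer_overlap:
            # Plain path (the AnyPrecisionAdamW branch from the diff is omitted here).
            return optim.AdamW(model.parameters(), lr=lr, weight_decay=0.0)

        # Overlap path: attach a per-parameter AdamW that can be stepped during the
        # backward pass; register_hook=False defers hook handling to FSDP.
        _apply_optimizer_in_backward(
            optimizer_class=optim.AdamW,
            params=model.parameters(),
            optimizer_kwargs={"lr": lr},
            register_hook=False,
        )
        # Sanity check mirroring the diff: every parameter should now carry its
        # in-backward optimizer.
        for p in model.parameters():
            assert hasattr(p, "_in_backward_optimizers")
        # The outer AdamW also mirrors the diff; it gives StepLR (and the rest of
        # the training loop) an optimizer object to hold on to.
        return optim.AdamW(model.parameters(), lr=lr, weight_decay=0.0)


    if __name__ == "__main__":
        model = nn.Linear(16, 16)
        optimizer = build_optimizer(model, FSDPConfig(optimizer_overlap=True))
        scheduler = StepLR(optimizer, step_size=1, gamma=0.85)

The design choice visible in the diff is the same as in the sketch: the overlap flag selects between a single conventional optimizer and a per-parameter in-backward optimizer plus an outer AdamW shell, so downstream code such as the StepLR scheduler keeps working unchanged in either mode.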