
fix test env

Hamid Shojanazeri 1 year ago
parent
commit
14a81f4e24
2 files changed, 28 insertions and 26 deletions
  1. + 26 - 25
      llama_finetuning.py
  2. + 2 - 1
      utils/train_utils.py

+ 26 - 25
llama_finetuning.py

@@ -217,35 +217,36 @@ def main(**kwargs):
         )
         
     # Initialize the optimizer and learning rate scheduler
-    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
-        optimizer = AnyPrecisionAdamW(
-            model.parameters(),
-            lr=train_config.lr,
-            momentum_dtype=torch.bfloat16,
-            variance_dtype=torch.bfloat16,
-            use_kahan_summation=False,
-        )
-    else:
-        optimizer = optim.AdamW(
-            model.parameters(),
-            lr=train_config.lr,
-            weight_decay=0.0,
-        )
-    # if fsdp_config.optimizer_overlap:
-    #     optim_kwargs = {"lr": train_config.lr}
-    #     _apply_optimizer_in_backward(
-    #         optimizer_class=optim.AdamW,
-    #         params=model.parameters(),
-    #         optimizer_kwargs=optim_kwargs,
-    #         register_hook=False,
+    # if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
+    #     optimizer = AnyPrecisionAdamW(
+    #         model.parameters(),
+    #         lr=train_config.lr,
+    #         momentum_dtype=torch.bfloat16,
+    #         variance_dtype=torch.bfloat16,
+    #         use_kahan_summation=False,
     #     )
-    #     for p in model.parameters():
-    #         assert hasattr(p, "_in_backward_optimizers")
-    #     optim_kwargs = {"lr": train_config.lr, "weight_decay":0.0}
+    # else:
     #     optimizer = optim.AdamW(
     #         model.parameters(),
-    #         **optim_kwargs
+    #         lr=train_config.lr,
+    #         weight_decay=0.0,
     #     )
+    if fsdp_config.optimizer_overlap:
+        print("we are hereeeeeeeee**************************************")
+        optim_kwargs = {"lr": train_config.lr}
+        _apply_optimizer_in_backward(
+            optimizer_class=optim.AdamW,
+            params=model.parameters(),
+            optimizer_kwargs=optim_kwargs,
+            register_hook=False,
+        )
+        for p in model.parameters():
+            assert hasattr(p, "_in_backward_optimizers")
+        optim_kwargs = {"lr": train_config.lr, "weight_decay":0.0}
+        optimizer = optim.AdamW(
+            model.parameters(),
+            **optim_kwargs
+        )
         
         
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
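
Note on the hunk above: the commit comments out the direct AdamW / AnyPrecisionAdamW construction and instead registers per-parameter optimizers that step during the backward pass via `_apply_optimizer_in_backward`, the underscore-prefixed (private) helper from `torch.distributed.optim`, before building a plain AdamW so the StepLR scheduler below still has an optimizer object to drive. A minimal sketch of that overlapped registration pattern, assuming this private helper and a stand-in `nn.Linear` model (hypothetical, not from the commit), is:

    # Minimal sketch of optimizer-overlap registration, assuming the private
    # _apply_optimizer_in_backward helper from torch.distributed.optim and a
    # stand-in nn.Linear model (both placeholders, not from the commit).
    import torch.nn as nn
    import torch.optim as optim
    from torch.distributed.optim import _apply_optimizer_in_backward

    model = nn.Linear(16, 4)  # placeholder for the FSDP-wrapped Llama model

    # Attach a per-parameter AdamW that can step as soon as that parameter's
    # gradient is accumulated, instead of one global optimizer.step() at the end.
    _apply_optimizer_in_backward(
        optimizer_class=optim.AdamW,
        params=model.parameters(),
        optimizer_kwargs={"lr": 1e-4},
        register_hook=False,  # mirrors the diff; hook wiring is left to FSDP
    )

    # Every parameter now carries its in-backward optimizer(s), which is what
    # the assert in the diff checks.
    for p in model.parameters():
        assert hasattr(p, "_in_backward_optimizers")

The extra `optim.AdamW` constructed afterwards in the diff presumably exists so the training loop and `StepLR` scheduler still have a conventional optimizer object to reference, since the overlapped optimizers live on the parameters themselves.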

+ 2 - 1
utils/train_utils.py

@@ -116,7 +116,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
             world_size = int(os.environ["WORLD_SIZE"])
         train_epoch_loss = total_loss / len(train_dataloader)
-        train_epoch_loss = train_epoch_loss/world_size
+        if train_config.enable_fsdp:
+            train_epoch_loss = train_epoch_loss/world_size
         train_perplexity = torch.exp(train_epoch_loss)
         
         train_prep.append(train_perplexity)
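
Note on the hunk above: dividing by `world_size` only makes sense when `dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)` has actually summed the per-rank losses, and `WORLD_SIZE` is only read inside that distributed branch, so the previously unconditional division skewed (or crashed) single-process runs. A minimal sketch of the guarded averaging, where `average_epoch_loss` and its arguments are hypothetical stand-ins for the locals in `train_utils.train()`, is:

    # Minimal sketch of the guarded loss normalization; average_epoch_loss and
    # its arguments are hypothetical stand-ins for locals in train_utils.train().
    import os
    import torch

    def average_epoch_loss(total_loss, num_batches, enable_fsdp):
        # total_loss is assumed to already be the cross-rank SUM when
        # enable_fsdp is True (dist.all_reduce ran earlier in the loop).
        train_epoch_loss = total_loss / num_batches
        if enable_fsdp:
            # Undo the SUM reduction: divide by the number of ranks.
            train_epoch_loss = train_epoch_loss / int(os.environ["WORLD_SIZE"])
        return train_epoch_loss, torch.exp(train_epoch_loss)  # loss, perplexity

In a single-GPU run without FSDP, neither the all_reduce nor `WORLD_SIZE` is available, which is exactly the case the added `if train_config.enable_fsdp:` guard protects.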