
fixing the logic

Hamid Shojanazeri, 1 year ago
parent
commit c63428da3a
1 changed file with 6 additions and 6 deletions

src/llama_recipes/finetuning.py (+6, -6)

@@ -160,7 +160,7 @@ def main(**kwargs):
             device_id=torch.cuda.current_device(),
             limit_all_gathers=True,
             sync_module_states=train_config.low_cpu_fsdp,
-            use_orig_params = True if optimizer_in_backward_available else False,
+            use_orig_params = optimizer_in_backward_available,
             param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
             if train_config.low_cpu_fsdp and rank != 0 else None,
         )
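
The first hunk only simplifies `True if optimizer_in_backward_available else False` to the boolean itself, but the coupling it encodes is worth spelling out: in-backward (overlapped) optimizers are attached per parameter, and FSDP only keeps the original parameters addressable when use_orig_params=True, so the flag mirrors the feature check. Below is a minimal single-process sketch of that wrapping, not taken from the repo, assuming one CUDA device and an NCCL build; the hard-coded flag stands in for the availability check in finetuning.py.

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Minimal single-process setup (assumption: one CUDA device, NCCL backend).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("nccl", rank=0, world_size=1)
torch.cuda.set_device(0)

optimizer_in_backward_available = True  # stand-in for the feature check in finetuning.py

model = FSDP(
    nn.Linear(8, 8).cuda(),
    device_id=torch.cuda.current_device(),
    # In-backward optimizers operate on the original parameters, which FSDP only
    # keeps visible when use_orig_params=True, hence the shared flag.
    use_orig_params=optimizer_in_backward_available,
)

# model.parameters() now yields the original (per-tensor) parameters, ready for
# _apply_optimizer_in_backward as shown in the sketch after the diff.
print(sum(p.numel() for p in model.parameters()))

dist.destroy_process_group()
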
@@ -217,21 +217,22 @@ def main(**kwargs):
         )
 
     # Initialize the optimizer and learning rate scheduler
+    optim_kwargs = {"lr": train_config.lr, "weight_decay":train_config.weight_decay}
     if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
         if optimizer_in_backward_available:
             print(f"setting up optimizer overlap")
             _apply_optimizer_in_backward(
             optimizer_class=AnyPrecisionAdamW,
             params=model.parameters(),
+            optimizer_kwargs = optim_kwargs,
             register_hook=False,
         )
         optimizer = AnyPrecisionAdamW(
             model.parameters(),
-            lr=train_config.lr,
             momentum_dtype=torch.bfloat16,
             variance_dtype=torch.bfloat16,
             use_kahan_summation=False,
-            weight_decay=train_config.weight_decay,
+            **optim_kwargs,
         )
     else:
         if optimizer_in_backward_available:
@@ -239,15 +240,14 @@ def main(**kwargs):
             _apply_optimizer_in_backward(
                 optimizer_class=optim.AdamW,
                 params=model.parameters(),
-                lr=train_config.lr,
+                optimizer_kwargs = optim_kwargs,
                 register_hook=False,
             )
             for p in model.parameters():
                 assert hasattr(p, "_in_backward_optimizers")
         optimizer = optim.AdamW(
             model.parameters(),
-            lr=train_config.lr,
-            weight_decay=train_config.weight_decay,
+            **optim_kwargs,
         )
    
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
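
The remaining hunks are the actual logic fix: `_apply_optimizer_in_backward` takes its hyperparameters as a single `optimizer_kwargs` dict rather than loose `lr=` / `weight_decay=` keywords, so the commit builds `optim_kwargs` once and reuses it both for the in-backward registration and for the regular AdamW / AnyPrecisionAdamW constructor. A self-contained sketch of that pattern on a toy model follows; the model, hyperparameter values, and `register_hook=True` are assumptions made so the example runs standalone (the finetuning script itself passes `register_hook=False` inside its FSDP setup).

import torch
import torch.nn as nn
from torch import optim
from torch.distributed.optim import _apply_optimizer_in_backward
from torch.optim.lr_scheduler import StepLR

# Stand-ins for the train_config values used in the real script (assumed).
lr, weight_decay, gamma = 1e-4, 0.0, 0.85

model = nn.Linear(16, 4)  # toy model in place of the FSDP-wrapped Llama model
optim_kwargs = {"lr": lr, "weight_decay": weight_decay}

# Hyperparameters go through the optimizer_kwargs dict, not bare lr=/weight_decay=.
# register_hook=True here lets the toy example actually step its per-parameter
# optimizers during backward; the finetuning script passes register_hook=False.
_apply_optimizer_in_backward(
    optimizer_class=optim.AdamW,
    params=model.parameters(),
    optimizer_kwargs=optim_kwargs,
    register_hook=True,
)
for p in model.parameters():
    assert hasattr(p, "_in_backward_optimizers")

# The same dict also configures the regular optimizer that drives the scheduler.
optimizer = optim.AdamW(model.parameters(), **optim_kwargs)
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

loss = model(torch.randn(2, 16)).sum()
loss.backward()  # per-parameter AdamW steps fire here and clear .grad afterwards
print(all(p.grad is None for p in model.parameters()))  # expected: True
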