@@ -218,6 +218,14 @@ def main(**kwargs):
 
     # Initialize the optimizer and learning rate scheduler
     if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
+        if optimizer_in_backward_available:
+            print(f"setting up optimizer overlap")
+            _apply_optimizer_in_backward(
+                optimizer_class=AnyPrecisionAdamW,
+                params=model.parameters(),
+                optimizer_kwargs={"lr": train_config.lr},
+                register_hook=False,
+            )
         optimizer = AnyPrecisionAdamW(
             model.parameters(),
             lr=train_config.lr,
@@ -226,23 +234,17 @@ def main(**kwargs):
             use_kahan_summation=False,
             weight_decay=train_config.weight_decay,
         )
-    elif optimizer_in_backward_available:
-        print(f"setting up optimizer overlap")
-        optim_kwargs = {"lr": train_config.lr}
-        _apply_optimizer_in_backward(
-            optimizer_class=optim.AdamW,
-            params=model.parameters(),
-            optimizer_kwargs=optim_kwargs,
-            register_hook=False,
-        )
-        for p in model.parameters():
-            assert hasattr(p, "_in_backward_optimizers")
-        optim_kwargs = {"lr": train_config.lr, "weight_decay":0.0}
-        optimizer = optim.AdamW(
-            model.parameters(),
-            **optim_kwargs
-        )
     else:
+        if optimizer_in_backward_available:
+            print(f"setting up optimizer overlap")
+            _apply_optimizer_in_backward(
+                optimizer_class=optim.AdamW,
+                params=model.parameters(),
+                optimizer_kwargs={"lr": train_config.lr},
+                register_hook=False,
+            )
+            for p in model.parameters():
+                assert hasattr(p, "_in_backward_optimizers")
         optimizer = optim.AdamW(
             model.parameters(),
             lr=train_config.lr,
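
For reference, the overlap that both branches now set up comes from `_apply_optimizer_in_backward` in `torch.distributed.optim` (the private helper that the `optimizer_in_backward_available` flag appears to guard): it attaches an optimizer instance to each parameter so that parameter's update can run as soon as its gradient has been accumulated, overlapping optimizer work with the rest of the backward pass. The sketch below is a minimal standalone illustration of that mechanism, not code from this repository; the toy model, sizes, and learning rate are invented, and it uses the helper's default `register_hook=True` so the per-parameter step actually fires during `backward()`, whereas the diff passes `register_hook=False` and apparently leaves the stepping to FSDP's own optimizer-overlap path.

    # Minimal sketch of optimizer-in-backward overlap. Assumes a recent PyTorch
    # whose torch.distributed.optim module exposes _apply_optimizer_in_backward;
    # the model and hyperparameters are illustrative only.
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.distributed.optim import _apply_optimizer_in_backward

    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

    # Attach a per-parameter AdamW; hyperparameters are passed via optimizer_kwargs.
    _apply_optimizer_in_backward(
        optimizer_class=optim.AdamW,
        params=model.parameters(),
        optimizer_kwargs={"lr": 1e-4},
        register_hook=True,  # default: autograd hooks run the step during backward
    )

    # Each parameter now carries its own in-backward optimizer instance,
    # which is what the assert in the diff checks for.
    for p in model.parameters():
        assert hasattr(p, "_in_backward_optimizers")

    # One training step: backward() both computes gradients and applies the
    # updates, so there is no separate optimizer.step() call here.
    loss = model(torch.randn(8, 16)).sum()
    loss.backward()

Note that the diff still constructs a regular top-level optimizer after registering the overlap; in the removed branch its weight decay was zeroed out, which suggests that optimizer mainly exists so the LR scheduler and the rest of the training loop have something to hold on to.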