@@ -217,36 +217,38 @@ def main(**kwargs):
         )
 
     # Initialize the optimizer and learning rate scheduler
-    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
-        optimizer = AnyPrecisionAdamW(
-            model.parameters(),
-            lr=train_config.lr,
-            momentum_dtype=torch.bfloat16,
-            variance_dtype=torch.bfloat16,
-            use_kahan_summation=False,
+    if not fsdp_config.optimizer_overlap:
+        if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
+            optimizer = AnyPrecisionAdamW(
+                model.parameters(),
+                lr=train_config.lr,
+                momentum_dtype=torch.bfloat16,
+                variance_dtype=torch.bfloat16,
+                use_kahan_summation=False,
+            )
+        else:
+            optimizer = optim.AdamW(
+                model.parameters(),
+                lr=train_config.lr,
+                weight_decay=0.0,
+            )
+    if fsdp_config.optimizer_overlap:
+        print("setting up optimizer overlap")
+        print("===============================")
+        optim_kwargs = {"lr": train_config.lr}
+        _apply_optimizer_in_backward(
+            optimizer_class=optim.AdamW,
+            params=model.parameters(),
+            optimizer_kwargs=optim_kwargs,
+            register_hook=False,
         )
-    else:
+        for p in model.parameters():
+            assert hasattr(p, "_in_backward_optimizers")
+        optim_kwargs = {"lr": train_config.lr, "weight_decay": 0.0}
         optimizer = optim.AdamW(
             model.parameters(),
-            lr=train_config.lr,
-            weight_decay=0.0,
+            **optim_kwargs
         )
-    # if fsdp_config.optimizer_overlap:
-    #     print("we are hereeeeeeeee**************************************")
-    #     optim_kwargs = {"lr": train_config.lr}
-    #     _apply_optimizer_in_backward(
-    #         optimizer_class=optim.AdamW,
-    #         params=model.parameters(),
-    #         optimizer_kwargs=optim_kwargs,
-    #         register_hook=False,
-    #     )
-    #     for p in model.parameters():
-    #         assert hasattr(p, "_in_backward_optimizers")
-    #     optim_kwargs = {"lr": train_config.lr, "weight_decay":0.0}
-    #     optimizer = optim.AdamW(
-    #         model.parameters(),
-    #         **optim_kwargs
-    #     )
 
 
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
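Note on the overlap path: `_apply_optimizer_in_backward` is a private helper in `torch.distributed.optim` that attaches a per-parameter optimizer which can step as soon as that parameter's gradient is accumulated, so optimizer work overlaps with the rest of the backward pass. The hunk does not show the import, and `register_hook=False` plus the trailing plain `optim.AdamW` are presumably there so that the FSDP wrapper drives the in-backward optimizers itself while `StepLR` and the rest of the training loop still receive a regular optimizer object. A minimal, self-contained sketch of the same API outside FSDP follows; the import path, toy model, and learning rate are assumptions for illustration, not part of this change:

```python
import torch
import torch.nn as nn
import torch.optim as optim
# Private PyTorch API; assumed import path, present in recent 2.x releases.
from torch.distributed.optim import _apply_optimizer_in_backward

# Toy module standing in for the FSDP-wrapped model in the recipe.
model = nn.Linear(16, 4)

# Attach an in-backward AdamW to every parameter. With the default
# register_hook=True, each parameter is stepped as soon as its gradient
# is accumulated, so no explicit optimizer.step() call is needed.
_apply_optimizer_in_backward(
    optimizer_class=optim.AdamW,
    params=model.parameters(),
    optimizer_kwargs={"lr": 1e-4},
)

for p in model.parameters():
    assert hasattr(p, "_in_backward_optimizers")  # same sanity check as in the hunk

before = model.weight.detach().clone()
loss = model(torch.randn(8, 16)).sum()
loss.backward()  # weights are updated during this call
print("weight updated in backward:", not torch.equal(before, model.weight.detach()))
```

Under these assumptions, the `optimizer` built after `_apply_optimizer_in_backward` does not perform the real updates in the overlap case; it mainly keeps `scheduler = StepLR(optimizer, ...)` and the downstream training loop unchanged.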