@@ -217,35 +217,36 @@ def main(**kwargs):
     )

     # Initialize the optimizer and learning rate scheduler
-    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
-        optimizer = AnyPrecisionAdamW(
-            model.parameters(),
-            lr=train_config.lr,
-            momentum_dtype=torch.bfloat16,
-            variance_dtype=torch.bfloat16,
-            use_kahan_summation=False,
-        )
-    else:
-        optimizer = optim.AdamW(
-            model.parameters(),
-            lr=train_config.lr,
-            weight_decay=0.0,
-        )
-    # if fsdp_config.optimizer_overlap:
-    #     optim_kwargs = {"lr": train_config.lr}
-    #     _apply_optimizer_in_backward(
-    #         optimizer_class=optim.AdamW,
-    #         params=model.parameters(),
-    #         optimizer_kwargs=optim_kwargs,
-    #         register_hook=False,
+    # if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
+    #     optimizer = AnyPrecisionAdamW(
+    #         model.parameters(),
+    #         lr=train_config.lr,
+    #         momentum_dtype=torch.bfloat16,
+    #         variance_dtype=torch.bfloat16,
+    #         use_kahan_summation=False,
     #     )
-    #     for p in model.parameters():
-    #         assert hasattr(p, "_in_backward_optimizers")
-    #     optim_kwargs = {"lr": train_config.lr, "weight_decay":0.0}
+    # else:
     #     optimizer = optim.AdamW(
     #         model.parameters(),
-    #         **optim_kwargs
+    #         lr=train_config.lr,
+    #         weight_decay=0.0,
     #     )
+    if fsdp_config.optimizer_overlap:
+        print("Optimizer overlap enabled: applying the optimizer step inside the backward pass")
+        optim_kwargs = {"lr": train_config.lr}
+        _apply_optimizer_in_backward(
+            optimizer_class=optim.AdamW,
+            params=model.parameters(),
+            optimizer_kwargs=optim_kwargs,
+            register_hook=False,
+        )
+        for p in model.parameters():
+            assert hasattr(p, "_in_backward_optimizers")
+        optim_kwargs = {"lr": train_config.lr, "weight_decay": 0.0}
+        optimizer = optim.AdamW(
+            model.parameters(),
+            **optim_kwargs
+        )


     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
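# --- Illustrative sketch, not part of the patch -----------------------------------
# The hunk above swaps the eager AdamW/AnyPrecisionAdamW setup for PyTorch's
# prototype optimizer-in-backward API. The sketch below shows the same pattern in
# isolation, plus a fallback for the optimizer_overlap=False case that the patch
# itself does not add; `model`, `train_config`, and the hyperparameters mirror the
# surrounding code, and `build_optimizer` is a hypothetical helper name.
import torch.optim as optim
from torch.distributed.optim import _apply_optimizer_in_backward
from torch.optim.lr_scheduler import StepLR


def build_optimizer(model, train_config, optimizer_overlap: bool):
    if optimizer_overlap:
        # Attach a per-parameter AdamW whose step can run during the backward pass.
        # register_hook=False follows the patch; the expectation there is that FSDP
        # (which wraps `model` elsewhere) drives the in-backward step itself.
        _apply_optimizer_in_backward(
            optimizer_class=optim.AdamW,
            params=model.parameters(),
            optimizer_kwargs={"lr": train_config.lr},
            register_hook=False,
        )
        # Sanity check mirrored from the patch: every parameter should now carry
        # its in-backward optimizer.
        for p in model.parameters():
            assert hasattr(p, "_in_backward_optimizers")
    # A top-level AdamW is still constructed (as in the patch) so that code which
    # expects an optimizer object, e.g. the StepLR scheduler, keeps working; with
    # optimizer_overlap=False this is simply the ordinary optimizer.
    optimizer = optim.AdamW(model.parameters(), lr=train_config.lr, weight_decay=0.0)
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
    return optimizer, scheduler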