@@ -160,7 +160,7 @@ def main(**kwargs):
             device_id=torch.cuda.current_device(),
             limit_all_gathers=True,
             sync_module_states=train_config.low_cpu_fsdp,
-            use_orig_params = True if optimizer_in_backward_available else False,
+            use_orig_params=optimizer_in_backward_available,
             param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
             if train_config.low_cpu_fsdp and rank != 0 else None,
         )
@@ -217,21 +217,22 @@ def main(**kwargs):
         )

     # Initialize the optimizer and learning rate scheduler
+    optim_kwargs = {"lr": train_config.lr, "weight_decay": train_config.weight_decay}
     if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
         if optimizer_in_backward_available:
             print(f"setting up optimizer overlap")
             _apply_optimizer_in_backward(
                 optimizer_class=AnyPrecisionAdamW,
                 params=model.parameters(),
+                optimizer_kwargs=optim_kwargs,
                 register_hook=False,
             )
         optimizer = AnyPrecisionAdamW(
             model.parameters(),
-            lr=train_config.lr,
             momentum_dtype=torch.bfloat16,
             variance_dtype=torch.bfloat16,
             use_kahan_summation=False,
-            weight_decay=train_config.weight_decay,
+            **optim_kwargs,
         )
     else:
         if optimizer_in_backward_available:
@@ -239,15 +240,14 @@ def main(**kwargs):
             _apply_optimizer_in_backward(
                 optimizer_class=optim.AdamW,
                 params=model.parameters(),
-                lr=train_config.lr,
+                optimizer_kwargs=optim_kwargs,
                 register_hook=False,
             )
             for p in model.parameters():
                 assert hasattr(p, "_in_backward_optimizers")
         optimizer = optim.AdamW(
             model.parameters(),
-            lr=train_config.lr,
-            weight_decay=train_config.weight_decay,
+            **optim_kwargs,
         )

     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
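
For reference, the pattern the hunks above introduce can be exercised on its own. The following is a minimal sketch, not part of the patch: it assumes the private torch.distributed.optim._apply_optimizer_in_backward helper that the surrounding code already import-guards, and uses a toy nn.Linear plus placeholder lr, weight_decay, and gamma values in place of train_config.

# Sketch only: one optim_kwargs dict feeds both the per-parameter in-backward
# optimizers and the regular optimizer that the LR scheduler drives.
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import _apply_optimizer_in_backward
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(16, 4)  # placeholder stand-in for the FSDP-wrapped model
optim_kwargs = {"lr": 1e-4, "weight_decay": 0.0}  # placeholder values

# Attach one AdamW instance per parameter. register_hook=False mirrors the patch:
# the in-backward step is left to FSDP's own hooks rather than the hook this
# helper would otherwise install.
_apply_optimizer_in_backward(
    optimizer_class=optim.AdamW,
    params=model.parameters(),
    optimizer_kwargs=optim_kwargs,
    register_hook=False,
)
for p in model.parameters():
    assert hasattr(p, "_in_backward_optimizers")

# A conventional optimizer is still built from the same kwargs so that StepLR
# has something to schedule, as in the patched script.
optimizer = optim.AdamW(model.parameters(), **optim_kwargs)
scheduler = StepLR(optimizer, step_size=1, gamma=0.85)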