@@ -64,7 +64,16 @@ def main(**kwargs):
         torch.cuda.set_device(local_rank)
         clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
-
+
+    # Import _apply_optimizer_in_backward for FSDP optimizer overlap
+    optimizer_in_backward_available = False
+    if fsdp_config.optimizer_overlap:
+        try:
+            from torch.distributed.optim import _apply_optimizer_in_backward
+            optimizer_in_backward_available = True
+        except ImportError:
+            print("The required module for optimizer overlap in 'torch.distributed.optim' is not available.")
+
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
     if train_config.enable_fsdp and train_config.low_cpu_fsdp:
@@ -151,6 +160,7 @@ def main(**kwargs):
             device_id=torch.cuda.current_device(),
             limit_all_gathers=True,
             sync_module_states=train_config.low_cpu_fsdp,
+            use_orig_params=optimizer_in_backward_available,
             param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
             if train_config.low_cpu_fsdp and rank != 0 else None,
         )
@@ -228,12 +238,29 @@ def main(**kwargs):
             use_kahan_summation=False,
             weight_decay=train_config.weight_decay,
         )
+    elif optimizer_in_backward_available:
+        print("Setting up optimizer overlap.")
+        optim_kwargs = {"lr": train_config.lr}
+        _apply_optimizer_in_backward(
+            optimizer_class=optim.AdamW,
+            params=model.parameters(),
+            optimizer_kwargs=optim_kwargs,
+            register_hook=False,
+        )
+        for p in model.parameters():
+            assert hasattr(p, "_in_backward_optimizers")
+        optim_kwargs = {"lr": train_config.lr, "weight_decay": 0.0}
+        optimizer = optim.AdamW(
+            model.parameters(),
+            **optim_kwargs
+        )
     else:
         optimizer = optim.AdamW(
             model.parameters(),
             lr=train_config.lr,
             weight_decay=train_config.weight_decay,
         )
+
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
 
     # Start the training process
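
For context, here is a minimal standalone sketch of what _apply_optimizer_in_backward provides; it is illustrative only and not part of the patch, and the toy model and learning rate are made up. With register_hook=True, each parameter gets its own optimizer that steps as soon as its gradient is accumulated, overlapping the update with the rest of the backward pass. The patch passes register_hook=False, presumably so that FSDP can drive these per-parameter optimizers from its own backward hooks, and keeps a separate AdamW mainly so the StepLR scheduler has an optimizer object to attach to.

    import torch
    import torch.nn as nn
    from torch import optim
    from torch.distributed.optim import _apply_optimizer_in_backward

    model = nn.Linear(16, 16)  # toy model for illustration

    # Attach a per-parameter AdamW that steps as each gradient becomes ready.
    _apply_optimizer_in_backward(
        optimizer_class=optim.AdamW,
        params=model.parameters(),
        optimizer_kwargs={"lr": 1e-4},
        register_hook=True,  # the patch uses False and defers the step to FSDP
    )

    loss = model(torch.randn(4, 16)).sum()
    loss.backward()  # parameters are already updated when backward() returns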