|
@@ -188,7 +188,7 @@ def main(**kwargs):
|
|
|
device_id=device_id,
|
|
|
limit_all_gathers=True,
|
|
|
sync_module_states=train_config.low_cpu_fsdp,
|
|
|
- param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
|
|
|
+ param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
|
|
|
if train_config.low_cpu_fsdp and rank != 0 else None,
|
|
|
)
|
|
|
if fsdp_config.fsdp_activation_checkpointing:
|