|
@@ -149,7 +149,7 @@ def main(**kwargs):
|
|
|
wandb_run.config.update(peft_config)
|
|
|
model.print_trainable_parameters()
|
|
|
|
|
|
- hsdp_device_mesh = None
|
|
|
+ hsdp_device_mesh_plan = None
|
|
|
if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
|
|
|
hsdp_device_mesh_plan = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)
|
|
|
print("HSDP device mesh is ready")
|