fsdp_utils.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import torch
  5. import torch.cuda.nccl as nccl
  6. import torch.distributed as dist
  7. from llama_cookbook.policies import fpSixteen,bfSixteen, get_llama_wrapper
  8. from torch.distributed._tensor.device_mesh import init_device_mesh
  9. def fsdp_auto_wrap_policy(model, transformer_layer_names):
  10. import functools
  11. from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
  12. def lambda_policy_fn(module):
  13. if (
  14. len(list(module.named_children())) == 0
  15. and getattr(module, "weight", None) is not None
  16. and module.weight.requires_grad
  17. ):
  18. return True
  19. return False
  20. lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
  21. transformer_wrap_policy = functools.partial(
  22. transformer_auto_wrap_policy,
  23. transformer_layer_cls=set(transformer_layer_names)
  24. )
  25. auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
  26. return auto_wrap_policy
  27. def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None):
  28. """
  29. Initializes a device mesh for use with Hybrid Sharding strategy in FSDP (HSDP) training.
  30. This function requires explicit sizes for replica and sharding groups to accommodate models
  31. whose GPU fit is unknown, providing flexibility in distributed training setups.
  32. Args:
  33. replica_group_size (int): The size of each replica group. Must be provided to ensure
  34. the model fits within the available resources.
  35. sharding_group_size (int): The size of each sharding group that the model can fit. Must be provided to
  36. ensure the correct distribution of model parameters.
  37. device (str, optional): The device to use (e.g., "cuda:0"). If None, defaults to "cuda"
  38. with the local rank as the device index.
  39. Returns:
  40. A device mesh object compatible with FSDP.
  41. Raises:
  42. ValueError: If replica_group_size or sharding_group_size are not provided, or if the
  43. world size is not evenly divisible by the sharding group size.
  44. RuntimeError: If a valid device mesh cannot be created.
  45. Usage:
  46. If your model fits on 4 GPUS, and you have 3 nodes of 8 GPUs, then:
  47. Sharding_Group_Size = 4
  48. Replica_Groups_Size = (24 total gpus, 4 per sharding group) = 6 Replica Groups
  49. >>> device_mesh = initialize_device_mesh(replica_group_size, sharding_group_size)
  50. >>> sharded_model = FSDP(model, device_mesh=device_mesh, ...)
  51. """
  52. if replica_group_size is None or sharding_group_size is None:
  53. raise ValueError("Both replica_group_size and sharding_group_size must be provided.")
  54. local_rank = int(os.getenv("LOCAL_RANK", "0"))
  55. world_size = int(os.getenv("WORLD_SIZE", "1"))
  56. device = device or f"cuda"
  57. if world_size % sharding_group_size != 0:
  58. raise ValueError(f"World size {world_size} is not evenly divisible by "
  59. f"sharding group size {sharding_group_size}.")
  60. if (world_size // sharding_group_size) % replica_group_size != 0:
  61. raise ValueError(f"The calculated number of replica groups is not evenly divisible by "
  62. f"replica_group_size {replica_group_size}.")
  63. device_mesh = init_device_mesh(device, (replica_group_size, sharding_group_size))
  64. if device_mesh is None:
  65. raise RuntimeError("Failed to create a valid device mesh.")
  66. return device_mesh
  67. def get_policies(cfg, rank):
  68. """Get the policies for mixed precision and fsdp wrapping"""
  69. verify_bfloat_support = ((
  70. torch.version.cuda
  71. and torch.cuda.is_bf16_supported()
  72. and torch.version.cuda >= "11.0"
  73. and dist.is_nccl_available()
  74. and nccl.version() >= (2, 10)
  75. ) or
  76. (is_xpu_available()))
  77. mixed_precision_policy = None
  78. wrapping_policy = None
  79. # Mixed precision
  80. if cfg.mixed_precision:
  81. bf16_ready = verify_bfloat_support
  82. if bf16_ready and not cfg.use_fp16:
  83. mixed_precision_policy = bfSixteen
  84. if rank == 0:
  85. print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
  86. elif cfg.use_fp16:
  87. mixed_precision_policy = fpSixteen
  88. if rank == 0:
  89. print(f"FP16 enabled")
  90. else:
  91. print(f"bFloat16 support not present. Using FP32, and not mixed precision")
  92. wrapping_policy = get_llama_wrapper()
  93. return mixed_precision_policy, wrapping_policy