initialize.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Model and data parallel groups."""

import torch

from .utils import ensure_divisibility

# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None

_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None

# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None

# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage
_PIPELINE_GLOBAL_RANKS = None


def is_unitialized():
    """Useful for code segments that may be accessed with or without mpu initialization"""
    return _DATA_PARALLEL_GROUP is None


def initialize_model_parallel(tensor_model_parallel_size_=1,
                              pipeline_model_parallel_size_=1,
                              virtual_pipeline_model_parallel_size_=None):
    """
    Initialize model data parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used to parallelize the model tensor.
        pipeline_model_parallel_size: number of GPUs used to parallelize the model pipeline.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
    and 8 data-parallel groups as:
        8 data-parallel groups:
            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
        8 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
        4 pipeline model-parallel groups:
            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]

    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example, if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, ranks 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    if torch.distributed.get_rank() == 0:
        print('> initializing tensor model parallel with size {}'.format(
            tensor_model_parallel_size_))
        print('> initializing pipeline model parallel with size {}'.format(
            pipeline_model_parallel_size_))
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size()
    tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
    pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
    ensure_divisibility(world_size,
                        tensor_model_parallel_size * pipeline_model_parallel_size)
    data_parallel_size = world_size // (tensor_model_parallel_size *
                                        pipeline_model_parallel_size)

    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
    num_data_parallel_groups = world_size // data_parallel_size

    if virtual_pipeline_model_parallel_size_ is not None:
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_

    rank = torch.distributed.get_rank()

    # Build the data-parallel groups.
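    # Each data-parallel group takes every tensor_model_parallel_size-th rank
    # within one pipeline stage, e.g. [g0, g2] and [g1, g3] in the 2-way
    # tensor / 4-way pipeline example from the docstring above.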
    global _DATA_PARALLEL_GROUP
    assert _DATA_PARALLEL_GROUP is None, \
        'data parallel group is already initialized'
    all_data_parallel_group_ranks = []
    for i in range(pipeline_model_parallel_size):
        start_rank = i * num_pipeline_model_parallel_groups
        end_rank = (i + 1) * num_pipeline_model_parallel_groups
        for j in range(tensor_model_parallel_size):
            ranks = range(start_rank + j, end_rank,
                          tensor_model_parallel_size)
            all_data_parallel_group_ranks.append(list(ranks))
            group = torch.distributed.new_group(ranks)
            if rank in ranks:
                _DATA_PARALLEL_GROUP = group

    # Build the model-parallel groups.
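    # A model-parallel group takes the i-th rank from every data-parallel
    # group, so it covers all tensor- and pipeline-parallel ranks of one
    # model replica.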
    global _MODEL_PARALLEL_GROUP
    assert _MODEL_PARALLEL_GROUP is None, \
        'model parallel group is already initialized'
    for i in range(data_parallel_size):
        ranks = [data_parallel_group_ranks[i]
                 for data_parallel_group_ranks in all_data_parallel_group_ranks]
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _MODEL_PARALLEL_GROUP = group

    # Build the tensor model-parallel groups.
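    # Tensor model-parallel groups are contiguous blocks of
    # tensor_model_parallel_size ranks, which keeps each group on a single
    # box when adjacent ranks share a box (see the note in the docstring).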
    global _TENSOR_MODEL_PARALLEL_GROUP
    assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
        'tensor model parallel group is already initialized'
    for i in range(num_tensor_model_parallel_groups):
        ranks = range(i * tensor_model_parallel_size,
                      (i + 1) * tensor_model_parallel_size)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group

    # Build the pipeline model-parallel groups and embedding groups
    # (first and last rank in each pipeline model-parallel group).
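    # Each pipeline group strides through the ranks with a step of
    # num_pipeline_model_parallel_groups, e.g. [g0, g4, g8, g12] in the
    # docstring example.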
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS
    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
        'pipeline model parallel group is already initialized'
    global _EMBEDDING_GROUP
    assert _EMBEDDING_GROUP is None, \
        'embedding group is already initialized'
    for i in range(num_pipeline_model_parallel_groups):
        ranks = range(i, world_size,
                      num_pipeline_model_parallel_groups)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group
            _PIPELINE_GLOBAL_RANKS = ranks
        # Setup embedding group (to exchange gradients between
        # first and last stages).
        if len(ranks) > 1:
            embedding_ranks = [ranks[0], ranks[-1]]
        else:
            embedding_ranks = ranks
        group = torch.distributed.new_group(embedding_ranks)
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
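

# Illustrative usage sketch (assumes a 16-process job launched with the NCCL
# backend, matching the 2-way tensor / 4-way pipeline example in the
# docstring above):
#
#     torch.distributed.init_process_group(backend='nccl')
#     initialize_model_parallel(tensor_model_parallel_size_=2,
#                               pipeline_model_parallel_size_=4)
#     # Each rank now belongs to one data-parallel, one tensor model-parallel,
#     # one pipeline model-parallel and one model-parallel group;
#     # data_parallel_size == 16 // (2 * 4) == 2.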


def model_parallel_is_initialized():
    """Check if model and data parallel groups are initialized."""
    if _TENSOR_MODEL_PARALLEL_GROUP is None or \
        _PIPELINE_MODEL_PARALLEL_GROUP is None or \
        _DATA_PARALLEL_GROUP is None:
        return False
    return True


def get_model_parallel_group():
    """Get the model parallel group the caller rank belongs to."""
    assert _MODEL_PARALLEL_GROUP is not None, \
        'model parallel group is not initialized'
    return _MODEL_PARALLEL_GROUP


def get_tensor_model_parallel_group():
    """Get the tensor model parallel group the caller rank belongs to."""
    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
        'tensor model parallel group is not initialized'
    return _TENSOR_MODEL_PARALLEL_GROUP


def get_pipeline_model_parallel_group():
    """Get the pipeline model parallel group the caller rank belongs to."""
    assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \
        'pipeline model parallel group is not initialized'
    return _PIPELINE_MODEL_PARALLEL_GROUP


def get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to."""
    assert _DATA_PARALLEL_GROUP is not None, \
        'data parallel group is not initialized'
    return _DATA_PARALLEL_GROUP


def get_embedding_group():
    """Get the embedding group the caller rank belongs to."""
    assert _EMBEDDING_GROUP is not None, \
        'embedding group is not initialized'
    return _EMBEDDING_GROUP


def set_tensor_model_parallel_world_size(world_size):
    """Set the tensor model parallel size"""
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size


def set_pipeline_model_parallel_world_size(world_size):
    """Set the pipeline model parallel size"""
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size


def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group."""
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())


def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline model parallel group."""
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())


def set_tensor_model_parallel_rank(rank):
    """Set tensor model parallel rank."""
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = rank


def set_pipeline_model_parallel_rank(rank):
    """Set pipeline model parallel rank."""
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank


def get_tensor_model_parallel_rank():
    """Return my rank for the tensor model parallel group."""
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
        return _MPU_TENSOR_MODEL_PARALLEL_RANK
    return torch.distributed.get_rank(group=get_tensor_model_parallel_group())


def get_pipeline_model_parallel_rank():
    """Return my rank for the pipeline model parallel group."""
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
        return _MPU_PIPELINE_MODEL_PARALLEL_RANK
    return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())


def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if in the first pipeline model-parallel stage, False otherwise."""
    if not ignore_virtual:
        if get_virtual_pipeline_model_parallel_world_size() is not None and \
                get_virtual_pipeline_model_parallel_rank() != 0:
            return False
    return get_pipeline_model_parallel_rank() == 0


def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if in the last pipeline model-parallel stage, False otherwise."""
    if not ignore_virtual:
        virtual_pipeline_model_parallel_world_size = \
            get_virtual_pipeline_model_parallel_world_size()
        if virtual_pipeline_model_parallel_world_size is not None and \
                get_virtual_pipeline_model_parallel_rank() != (
                    virtual_pipeline_model_parallel_world_size - 1):
            return False
    return get_pipeline_model_parallel_rank() == (
        get_pipeline_model_parallel_world_size() - 1)


def get_virtual_pipeline_model_parallel_rank():
    """Return the virtual pipeline-parallel rank."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK


def set_virtual_pipeline_model_parallel_rank(rank):
    """Set the virtual pipeline-parallel rank."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank


def get_virtual_pipeline_model_parallel_world_size():
    """Return the virtual pipeline-parallel world size."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE


def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group."""
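    # For example, with a tensor model parallel size of 2, global ranks 2 and 3
    # both map to source rank 2.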
    global_rank = torch.distributed.get_rank()
    local_world_size = get_tensor_model_parallel_world_size()
    return (global_rank // local_world_size) * local_world_size


def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first stage in the caller's pipeline group."""
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    return _PIPELINE_GLOBAL_RANKS[0]


def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last stage in the caller's pipeline group."""
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    last_rank_local = get_pipeline_model_parallel_world_size() - 1
    return _PIPELINE_GLOBAL_RANKS[last_rank_local]


def get_pipeline_model_parallel_next_rank():
    """Return the global rank of the next pipeline stage, wrapping around
    from the last stage to the first."""
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]


def get_pipeline_model_parallel_prev_rank():
    """Return the global rank of the previous pipeline stage, wrapping around
    from the first stage to the last."""
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]


def get_data_parallel_world_size():
    """Return world size for the data parallel group."""
    return torch.distributed.get_world_size(group=get_data_parallel_group())


def get_data_parallel_rank():
    """Return my rank for the data parallel group."""
    return torch.distributed.get_rank(group=get_data_parallel_group())


def destroy_model_parallel():
    """Set the groups to none."""
    global _MODEL_PARALLEL_GROUP
    _MODEL_PARALLEL_GROUP = None
    global _TENSOR_MODEL_PARALLEL_GROUP
    _TENSOR_MODEL_PARALLEL_GROUP = None
    global _PIPELINE_MODEL_PARALLEL_GROUP
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None
    global _EMBEDDING_GROUP
    _EMBEDDING_GROUP = None