merge_mp_partitions.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. """Merge model parallel partitions."""

import os
import re
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                              os.path.pardir)))

import torch

from megatron import mpu
from megatron.checkpointing import load_checkpoint, save_checkpoint
from megatron.checkpointing import ensure_directory_exists
from megatron.checkpointing import get_checkpoint_name
from megatron.checkpointing import get_checkpoint_version
from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.global_vars import set_global_variables, get_args
from megatron.global_vars import rebuild_tokenizer


def split_into_partitions(tensor, num_partitions, partition_dim, stride):
    """Split tensor into num_partitions along partition_dim, honoring stride:
    the tensor is cut into num_partitions * stride equal chunks and partition
    i receives chunks i, i + num_partitions, i + 2 * num_partitions, ..."""
    per_partition_size = mpu.utils.divide(tensor.size(partition_dim),
                                          num_partitions)
    per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride)

    partitions_list = torch.split(tensor,
                                  per_partition_per_stride_size,
                                  dim=partition_dim)

    partitions = []
    for i in range(num_partitions):
        partition = torch.cat(partitions_list[i::num_partitions],
                              dim=partition_dim)
        partitions.append(partition)

    return partitions
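
# Illustration (added for clarity, not part of the original logic): with a
# tensor of 12 rows, num_partitions=2 and stride=3 along dim 0, the rows are
# cut into 6 chunks of 2 rows each, [c0, c1, c2, c3, c4, c5]; partition 0 is
# cat([c0, c2, c4]) and partition 1 is cat([c1, c3, c5]). This matches the
# strided layout Megatron uses for fused weights such as QKV across
# tensor-parallel ranks.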


def merge_partitions(merged, partitions, partition_dim, stride):
    """Merge the list of partitions back into the full tensor `merged`,
    inverting split_into_partitions for the given dimension and stride."""
    # Number and size of each partition.
    num_partitions = len(partitions)
    per_partition_size = None
    for partition in partitions:
        if per_partition_size is None:
            per_partition_size = partition.size(partition_dim)
        else:
            assert per_partition_size == partition.size(partition_dim)

    def concat_partitions(partitions_):
        with torch.no_grad():
            if (per_partition_size * num_partitions) == merged.size(
                    partition_dim):
                torch.cat(partitions_, dim=partition_dim, out=merged)
            else:
                print('     ***WARNING*** sizes do not match. Will cut '
                      'the merged partitions by {} along dimension {} '
                      'to reduce the size from {} to {} ...'.format(
                          (per_partition_size * num_partitions) -
                          merged.size(partition_dim), partition_dim,
                          per_partition_size * num_partitions,
                          merged.size(partition_dim)))
                merged_ = torch.cat(partitions_, dim=partition_dim)
                merged_split = torch.split(merged_, merged.size(partition_dim),
                                           dim=partition_dim)
                merged_ = merged_split[0]
                assert merged_.size(partition_dim) == merged.size(partition_dim)
                merged.data.copy_(merged_.data)

    # If stride is 1, then do a simple concatenation.
    if stride == 1:
        concat_partitions(partitions)
        return

    # For non-unity strides, first split based on stride and then group.
    per_partition_per_stride_size = mpu.utils.divide(per_partition_size, stride)
    # Chunk and build a list.
    chunks = None
    for i, partition in enumerate(partitions):
        chunk = torch.split(partition,
                            per_partition_per_stride_size,
                            dim=partition_dim)

        if chunks is None:
            chunks = [0] * (num_partitions * len(chunk))
        chunks[i::num_partitions] = chunk

    # Concatenate.
    concat_partitions(chunks)

    return
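
# Note (added): when the concatenated partitions are larger than the target
# tensor, concat_partitions() above truncates the result with a warning. In
# practice this typically affects the word-embedding weight, since Megatron
# pads the vocabulary so it divides evenly across the tensor-parallel ranks,
# and the padded size shrinks when merging down to a single partition.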


def get_model(model_type):

    if model_type == 'BERT':
        from pretrain_bert import model_provider
    elif model_type == 'GPT':
        from pretrain_gpt import model_provider
    elif model_type == 'RACE':
        from tasks.race.finetune import model_provider
    elif model_type in ['MNLI', 'QQP']:
        num_classes = 2
        if model_type == 'MNLI':
            num_classes = 3
        from megatron.model.classification import Classification

        def model_provider():
            return Classification(num_classes=num_classes, num_tokentypes=2)
    else:
        raise Exception('unrecognized model type: {}'.format(model_type))

    model = model_provider()
    model = model.half()

    return model


def get_parallel_checkpoint_name(path):
    """Read the latest iteration from the checkpoint tracker file and return
    the corresponding checkpoint name along with that iteration."""
    tracker_filename = get_checkpoint_tracker_filename(path)
    iteration = 0
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
        iteration = int(metastring)
    assert iteration > 0
    checkpoint_name = get_checkpoint_name(path, iteration)

    return checkpoint_name, iteration


def test_split_merge():

    print('testing split and merge ...')

    # [QKV.ROW-COL]
    tensor = torch.FloatTensor([[1.11, 1.12, 1.13, 1.14, 1.15],
                                [1.21, 1.22, 1.23, 1.24, 1.25],
                                [1.31, 1.32, 1.33, 1.34, 1.35],
                                [1.41, 1.42, 1.43, 1.44, 1.45],
                                [2.11, 2.12, 2.13, 2.14, 2.15],
                                [2.21, 2.22, 2.23, 2.24, 2.25],
                                [2.31, 2.32, 2.33, 2.34, 2.35],
                                [2.41, 2.42, 2.43, 2.44, 2.45],
                                [3.11, 3.12, 3.13, 3.14, 3.15],
                                [3.21, 3.22, 3.23, 3.24, 3.25],
                                [3.31, 3.32, 3.33, 3.34, 3.35],
                                [3.41, 3.42, 3.43, 3.44, 3.45]])

    num_partitions = 2
    partition_dim = 0
    stride = 3
    partitions = split_into_partitions(tensor, num_partitions,
                                       partition_dim, stride)

    merged = torch.zeros_like(tensor)
    merge_partitions(merged, partitions, partition_dim, stride)

    max_error = (merged - tensor).abs().max()
    print('   > max error (should be zero): {}'.format(max_error))


def get_mp_merge_args(parser):
    """Provide extra arguments required for merging."""
    group = parser.add_argument_group(title='mp merge')

    group.add_argument('--model-type', type=str, required=True,
                       choices=['BERT', 'GPT', 'RACE', 'MNLI', 'QQP'],
                       help='Type of the model.')
    group.add_argument('--target-pipeline-model-parallel-size', type=int, default=1,
                       help='Degree of pipeline model parallelism in the output model.')

    return parser
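
# Example invocation (added as a sketch; the exact set of required arguments
# depends on the Megatron-LM version, model and tokenizer configuration, and
# additional tokenizer/model flags may be needed):
#
#   python merge_mp_partitions.py \
#       --model-type GPT \
#       --tensor-model-parallel-size 4 \
#       --num-layers 24 --hidden-size 1024 --num-attention-heads 16 \
#       --max-position-embeddings 1024 \
#       --load <dir with the tensor-parallel checkpoint> \
#       --save <dir for the merged checkpoint>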


def main():
    # Arguments do sanity checks on the world size, but we don't care,
    # so trick it into thinking there are plenty of processes.
    os.environ["WORLD_SIZE"] = f'{2**31}'

    # Args
    set_global_variables(extra_args_provider=get_mp_merge_args,
                         args_defaults={'use_cpu_initialization': True,
                                        'micro_batch_size': 1,
                                        'no_load_optim': True,
                                        'no_load_rng': True,
                                        'no_save_optim': True,
                                        'no_save_rng': True,
                                        'save_interval': 1})
    args = get_args()

    if args.pipeline_model_parallel_size > 1:
        print("Checkpoints with pipeline model parallelism are not currently supported.")
        sys.exit()

    model_type = args.model_type
    orig_tensor_model_parallel_size = args.tensor_model_parallel_size
    args.tensor_model_parallel_size = 1
    tokenizer = rebuild_tokenizer(args)

    print('\n merging model parallel partitions ...')
    print(' > number of partitions: {}'.format(orig_tensor_model_parallel_size))
    print(' > checkpoint path: {}'.format(args.load))
    print(' > model parameters:')
    print('    number of tokens ................ {}'.format(
        tokenizer.vocab_size))
    print('    number of layers ................ {}'.format(args.num_layers))
    print('    hidden size ..................... {}'.format(args.hidden_size))
    print('    number of attention heads ....... {}'.format(
        args.num_attention_heads))
    print('    maximum position embeddings ..... {}'.format(
        args.max_position_embeddings))

    # Full model.
    print('> building the full model ...')
    mpu.initialize.set_tensor_model_parallel_world_size(1)
    mpu.initialize.set_tensor_model_parallel_rank(0)
    mpu.initialize.set_pipeline_model_parallel_world_size(1)
    mpu.initialize.set_pipeline_model_parallel_rank(0)
    merged_model = get_model(model_type)

    # Build and load partitions.
    partitions = []
    iteration = 0
    args.tensor_model_parallel_size = orig_tensor_model_parallel_size
    tokenizer = rebuild_tokenizer(args)
    mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
    for rank in range(args.tensor_model_parallel_size):
        # Reset these since load_checkpoint asserts they are 0, but we are
        # loading multiple checkpoints in the same process and they get set
        # each time.
        args.consumed_train_samples = 0
        args.consumed_valid_samples = 0

        mpu.initialize.set_tensor_model_parallel_rank(rank)
        checkpoint_name, iteration = get_parallel_checkpoint_name(args.load)
        model_ = get_model(model_type)
        print(f'> loading {checkpoint_name} ...')
        load_checkpoint(model_, None, None)
        print(f'> checkpoint version {get_checkpoint_version()}')
        partitions.append(model_)

    # Parameter generators so we can loop through them simultaneously.
    merged_params_gen = merged_model.named_parameters()
    partitions_params_gen = [partition.named_parameters()
                             for partition in partitions]
    while True:
        try:
            # Get the params and check names.
            name, merged_param = next(merged_params_gen)
            print(' > working on {} ...'.format(name))
            print('     merged         type: {}, size: {}'.format(
                merged_param.dtype, list(merged_param.size())))
            partitions_param = []
            for rank, partition_params_gen in enumerate(partitions_params_gen):
                partition_name, partition_param = next(partition_params_gen)
                assert partition_name == name
                partitions_param.append(partition_param)
                print('     partition {}    type: {}, size: {}'.format(
                    rank, partition_param.dtype, list(partition_param.size())))

            # For the non-parallel parameters, simply copy the rank 0 values.
            if not hasattr(merged_param, 'tensor_model_parallel'):
                print('     non-parallel parameter, simple copy from rank 0')
                with torch.no_grad():
                    merged_param.data.copy_(partitions_param[0].data)
            # For parallel parameters, merge the values.
            else:
                dim = merged_param.partition_dim
                stride = merged_param.partition_stride
                print(f'     parallel parameter merge with stride {stride} along '
                      f'dimension {dim}')
                merge_partitions(merged_param,
                                 partitions_param,
                                 dim,
                                 stride)

        except StopIteration:
            break

    partitions = []
    args.tensor_model_parallel_size = 1
    args.pipeline_model_parallel_size = args.target_pipeline_model_parallel_size

    assert args.num_layers % args.pipeline_model_parallel_size == 0, \
        'num_layers must be divisible by target pipeline model parallel size'
    layers_per_part = args.num_layers // args.pipeline_model_parallel_size

    tokenizer = rebuild_tokenizer(args)
    mpu.initialize.set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
    mpu.initialize.set_tensor_model_parallel_rank(0)
    mpu.initialize.set_pipeline_model_parallel_world_size(args.pipeline_model_parallel_size)

    # Regex to parse out the layer number from a param name.
    layer_re = re.compile(r'layers\.([0-9]+)')

    if args.pipeline_model_parallel_size > 1:
        merged_params = {}
        for name, merged_param in merged_model.named_parameters():
            merged_params[name] = merged_param

        for rank in range(args.pipeline_model_parallel_size):
            mpu.initialize.set_pipeline_model_parallel_rank(rank)
            model = get_model(model_type)

            def update_layer_num(m):
                # TODO! This assumes no interleaved pipeline execution.
                layer = int(m.group(1))
                layer += rank * layers_per_part
                return f'layers.{layer}'

            for dst_name, partition_param in model.named_parameters():
                if dst_name == "word_embeddings.weight":
                    # See comment in MegatronModule.initialize_word_embeddings()
                    src_name = "language_model.embedding.word_embeddings.weight"
                else:
                    # Translate the destination layer number (0-N for each
                    # partition) to the source layer number (single-model
                    # layer number).
                    src_name = re.sub(layer_re, update_layer_num, dst_name)
                print(f" > copying {src_name} to {dst_name} in rank {rank}'s model")
                partition_param.data.copy_(merged_params[src_name].data)

            partitions.append(model)
    else:
        partitions = [merged_model]

    for rank, model in enumerate(partitions):
        mpu.initialize.set_pipeline_model_parallel_rank(rank)
        print(f"> saving rank {rank}'s model")
        save_checkpoint(iteration, model, None, None)

    print('done :-)')


if __name__ == '__main__':
    main()