# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input/output checkpointing."""

import os
import random
import sys

import numpy as np
import torch

from megatron import (get_args,
                      mpu,
                      print_rank_0,
                      update_num_microbatches,
                      utils)

_CHECKPOINT_VERSION = None


def set_checkpoint_version(value):
    global _CHECKPOINT_VERSION
    if _CHECKPOINT_VERSION is not None:
        assert _CHECKPOINT_VERSION == value, \
            "checkpoint versions do not match"
    _CHECKPOINT_VERSION = value


def get_checkpoint_version():
    global _CHECKPOINT_VERSION
    return _CHECKPOINT_VERSION


def check_checkpoint_args(checkpoint_args):
    """Ensure fixed arguments for a model are the same between the input
    arguments and the ones retrieved from the checkpoint."""
    args = get_args()

    def _compare(arg_name, old_arg_name=None):
        if old_arg_name is not None:
            checkpoint_value = getattr(checkpoint_args, old_arg_name)
        else:
            checkpoint_value = getattr(checkpoint_args, arg_name)
        args_value = getattr(args, arg_name)
        error_message = '{} value from checkpoint ({}) is not equal to the ' \
                        'input argument value ({}).'.format(
                            arg_name, checkpoint_value, args_value)
        assert checkpoint_value == args_value, error_message

    _compare('num_layers')
    _compare('hidden_size')
    _compare('num_attention_heads')
    if args.vocab_file:
        _compare('max_position_embeddings')
        _compare('make_vocab_size_divisible_by')
        _compare('padded_vocab_size')
        _compare('tokenizer_type')
    if get_checkpoint_version() < 3.0:
        _compare('tensor_model_parallel_size',
                 old_arg_name='model_parallel_size')
    if get_checkpoint_version() >= 3.0:
        _compare('tensor_model_parallel_size')
        _compare('pipeline_model_parallel_size')


def ensure_directory_exists(filename):
    """Build filename's path if it does not already exist."""
    dirname = os.path.dirname(filename)
    if not os.path.exists(dirname):
        os.makedirs(dirname)


def get_checkpoint_name(checkpoints_path, iteration,
                        release=False):
    """A unified checkpoint name."""
    if release:
        directory = 'release'
    else:
        directory = 'iter_{:07d}'.format(iteration)
    # Use both the tensor and pipeline MP rank.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        return os.path.join(checkpoints_path, directory,
                            'mp_rank_{:02d}'.format(
                                mpu.get_tensor_model_parallel_rank()),
                            'model_optim_rng.pt')
    return os.path.join(checkpoints_path, directory,
                        'mp_rank_{:02d}_{:03d}'.format(
                            mpu.get_tensor_model_parallel_rank(),
                            mpu.get_pipeline_model_parallel_rank()),
                        'model_optim_rng.pt')
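
# Illustrative note (not part of the original source): with the naming scheme
# above, a checkpoint written at iteration 1000 by tensor-parallel rank 0 and
# pipeline-parallel rank 2 lands at
#     <checkpoints_path>/iter_0001000/mp_rank_00_002/model_optim_rng.pt
# while a run with a single pipeline stage drops the pipeline suffix:
#     <checkpoints_path>/iter_0001000/mp_rank_00/model_optim_rng.pt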


def get_checkpoint_tracker_filename(checkpoints_path):
    """Tracker file records the latest checkpoint during
    training to restart from."""
    return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
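
# Note (illustrative, not in the original source): the tracker file holds either
# the integer of the latest saved iteration or the literal string 'release';
# load_checkpoint below parses it accordingly.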


def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint."""
    args = get_args()

    # Only rank zero of the data parallel group writes to the disk.
    model = utils.unwrap_model(model)

    print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

    if not torch.distributed.is_initialized() or mpu.get_data_parallel_rank() == 0:

        # Arguments, iteration, and model.
        state_dict = {}
        state_dict['args'] = args
        state_dict['checkpoint_version'] = 3.0
        state_dict['iteration'] = iteration
        if len(model) == 1:
            state_dict['model'] = model[0].state_dict_for_save_checkpoint()
        else:
            for i in range(len(model)):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                state_dict['model%d' % i] = model[i].state_dict_for_save_checkpoint()

        # Optimizer stuff.
        if not args.no_save_optim:
            if optimizer is not None:
                state_dict['optimizer'] = optimizer.state_dict()
            if lr_scheduler is not None:
                state_dict['lr_scheduler'] = lr_scheduler.state_dict()

        # RNG states.
        if not args.no_save_rng:
            state_dict['random_rng_state'] = random.getstate()
            state_dict['np_rng_state'] = np.random.get_state()
            state_dict['torch_rng_state'] = torch.get_rng_state()
            state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
            state_dict['rng_tracker_states'] \
                = mpu.get_cuda_rng_tracker().get_states()

        # Save.
        checkpoint_name = get_checkpoint_name(args.save, iteration)
        ensure_directory_exists(checkpoint_name)
        torch.save(state_dict, checkpoint_name)

    # Wait so everyone is done (necessary).
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

    # And update the latest iteration.
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))

    # Wait so everyone is done (not necessary).
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
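
# Illustrative usage sketch (an assumption, not part of the original file): a
# training loop would typically call this at its save interval, e.g.
#     if args.save and iteration % args.save_interval == 0:
#         save_checkpoint(iteration, model, optimizer, lr_scheduler)
# where `model` is the list of (possibly virtual-pipeline) model chunks that
# the rest of this module expects.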


def _transpose_first_dim(t, num_splits, num_splits_first, model):
    input_shape = t.size()
    # We use a self_attention module but the values extracted aren't
    # specific to self attention so should work for cross attention as well.
    while hasattr(model, 'module'):
        model = model.module
    attention_module = model.language_model.encoder.layers[0].self_attention
    hidden_size_per_attention_head = attention_module.hidden_size_per_attention_head
    num_attention_heads_per_partition = attention_module.num_attention_heads_per_partition
    if num_splits_first:
        """[num_splits * np * hn, h]
        -->(view) [num_splits, np, hn, h]
        -->(transpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h] """

        intermediate_shape = \
            (num_splits, num_attention_heads_per_partition,
             hidden_size_per_attention_head) + input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(0, 1).contiguous()
    else:
        """[np * hn * num_splits, h]
        -->(view) [np, hn, num_splits, h]
        -->(transpose) [np, num_splits, hn, h]
        -->(view) [np * num_splits * hn, h] """

        intermediate_shape = \
            (num_attention_heads_per_partition,
             hidden_size_per_attention_head, num_splits) + \
            input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(1, 2).contiguous()
    t = t.view(*input_shape)

    return t
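
# Illustrative shape walk-through (not part of the original source): for a fused
# QKV weight with num_splits=3, np=2 heads per partition, hn=4 dims per head and
# hidden size h, the checkpoint-version-0 layout [3*2*4, h] is viewed as
# [3, 2, 4, h], transposed to [2, 3, 4, h], and flattened back to [2*3*4, h],
# i.e. the [np * num_splits * hn, h] ordering the current code expects.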


def fix_query_key_value_ordering(model, checkpoint_version):
    """Fix up query/key/value matrix ordering if checkpoint
    version is smaller than 2.0
    """
    if checkpoint_version < 2.0:
        if isinstance(model, list):
            assert len(model) == 1
            model = model[0]
        for name, param in model.named_parameters():
            if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
                if checkpoint_version == 0:
                    fixed_param = _transpose_first_dim(param.data, 3, True, model)
                elif checkpoint_version == 1.0:
                    fixed_param = _transpose_first_dim(param.data, 3, False, model)
                else:
                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                    sys.exit()
                param.data.copy_(fixed_param)
            if name.endswith(('.key_value.weight', '.key_value.bias')):
                if checkpoint_version == 0:
                    fixed_param = _transpose_first_dim(param.data, 2, True, model)
                elif checkpoint_version == 1.0:
                    fixed_param = _transpose_first_dim(param.data, 2, False, model)
                else:
                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                    sys.exit()
                param.data.copy_(fixed_param)
        print_rank_0(" successfully fixed query-key-values ordering for"
                     " checkpoint version {}".format(checkpoint_version))


def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
    """Load a model checkpoint and return the iteration.

    strict (bool): whether to strictly enforce that the keys in
        :attr:`state_dict` of the checkpoint match the names of
        parameters and buffers in model.
    """
    args = get_args()
    load_dir = getattr(args, load_arg)

    model = utils.unwrap_model(model)

    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

    # If no tracker file, return iteration zero.
    if not os.path.isfile(tracker_filename):
        print_rank_0('WARNING: could not find the metadata file {} '.format(
            tracker_filename))
        print_rank_0(' will not load any checkpoints and will start from '
                     'random')
        return 0

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration = 0
    release = False
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
        try:
            iteration = int(metastring)
        except ValueError:
            release = metastring == 'release'
            if not release:
                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                    tracker_filename))
                sys.exit()

    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)

    # Checkpoint.
    checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
    print_rank_0(f' loading checkpoint from {args.load} at iteration {iteration}')

    # Load the checkpoint.
    try:
        state_dict = torch.load(checkpoint_name, map_location='cpu')
    except ModuleNotFoundError:
        from megatron.fp16_deprecated import loss_scaler
        # For backward compatibility.
        print_rank_0(' > deserializing using the old code structure ...')
        sys.modules['fp16.loss_scaler'] = sys.modules[
            'megatron.fp16_deprecated.loss_scaler']
        sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
            'megatron.fp16_deprecated.loss_scaler']
        state_dict = torch.load(checkpoint_name, map_location='cpu')
        sys.modules.pop('fp16.loss_scaler', None)
        sys.modules.pop('megatron.fp16.loss_scaler', None)
    except BaseException as e:
        print_rank_0('could not load the checkpoint')
        print_rank_0(e)
        sys.exit()

    # Set checkpoint version.
    set_checkpoint_version(state_dict.get('checkpoint_version', 0))

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = state_dict['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but unable to load '
                             'iteration from checkpoint {}, exiting'.format(
                                 checkpoint_name))
                sys.exit()

    # Check arguments.
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    if 'args' in state_dict:
        checkpoint_args = state_dict['args']
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(checkpoint_args,
                                              'consumed_train_samples', 0)
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
        args.consumed_valid_samples = getattr(checkpoint_args,
                                              'consumed_valid_samples', 0)
    else:
        print_rank_0('could not find arguments in the checkpoint ...')

    # Model.
    if len(model) == 1:
        model[0].load_state_dict(state_dict['model'], strict=strict)
    else:
        for i in range(len(model)):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            model[i].load_state_dict(state_dict['model%d' % i], strict=strict)

    # Fix up query/key/value matrix ordering if needed.
    checkpoint_version = get_checkpoint_version()
    print_rank_0(f' checkpoint version {checkpoint_version}')
    fix_query_key_value_ordering(model, checkpoint_version)

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(state_dict['optimizer'])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
        except KeyError:
            print_rank_0('Unable to load optimizer from checkpoint {}. '
                         'Specify --no-load-optim or --finetune to prevent '
                         'attempting to load the optimizer state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit()

    # RNG states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(state_dict['random_rng_state'])
            np.random.set_state(state_dict['np_rng_state'])
            torch.set_rng_state(state_dict['torch_rng_state'])
            torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
            # Check for empty states array.
            if not state_dict['rng_tracker_states']:
                raise KeyError
            mpu.get_cuda_rng_tracker().set_states(
                state_dict['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load rng state from checkpoint {}. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the rng state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit()

    # Some utilities want to load a checkpoint without distributed being initialized.
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0(f' successfully loaded checkpoint from {args.load} '
                 f'at iteration {iteration}')

    return iteration
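
# Illustrative usage sketch (an assumption, not part of the original file):
# resuming training typically looks like
#     iteration = load_checkpoint(model, optimizer, lr_scheduler)
# which returns 0 (a fresh start) when no tracker file is found, when
# --finetune is set, or when loading a 'release' checkpoint.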


def load_biencoder_checkpoint(model, only_query_model=False,
                              only_context_model=False, custom_load_path=None):
    """Selectively load retrieval models for indexing/retrieving
    from saved checkpoints."""
    args = get_args()

    model = utils.unwrap_model(model)

    load_path = custom_load_path if custom_load_path is not None else args.load

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ret_state_dict = state_dict['model']

    if only_query_model:
        ret_state_dict.pop('context_model')
    if only_context_model:
        ret_state_dict.pop('query_model')

    assert len(model) == 1
    model[0].load_state_dict(ret_state_dict)
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model