arguments.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron arguments."""

import argparse
import os

import torch


def parse_args(extra_args_provider=None, defaults={},
               ignore_unknown_args=False):
    """Parse all arguments."""
    parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
                                     allow_abbrev=False)

    # Standard arguments.
    parser = _add_network_size_args(parser)
    parser = _add_regularization_args(parser)
    parser = _add_training_args(parser)
    parser = _add_initialization_args(parser)
    parser = _add_learning_rate_args(parser)
    parser = _add_checkpointing_args(parser)
    parser = _add_mixed_precision_args(parser)
    parser = _add_distributed_args(parser)
    parser = _add_validation_args(parser)
    parser = _add_data_args(parser)
    parser = _add_autoresume_args(parser)
    parser = _add_biencoder_args(parser)
    parser = _add_vit_args(parser)
    parser = _add_logging_args(parser)

    # Custom arguments.
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)

    # Parse.
    if ignore_unknown_args:
        args, _ = parser.parse_known_args()
    else:
        args = parser.parse_args()

    # Distributed args.
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv('WORLD_SIZE', '1'))
    # Tensor model parallel size.
    args.tensor_model_parallel_size = min(
        args.tensor_model_parallel_size, args.world_size)
    assert args.world_size % args.tensor_model_parallel_size == 0, 'world size'\
        ' ({}) is not divisible by tensor model parallel size ({})'.format(
            args.world_size, args.tensor_model_parallel_size)
    # Pipeline model parallel size.
    args.pipeline_model_parallel_size = min(
        args.pipeline_model_parallel_size,
        (args.world_size // args.tensor_model_parallel_size))
    # Checks.
    model_parallel_size = args.pipeline_model_parallel_size * \
                          args.tensor_model_parallel_size
    assert args.world_size % model_parallel_size == 0, 'world size ({}) is ' \
        'not divisible by tensor parallel size ({}) times pipeline parallel ' \
        'size ({})'.format(args.world_size, args.tensor_model_parallel_size,
                           args.pipeline_model_parallel_size)
    args.data_parallel_size = args.world_size // model_parallel_size
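    # Example: with a world size of 16, tensor-model-parallel size 2 and
    # pipeline-model-parallel size 2, the data-parallel size is 16 // (2 * 2) = 4.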

    if args.rank == 0:
        print('using world size: {}, data-parallel-size: {}, '
              'tensor-model-parallel size: {}, '
              'pipeline-model-parallel size: {} '.format(
                  args.world_size, args.data_parallel_size,
                  args.tensor_model_parallel_size,
                  args.pipeline_model_parallel_size), flush=True)

    # Deprecated arguments.
    assert args.batch_size is None, '--batch-size argument is no longer ' \
        'valid, use --micro-batch-size instead'
    del args.batch_size
    assert args.warmup is None, '--warmup argument is no longer valid, use ' \
        '--lr-warmup-fraction instead'
    del args.warmup
    assert args.model_parallel_size is None, '--model-parallel-size is no ' \
        'longer valid, use --tensor-model-parallel-size instead'
    del args.model_parallel_size

    # Set input defaults.
    for key in defaults:
        # For default to be valid, it should not be provided in the
        # arguments that are passed to the program. We check this by
        # ensuring the arg is set to None.
        if getattr(args, key) is not None:
            if args.rank == 0:
                print('WARNING: overriding default arguments for '
                      '{key}:{v} with {key}:{v2}'.format(
                          key=key, v=defaults[key], v2=getattr(args, key)),
                      flush=True)
        else:
            setattr(args, key, defaults[key])

    # Batch size.
    assert args.micro_batch_size is not None
    assert args.micro_batch_size > 0
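    # If no global batch size is provided, fall back to one micro batch per
    # data-parallel rank, i.e. micro-batch-size * data-parallel-size.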
    if args.global_batch_size is None:
        args.global_batch_size = args.micro_batch_size * args.data_parallel_size
        if args.rank == 0:
            print('setting global batch size to {}'.format(
                args.global_batch_size), flush=True)
    assert args.global_batch_size > 0
    if args.num_layers_per_virtual_pipeline_stage is not None:
        assert args.pipeline_model_parallel_size > 2, \
            'pipeline-model-parallel size should be greater than 2 with ' \
            'interleaved schedule'
        assert args.num_layers % args.num_layers_per_virtual_pipeline_stage == 0, \
            'number of layers is not divisible by number of layers per virtual ' \
            'pipeline stage'
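        # Example: 24 layers on a pipeline-model-parallel size of 4 gives
        # 6 layers per pipeline stage; with 3 layers per virtual stage the
        # virtual-pipeline-model-parallel size is 6 // 3 = 2.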
        args.virtual_pipeline_model_parallel_size = \
            (args.num_layers // args.pipeline_model_parallel_size) // \
            args.num_layers_per_virtual_pipeline_stage
    else:
        args.virtual_pipeline_model_parallel_size = None

    # Parameters dtype.
    args.params_dtype = torch.float
    if args.fp16:
        assert not args.bf16
        args.params_dtype = torch.half
    if args.bf16:
        assert not args.fp16
        args.params_dtype = torch.bfloat16
        # bfloat16 requires gradient accumulation and all-reduce to
        # be done in fp32.
        if not args.accumulate_allreduce_grads_in_fp32:
            args.accumulate_allreduce_grads_in_fp32 = True
            if args.rank == 0:
                print('accumulate and all-reduce gradients in fp32 for '
                      'bfloat16 data type.', flush=True)

    if args.rank == 0:
        print('using {} for parameters ...'.format(args.params_dtype),
              flush=True)

    # If we do accumulation and all-reduces in fp32, we need to have
    # local DDP and we should set the use-contiguous-buffers-in-ddp.
    if args.accumulate_allreduce_grads_in_fp32:
        assert args.DDP_impl == 'local'
        args.use_contiguous_buffers_in_ddp = True

    if args.dataloader_type is None:
        args.dataloader_type = 'single'

    # Consumed tokens.
    args.consumed_train_samples = 0
    args.consumed_valid_samples = 0

    # Iteration-based training.
    if args.train_iters:
        # If we use iteration-based training, make sure the
        # sample-based options are off.
        assert args.train_samples is None, \
            'expected iteration-based training'
        assert args.lr_decay_samples is None, \
            'expected iteration-based learning rate decay'
        assert args.lr_warmup_samples == 0, \
            'expected iteration-based learning rate warmup'
        assert args.rampup_batch_size is None, \
            'expected no batch-size rampup for iteration-based training'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_iters == 0, \
                'can only specify one of lr-warmup-fraction and lr-warmup-iters'

    # Sample-based training.
    if args.train_samples:
        # If we use sample-based training, make sure the
        # iteration-based options are off.
        assert args.train_iters is None, \
            'expected sample-based training'
        assert args.lr_decay_iters is None, \
            'expected sample-based learning rate decay'
        assert args.lr_warmup_iters == 0, \
            'expected sample-based learning rate warmup'
        if args.lr_warmup_fraction is not None:
            assert args.lr_warmup_samples == 0, \
                'can only specify one of lr-warmup-fraction ' \
                'and lr-warmup-samples'

    # Check required arguments.
    required_args = ['num_layers', 'hidden_size', 'num_attention_heads',
                     'max_position_embeddings']
    for req_arg in required_args:
        _check_arg_is_not_none(args, req_arg)

    # Checks.
    if args.ffn_hidden_size is None:
        args.ffn_hidden_size = 4 * args.hidden_size
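    # The per-head projection size defaults to hidden-size / num-attention-heads
    # (e.g. a hidden size of 1024 split across 16 heads gives 64 kv channels).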
    if args.kv_channels is None:
        assert args.hidden_size % args.num_attention_heads == 0
        args.kv_channels = args.hidden_size // args.num_attention_heads
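    # --seq-length and --encoder-seq-length are aliases: exactly one should be
    # provided, and its value is copied to the other.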
    if args.seq_length is not None:
        assert args.encoder_seq_length is None
        args.encoder_seq_length = args.seq_length
    else:
        assert args.encoder_seq_length is not None
        args.seq_length = args.encoder_seq_length

    if args.seq_length is not None:
        assert args.max_position_embeddings >= args.seq_length
    if args.decoder_seq_length is not None:
        assert args.max_position_embeddings >= args.decoder_seq_length
    if args.lr is not None:
        assert args.min_lr <= args.lr
    if args.save is not None:
        assert args.save_interval is not None
    # Mixed precision checks.
    if args.fp16_lm_cross_entropy:
        assert args.fp16, 'lm cross entropy in fp16 only supported in fp16 mode.'
    if args.fp32_residual_connection:
        assert args.fp16 or args.bf16, \
            'residual connection in fp32 only supported when using fp16 or bf16.'
    # Activation checkpointing.
    if args.distribute_checkpointed_activations:
        assert args.checkpoint_activations, \
            'for distribute-checkpointed-activations to work you '\
            'need to enable checkpoint-activations'

    _print_args(args)
    return args
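
# NOTE: callers are expected to invoke parse_args() once at start-up. In
# Megatron-LM this typically happens via initialize_megatron(), which stores
# the returned namespace as the global arguments; the exact call site may
# differ between versions.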


def _print_args(args):
    """Print arguments."""
    if args.rank == 0:
        print('------------------------ arguments ------------------------',
              flush=True)
        str_list = []
        for arg in vars(args):
            dots = '.' * (48 - len(arg))
            str_list.append(' {} {} {}'.format(arg, dots, getattr(args, arg)))
        for arg in sorted(str_list, key=lambda x: x.lower()):
            print(arg, flush=True)
        print('-------------------- end of arguments ---------------------',
              flush=True)


def _check_arg_is_not_none(args, arg):
    assert getattr(args, arg) is not None, '{} argument is None'.format(arg)


def _add_network_size_args(parser):
    group = parser.add_argument_group(title='network size')

    group.add_argument('--num-layers', type=int, default=None,
                       help='Number of transformer layers.')
    group.add_argument('--hidden-size', type=int, default=None,
                       help='Transformer hidden size.')
    group.add_argument('--ffn-hidden-size', type=int, default=None,
                       help='Transformer Feed-Forward Network hidden size. '
                       'This is set to 4*hidden-size if not provided.')
    group.add_argument('--num-attention-heads', type=int, default=None,
                       help='Number of transformer attention heads.')
    group.add_argument('--kv-channels', type=int, default=None,
                       help='Projection weights dimension in multi-head '
                       'attention. This is set to '
                       'args.hidden_size // args.num_attention_heads '
                       'if not provided.')
    group.add_argument('--max-position-embeddings', type=int, default=None,
                       help='Maximum number of position embeddings to use. '
                       'This is the size of the position embedding.')
    group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
                       help='Pad the vocab size to be divisible by this value. '
                       'This is added for computational efficiency reasons.')
    group.add_argument('--layernorm-epsilon', type=float, default=1e-5,
                       help='Layer norm epsilon.')
    group.add_argument('--apply-residual-connection-post-layernorm',
                       action='store_true',
                       help='If set, use the original BERT residual connection '
                       'ordering.')
    group.add_argument('--openai-gelu', action='store_true',
                       help='Use OpenAI\'s GeLU implementation. This option '
                       'should not be used unless for backward compatibility '
                       'reasons.')
    group.add_argument('--onnx-safe', type=bool, required=False,
                       help='Use workarounds for known problems with the '
                       'Torch ONNX exporter.')
    group.add_argument('--bert-no-binary-head', action='store_false',
                       help='Disable BERT binary head.',
                       dest='bert_binary_head')

    return parser


def _add_logging_args(parser):
    group = parser.add_argument_group(title='logging')

    group.add_argument('--log-params-norm', action='store_true',
                       help='If set, calculate and log parameters norm.')
    group.add_argument('--log-num-zeros-in-grad', action='store_true',
                       help='If set, calculate and log the number of zeros in gradient.')
    group.add_argument('--tensorboard-log-interval', type=int, default=1,
                       help='Report to tensorboard interval.')
    group.add_argument('--tensorboard-queue-size', type=int, default=1000,
                       help='Size of the tensorboard queue for pending events '
                       'and summaries before one of the ‘add’ calls forces a '
                       'flush to disk.')
    group.add_argument('--log-timers-to-tensorboard', action='store_true',
                       help='If set, write timers to tensorboard.')
    group.add_argument('--log-batch-size-to-tensorboard', action='store_true',
                       help='If set, write batch-size to tensorboard.')
    group.add_argument('--no-log-learnig-rate-to-tensorboard',
                       action='store_false',
                       help='Disable learning rate logging to tensorboard.',
                       dest='log_learning_rate_to_tensorboard')
    group.add_argument('--no-log-loss-scale-to-tensorboard',
                       action='store_false',
                       help='Disable loss-scale logging to tensorboard.',
                       dest='log_loss_scale_to_tensorboard')
    group.add_argument('--log-validation-ppl-to-tensorboard',
                       action='store_true',
                       help='If set, write validation perplexity to '
                       'tensorboard.')

    return parser


def _add_regularization_args(parser):
    group = parser.add_argument_group(title='regularization')

    group.add_argument('--attention-dropout', type=float, default=0.1,
                       help='Post attention dropout probability.')
    group.add_argument('--hidden-dropout', type=float, default=0.1,
                       help='Dropout probability for hidden state transformer.')
    group.add_argument('--weight-decay', type=float, default=0.01,
                       help='Weight decay coefficient for L2 regularization.')
    group.add_argument('--clip-grad', type=float, default=1.0,
                       help='Gradient clipping based on global L2 norm.')
    group.add_argument('--adam-beta1', type=float, default=0.9,
                       help='First coefficient for computing running averages '
                       'of gradient and its square.')
    group.add_argument('--adam-beta2', type=float, default=0.999,
                       help='Second coefficient for computing running averages '
                       'of gradient and its square.')
    group.add_argument('--adam-eps', type=float, default=1e-08,
                       help='Term added to the denominator to improve '
                       'numerical stability.')
    group.add_argument('--sgd-momentum', type=float, default=0.9,
                       help='Momentum factor for sgd.')

    return parser


def _add_training_args(parser):
    group = parser.add_argument_group(title='training')

    group.add_argument('--micro-batch-size', type=int, default=None,
                       help='Batch size per model instance (local batch size). '
                       'Global batch size is local batch size times data '
                       'parallel size times number of micro batches.')
    group.add_argument('--batch-size', type=int, default=None,
                       help='Old batch size parameter, do not use. '
                       'Use --micro-batch-size instead.')
    group.add_argument('--global-batch-size', type=int, default=None,
                       help='Training batch size. If set, it should be a '
                       'multiple of micro-batch-size times data-parallel-size. '
                       'If this value is None, then '
                       'use micro-batch-size * data-parallel-size as the '
                       'global batch size. This choice will result in 1 for '
                       'number of micro-batches.')
    group.add_argument('--rampup-batch-size', nargs='*', default=None,
                       help='Batch size ramp up with the following values: '
                       '--rampup-batch-size <start batch size> '
                       '<batch size increment> '
                       '<ramp-up samples>. '
                       'For example: '
                       '--rampup-batch-size 16 8 300000 '
                       '--global-batch-size 1024 '
                       'will start with global batch size 16 and over '
                       '(1024 - 16) / 8 = 126 intervals will increase '
                       'the batch size linearly to 1024. In each interval '
                       'we will use approximately 300000 / 126 = 2380 samples.')
    group.add_argument('--checkpoint-activations', action='store_true',
                       help='Checkpoint activations to allow for training '
                       'with larger models, sequences, and batch sizes.')
    group.add_argument('--distribute-checkpointed-activations',
                       action='store_true',
                       help='If set, distribute checkpointed activations '
                       'across the model parallel group.')
    group.add_argument('--checkpoint-num-layers', type=int, default=1,
                       help='Chunk size (number of layers) for checkpointing.')
    group.add_argument('--train-iters', type=int, default=None,
                       help='Total number of iterations to train over all '
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
    group.add_argument('--train-samples', type=int, default=None,
                       help='Total number of samples to train over all '
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
    group.add_argument('--log-interval', type=int, default=100,
                       help='Report loss and timing interval.')
    group.add_argument('--exit-interval', type=int, default=None,
                       help='Exit the program after the iteration is divisible '
                       'by this value.')
    group.add_argument('--exit-duration-in-mins', type=int, default=None,
                       help='Exit the program after this many minutes.')
    group.add_argument('--tensorboard-dir', type=str, default=None,
                       help='Write TensorBoard logs to this directory.')
    group.add_argument('--no-masked-softmax-fusion',
                       action='store_false',
                       help='Disable fusion of query_key_value scaling, '
                       'masking, and softmax.',
                       dest='masked_softmax_fusion')
    group.add_argument('--no-bias-gelu-fusion', action='store_false',
                       help='Disable bias and gelu fusion.',
                       dest='bias_gelu_fusion')
    group.add_argument('--no-bias-dropout-fusion', action='store_false',
                       help='Disable bias and dropout fusion.',
                       dest='bias_dropout_fusion')
    group.add_argument('--optimizer', type=str, default='adam',
                       choices=['adam', 'sgd'],
                       help='Optimizer function.')
    group.add_argument('--dataloader-type', type=str, default=None,
                       choices=['single', 'cyclic'],
                       help='Single pass vs multiple pass data loader.')

    return parser


def _add_initialization_args(parser):
    group = parser.add_argument_group(title='initialization')

    group.add_argument('--seed', type=int, default=1234,
                       help='Random seed used for python, numpy, '
                       'pytorch, and cuda.')
    group.add_argument('--init-method-std', type=float, default=0.02,
                       help='Standard deviation of the zero mean normal '
                       'distribution used for weight initialization.')
    group.add_argument('--init-method-xavier-uniform', action='store_true',
                       help='Enable Xavier uniform parameter initialization.')

    return parser


def _add_learning_rate_args(parser):
    group = parser.add_argument_group(title='learning rate')

    group.add_argument('--lr', type=float, default=None,
                       help='Initial learning rate. Depending on decay style '
                       'and initial warmup, the learning rate at each '
                       'iteration would be different.')
    group.add_argument('--lr-decay-style', type=str, default='linear',
                       choices=['constant', 'linear', 'cosine'],
                       help='Learning rate decay function.')
    group.add_argument('--lr-decay-iters', type=int, default=None,
                       help='Number of iterations to decay learning rate over. '
                       'If None, defaults to `--train-iters`.')
    group.add_argument('--lr-decay-samples', type=int, default=None,
                       help='Number of samples to decay learning rate over. '
                       'If None, defaults to `--train-samples`.')
    group.add_argument('--lr-warmup-fraction', type=float, default=None,
                       help='Fraction of lr-warmup-(iters/samples) to use '
                       'for warmup (as a float).')
    group.add_argument('--lr-warmup-iters', type=int, default=0,
                       help='Number of iterations to linearly warm up the '
                       'learning rate over.')
    group.add_argument('--lr-warmup-samples', type=int, default=0,
                       help='Number of samples to linearly warm up the '
                       'learning rate over.')
    group.add_argument('--warmup', type=int, default=None,
                       help='Old lr warmup argument, do not use. Use one of '
                       'the --lr-warmup-* arguments above.')
    group.add_argument('--min-lr', type=float, default=0.0,
                       help='Minimum value for learning rate. The scheduler '
                       'clips values below this threshold.')
    group.add_argument('--override-lr-scheduler', action='store_true',
                       help='Reset the values of the scheduler (learning rate, '
                       'warmup iterations, minimum learning rate, maximum '
                       'number of iterations, and decay style) from input '
                       'arguments and ignore values from checkpoints. Note '
                       'that all the above values will be reset.')
    group.add_argument('--use-checkpoint-lr-scheduler', action='store_true',
                       help='Use checkpoint to set the values of the scheduler '
                       '(learning rate, warmup iterations, minimum learning '
                       'rate, maximum number of iterations, and decay style) '
                       'from checkpoint and ignore input arguments.')

    return parser


def _add_checkpointing_args(parser):
    group = parser.add_argument_group(title='checkpointing')

    group.add_argument('--save', type=str, default=None,
                       help='Output directory to save checkpoints to.')
    group.add_argument('--save-interval', type=int, default=None,
                       help='Number of iterations between checkpoint saves.')
    group.add_argument('--no-save-optim', action='store_true', default=None,
                       help='Do not save current optimizer.')
    group.add_argument('--no-save-rng', action='store_true', default=None,
                       help='Do not save current rng state.')
    group.add_argument('--load', type=str, default=None,
                       help='Directory containing a model checkpoint.')
    group.add_argument('--no-load-optim', action='store_true', default=None,
                       help='Do not load optimizer when loading checkpoint.')
    group.add_argument('--no-load-rng', action='store_true', default=None,
                       help='Do not load rng state when loading checkpoint.')
    group.add_argument('--finetune', action='store_true',
                       help='Load model for finetuning. Do not load optimizer '
                       'or rng state from checkpoint and set iteration to 0. '
                       'Assumed when loading a release checkpoint.')

    return parser


def _add_mixed_precision_args(parser):
    group = parser.add_argument_group(title='mixed precision')

    group.add_argument('--fp16', action='store_true',
                       help='Run model in fp16 mode.')
    group.add_argument('--bf16', action='store_true',
                       help='Run model in bfloat16 mode.')
    group.add_argument('--loss-scale', type=float, default=None,
                       help='Static loss scaling, positive power of 2 '
                       'values can improve fp16 convergence. If None, dynamic '
                       'loss scaling is used.')
    group.add_argument('--initial-loss-scale', type=float, default=2**32,
                       help='Initial loss-scale for dynamic loss scaling.')
    group.add_argument('--min-loss-scale', type=float, default=1.0,
                       help='Minimum loss scale for dynamic loss scaling.')
    group.add_argument('--loss-scale-window', type=float, default=1000,
                       help='Window over which to raise/lower dynamic scale.')
    group.add_argument('--hysteresis', type=int, default=2,
                       help='Hysteresis for dynamic loss scaling.')
    group.add_argument('--fp32-residual-connection', action='store_true',
                       help='Move residual connections to fp32.')
    group.add_argument('--no-query-key-layer-scaling', action='store_false',
                       help='Do not scale Q * K^T by 1 / layer-number.',
                       dest='apply_query_key_layer_scaling')
    group.add_argument('--attention-softmax-in-fp32', action='store_true',
                       help='Run attention masking and softmax in fp32. '
                       'This flag is ignored unless '
                       '--no-query-key-layer-scaling is specified.')
    group.add_argument('--accumulate-allreduce-grads-in-fp32',
                       action='store_true',
                       help='Gradient accumulation and all-reduce in fp32.')
    group.add_argument('--fp16-lm-cross-entropy', action='store_true',
                       help='Move the cross entropy unreduced loss calculation '
                       'for the lm head to fp16.')

    return parser


def _add_distributed_args(parser):
    group = parser.add_argument_group(title='distributed')

    group.add_argument('--tensor-model-parallel-size', type=int, default=1,
                       help='Degree of tensor model parallelism.')
    group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
                       help='Degree of pipeline model parallelism.')
    group.add_argument('--model-parallel-size', type=int, default=None,
                       help='Old model parallel argument, do not use. Use '
                       '--tensor-model-parallel-size instead.')
    group.add_argument('--num-layers-per-virtual-pipeline-stage', type=int, default=None,
                       help='Number of layers per virtual pipeline stage.')
    group.add_argument('--distributed-backend', default='nccl',
                       choices=['nccl', 'gloo'],
                       help='Which backend to use for distributed training.')
    group.add_argument('--DDP-impl', default='local',
                       choices=['local', 'torch'],
                       help='Which DistributedDataParallel implementation '
                       'to use.')
    group.add_argument('--use-contiguous-buffers-in-ddp', action='store_true',
                       help='If set, use contiguous buffer in DDP. Note that '
                       'this option only works with local DDP.')
    group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
                       help='Use scatter/gather to optimize communication of tensors in pipeline.',
                       dest='scatter_gather_tensors_in_pipeline')
    group.add_argument('--local_rank', type=int, default=None,
                       help='Local rank passed from the distributed launcher.')
    group.add_argument('--lazy-mpu-init', type=bool, required=False,
                       help='If set to True, initialize_megatron() '
                       'skips DDP initialization and returns a function to '
                       'complete it instead. Also turns on the '
                       '--use-cpu-initialization flag. This is for '
                       'an external DDP manager.')
    group.add_argument('--use-cpu-initialization', action='store_true',
                       default=None, help='If set, affine parallel weights '
                       'initialization uses CPU.')

    return parser


def _add_validation_args(parser):
    group = parser.add_argument_group(title='validation')

    group.add_argument('--eval-iters', type=int, default=100,
                       help='Number of iterations to run evaluation on the '
                       'validation/test sets for.')
    group.add_argument('--eval-interval', type=int, default=1000,
                       help='Interval between running evaluation on the '
                       'validation set.')

    return parser


def _add_data_args(parser):
    group = parser.add_argument_group(title='data and dataloader')

    group.add_argument('--data-path', nargs='*', default=None,
                       help='Path to the training dataset. Accepted format: '
                       '1) a single data path, 2) multiple datasets in the '
                       'form: dataset1-weight dataset1-path dataset2-weight '
                       'dataset2-path ...')
    group.add_argument('--split', type=str, default='969, 30, 1',
                       help='Comma-separated list of proportions for training, '
                       'validation, and test split. For example the split '
                       '`90,5,5` will use 90%% of data for training, 5%% for '
                       'validation and 5%% for test.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file.')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file.')
    group.add_argument('--vocab-extra-ids', type=int, default=0,
                       help='Number of additional vocabulary tokens. '
                       'They are used for span masking in the T5 model.')
    group.add_argument('--seq-length', type=int, default=None,
                       help='Maximum sequence length to process.')
    group.add_argument('--encoder-seq-length', type=int, default=None,
                       help='Maximum encoder sequence length to process. '
                       'This should be exclusive of --seq-length.')
    group.add_argument('--decoder-seq-length', type=int, default=None,
                       help='Maximum decoder sequence length to process.')
    group.add_argument('--retriever-seq-length', type=int, default=256,
                       help='Maximum sequence length for the biencoder model '
                       'for retriever.')
    group.add_argument('--sample-rate', type=float, default=1.0,
                       help='Sample rate for training data. Supposed to be '
                       '0 < sample_rate < 1.')
    group.add_argument('--mask-prob', type=float, default=0.15,
                       help='Probability of replacing a token with mask.')
    group.add_argument('--short-seq-prob', type=float, default=0.1,
                       help='Probability of producing a short sequence.')
    group.add_argument('--mmap-warmup', action='store_true',
                       help='Warm up mmap files.')
    group.add_argument('--num-workers', type=int, default=2,
                       help='Dataloader number of workers.')
    group.add_argument('--tokenizer-type', type=str,
                       default=None,
                       choices=['BertWordPieceLowerCase',
                                'BertWordPieceCase',
                                'GPT2BPETokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--data-impl', type=str, default='infer',
                       choices=['lazy', 'cached', 'mmap', 'infer'],
                       help='Implementation of indexed datasets.')
    group.add_argument('--reset-position-ids', action='store_true',
                       help='Reset position ids after end-of-document token.')
    group.add_argument('--reset-attention-mask', action='store_true',
                       help='Reset self attention mask after '
                       'end-of-document token.')
    group.add_argument('--eod-mask-loss', action='store_true',
                       help='Mask loss for the end of document tokens.')

    return parser


def _add_autoresume_args(parser):
    group = parser.add_argument_group(title='autoresume')

    group.add_argument('--adlr-autoresume', action='store_true',
                       help='Enable autoresume on adlr cluster.')
    group.add_argument('--adlr-autoresume-interval', type=int, default=1000,
                       help='Interval over which to check for the autoresume '
                       'termination signal.')

    return parser


def _add_biencoder_args(parser):
    group = parser.add_argument_group(title='biencoder')

    # network size
    group.add_argument('--ict-head-size', type=int, default=None,
                       help='Size of block embeddings to be used in ICT and '
                       'REALM (paper default: 128).')
    group.add_argument('--biencoder-projection-dim', type=int, default=0,
                       help='Size of projection head used in biencoder (paper '
                       'default: 128).')
    group.add_argument('--biencoder-shared-query-context-model', action='store_true',
                       help='Whether to share the parameters of the query '
                       'and context models or not.')

    # checkpointing
    group.add_argument('--ict-load', type=str, default=None,
                       help='Directory containing an ICTBertModel checkpoint.')
    group.add_argument('--bert-load', type=str, default=None,
                       help='Directory containing a BertModel checkpoint '
                       '(needed to start ICT and REALM).')

    # data
    group.add_argument('--titles-data-path', type=str, default=None,
                       help='Path to titles dataset used for ICT.')
    group.add_argument('--query-in-block-prob', type=float, default=0.1,
                       help='Probability of keeping the query in block for '
                       'the ICT dataset.')
    group.add_argument('--use-one-sent-docs', action='store_true',
                       help='Whether to use one-sentence documents in ICT.')
    group.add_argument('--evidence-data-path', type=str, default=None,
                       help='Path to Wikipedia Evidence from the DPR paper.')

    # training
    group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
                       default=[], help="Which top-k accuracies to report "
                       "(e.g. '1 5 20').")
    group.add_argument('--retriever-score-scaling', action='store_true',
                       help='Whether to scale retriever scores by the inverse '
                       'square root of the hidden size.')

    # faiss index
    group.add_argument('--block-data-path', type=str, default=None,
                       help='Where to save/load BlockData to/from.')
    group.add_argument('--embedding-path', type=str, default=None,
                       help='Where to save/load Open-Retrieval Embedding '
                       'data to/from.')

    # indexer
    group.add_argument('--indexer-batch-size', type=int, default=128,
                       help='How large of batches to use when doing indexing '
                       'jobs.')
    group.add_argument('--indexer-log-interval', type=int, default=1000,
                       help='After how many batches the indexer should '
                       'report progress.')

    return parser


def _add_vit_args(parser):
    group = parser.add_argument_group(title="vit")

    group.add_argument('--num-classes', type=int, default=1000,
                       help='Number of classes in the vision classification task.')
    group.add_argument('--img-dim', type=int, default=224,
                       help='Image size for the vision classification task.')
    group.add_argument('--num-channels', type=int, default=3,
                       help='Number of channels in input image data.')
    group.add_argument('--patch-dim', type=int, default=16,
                       help='Patch dimension used in ViT.')

    return parser