# main.py
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main tasks functionality."""
  16. import os
  17. import sys
  18. sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
  19. os.path.pardir)))
  20. from megatron import get_args
  21. from megatron.initialize import initialize_megatron
  22. def get_tasks_args(parser):
  23. """Provide extra arguments required for tasks."""
  24. group = parser.add_argument_group(title='tasks')
  25. group.add_argument('--task', type=str, required=True,
  26. help='Task name.')
  27. group.add_argument('--epochs', type=int, default=None,
  28. help='Number of finetunning epochs. Zero results in '
  29. 'evaluation only.')
  30. group.add_argument('--pretrained-checkpoint', type=str, default=None,
  31. help='Pretrained checkpoint used for finetunning.')
  32. group.add_argument('--keep-last', action='store_true',
  33. help='Keep the last batch (maybe incomplete) in'
  34. 'the data loader')
  35. group.add_argument('--train-data', nargs='+', default=None,
  36. help='Whitespace separated paths or corpora names '
  37. 'for training.')
  38. group.add_argument('--valid-data', nargs='*', default=None,
  39. help='path(s) to the validation data.')
  40. group.add_argument('--overlapping-eval', type=int, default=32,
  41. help='Sliding window for overlapping evaluation.')
  42. group.add_argument('--strict-lambada', action='store_true',
  43. help='Use more difficult formulation of lambada.')
  44. # Retriever args
  45. group.add_argument('--qa-data-dev', type=str, default=None,
  46. help='Path to the QA dataset dev file.')
  47. group.add_argument('--qa-data-test', type=str, default=None,
  48. help='Path to the QA dataset test file.')
  49. # Faiss arguments for retriever
  50. group.add_argument('--faiss-use-gpu', action='store_true',
  51. help='Whether create the FaissMIPSIndex on GPU')
  52. group.add_argument('--faiss-match', type=str, default='string', \
  53. choices=['regex', 'string'], help="Answer matching '\
  54. 'logic type")
  55. group.add_argument('--faiss-topk-retrievals', type=int, default=100,
  56. help='Number of blocks to use as top-k during retrieval')
  57. # finetune for retriever
  58. group.add_argument('--eval-micro-batch-size', type=int, default=None,
  59. help='Eval Batch size per model instance (local batch '
  60. 'size). Global batch size is local batch size '
  61. 'times data parallel size.')
  62. group.add_argument('--train-with-neg', action='store_true',
  63. help='Whether to use negative examples during model '
  64. 'training')
  65. group.add_argument('--train-hard-neg', type=int, default=0,
  66. help='Number of hard negative exmaples to use during '
  67. 'training')
  68. # parameters for Av.rank validation method
  69. # Following options/arguments have been taken directly from DPR codebase
  70. group.add_argument('--val-av-rank-hard-neg', type=int, default=30,
  71. help='Av.rank validation: how many hard negatives to'
  72. ' take from each question pool')
  73. group.add_argument('--val-av-rank-other-neg', type=int, default=30,
  74. help='Av.rank validation: how many other negatives to'
  75. ' take from each question pool')
  76. return parser
  77. if __name__ == '__main__':
  78. initialize_megatron(extra_args_provider=get_tasks_args)
  79. args = get_args()
  80. if args.num_layers_per_virtual_pipeline_stage is not None:
  81. print("Interleaved pipeline schedule is not yet supported for downstream tasks.")
  82. exit()
  83. if args.task == 'RACE':
  84. from race.finetune import main
  85. elif args.task in ['MNLI', 'QQP']:
  86. from glue.finetune import main
  87. elif args.task in ['LAMBADA', 'WIKITEXT103']:
  88. from zeroshot_gpt.evaluate import main
  89. elif args.task in ['ICT-ZEROSHOT-NQ', 'RETRIEVER-EVAL']:
  90. from orqa.evaluate_orqa import main
  91. elif args.task in ['RET-FINETUNE-NQ']:
  92. from orqa.supervised.finetune import main
  93. else:
  94. raise NotImplementedError('Task {} is not implemented.'.format(
  95. args.task))
  96. main()