main.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. # coding=utf-8
  2. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Main tasks functionality."""
  16. import os
  17. import sys
  18. sys.path.append(
  19. os.path.abspath(
  20. os.path.join(
  21. os.path.join(os.path.dirname(__file__), os.path.pardir),
  22. os.path.pardir,
  23. )
  24. )
  25. )
  26. from megatron import get_args
  27. from megatron.initialize import initialize_megatron
  28. from classification import main
  29. def get_tasks_args(parser):
  30. """Provide extra arguments required for tasks."""
  31. group = parser.add_argument_group(title="tasks")
  32. group.add_argument(
  33. "--epochs",
  34. type=int,
  35. default=None,
  36. help="Number of finetunning epochs. Zero results in "
  37. "evaluation only.",
  38. )
  39. group.add_argument(
  40. "--pretrained-checkpoint",
  41. type=str,
  42. default=None,
  43. help="Pretrained checkpoint used for finetunning.",
  44. )
  45. group.add_argument(
  46. "--keep-last",
  47. action="store_true",
  48. help="Keep the last batch (maybe incomplete) in" "the data loader",
  49. )
  50. return parser
  51. if __name__ == "__main__":
  52. initialize_megatron(extra_args_provider=get_tasks_args)
  53. args = get_args()
  54. main()