__init__.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. # coding=utf-8
  2. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import torch
  16. from .package_info import (
  17. __description__,
  18. __contact_names__,
  19. __url__,
  20. __download_url__,
  21. __keywords__,
  22. __license__,
  23. __package_name__,
  24. __version__,
  25. )
  26. from .global_vars import get_args
  27. from .global_vars import get_current_global_batch_size
  28. from .global_vars import get_num_microbatches
  29. from .global_vars import update_num_microbatches
  30. from .global_vars import get_tokenizer
  31. from .global_vars import get_tensorboard_writer
  32. from .global_vars import get_adlr_autoresume
  33. from .global_vars import get_timers
  34. from .initialize import initialize_megatron
  35. def print_rank_0(message):
  36. """If distributed is initialized, print only on rank 0."""
  37. if torch.distributed.is_initialized():
  38. if torch.distributed.get_rank() == 0:
  39. print(message, flush=True)
  40. else:
  41. print(message, flush=True)
  42. def is_last_rank():
  43. return torch.distributed.get_rank() == (
  44. torch.distributed.get_world_size() - 1)
  45. def print_rank_last(message):
  46. """If distributed is initialized, print only on last rank."""
  47. if torch.distributed.is_initialized():
  48. if is_last_rank():
  49. print(message, flush=True)
  50. else:
  51. print(message, flush=True)