utils.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. # coding=utf-8
  2. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import torch
  16. def ensure_divisibility(numerator, denominator):
  17. """Ensure that numerator is divisible by the denominator."""
  18. assert numerator % denominator == 0, '{} is not divisible by {}'.format(
  19. numerator, denominator)
  20. def divide(numerator, denominator):
  21. """Ensure that numerator is divisible by the denominator and return
  22. the division value."""
  23. ensure_divisibility(numerator, denominator)
  24. return numerator // denominator
  25. def split_tensor_along_last_dim(tensor, num_partitions,
  26. contiguous_split_chunks=False):
  27. """Split a tensor along its last dimension.
  28. Arguments:
  29. tensor: input tensor.
  30. num_partitions: number of partitions to split the tensor
  31. contiguous_split_chunks: If True, make each chunk contiguous
  32. in memory.
  33. """
  34. # Get the size and dimension.
  35. last_dim = tensor.dim() - 1
  36. last_dim_size = divide(tensor.size()[last_dim], num_partitions)
  37. # Split.
  38. tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
  39. # Note: torch.split does not create contiguous tensors by default.
  40. if contiguous_split_chunks:
  41. return tuple(chunk.contiguous() for chunk in tensor_list)
  42. return tensor_list
  43. class VocabUtility:
  44. """Split the vocabulary into `world_size` chunks amd return the
  45. first and last index of the vocabulary belonging to the `rank`
  46. partition: Note that indecies in [fist, last)"""
  47. @staticmethod
  48. def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
  49. rank, world_size):
  50. index_f = rank * per_partition_vocab_size
  51. index_l = index_f + per_partition_vocab_size
  52. return index_f, index_l
  53. @staticmethod
  54. def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
  55. per_partition_vocab_size = divide(global_vocab_size, world_size)
  56. return VocabUtility.vocab_range_from_per_partition_vocab_size(
  57. per_partition_vocab_size, rank, world_size)