1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071 |
- # coding=utf-8
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import torch
def ensure_divisibility(numerator, denominator):
    """Ensure that ``numerator`` is exactly divisible by ``denominator``.

    Arguments:
        numerator: value to be divided.
        denominator: divisor; a zero value raises ``ZeroDivisionError``
            from the modulo operation.

    Raises:
        AssertionError: if ``numerator % denominator`` is non-zero.
    """
    # An explicit raise (rather than a bare ``assert``) keeps this check
    # active under ``python -O``. Without it, callers such as ``divide``
    # would silently floor the quotient. AssertionError is kept as the
    # exception type so existing callers that catch it keep working.
    if numerator % denominator != 0:
        raise AssertionError('{} is not divisible by {}'.format(
            numerator, denominator))
def divide(numerator, denominator):
    """Return the exact quotient ``numerator // denominator``.

    ``ensure_divisibility`` is called first so that a non-zero remainder is
    reported instead of being silently floored away.
    """
    ensure_divisibility(numerator, denominator)
    quotient = numerator // denominator
    return quotient
def split_tensor_along_last_dim(tensor, num_partitions,
                                contiguous_split_chunks=False):
    """Split ``tensor`` into ``num_partitions`` equal pieces along its last dim.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor into; the
            last dimension's size must be divisible by it (checked via
            ``divide``).
        contiguous_split_chunks: if True, return a contiguous copy of each
            chunk instead of the views produced by ``torch.split``.

    Returns:
        A tuple of tensors, each ``tensor.size()[-1] // num_partitions`` wide
        along the last dimension.
    """
    dim = tensor.dim() - 1
    chunk_width = divide(tensor.size()[dim], num_partitions)
    pieces = torch.split(tensor, chunk_width, dim=dim)

    if not contiguous_split_chunks:
        return pieces
    # torch.split hands back views that are not necessarily contiguous in
    # memory, so materialize contiguous copies when the caller asks for them.
    return tuple(piece.contiguous() for piece in pieces)
class VocabUtility:
    """Locate the slice of the vocabulary owned by one partition.

    The vocabulary is divided into ``world_size`` equally sized chunks, and
    each method returns the half-open index range ``[first, last)`` that
    belongs to partition ``rank``.
    """

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
                                                  rank, world_size):
        """Return ``(first, last)`` for ``rank`` given each partition's size.

        ``world_size`` is accepted for signature symmetry with
        ``vocab_range_from_global_vocab_size`` but is not used here.
        """
        first = rank * per_partition_vocab_size
        last = first + per_partition_vocab_size
        return first, last

    @staticmethod
    def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
        """Return ``(first, last)`` for ``rank`` given the full vocabulary size.

        ``global_vocab_size`` must be divisible by ``world_size`` (checked via
        ``divide``).
        """
        chunk = divide(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(
            chunk, rank, world_size)
|