# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test that vocab-parallel cross entropy matches the single-process
torch.nn.functional.cross_entropy reference, for both loss and gradient."""

# The path append must run before the commons/mpu imports below, otherwise
# those modules cannot be resolved.
import sys
sys.path.append("../..")

import torch
import torch.nn.functional as F

from commons import set_random_seed
from commons import IdentityLayer
from commons import print_separator
from commons import initialize_distributed
from mpu.cross_entropy import vocab_parallel_cross_entropy
import mpu
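
# Note: IdentityLayer (defined in commons) is expected to be a module whose
# forward() simply returns its trainable weight tensor of the requested shape,
# so `identity()` below produces the logits and `identity.weight.grad` is the
# gradient of the loss with respect to those logits.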

def torch_cross_entropy(batch_size, seq_length, vocab_size,
                        logits_scale, seed):
    """Reference: loss and logit-gradient from unsharded F.cross_entropy."""
    set_random_seed(seed)
    identity = IdentityLayer((batch_size, seq_length, vocab_size),
                             scale=logits_scale).cuda()
    logits = identity()
    target = torch.cuda.LongTensor(
        size=(batch_size, seq_length)).random_(0, vocab_size)
    loss = F.cross_entropy(logits.view(-1, logits.size()[-1]),
                           target.view(-1),
                           reduction='none').view_as(target).mean()
    loss.backward()
    return loss, identity.weight.grad


def mpu_cross_entropy(batch_size, seq_length, vocab_size,
                      logits_scale, seed):
    """Same computation, with the logits sharded over the vocab dimension."""
    set_random_seed(seed)
    identity = IdentityLayer((batch_size, seq_length, vocab_size),
                             scale=logits_scale).cuda()
    logits = identity()
    # Split the logits along the (last) vocab dimension so that each
    # tensor-parallel rank holds only its own vocab shard.
    logits_parallel = mpu.scatter_to_tensor_model_parallel_region(logits)
    target = torch.cuda.LongTensor(
        size=(batch_size, seq_length)).random_(0, vocab_size)
    loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
    loss.backward()
    return loss, identity.weight.grad
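

# For intuition, the numerically stable recipe behind
# vocab_parallel_cross_entropy can be sketched with plain torch.distributed
# collectives. This is an illustrative re-derivation, not the mpu
# implementation; the helper name _sketch_vocab_parallel_ce and its explicit
# `vocab_start`/`group` arguments are assumptions made for the sketch (the
# real kernel derives shard offsets from the tensor-parallel rank and
# implements its own backward).
def _sketch_vocab_parallel_ce(logits_shard, target, vocab_start, group=None):
    # 1) Global max over the full vocab, for numerical stability.
    logits_max = logits_shard.max(dim=-1)[0]
    torch.distributed.all_reduce(
        logits_max, op=torch.distributed.ReduceOp.MAX, group=group)
    shifted = logits_shard - logits_max.unsqueeze(-1)

    # 2) Global softmax denominator: sum of exp over all vocab shards.
    sum_exp = shifted.exp().sum(dim=-1)
    torch.distributed.all_reduce(sum_exp, group=group)

    # 3) The target logit: only the rank whose shard contains the target
    #    contributes; every other rank adds zero to the all-reduce.
    vocab_end = vocab_start + logits_shard.size(-1)
    in_shard = (target >= vocab_start) & (target < vocab_end)
    local_idx = (target - vocab_start).clamp(0, logits_shard.size(-1) - 1)
    target_logit = shifted.gather(-1, local_idx.unsqueeze(-1)).squeeze(-1)
    target_logit = target_logit * in_shard.float()
    torch.distributed.all_reduce(target_logit, group=group)

    # Per-token loss: log(sum_exp) - shifted_target_logit.
    return sum_exp.log() - target_logit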


def test_cross_entropy(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing cross entropy with model parallel size {} ...'.
              format(tensor_model_parallel_size))

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    batch_size = 13
    seq_length = 17
    vocab_size_per_partition = 11
    logits_scale = 1000.0
    vocab_size = vocab_size_per_partition * tensor_model_parallel_size
    seed = 1234

    loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length,
                                                 vocab_size, logits_scale,
                                                 seed)
    loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length,
                                           vocab_size, logits_scale,
                                           seed)

    # Both paths start from the same seed and therefore the same logits, so
    # loss and gradient must agree to within floating-point tolerance.
    error = loss_torch.sub_(loss_mpu).abs().max()
    print('   max error in loss on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = grad_torch.sub_(grad_mpu).abs().max()
    print('   max error in grad on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups.
    mpu.destroy_tensor_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')


if __name__ == '__main__':

    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    # Sweep tensor-model-parallel sizes 1, 2, 4, ... up to the world size.
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test cross entropy')
        test_cross_entropy(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2
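
# A typical multi-GPU launch for this script (the exact runner is an
# assumption, not part of the original file) looks like:
#
#   torchrun --nproc_per_node=4 test_cross_entropy.py
#
# initialize_distributed() from commons is expected to pick up the standard
# torch.distributed environment variables (RANK, WORLD_SIZE, MASTER_ADDR, ...).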