cross_entropy.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. # coding=utf-8
  2. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import torch
  16. from .initialize import get_tensor_model_parallel_group
  17. from .initialize import get_tensor_model_parallel_rank
  18. from .initialize import get_tensor_model_parallel_world_size
  19. from .utils import VocabUtility
  20. class _VocabParallelCrossEntropy(torch.autograd.Function):
  21. @staticmethod
  22. def forward(ctx, vocab_parallel_logits, target):
  23. # Maximum value along vocab dimension across all GPUs.
  24. logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
  25. torch.distributed.all_reduce(logits_max,
  26. op=torch.distributed.ReduceOp.MAX,
  27. group=get_tensor_model_parallel_group())
  28. # Subtract the maximum value.
  29. vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
  30. # Get the partition's vocab indecies
  31. get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
  32. partition_vocab_size = vocab_parallel_logits.size()[-1]
  33. rank = get_tensor_model_parallel_rank()
  34. world_size = get_tensor_model_parallel_world_size()
  35. vocab_start_index, vocab_end_index = get_vocab_range(
  36. partition_vocab_size, rank, world_size)
  37. # Create a mask of valid vocab ids (1 means it needs to be masked).
  38. target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
  39. masked_target = target.clone() - vocab_start_index
  40. masked_target[target_mask] = 0
  41. # Get predicted-logits = logits[target].
  42. # For Simplicity, we convert logits to a 2-D tensor with size
  43. # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
  44. logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
  45. masked_target_1d = masked_target.view(-1)
  46. arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
  47. device=logits_2d.device)
  48. predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
  49. predicted_logits_1d = predicted_logits_1d.clone().contiguous()
  50. predicted_logits = predicted_logits_1d.view_as(target)
  51. predicted_logits[target_mask] = 0.0
  52. # All reduce is needed to get the chunks from other GPUs.
  53. torch.distributed.all_reduce(predicted_logits,
  54. op=torch.distributed.ReduceOp.SUM,
  55. group=get_tensor_model_parallel_group())
  56. # Sum of exponential of logits along vocab dimension across all GPUs.
  57. exp_logits = vocab_parallel_logits
  58. torch.exp(vocab_parallel_logits, out=exp_logits)
  59. sum_exp_logits = exp_logits.sum(dim=-1)
  60. torch.distributed.all_reduce(sum_exp_logits,
  61. op=torch.distributed.ReduceOp.SUM,
  62. group=get_tensor_model_parallel_group())
  63. # Loss = log(sum(exp(logits))) - predicted-logit.
  64. loss = torch.log(sum_exp_logits) - predicted_logits
  65. # Store softmax, target-mask and masked-target for backward pass.
  66. exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
  67. ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
  68. return loss
  69. @staticmethod
  70. def backward(ctx, grad_output):
  71. # Retreive tensors from the forward path.
  72. softmax, target_mask, masked_target_1d = ctx.saved_tensors
  73. # All the inputs have softmax as thier gradient.
  74. grad_input = softmax
  75. # For simplicity, work with the 2D gradient.
  76. partition_vocab_size = softmax.size()[-1]
  77. grad_2d = grad_input.view(-1, partition_vocab_size)
  78. # Add the gradient from matching classes.
  79. arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
  80. device=grad_2d.device)
  81. grad_2d[arange_1d, masked_target_1d] -= (
  82. 1.0 - target_mask.view(-1).float())
  83. # Finally elementwise multiplication with the output gradients.
  84. grad_input.mul_(grad_output.unsqueeze(dim=-1))
  85. return grad_input, None
  86. def vocab_parallel_cross_entropy(vocab_parallel_logits, target):
  87. """Helper function for the cross entropy."""
  88. return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target)