// layer_norm_cuda.cpp
  1. /* coding=utf-8
  2. * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
/* This code is copied from NVIDIA apex:
 * https://github.com/NVIDIA/apex
 * with minor changes. */
#include <torch/extension.h>

#include <cassert>
#include <sstream>
#include <stdexcept>
#include <vector>

#include "compat.h"
  23. namespace {
  24. void compute_n1_n2(
  25. at::Tensor input,
  26. at::IntArrayRef normalized_shape,
  27. int& n1,
  28. int& n2) {
  29. int idiff = input.ndimension() - normalized_shape.size();
  30. n2 = 1;
  31. for (int i = 0; i < (int)normalized_shape.size(); ++i) {
  32. assert( input.sizes()[i+idiff] == normalized_shape[i] );
  33. n2 *= normalized_shape[i];
  34. }
  35. n1 = 1;
  36. for (int i = 0; i < idiff; ++i) {
  37. n1 *= input.sizes()[i];
  38. }
  39. }
  40. void check_args(
  41. at::IntArrayRef normalized_shape,
  42. at::Tensor gamma,
  43. at::Tensor beta
  44. )
  45. {
  46. TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
  47. TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
  48. }
  49. void check_args(
  50. at::Tensor input,
  51. at::IntArrayRef normalized_shape,
  52. int& n1,
  53. int& n2
  54. )
  55. {
  56. int64_t normalized_ndim = normalized_shape.size();
  57. if (normalized_ndim < 1) {
  58. std::stringstream ss;
  59. ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
  60. << "containing at least one element, but got normalized_shape="
  61. << normalized_shape;
  62. throw std::runtime_error(ss.str());
  63. }
  64. auto input_shape = input.sizes();
  65. auto input_ndim = input.dim();
  66. if (input_ndim < normalized_ndim ||
  67. !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
  68. std::stringstream ss;
  69. ss << "Given normalized_shape=" << normalized_shape
  70. << ", expected input with shape [*";
  71. for (auto size : normalized_shape) {
  72. ss << ", " << size;
  73. }
  74. ss << "], but got input of size" << input_shape;
  75. throw std::runtime_error(ss.str());
  76. }
  77. compute_n1_n2(input,normalized_shape,n1,n2);
  78. }
  79. void check_args(
  80. at::Tensor input,
  81. at::IntArrayRef normalized_shape,
  82. at::Tensor gamma,
  83. at::Tensor beta,
  84. int& n1,
  85. int& n2
  86. )
  87. {
  88. check_args(input,normalized_shape,n1,n2);
  89. check_args(normalized_shape,gamma,beta);
  90. }
  91. }
// Forward kernel launcher, implemented in the companion .cu translation
// unit. Given the flattened sizes n1 (rows) x n2 (normalized region per
// row), it presumably fills *output with the normalized result and
// *mean / *invvar with the per-row statistics that the caller allocated
// (float tensors of length n1 — see layer_norm_affine below).
// NOTE(review): gamma/beta pointers may reference undefined tensors in the
// non-affine path upstream — confirm against the .cu implementation.
void cuda_layer_norm(
at::Tensor* output,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
at::IntArrayRef normalized_shape,
at::Tensor* gamma,
at::Tensor* beta,
double epsilon);
  103. #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
  104. #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
  105. #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
  106. std::vector<at::Tensor> layer_norm_affine(
  107. at::Tensor input,
  108. at::IntArrayRef normalized_shape,
  109. at::Tensor gamma,
  110. at::Tensor beta,
  111. double epsilon) {
  112. CHECK_INPUT(input);
  113. CHECK_INPUT(gamma);
  114. CHECK_INPUT(beta);
  115. int n1, n2;
  116. check_args(input, normalized_shape, gamma, beta, n1, n2);
  117. at::Tensor output = at::empty_like(
  118. input, gamma.options().dtype(gamma.scalar_type()));
  119. at::Tensor mean = at::empty(
  120. {n1}, input.options().dtype(at::ScalarType::Float));
  121. at::Tensor invvar = at::empty_like(mean);
  122. cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
  123. normalized_shape, &gamma, &beta, epsilon);
  124. return {output, mean, invvar};
  125. }
// Backward kernel launcher, implemented in the companion .cu translation
// unit. Takes the upstream gradient *dout together with the forward
// statistics (*mean, *invvar) and original *input, and presumably fills
// the caller-allocated *grad_input / *grad_gamma / *grad_beta.
// NOTE(review): epsilon is passed again even though invvar already encodes
// it — confirm whether the .cu side actually uses it in backward.
void cuda_layer_norm_gradient(
at::Tensor* dout,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
at::IntArrayRef normalized_shape,
at::Tensor* gamma,
at::Tensor* beta,
double epsilon,
at::Tensor* grad_input,
at::Tensor* grad_gamma,
at::Tensor* grad_beta
);
  141. std::vector<at::Tensor> layer_norm_gradient_affine(
  142. at::Tensor dout,
  143. at::Tensor mean,
  144. at::Tensor invvar,
  145. at::Tensor input,
  146. at::IntArrayRef normalized_shape,
  147. at::Tensor gamma,
  148. at::Tensor beta,
  149. double epsilon) {
  150. CHECK_INPUT(dout);
  151. CHECK_INPUT(mean);
  152. CHECK_INPUT(invvar);
  153. CHECK_INPUT(input);
  154. CHECK_INPUT(gamma);
  155. CHECK_INPUT(beta);
  156. int n1, n2;
  157. check_args(input, normalized_shape, gamma, beta, n1, n2);
  158. at::Tensor grad_input = at::empty_like(input);
  159. at::Tensor grad_gamma = at::empty_like(gamma);
  160. at::Tensor grad_beta = at::empty_like(beta);
  161. cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
  162. normalized_shape, &gamma, &beta, epsilon,
  163. &grad_input, &grad_gamma, &grad_beta);
  164. return {grad_input, grad_gamma, grad_beta};
  165. }
// Python bindings: exposes the affine forward/backward entry points to the
// PyTorch extension loader under the module name chosen at build time
// (TORCH_EXTENSION_NAME is defined by the build system).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward_affine", &layer_norm_affine,
"LayerNorm forward (CUDA)");
m.def("backward_affine", &layer_norm_gradient_affine,
"LayerNorm backward (CUDA)");
}