layer_norm_cuda_kernel.cu

/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* This code is copied from NVIDIA apex:
 *    https://github.com/NVIDIA/apex
 * with minor changes. */

#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include <THC/THCDeviceUtils.cuh>
#include <cuda.h>
#include <cuda_runtime.h>

#include "type_shim.h"
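
// cuWelfordOnlineSum applies Welford's online update: for each new value x,
//   count += 1
//   delta  = x - mu
//   mu    += delta / count
//   sigma2 += delta * (x - mu)   // running sum of squared deviations (M2)
// so that after all values are consumed, mu is the mean and sigma2/count the variance.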
template<typename U> __device__
void cuWelfordOnlineSum(
  const U curr,
  U& mu,
  U& sigma2,
  U& count)
{
  count = count + U(1);
  U delta = curr - mu;
  U lmean = mu + delta / count;
  mu = lmean;
  U delta2 = curr - lmean;
  sigma2 = sigma2 + delta * delta2;
}
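
// cuChanOnlineSum merges two partial (mu, sigma2, count) accumulators with the
// Chan et al. parallel combination:
//   n      = nA + nB
//   mu     = (nA*muA + nB*muB) / n
//   sigma2 = sigma2A + sigma2B + delta*delta * nA*nB/n,  where delta = muB - muA
// (the code folds the 1/n factors into nA and nB before using them).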
template<typename U> __device__
void cuChanOnlineSum(
  const U muB,
  const U sigma2B,
  const U countB,
  U& mu,
  U& sigma2,
  U& count)
{
  U delta = muB - mu;
  U nA = count;
  U nB = countB;
  count = count + countB;
  U nX = count;
  if (nX > U(0)) {
    nA = nA / nX;
    nB = nB / nX;
    mu = nA*mu + nB*muB;
    sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
  } else {
    mu = U(0);
    sigma2 = U(0);
  }
}
template<typename T, typename U> __device__
void cuWelfordMuSigma2(
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const int i1,
  U& mu,
  U& sigma2,
  U* buf)
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensor is contiguous
  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
  //
  // compute variance and mean over n2
  U count = U(0);
  mu = U(0);
  sigma2 = U(0);
  if (i1 < n1) {
    // one warp normalizes one n1 index,
    // synchronization is implicit
    // initialize with standard Welford algorithm
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const T* lvals = vals + i1*n2;
    int l = 4*thrx;
    for (; l+3 < n2; l+=4*numx) {
      for (int k = 0; k < 4; ++k) {
        U curr = static_cast<U>(lvals[l+k]);
        cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
      }
    }
    for (; l < n2; ++l) {
      U curr = static_cast<U>(lvals[l]);
      cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
    }
    // intra-warp reductions
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x+(1<<l))&31;
      U muB = WARP_SHFL(mu, srcLaneB);
      U countB = WARP_SHFL(count, srcLaneB);
      U sigma2B = WARP_SHFL(sigma2, srcLaneB);
      cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
    }
    // threadIdx.x == 0 has correct values for each warp
    // inter-warp reductions
    if (blockDim.y > 1) {
      U* ubuf = (U*)buf;
      U* ibuf = (U*)(ubuf + blockDim.y);
      for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2*wrt_y] = mu;
          ubuf[2*wrt_y+1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          U muB = ubuf[2*threadIdx.y];
          U sigma2B = ubuf[2*threadIdx.y+1];
          U countB = ibuf[threadIdx.y];
          cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
        }
        __syncthreads();
      }
      // threadIdx.x == 0 && threadIdx.y == 0 is the only thread with correct values
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        ubuf[0] = mu;
        ubuf[1] = sigma2;
      }
      __syncthreads();
      mu = ubuf[0];
      sigma2 = ubuf[1]/U(n2);
      // don't care about final value of count, we know count == n2
    } else {
      mu = WARP_SHFL(mu, 0);
      sigma2 = WARP_SHFL(sigma2/U(n2), 0);
    }
  }
}
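
// Specialization for at::Half: after stepping past a possible unaligned first element,
// values are loaded two at a time as __half2 and converted with __half22float2, so each
// thread feeds eight fp16 elements per outer iteration into the same Welford update.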
template<> __device__
void cuWelfordMuSigma2(
  const at::Half* __restrict__ vals,
  const int n1,
  const int n2,
  const int i1,
  float& mu,
  float& sigma2,
  float* buf)
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensor is contiguous
  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
  //
  // compute variance and mean over n2
  float count = 0.0f;
  mu = float(0);
  sigma2 = float(0);
  if (i1 < n1) {
    // one warp normalizes one n1 index,
    // synchronization is implicit
    // initialize with standard Welford algorithm
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const at::Half* lvals = vals + i1*n2;
    int l = 8*thrx;
    if ((((size_t)lvals)&3) != 0) {
      // 16 bit alignment
      // first thread consumes first point
      if (thrx == 0) {
        float curr = static_cast<float>(lvals[0]);
        cuWelfordOnlineSum(curr,mu,sigma2,count);
      }
      ++l;
    }
    // at this point, lvals[l] are 32 bit aligned for all threads.
    for (; l+7 < n2; l+=8*numx) {
      for (int k = 0; k < 8; k+=2) {
        float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
        cuWelfordOnlineSum(curr.x,mu,sigma2,count);
        cuWelfordOnlineSum(curr.y,mu,sigma2,count);
      }
    }
    for (; l < n2; ++l) {
      float curr = static_cast<float>(lvals[l]);
      cuWelfordOnlineSum(curr,mu,sigma2,count);
    }
    // intra-warp reductions
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x+(1<<l))&31;
      float muB = WARP_SHFL(mu, srcLaneB);
      float countB = WARP_SHFL(count, srcLaneB);
      float sigma2B = WARP_SHFL(sigma2, srcLaneB);
      cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
    }
    // threadIdx.x == 0 has correct values for each warp
    // inter-warp reductions
    if (blockDim.y > 1) {
      float* ubuf = (float*)buf;
      float* ibuf = (float*)(ubuf + blockDim.y);
      for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2*wrt_y] = mu;
          ubuf[2*wrt_y+1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          float muB = ubuf[2*threadIdx.y];
          float sigma2B = ubuf[2*threadIdx.y+1];
          float countB = ibuf[threadIdx.y];
          cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
        }
        __syncthreads();
      }
      // threadIdx.x == 0 && threadIdx.y == 0 is the only thread with correct values
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        ubuf[0] = mu;
        ubuf[1] = sigma2;
      }
      __syncthreads();
      mu = ubuf[0];
      sigma2 = ubuf[1]/float(n2);
      // don't care about final value of count, we know count == n2
    } else {
      mu = WARP_SHFL(mu, 0);
      sigma2 = WARP_SHFL(sigma2/float(n2), 0);
    }
  }
}
template<typename U> U rsqrt(U v) {
  return U(1) / sqrt(v);
}
template<> float rsqrt(float v) {
  return rsqrtf(v);
}
template<> double rsqrt(double v) {
  return rsqrt(v);
}

namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
//  template <typename T>
//  struct SharedMemory
//  {
//      // Ensure that we won't compile any un-specialized types
//      __device__ T *getPointer()
//      {
//          extern __device__ void error(void);
//          error();
//          return NULL;
//      }
//  };
// https://github.com/NVIDIA/apex/issues/246
template <typename T>
struct SharedMemory;

template <>
struct SharedMemory <float>
{
  __device__ float *getPointer()
  {
    extern __shared__ float s_float[];
    return s_float;
  }
};
}
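
// cuApplyLayerNorm: each grid row (blockIdx.y) handles one or more rows i1 of the
// (n1 x n2) input. It computes mu and sigma2 over the n2 elements of the row, writes
// output = gamma * (x - mu) * rsqrt(sigma2 + epsilon) + beta, and saves mean[i1] and
// invvar[i1] = rsqrt(sigma2 + epsilon) for the backward pass.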
template<typename T, typename U, typename V> __global__
void cuApplyLayerNorm(
  V* __restrict__ output_vals,
  U* __restrict__ mean,
  U* __restrict__ invvar,
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const U epsilon,
  const V* __restrict__ gamma,
  const V* __restrict__ beta
  )
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensors are contiguous
  //
  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
    SharedMemory<U> shared;
    U* buf = shared.getPointer();
    U mu,sigma2;
    cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);
    const T* lvals = vals + i1*n2;
    V* ovals = output_vals + i1*n2;
    U c_invvar = rsqrt(sigma2 + epsilon);
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL && beta != NULL) {
      for (int i = thrx; i < n2; i+=numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
      }
    } else {
      for (int i = thrx; i < n2; i+=numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = static_cast<V>(c_invvar * (curr - mu));
      }
    }
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      mean[i1] = mu;
      invvar[i1] = c_invvar;
    }
  }
}
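
// cuLoadWriteStridedInputs / cuLoadAddStridedInputs stage per-element backward terms
// into shared-memory tiles: warp_buf1 accumulates dout (for grad_beta) and
// warp_buf2 accumulates dout * (input - mean) * invvar (for grad_gamma).
// cuLoadWriteStridedInputs zero-fills out-of-range entries so the later tile
// reductions need no bounds checks.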
template<typename T, typename U, typename V> __device__
void cuLoadWriteStridedInputs(
    const int i1_block,
    const int thr_load_row_off,
    const int thr_load_col_off,
    const int i2_off,
    const int row_stride,
    U* warp_buf1,
    U* warp_buf2,
    const T* input,
    const V* dout,
    const int i1_end,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar
    )
{
  int i1 = i1_block+thr_load_row_off;
  if (i1 < i1_end) {
    U curr_mean = mean[i1];
    U curr_invvar = invvar[i1];
    for (int k = 0; k < blockDim.y; ++k) {
      int i2 = i2_off + k;
      int load_idx = i1*n2+i2;
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      if (i2<n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] = curr_dout;
        warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
      } else {
        warp_buf1[write_idx] = U(0);
        warp_buf2[write_idx] = U(0);
      }
    }
  } else {
    for (int k = 0; k < blockDim.y; ++k) {
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      warp_buf1[write_idx] = U(0);
      warp_buf2[write_idx] = U(0);
    }
  }
}

template<typename T, typename U, typename V> __device__
void cuLoadAddStridedInputs(
    const int i1_block,
    const int thr_load_row_off,
    const int thr_load_col_off,
    const int i2_off,
    const int row_stride,
    U* warp_buf1,
    U* warp_buf2,
    const T* input,
    const V* dout,
    const int i1_end,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar
    )
{
  int i1 = i1_block+thr_load_row_off;
  if (i1 < i1_end) {
    U curr_mean = mean[i1];
    U curr_invvar = invvar[i1];
    for (int k = 0; k < blockDim.y; ++k) {
      int i2 = i2_off + k;
      int load_idx = i1*n2+i2;
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      if (i2<n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] += curr_dout;
        warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
      }
    }
  }
}
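
// cuComputePartGradGammaBeta reduces the staged tiles over n1 in two steps: each block
// accumulates its assigned rows into shared memory, then writes its per-column sums to
// part_grad_gamma / part_grad_beta (shape: part_size x n2), which cuComputeGradGammaBeta
// reduces to the final gradients.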
template<typename T, typename U, typename V> __global__
void cuComputePartGradGammaBeta(
    const V* __restrict__ dout,
    const T* __restrict__ input,
    const int n1,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    U epsilon,
    U* part_grad_gamma,
    U* part_grad_beta)
{
  const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
  const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
  const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;
  const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;
  const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
  const int row_stride = blockDim.x+1;
  const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);
  const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;
  const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
  SharedMemory<U> shared;
  U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
  U* warp_buf1 = (U*)buf;
  U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
  // compute partial sums from strided inputs
  // do this to increase number of loads in flight
  cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
  for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) {
    cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
  }
  __syncthreads();
  // inter-warp reductions
  // sum within each warp
  U acc1 = U(0);
  U acc2 = U(0);
  for (int k = 0; k < blockDim.y; ++k) {
    int row1 = threadIdx.y + k*blockDim.y;
    int idx1 = row1*row_stride + threadIdx.x;
    acc1 += warp_buf1[idx1];
    acc2 += warp_buf2[idx1];
  }
  warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;
  warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;
  __syncthreads();
  // sum all warps
  for (int offset = blockDim.y/2; offset > 1; offset /= 2) {
    if (threadIdx.y < offset) {
      int row1 = threadIdx.y;
      int row2 = threadIdx.y + offset;
      int idx1 = row1*row_stride + threadIdx.x;
      int idx2 = row2*row_stride + threadIdx.x;
      warp_buf1[idx1] += warp_buf1[idx2];
      warp_buf2[idx1] += warp_buf2[idx2];
    }
    __syncthreads();
  }
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (threadIdx.y == 0 && i2 < n2) {
    int row1 = threadIdx.y;
    int row2 = threadIdx.y + 1;
    int idx1 = row1*row_stride + threadIdx.x;
    int idx2 = row2*row_stride + threadIdx.x;
    part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];
    part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2];
  }
}
template<typename U, typename V> __global__
void cuComputeGradGammaBeta(
    const U* part_grad_gamma,
    const U* part_grad_beta,
    const int part_size,
    const int n1,
    const int n2,
    V* grad_gamma,
    V* grad_beta)
{
  // sum partial gradients for gamma and beta
  SharedMemory<U> shared;
  U* buf = shared.getPointer();
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (i2 < n2) {
    // each warp does sequential reductions until reduced part_size is num_warps
    int num_warp_reductions = part_size / blockDim.y;
    U sum_gamma = U(0);
    U sum_beta = U(0);
    const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
    const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
    for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
      sum_gamma += part_grad_gamma_ptr[warp_offset*n2];
      sum_beta += part_grad_beta_ptr[warp_offset*n2];
    }
    // inter-warp reductions
    const int nbsize3 = blockDim.x * blockDim.y / 2;
    for (int offset = blockDim.y/2; offset >= 1; offset /= 2) {
      // top half write to shared memory
      if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
        const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
        buf[write_idx] = sum_gamma;
        buf[write_idx+nbsize3] = sum_beta;
      }
      __syncthreads();
      // bottom half sums
      if (threadIdx.y < offset) {
        const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
        sum_gamma += buf[read_idx];
        sum_beta += buf[read_idx+nbsize3];
      }
      __syncthreads();
    }
    // write out fully summed gradients
    if (threadIdx.y == 0) {
      grad_gamma[i2] = sum_gamma;
      grad_beta[i2] = sum_beta;
    }
  }
}
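
// cuComputeGradInput: for each row, with sum_loss1 = sum(dy*gamma) and
// sum_loss2 = sum(dy*gamma*(x-mu)*invvar), the input gradient is
//   dx = invvar/n2 * (n2*dy*gamma - sum_loss1 - (x-mu)*invvar*sum_loss2)
// (gamma is treated as 1 when it is NULL).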
template<typename T, typename U, typename V> __global__
void cuComputeGradInput(
    const V* __restrict__ dout,
    const T* __restrict__ input,
    const int n1,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    U epsilon,
    const V* gamma,
    T* grad_input)
{
  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
    U sum_loss1 = U(0);
    U sum_loss2 = U(0);
    const U c_mean = mean[i1];
    const U c_invvar = invvar[i1];
    const T* k_input = input + i1*n2;
    const V* k_dout = dout + i1*n2;
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL) {
      int l = 4*thrx;
      for (; l+3 < n2; l+=4*numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l+k]);
          const U c_loss = static_cast<U>(k_dout[l+k]);
          sum_loss1 += c_loss * gamma[l+k];
          sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss * gamma[l];
        sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
      }
    } else {
      int l = 4*thrx;
      for (; l+3 < n2; l+=4*numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l+k]);
          const U c_loss = static_cast<U>(k_dout[l+k]);
          sum_loss1 += c_loss;
          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss;
        sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
      }
    }
    // intra-warp reductions
    for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
    }
    // inter-warp reductions
    if (blockDim.y > 1) {
      SharedMemory<U> shared;
      U* buf = shared.getPointer();
      for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
          buf[2*wrt_i] = sum_loss1;
          buf[2*wrt_i+1] = sum_loss2;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.y < offset) {
          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
          sum_loss1 += buf[2*read_i];
          sum_loss2 += buf[2*read_i+1];
        }
        __syncthreads();
      }
      if (threadIdx.y == 0) {
        buf[2*threadIdx.x] = sum_loss1;
        buf[2*threadIdx.x+1] = sum_loss2;
      }
      __syncthreads();
      if (threadIdx.y != 0) {
        sum_loss1 = buf[2*threadIdx.x];
        sum_loss2 = buf[2*threadIdx.x+1];
      }
    }
    // all threads now have the two sums over l
    U fH = (U)n2;
    U term1 = (U(1) / fH) * c_invvar;
    T* k_grad_input = grad_input + i1*n2;
    if (gamma != NULL) {
      for (int l = thrx; l < n2; l+=numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss * gamma[l];
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    } else {
      for (int l = thrx; l < n2; l+=numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss;
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    }
  }
}
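
// Host-side launcher for the forward kernel: 32x4 threads per block (one warp per
// blockDim.y row), grid y capped at the device's maxGridSize[1], and shared memory
// sized for the inter-warp Welford reduction when blockDim.y > 1.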
template<typename T, typename U, typename V>
void HostApplyLayerNorm(
    V* output,
    U* mean,
    U* invvar,
    const T* input,
    int n1,
    int n2,
    double epsilon,
    const V* gamma,
    const V* beta
    )
{
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const dim3 threads(32,4,1);
  const uint64_t maxGridY =
      at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
  int nshared =
      threads.y > 1 ?
          threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
          0;
  cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
      output,
      mean,
      invvar,
      input,
      n1,n2,
      U(epsilon),
      gamma,beta);
}

void cuda_layer_norm(
    at::Tensor* output,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon)
{
  using namespace at;
  DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
      input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
      HostApplyLayerNorm(
          output->DATA_PTR<scalar_t_out>(),
          mean->DATA_PTR<float>(),
          invvar->DATA_PTR<float>(),
          input->DATA_PTR<scalar_t_in>(),
          n1,n2,
          epsilon,
          gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
          beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
  )
}
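
// Host-side launcher for the backward pass: when gamma/beta are present it runs the
// two-stage weight-gradient reduction (cuComputePartGradGammaBeta into a
// part_size x n2 buffer, then cuComputeGradGammaBeta), and always launches
// cuComputeGradInput for the input gradient.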
template<typename T, typename U, typename V>
void HostLayerNormGradient(
    const V* dout,
    const U* mean,
    const U* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    const V* gamma,
    const V* beta,
    double epsilon,
    T* grad_input,
    V* grad_gamma,
    V* grad_beta
    )
{
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  if (gamma != NULL && beta != NULL) {
    // compute grad_gamma(j) and grad_beta(j)
    const int part_size = 16;
    const dim3 threads2(32,4,1);
    const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
    const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y *
        (threads2.x + 1);
    const int nshared2_b = threads2.x * threads2.y * sizeof(U);
    const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
    at::Tensor part_grad_gamma = at::empty(
        {part_size,n2}, input->options().dtype(at::ScalarType::Float));
    at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
    cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
        dout,
        input->DATA_PTR<T>(),
        n1,n2,
        mean,
        invvar,
        U(epsilon),
        part_grad_gamma.DATA_PTR<U>(),
        part_grad_beta.DATA_PTR<U>());
    const dim3 threads3(32,8,1);
    const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1);
    const int nshared3 = threads3.x * threads3.y * sizeof(U);
    cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
        part_grad_gamma.DATA_PTR<U>(),
        part_grad_beta.DATA_PTR<U>(),
        part_size,
        n1,n2,
        grad_gamma,
        grad_beta);
  }
  // compute grad_input
  const uint64_t maxGridY =
      at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
  const dim3 threads1(32,4,1);
  int nshared =
      threads1.y > 1 ?
          threads1.y*threads1.x*sizeof(U) :
          0;
  cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
      dout,
      input->DATA_PTR<T>(),
      n1,n2,
      mean,
      invvar,
      U(epsilon),
      gamma,
      grad_input);
}

void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    #ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
    #else
    at::IntList normalized_shape,
    #endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma,
    at::Tensor* grad_beta)
{
  using namespace at;
  DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
      input->scalar_type(), gamma->scalar_type(),
      "cuda_layer_norm_gradient_kernel",
      HostLayerNormGradient(
          dout->DATA_PTR<scalar_t_out>(),
          mean->DATA_PTR<float>(),
          invvar->DATA_PTR<float>(),
          input,
          n1,n2,
          // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
          // if gamma Tensor is NULL on input.
          gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
          gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
          epsilon,
          grad_input->DATA_PTR<scalar_t_in>(),
          gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
          gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
  )
}