scaled_masked_softmax.cpp

/* coding=utf-8
 * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {
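
// Forward/backward CUDA entry points, implemented in a companion .cu
// translation unit (assumed; this file only declares them). Both operate on
// 4-D tensors, presumably [batches, attn_heads, seq_len_q, seq_len_k] given
// the rank checks in the wrappers below.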
torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor);

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor);
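
// Forward wrapper: validates rank and dtype, then dispatches to the CUDA kernel.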
torch::Tensor fwd(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor) {
  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
             (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");

  return fwd_cuda(input, mask, scale_factor);
}
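
// Backward wrapper: same rank/dtype validation, applied to the incoming
// gradients and to the softmax output saved from the forward pass.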
torch::Tensor bwd(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
             (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
             (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

} // end namespace scaled_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
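
// Python bindings: exposes the wrappers as module.forward / module.backward.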
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
}
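
// Usage from Python -- a minimal sketch, not taken from this file. It assumes
// the fwd_cuda/bwd_cuda implementations live in a sibling file named
// scaled_masked_softmax_cuda.cu, and that the mask is a 4-D uint8 tensor
// broadcastable over attention heads; both the file layout and the mask
// dtype are assumptions about the kernel contract, not guarantees made by
// this wrapper.
//
//   import torch
//   from torch.utils.cpp_extension import load
//
//   # JIT-compile the extension (source list is an assumed layout).
//   ext = load(
//       name="scaled_masked_softmax",
//       sources=["scaled_masked_softmax.cpp", "scaled_masked_softmax_cuda.cu"],
//       extra_cuda_cflags=["-O3"],
//   )
//
//   b, h, sq, sk = 2, 8, 128, 128
//   logits = torch.randn(b, h, sq, sk, device="cuda", dtype=torch.half)
//   mask = torch.zeros(b, 1, sq, sk, device="cuda", dtype=torch.uint8)
//
//   # forward(input, mask, scale_factor) -> softmax probabilities
//   probs = ext.forward(logits, mask, 1.0)
//
//   # backward(output_grads, softmax_results, scale_factor) -> input grads
//   grads = ext.backward(torch.ones_like(probs), probs, 1.0)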