type_shim.h 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. /* coding=utf-8
  2. * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <ATen/ATen.h>
  17. #include "compat.h"
  /*
   * DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...)
   *
   * Expands to a `switch` on TYPE, whose case labels are at::ScalarType
   * enumerators. For each supported value, the trailing macro arguments (the
   * "body") are expanded inside a case-local scope that defines the alias
   * `scalar_t`:
   *   at::ScalarType::Half     -> scalar_t = at::Half
   *   at::ScalarType::BFloat16 -> scalar_t = at::BFloat16
   * Any other value reaches AT_ERROR with the message
   *   <#NAME> not implemented for '<toString(TYPE)>'
   *
   * Notes:
   *  - The body is captured as `...`, so it may contain unparenthesized
   *    commas (for example template argument lists).
   *  - NAME is stringized with `#`. A string literal passed as NAME therefore
   *    keeps its quotes in the error message.
   *  - TYPE is expanded twice: once in the switch condition and once, on the
   *    error path, in toString(TYPE). Pass an expression without side effects.
   *  - The expansion is a bare `switch` statement, not wrapped in
   *    do { } while (0). Brace the call when it is the body of an if/else.
   */
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...)                       \
  switch (TYPE) {                                                       \
    case at::ScalarType::Half: {                                        \
      using scalar_t = at::Half;                                        \
      __VA_ARGS__;                                                      \
      break;                                                            \
    }                                                                   \
    case at::ScalarType::BFloat16: {                                    \
      using scalar_t = at::BFloat16;                                    \
      __VA_ARGS__;                                                      \
      break;                                                            \
    }                                                                   \
    default:                                                            \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");   \
  }
  /*
   * DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...)
   *
   * Two-type dispatch. Expands the trailing macro arguments (the "body")
   * inside a case-local scope that defines the aliases `scalar_t_in` and
   * `scalar_t_out`:
   *   TYPEIN Float    -> scalar_t_in = float, and scalar_t_out is chosen
   *                      from TYPEOUT:
   *                        Float    -> float
   *                        Half     -> at::Half
   *                        BFloat16 -> at::BFloat16
   *                      Any other TYPEOUT raises AT_ERROR naming TYPEOUT.
   *   TYPEIN Half     -> scalar_t_in = scalar_t_out = at::Half
   *   TYPEIN BFloat16 -> scalar_t_in = scalar_t_out = at::BFloat16
   * Any other TYPEIN raises AT_ERROR with the message
   *   <#NAME> not implemented for '<toString(TYPEIN)>'
   *
   * NOTE(review): for Half/BFloat16 inputs TYPEOUT is never inspected, so a
   * differing output type is not rejected by this macro. Presumably callers
   * only pass matching types for those inputs; confirm at call sites.
   *
   * Notes:
   *  - The body is textually expanded five times, once per supported
   *    (in, out) pair.
   *  - The body is captured as `...`, so it may contain unparenthesized
   *    commas.
   *  - NAME is stringized with `#`. A string literal passed as NAME keeps its
   *    quotes in the error message.
   *  - TYPEIN and TYPEOUT are re-expanded in the error branches. Pass
   *    expressions without side effects.
   *  - The expansion is a bare `switch` statement, not wrapped in
   *    do { } while (0). Brace the call when it is the body of an if/else.
   */
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
  switch (TYPEIN) {                                                          \
    case at::ScalarType::Float: {                                            \
      using scalar_t_in = float;                                             \
      switch (TYPEOUT) {                                                     \
        case at::ScalarType::Float: {                                        \
          using scalar_t_out = float;                                        \
          __VA_ARGS__;                                                       \
          break;                                                             \
        }                                                                    \
        case at::ScalarType::Half: {                                         \
          using scalar_t_out = at::Half;                                     \
          __VA_ARGS__;                                                       \
          break;                                                             \
        }                                                                    \
        case at::ScalarType::BFloat16: {                                     \
          using scalar_t_out = at::BFloat16;                                 \
          __VA_ARGS__;                                                       \
          break;                                                             \
        }                                                                    \
        default:                                                             \
          AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
      }                                                                      \
      break;                                                                 \
    }                                                                        \
    case at::ScalarType::Half: {                                             \
      using scalar_t_in = at::Half;                                          \
      using scalar_t_out = at::Half;                                         \
      __VA_ARGS__;                                                           \
      break;                                                                 \
    }                                                                        \
    case at::ScalarType::BFloat16: {                                         \
      using scalar_t_in = at::BFloat16;                                      \
      using scalar_t_out = at::BFloat16;                                     \
      __VA_ARGS__;                                                           \
      break;                                                                 \
    }                                                                        \
    default:                                                                 \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'");      \
  }