DXUtilities.h 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. // The MIT License(MIT)
  2. //
  3. // Copyright(c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  4. //
  5. // Permission is hereby granted, free of charge, to any person obtaining a copy of
  6. // this software and associated documentation files(the "Software"), to deal in
  7. // the Software without restriction, including without limitation the rights to
  8. // use, copy, modify, merge, publish, distribute, sublicense, and / or sell copies of
  9. // the Software, and to permit persons to whom the Software is furnished to do so,
  10. // subject to the following conditions :
  11. //
  12. // The above copyright notice and this permission notice shall be included in all
  13. // copies or substantial portions of the Software.
  14. //
  15. // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  16. // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
  17. // FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE AUTHORS OR
  18. // COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
  19. // IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  20. // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  21. #pragma once
#include <stdio.h>
#include <tchar.h>
#include <dxgi1_4.h>
#include <d3d11.h>
#include <d3dcompiler.h>
#include <wrl.h>
#include <algorithm>
#include <cwchar>
#include <deque>
#include <fstream>
#include <iostream>
#include <iterator>
#include <memory>
#include <new>
#include <stdexcept>
#include <string>
#include <vector>
  33. namespace DX
  34. {
  35. using namespace Microsoft::WRL;
  36. inline LPTSTR GetErrorDescription(HRESULT hr, WCHAR* buffer, size_t size)
  37. {
  38. if (FACILITY_WINDOWS == HRESULT_FACILITY(hr))
  39. hr = HRESULT_CODE(hr);
  40. FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM,
  41. nullptr, hr, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), buffer, DWORD(size), nullptr);
  42. return buffer;
  43. }
  44. inline void ThrowIfFailed(HRESULT hr)
  45. {
  46. if (FAILED(hr))
  47. {
  48. // Set a breakpoint on this line to catch Win32 API errors.
  49. const size_t size = 1024;
  50. WCHAR buffer[size];
  51. GetErrorDescription(hr, buffer, size);
  52. char str[size];
  53. size_t i = 0;
  54. wcstombs_s(&i, str, 1024, buffer, 1024);
  55. throw std::runtime_error(str);
  56. }
  57. }
  58. inline void CompileComputeShader(ID3D11Device* device,
  59. LPCWSTR pFileName,
  60. LPCSTR pEntryPoint,
  61. ID3D11ComputeShader** csShader,
  62. const D3D_SHADER_MACRO* pDefines = nullptr,
  63. ID3DInclude* pInclude = nullptr,
  64. LPCSTR pTarget = "cs_5_0")
  65. {
  66. ComPtr<ID3DBlob> csBlob;
  67. ComPtr<ID3DBlob> cdErrorBlob = nullptr;
  68. HRESULT hr = D3DCompileFromFile(pFileName, pDefines, pInclude, pEntryPoint, pTarget, 0, 0, &csBlob, &cdErrorBlob);
  69. if (FAILED(hr)) {
  70. if (cdErrorBlob) {
  71. OutputDebugStringA((char*)cdErrorBlob->GetBufferPointer());
  72. }
  73. DX::ThrowIfFailed(hr);
  74. }
  75. DX::ThrowIfFailed(device->CreateComputeShader(csBlob->GetBufferPointer(), csBlob->GetBufferSize(), nullptr, csShader));
  76. }
  77. inline void CompileComputeShader(ID3D11Device* device,
  78. LPCVOID pSrcData,
  79. size_t SrcDataSize,
  80. LPCSTR pEntrypoint,
  81. ID3D11ComputeShader** csShader,
  82. const D3D_SHADER_MACRO* pDefines = nullptr,
  83. LPCSTR pTarget = "cs_5_0")
  84. {
  85. ComPtr<ID3DBlob> csBlob;
  86. ComPtr<ID3DBlob> cdErrorBlob = nullptr;
  87. HRESULT hr = D3DCompile(pSrcData, SrcDataSize, nullptr, pDefines, nullptr, pEntrypoint, pTarget, 0, 0, &csBlob, &cdErrorBlob);
  88. if (FAILED(hr)) {
  89. if (cdErrorBlob) {
  90. OutputDebugStringA((char*)cdErrorBlob->GetBufferPointer());
  91. }
  92. DX::ThrowIfFailed(hr);
  93. }
  94. DX::ThrowIfFailed(device->CreateComputeShader(csBlob->GetBufferPointer(), csBlob->GetBufferSize(), nullptr, csShader));
  95. }
  96. struct IncludeHeader : ID3DInclude {
  97. IncludeHeader(const std::vector<std::string>& includePath)
  98. : m_includePath(includePath)
  99. , m_idx(0) {}
  100. HRESULT Open(
  101. D3D_INCLUDE_TYPE IncludeType,
  102. LPCSTR pFileName,
  103. LPCVOID pParentData,
  104. LPCVOID* ppData,
  105. UINT* pBytes
  106. ) {
  107. m_data.push_back("");
  108. std::ifstream t;
  109. size_t i = 0;
  110. while (!t.is_open() && i < m_includePath.size()) {
  111. t.open(m_includePath[i] + "/" + pFileName);
  112. i++;
  113. }
  114. if (!t.is_open())
  115. throw std::runtime_error("Error opening D3DCompileFromFile include header");
  116. t.seekg(0, std::ios::end);
  117. size_t size = t.tellg();
  118. m_data[m_idx].resize(size);
  119. t.seekg(0, std::ios::beg);
  120. t.read(m_data[m_idx].data(), size);
  121. m_data[m_idx].erase(std::remove(m_data[m_idx].begin(), m_data[m_idx].end(), '\0'), m_data[m_idx].end());
  122. *ppData = m_data[m_idx].data();
  123. *pBytes = UINT(m_data[m_idx].size());
  124. m_idx++;
  125. return S_OK;
  126. }
  127. HRESULT Close(LPCVOID pData) {
  128. return S_OK;
  129. }
  130. std::vector<std::string> m_data;
  131. std::vector<std::string> m_includePath;
  132. size_t m_idx;
  133. };
  134. class Defines {
  135. public:
  136. template<typename T>
  137. void add(const std::string& define, const T& val) {
  138. m_definesVector.push_back({ define, toStr(val) });
  139. }
  140. D3D_SHADER_MACRO* get() {
  141. m_defines = std::make_unique<D3D_SHADER_MACRO[]>(m_definesVector.size() + 1);
  142. for (size_t i = 0; i < m_definesVector.size(); ++i)
  143. m_defines[i] = { m_definesVector[i].first.c_str(), m_definesVector[i].second.c_str() };
  144. m_defines[m_definesVector.size()] = { nullptr, nullptr };
  145. return m_defines.get();
  146. }
  147. private:
  148. std::vector<std::pair<std::string, std::string>> m_definesVector;
  149. std::unique_ptr<D3D_SHADER_MACRO[]> m_defines;
  150. };
  151. }