BilinearUpscale.cpp 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. // The MIT License(MIT)
  2. //
  3. // Copyright(c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  4. //
  5. // Permission is hereby granted, free of charge, to any person obtaining a copy of
  6. // this software and associated documentation files(the "Software"), to deal in
  7. // the Software without restriction, including without limitation the rights to
  8. // use, copy, modify, merge, publish, distribute, sublicense, and / or sell copies of
  9. // the Software, and to permit persons to whom the Software is furnished to do so,
  10. // subject to the following conditions :
  11. //
  12. // The above copyright notice and this permission notice shall be included in all
  13. // copies or substantial portions of the Software.
  14. //
  15. // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  16. // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
  17. // FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE AUTHORS OR
  18. // COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
  19. // IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  20. // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  21. #include "BilinearUpscale.h"
  22. #include "DXUtilities.h"
  23. #include "Utilities.h"
  24. #include <iostream>
  25. void BilinearUpdateConfig(BilinearUpscaleConfig& config,
  26. uint32_t inputViewportOriginX, uint32_t inputViewportOriginY,
  27. uint32_t inputViewportWidth, uint32_t inputViewportHeight,
  28. uint32_t inputTextureWidth, uint32_t inputTextureHeight,
  29. uint32_t outputViewportOriginX, uint32_t outputViewportOriginY,
  30. uint32_t outputViewportWidth, uint32_t outputViewportHeight,
  31. uint32_t outputTextureWidth, uint32_t outputTextureHeight)
  32. {
  33. config.kInputViewportHeight = inputViewportHeight;
  34. config.kInputViewportWidth = inputViewportWidth;
  35. config.kOutputViewportHeight = outputViewportHeight;
  36. config.kOutputViewportWidth = outputViewportWidth;
  37. config.kInputViewportOriginX = inputViewportOriginX;
  38. config.kInputViewportOriginY = inputViewportOriginY;
  39. config.kOutputViewportOriginX = outputViewportOriginX;
  40. config.kOutputViewportOriginY = outputViewportOriginY;
  41. config.kScaleX = inputTextureWidth / float(outputTextureWidth);
  42. config.kScaleY = inputTextureWidth / float(outputTextureWidth);
  43. config.kDstNormX = 1.f / outputTextureWidth;
  44. config.kDstNormY = 1.f / outputTextureHeight;
  45. config.kSrcNormX = 1.f / inputTextureWidth;
  46. config.kSrcNormY = 1.f / inputTextureHeight;
  47. }
  48. BilinearUpscale::BilinearUpscale(DeviceResources& deviceResources, const std::vector<std::string>& shaderPaths)
  49. : m_deviceResources(deviceResources)
  50. , m_outputWidth(0)
  51. , m_outputHeight(0)
  52. {
  53. std::string shaderName = "bilinearUpscale.hlsl";
  54. std::string shaderPath;
  55. for (auto& e : shaderPaths)
  56. {
  57. if (std::filesystem::exists(e + "/" + shaderName))
  58. {
  59. shaderPath = e + "/" + shaderName;
  60. break;
  61. }
  62. }
  63. if (shaderPath.empty())
  64. throw std::runtime_error("Shader file not found" + shaderName);
  65. ComPtr<IDxcLibrary> library;
  66. DX::ThrowIfFailed(DxcCreateInstance(CLSID_DxcLibrary, __uuidof(IDxcLibrary), &library));
  67. ComPtr<IDxcCompiler> compiler;
  68. DX::ThrowIfFailed(DxcCreateInstance(CLSID_DxcCompiler, __uuidof(IDxcCompiler), &compiler));
  69. std::wstring wShaderFilename = widen(shaderPath);
  70. uint32_t codePage = CP_UTF8;
  71. ComPtr<IDxcBlobEncoding> sourceBlob;
  72. DX::ThrowIfFailed(library->CreateBlobFromFile(wShaderFilename.c_str(), &codePage, &sourceBlob));
  73. constexpr uint32_t nDefines = 2;
  74. std::wstring wBlockWidth = widen(toStr(m_BlockWidth));
  75. std::wstring wBlockHeight = widen(toStr(m_BlockHeight));
  76. DxcDefine defines[nDefines] = {
  77. {L"BLOCK_WIDTH", wBlockWidth.c_str()},
  78. {L"BLOCK_HEIGHT", wBlockHeight.c_str()},
  79. };
  80. ComPtr<IDxcOperationResult> result;
  81. HRESULT hr = compiler->Compile(sourceBlob.Get(), wShaderFilename.c_str(), L"main", L"cs_6_2", nullptr, 0, defines, nDefines, nullptr, &result);
  82. if (SUCCEEDED(hr))
  83. {
  84. result->GetStatus(&hr);
  85. }
  86. if (FAILED(hr))
  87. {
  88. if (result)
  89. {
  90. ComPtr<IDxcBlobEncoding> errorsBlob;
  91. hr = result->GetErrorBuffer(&errorsBlob);
  92. if (SUCCEEDED(hr) && errorsBlob)
  93. {
  94. wprintf(L"Compilation failed with errors:\n%hs\n", (const char*)errorsBlob->GetBufferPointer());
  95. }
  96. }
  97. DX::ThrowIfFailed(hr);
  98. }
  99. ComPtr<IDxcBlob> computeShaderBlob;
  100. result->GetResult(&computeShaderBlob);
  101. m_deviceResources.CreateBuffer(sizeof(BilinearUpscaleConfig), D3D12_HEAP_TYPE_UPLOAD, D3D12_RESOURCE_STATE_GENERIC_READ, &m_stagingBuffer);
  102. m_deviceResources.CreateBuffer(sizeof(BilinearUpscaleConfig), D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_COMMON, &m_constatBuffer);
  103. // Define root table layout
  104. constexpr uint32_t nParams = 4;
  105. CD3DX12_DESCRIPTOR_RANGE descriptorRange[nParams] = {};
  106. descriptorRange[0] = CD3DX12_DESCRIPTOR_RANGE(D3D12_DESCRIPTOR_RANGE_TYPE_CBV, 1, 0);
  107. descriptorRange[1] = CD3DX12_DESCRIPTOR_RANGE(D3D12_DESCRIPTOR_RANGE_TYPE_SAMPLER, 1, 0);
  108. descriptorRange[2] = CD3DX12_DESCRIPTOR_RANGE(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 1, 0);
  109. descriptorRange[3] = CD3DX12_DESCRIPTOR_RANGE(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0);
  110. CD3DX12_ROOT_PARAMETER m_rootParams[nParams] = {};
  111. m_rootParams[0].InitAsDescriptorTable(1, &descriptorRange[0]);
  112. m_rootParams[1].InitAsDescriptorTable(1, &descriptorRange[1]);
  113. m_rootParams[2].InitAsDescriptorTable(1, &descriptorRange[2]);
  114. m_rootParams[3].InitAsDescriptorTable(1, &descriptorRange[3]);
  115. D3D12_ROOT_SIGNATURE_DESC rootSignatureDesc;
  116. rootSignatureDesc.NumParameters = nParams;
  117. rootSignatureDesc.pParameters = m_rootParams;
  118. rootSignatureDesc.NumStaticSamplers = 0;
  119. rootSignatureDesc.pStaticSamplers = nullptr;
  120. rootSignatureDesc.Flags = D3D12_ROOT_SIGNATURE_FLAG_NONE;
  121. ComPtr<ID3DBlob> serializedSignature;
  122. DX::ThrowIfFailed(D3D12SerializeRootSignature(&rootSignatureDesc, D3D_ROOT_SIGNATURE_VERSION_1, &serializedSignature, nullptr));
  123. // Create the root signature
  124. DX::ThrowIfFailed(
  125. m_deviceResources.device()->CreateRootSignature(
  126. 0,
  127. serializedSignature->GetBufferPointer(),
  128. serializedSignature->GetBufferSize(),
  129. __uuidof(ID3D12RootSignature),
  130. &m_computeRootSignature));
  131. m_computeRootSignature->SetName(L"BilinearUpscale");
  132. D3D12_COMPUTE_PIPELINE_STATE_DESC descComputePSO = {};
  133. descComputePSO.pRootSignature = m_computeRootSignature.Get();
  134. descComputePSO.CS.pShaderBytecode = computeShaderBlob->GetBufferPointer();
  135. descComputePSO.CS.BytecodeLength = computeShaderBlob->GetBufferSize();
  136. DX::ThrowIfFailed(
  137. m_deviceResources.device()->CreateComputePipelineState(&descComputePSO, __uuidof(ID3D12PipelineState), &m_computePSO));
  138. m_computePSO->SetName(L"BilinearUpscale Compute PSO");
  139. }
  140. void BilinearUpscale::update(uint32_t inputWidth, uint32_t inputHeight, uint32_t outputWidth, uint32_t outputHeight)
  141. {
  142. BilinearUpdateConfig(m_config, 0, 0, inputWidth, inputHeight, inputWidth, inputHeight, 0, 0, outputWidth, outputHeight, outputWidth, outputHeight);
  143. m_deviceResources.UploadBufferData(&m_config, sizeof(BilinearUpscaleConfig), m_constatBuffer.Get(), m_stagingBuffer.Get());
  144. m_outputWidth = outputWidth;
  145. m_outputHeight = outputHeight;
  146. }