NVSharpen.cpp 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. // The MIT License(MIT)
  2. //
  3. // Copyright(c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  4. //
  5. // Permission is hereby granted, free of charge, to any person obtaining a copy of
  6. // this software and associated documentation files(the "Software"), to deal in
  7. // the Software without restriction, including without limitation the rights to
  8. // use, copy, modify, merge, publish, distribute, sublicense, and / or sell copies of
  9. // the Software, and to permit persons to whom the Software is furnished to do so,
  10. // subject to the following conditions :
  11. //
  12. // The above copyright notice and this permission notice shall be included in all
  13. // copies or substantial portions of the Software.
  14. //
  15. // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  16. // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
  17. // FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
  18. // COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
  19. // IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  20. // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  21. #include "NVSharpen.h"
  22. #include <iostream>
  23. #include "DXUtilities.h"
  24. #include "DeviceResources.h"
  25. #include "Utilities.h"
  26. NVSharpen::NVSharpen(DeviceResources& deviceResources, const std::vector<std::string>& shaderPaths)
  27. : m_deviceResources(deviceResources)
  28. , m_outputWidth(1)
  29. , m_outputHeight(1)
  30. {
  31. std::string shaderName = "NIS_Main.hlsl";
  32. std::string shaderPath;
  33. for (auto& e : shaderPaths)
  34. {
  35. if (std::filesystem::exists(e + "/" + shaderName))
  36. {
  37. shaderPath = e + "/" + shaderName;
  38. break;
  39. }
  40. }
  41. if (shaderPath.empty())
  42. throw std::runtime_error("Shader file not found" + shaderName);
  43. NISOptimizer opt(false, NISGPUArchitecture::NVIDIA_Generic);
  44. m_blockWidth = opt.GetOptimalBlockWidth();
  45. m_blockHeight = opt.GetOptimalBlockHeight();
  46. uint32_t threadGroupSize = opt.GetOptimalThreadGroupSize();
  47. std::wstring wNIS_BLOCK_WIDTH = widen(toStr(m_blockWidth));
  48. std::wstring wNIS_BLOCK_HEIGHT = widen(toStr(m_blockHeight));
  49. std::wstring wNIS_THREAD_GROUP_SIZE = widen(toStr(threadGroupSize));
  50. std::wstring wNIS_HDR_MODE = widen(toStr(uint32_t(NISHDRMode::None)));
  51. std::vector<DxcDefine> defines {
  52. {L"NIS_SCALER", L"0"},
  53. {L"NIS_HDR_MODE", wNIS_HDR_MODE.c_str()},
  54. {L"NIS_USE_HALF_PRECISION", L"1"},
  55. {L"NIS_HLSL_6_2", L"1"},
  56. {L"NIS_BLOCK_WIDTH", wNIS_BLOCK_WIDTH.c_str()},
  57. {L"NIS_BLOCK_HEIGHT", wNIS_BLOCK_WIDTH.c_str()},
  58. {L"NIS_THREAD_GROUP_SIZE", wNIS_BLOCK_HEIGHT.c_str()},
  59. };
  60. ComPtr<IDxcLibrary> library;
  61. DX::ThrowIfFailed(DxcCreateInstance(CLSID_DxcLibrary, __uuidof(IDxcLibrary), &library));
  62. ComPtr<IDxcCompiler> compiler;
  63. DX::ThrowIfFailed(DxcCreateInstance(CLSID_DxcCompiler, __uuidof(IDxcCompiler), &compiler));
  64. std::wstring wShaderFilename = widen(shaderPath);
  65. uint32_t codePage = CP_UTF8;
  66. ComPtr<IDxcBlobEncoding> sourceBlob;
  67. DX::ThrowIfFailed(library->CreateBlobFromFile(wShaderFilename.c_str(), &codePage, &sourceBlob));
  68. ComPtr<IDxcIncludeHandler> includeHandler;
  69. library->CreateIncludeHandler(&includeHandler);
  70. std::vector<LPCWSTR> args{ L"-O3", L"-enable-16bit-types" };
  71. ComPtr<IDxcOperationResult> result;
  72. HRESULT hr = compiler->Compile(sourceBlob.Get(), wShaderFilename.c_str(), L"main", L"cs_6_2", args.data(), uint32_t(args.size()),
  73. defines.data(), uint32_t(defines.size()), includeHandler.Get(), &result);
  74. if (SUCCEEDED(hr))
  75. result->GetStatus(&hr);
  76. if (FAILED(hr))
  77. {
  78. if (result)
  79. {
  80. ComPtr<IDxcBlobEncoding> errorsBlob;
  81. hr = result->GetErrorBuffer(&errorsBlob);
  82. if (SUCCEEDED(hr) && errorsBlob)
  83. {
  84. wprintf(L"Compilation failed with errors:\n%hs\n", (const char*)errorsBlob->GetBufferPointer());
  85. }
  86. }
  87. DX::ThrowIfFailed(hr);
  88. }
  89. ComPtr<IDxcBlob> computeShaderBlob;
  90. result->GetResult(&computeShaderBlob);
  91. m_deviceResources.CreateBuffer(sizeof(NISConfig), D3D12_HEAP_TYPE_UPLOAD, D3D12_RESOURCE_STATE_GENERIC_READ, &m_stagingBuffer);
  92. m_deviceResources.CreateBuffer(sizeof(NISConfig), D3D12_HEAP_TYPE_DEFAULT, D3D12_RESOURCE_STATE_COMMON, &m_constatBuffer);
  93. constexpr uint32_t nParams = 4;
  94. CD3DX12_DESCRIPTOR_RANGE descriptorRange[nParams] = {};
  95. descriptorRange[0] = CD3DX12_DESCRIPTOR_RANGE(D3D12_DESCRIPTOR_RANGE_TYPE_CBV, 1, 0);
  96. descriptorRange[1] = CD3DX12_DESCRIPTOR_RANGE(D3D12_DESCRIPTOR_RANGE_TYPE_SAMPLER, 1, 0);
  97. descriptorRange[2] = CD3DX12_DESCRIPTOR_RANGE(D3D12_DESCRIPTOR_RANGE_TYPE_SRV, 1, 0);
  98. descriptorRange[3] = CD3DX12_DESCRIPTOR_RANGE(D3D12_DESCRIPTOR_RANGE_TYPE_UAV, 1, 0);
  99. CD3DX12_ROOT_PARAMETER m_rootParams[nParams] = {};
  100. m_rootParams[0].InitAsDescriptorTable(1, &descriptorRange[0]);
  101. m_rootParams[1].InitAsDescriptorTable(1, &descriptorRange[1]);
  102. m_rootParams[2].InitAsDescriptorTable(1, &descriptorRange[2]);
  103. m_rootParams[3].InitAsDescriptorTable(1, &descriptorRange[3]);
  104. D3D12_ROOT_SIGNATURE_DESC rootSignatureDesc;
  105. rootSignatureDesc.NumParameters = nParams;
  106. rootSignatureDesc.pParameters = m_rootParams;
  107. rootSignatureDesc.NumStaticSamplers = 0;
  108. rootSignatureDesc.pStaticSamplers = nullptr;
  109. rootSignatureDesc.Flags = D3D12_ROOT_SIGNATURE_FLAG_NONE;
  110. ComPtr<ID3DBlob> serializedSignature;
  111. DX::ThrowIfFailed(D3D12SerializeRootSignature(&rootSignatureDesc, D3D_ROOT_SIGNATURE_VERSION_1, &serializedSignature, nullptr));
  112. // Create the root signature
  113. DX::ThrowIfFailed(m_deviceResources.device()->CreateRootSignature(0, serializedSignature->GetBufferPointer(),serializedSignature->GetBufferSize(),
  114. __uuidof(ID3D12RootSignature),&m_computeRootSignature));
  115. m_computeRootSignature->SetName(L"NVSharpen");
  116. // Create compute pipeline state
  117. D3D12_COMPUTE_PIPELINE_STATE_DESC descComputePSO = {};
  118. descComputePSO.pRootSignature = m_computeRootSignature.Get();
  119. descComputePSO.CS.pShaderBytecode = computeShaderBlob->GetBufferPointer();
  120. descComputePSO.CS.BytecodeLength = computeShaderBlob->GetBufferSize();
  121. DX::ThrowIfFailed(m_deviceResources.device()->CreateComputePipelineState(&descComputePSO, __uuidof(ID3D12PipelineState), &m_computePSO));
  122. m_computePSO->SetName(L"NVSharpen Compute PSO");
  123. }
  124. void NVSharpen::update(float sharpness, uint32_t inputWidth, uint32_t inputHeight)
  125. {
  126. NVSharpenUpdateConfig(m_config, sharpness,
  127. 0, 0, inputWidth, inputHeight, inputWidth, inputHeight,
  128. 0, 0, NISHDRMode::None);
  129. m_deviceResources.UploadBufferData(&m_config, sizeof(NISConfig), m_constatBuffer.Get(), m_stagingBuffer.Get());
  130. m_outputWidth = inputWidth;
  131. m_outputHeight = inputHeight;
  132. }