NVSharpen.cpp 3.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. // The MIT License(MIT)
  2. //
  3. // Copyright(c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  4. //
  5. // Permission is hereby granted, free of charge, to any person obtaining a copy of
  6. // this software and associated documentation files(the "Software"), to deal in
  7. // the Software without restriction, including without limitation the rights to
  8. // use, copy, modify, merge, publish, distribute, sublicense, and / or sell copies of
  9. // the Software, and to permit persons to whom the Software is furnished to do so,
  10. // subject to the following conditions :
  11. //
  12. // The above copyright notice and this permission notice shall be included in all
  13. // copies or substantial portions of the Software.
  14. //
  15. // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  16. // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
  17. // FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE AUTHORS OR
  18. // COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
  19. // IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
  20. // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  21. #include "NVSharpen.h"
  22. #include <dxgi1_4.h>
  23. #include <d3d11_3.h>
  24. #include <d3dcompiler.h>
  25. #include <iostream>
  26. #include "DXUtilities.h"
  27. #include "DeviceResources.h"
  28. #include "Utilities.h"
  29. NVSharpen::NVSharpen(DeviceResources& deviceResources, const std::vector<std::string>& shaderPaths)
  30. : m_deviceResources(deviceResources)
  31. , m_outputWidth(1)
  32. , m_outputHeight(1)
  33. {
  34. NISOptimizer opt(false, NISGPUArchitecture::NVIDIA_Generic);
  35. m_blockWidth = opt.GetOptimalBlockWidth();
  36. m_blockHeight = opt.GetOptimalBlockHeight();
  37. uint32_t threadGroupSize = opt.GetOptimalThreadGroupSize();
  38. DX::Defines defines;
  39. defines.add("NIS_SCALER", false);
  40. defines.add("NIS_HDR_MODE", uint32_t(NISHDRMode::None));
  41. defines.add("NIS_BLOCK_WIDTH", m_blockWidth);
  42. defines.add("NIS_BLOCK_HEIGHT", m_blockHeight);
  43. defines.add("NIS_THREAD_GROUP_SIZE", threadGroupSize);
  44. std::string shaderName = "NIS_Main.hlsl";
  45. std::string shaderFolder;
  46. for (auto& e : shaderPaths)
  47. {
  48. if (std::filesystem::exists(e + "/" + shaderName))
  49. {
  50. shaderFolder = e;
  51. break;
  52. }
  53. }
  54. if (shaderFolder.empty())
  55. throw std::runtime_error("Shader file not found" + shaderName);
  56. std::wstring wShaderFilename = widen(shaderFolder + "/" + "NIS_Main.hlsl");
  57. DX::IncludeHeader includeHeader({ shaderFolder });
  58. DX::CompileComputeShader(m_deviceResources.device(),
  59. wShaderFilename.c_str(),
  60. "main",
  61. &m_computeShader,
  62. defines.get(),
  63. &includeHeader);
  64. const int rowPitch = kFilterSize * 4;
  65. const int imageSize = rowPitch * kPhaseCount;
  66. m_deviceResources.createLinearClampSampler(&m_LinearClampSampler);
  67. m_deviceResources.createConstBuffer(&m_config, sizeof(NISConfig), &m_csBuffer);
  68. }
  69. void NVSharpen::update(float sharpness, uint32_t inputWidth, uint32_t inputHeight)
  70. {
  71. NVSharpenUpdateConfig(m_config, sharpness,
  72. 0, 0, inputWidth, inputHeight, inputWidth, inputHeight,
  73. 0, 0, NISHDRMode::None);
  74. m_deviceResources.updateConstBuffer(&m_config, sizeof(NISConfig), m_csBuffer.Get());
  75. m_outputWidth = inputWidth;
  76. m_outputHeight = inputHeight;
  77. }
  78. void NVSharpen::dispatch(ID3D11ShaderResourceView* const* input, ID3D11UnorderedAccessView* const* output)
  79. {
  80. auto context = m_deviceResources.context();
  81. context->CSSetShaderResources(0, 1, input);
  82. context->CSSetUnorderedAccessViews(0, 1, output, nullptr);
  83. context->CSSetSamplers(0, 1, m_LinearClampSampler.GetAddressOf());
  84. context->CSSetConstantBuffers(0, 1, m_csBuffer.GetAddressOf());
  85. context->CSSetShader(m_computeShader.Get(), nullptr, 0);
  86. context->Dispatch(UINT(std::ceil(m_outputWidth / float(m_blockWidth))), UINT(std::ceil(m_outputHeight / float(m_blockHeight))), 1);
  87. }