Skip to content

Commit

Permalink
Moved more of common logic to DeviceContextBase::SetPipelineState
Browse files Browse the repository at this point in the history
  • Loading branch information
TheMostDiligent committed Jan 9, 2025
1 parent 9014dee commit 41e6218
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 81 deletions.
35 changes: 27 additions & 8 deletions Graphics/GraphicsEngine/include/DeviceContextBase.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2019-2024 Diligent Graphics LLC
* Copyright 2019-2025 Diligent Graphics LLC
* Copyright 2015-2019 Egor Yusov
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -465,7 +465,7 @@ class DeviceContextBase : public ObjectBase<typename EngineImplTraits::DeviceCon

inline bool SetStencilRef(Uint32 StencilRef, int Dummy);

inline void SetPipelineState(RefCntAutoPtr<PipelineStateImplType> pPipelineState, int /*Dummy*/);
inline bool SetPipelineState(IPipelineState* pPipelineState, const INTERFACE_ID& IID_PSOImpl);

/// Clears all cached resources
inline void ClearStateCache();
Expand Down Expand Up @@ -766,17 +766,36 @@ inline void DeviceContextBase<ImplementationTraits>::SetVertexBuffers(
}

template <typename ImplementationTraits>
inline void DeviceContextBase<ImplementationTraits>::SetPipelineState(
RefCntAutoPtr<PipelineStateImplType> pPipelineState,
int /*Dummy*/)
inline bool DeviceContextBase<ImplementationTraits>::SetPipelineState(
IPipelineState* pPipelineState,
const INTERFACE_ID& IID_PSOImpl)
{
if (pPipelineState == nullptr)
{
DEV_ERROR("Pipeline state must not be null");
return false;
}

DVP_CHECK_QUEUE_TYPE_COMPATIBILITY(COMMAND_QUEUE_TYPE_COMPUTE, "SetPipelineState");

DEV_CHECK_ERR((pPipelineState->GetDesc().ImmediateContextMask & (Uint64{1} << GetExecutionCtxId())) != 0,
"PSO '", pPipelineState->GetDesc().Name, "' can't be used in device context '", m_Desc.Name, "'.");
DEV_CHECK_ERR(pPipelineState->GetStatus() == PIPELINE_STATE_STATUS_READY, "PSO '", pPipelineState->GetDesc().Name, "' is not ready. Use GetStatus() to check the pipeline status.");

m_pPipelineState = std::move(pPipelineState);
// Check that the PSO is ready before querying the implementation.
DEV_CHECK_ERR(pPipelineState->GetStatus() == PIPELINE_STATE_STATUS_READY, "PSO '", pPipelineState->GetDesc().Name,
"' is not ready. Use GetStatus() to check the pipeline status.");

// Note that pPipelineStateImpl may not be the same as pPipelineState (for example, if pPipelineState
// is a reloadable pipeline).
RefCntAutoPtr<PipelineStateImplType> pPipelineStateImpl{pPipelineState, IID_PSOImpl};
VERIFY(pPipelineStateImpl != nullptr, "Unknown pipeline state object implementation");
if (PipelineStateImplType::IsSameObject(m_pPipelineState, pPipelineStateImpl))
return false;

m_pPipelineState = std::move(pPipelineStateImpl);
++m_Stats.CommandCounters.SetPipelineState;

return true;
}

template <typename ImplementationTraits>
Expand Down Expand Up @@ -2200,7 +2219,7 @@ inline Uint32 GetPrimitiveCount(PRIMITIVE_TOPOLOGY Topology, Uint32 Elements)
UNEXPECTED("Undefined primitive topology");
return 0;

// clang-format off
// clang-format off
case PRIMITIVE_TOPOLOGY_TRIANGLE_LIST: return Elements / 3;
case PRIMITIVE_TOPOLOGY_TRIANGLE_STRIP: return (std::max)(Elements, 2u) - 2;
case PRIMITIVE_TOPOLOGY_POINT_LIST: return Elements;
Expand Down
19 changes: 8 additions & 11 deletions Graphics/GraphicsEngineD3D11/src/DeviceContextD3D11Impl.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2019-2024 Diligent Graphics LLC
* Copyright 2019-2025 Diligent Graphics LLC
* Copyright 2015-2019 Egor Yusov
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -80,16 +80,13 @@ void DeviceContextD3D11Impl::Begin(Uint32 ImmediateContextId)

void DeviceContextD3D11Impl::SetPipelineState(IPipelineState* pPipelineState)
{
RefCntAutoPtr<PipelineStateD3D11Impl> pPipelineStateD3D11{pPipelineState, PipelineStateD3D11Impl::IID_InternalImpl};
VERIFY(pPipelineState == nullptr || pPipelineStateD3D11 != nullptr, "Unknown pipeline state object implementation");
if (PipelineStateD3D11Impl::IsSameObject(m_pPipelineState, pPipelineStateD3D11))
if (!TDeviceContextBase::SetPipelineState(pPipelineState, PipelineStateD3D11Impl::IID_InternalImpl))
return;

TDeviceContextBase::SetPipelineState(std::move(pPipelineStateD3D11), 0 /*Dummy*/);
const auto& Desc = m_pPipelineState->GetDesc();
const PipelineStateDesc& Desc = m_pPipelineState->GetDesc();
if (Desc.PipelineType == PIPELINE_TYPE_COMPUTE)
{
auto* pd3d11CS = m_pPipelineState->GetD3D11ComputeShader();
ID3D11ComputeShader* pd3d11CS = m_pPipelineState->GetD3D11ComputeShader();
if (pd3d11CS == nullptr)
{
LOG_ERROR("Compute shader is not set in the pipeline");
Expand Down Expand Up @@ -118,13 +115,13 @@ void DeviceContextD3D11Impl::SetPipelineState(IPipelineState* pPipelineState)
COMMIT_SHADER(DS, DomainShader);
#undef COMMIT_SHADER

const auto& GraphicsPipeline = m_pPipelineState->GetGraphicsPipelineDesc();
const GraphicsPipelineDesc& GraphicsPipeline = m_pPipelineState->GetGraphicsPipelineDesc();

m_pd3d11DeviceContext->OMSetBlendState(m_pPipelineState->GetD3D11BlendState(), m_BlendFactors, GraphicsPipeline.SampleMask);
m_pd3d11DeviceContext->RSSetState(m_pPipelineState->GetD3D11RasterizerState());
m_pd3d11DeviceContext->OMSetDepthStencilState(m_pPipelineState->GetD3D11DepthStencilState(), m_StencilRef);

auto* pd3d11InputLayout = m_pPipelineState->GetD3D11InputLayout();
ID3D11InputLayout* pd3d11InputLayout = m_pPipelineState->GetD3D11InputLayout();
// It is safe to perform raw pointer comparison as the device context
// keeps bound input layout alive
if (m_CommittedD3D11InputLayout != pd3d11InputLayout)
Expand All @@ -133,7 +130,7 @@ void DeviceContextD3D11Impl::SetPipelineState(IPipelineState* pPipelineState)
m_CommittedD3D11InputLayout = pd3d11InputLayout;
}

auto PrimTopology = GraphicsPipeline.PrimitiveTopology;
PRIMITIVE_TOPOLOGY PrimTopology = GraphicsPipeline.PrimitiveTopology;
if (m_CommittedPrimitiveTopology != PrimTopology)
{
m_CommittedPrimitiveTopology = PrimTopology;
Expand All @@ -149,7 +146,7 @@ void DeviceContextD3D11Impl::SetPipelineState(IPipelineState* pPipelineState)
Uint32 DvpCompatibleSRBCount = 0;
PrepareCommittedResources(m_BindInfo, DvpCompatibleSRBCount);

const auto ActiveStages = m_pPipelineState->GetActiveShaderStages();
const SHADER_TYPE ActiveStages = m_pPipelineState->GetActiveShaderStages();
if (m_BindInfo.ActiveStages != ActiveStages)
{
m_BindInfo.ActiveStages = ActiveStages;
Expand Down
40 changes: 19 additions & 21 deletions Graphics/GraphicsEngineD3D12/src/DeviceContextD3D12Impl.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2019-2024 Diligent Graphics LLC
* Copyright 2019-2025 Diligent Graphics LLC
* Copyright 2015-2019 Egor Yusov
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -264,16 +264,15 @@ void DeviceContextD3D12Impl::Begin(Uint32 ImmediateContextId)

void DeviceContextD3D12Impl::SetPipelineState(IPipelineState* pPipelineState)
{
RefCntAutoPtr<PipelineStateD3D12Impl> pPipelineStateD3D12{pPipelineState, PipelineStateD3D12Impl::IID_InternalImpl};
VERIFY(pPipelineState == nullptr || pPipelineStateD3D12 != nullptr, "Unknown pipeline state object implementation");
if (PipelineStateD3D12Impl::IsSameObject(m_pPipelineState, pPipelineStateD3D12))
RefCntAutoPtr<PipelineStateD3D12Impl> pOldPipeline = m_pPipelineState;
if (!TDeviceContextBase::SetPipelineState(pPipelineState, PipelineStateD3D12Impl::IID_InternalImpl))
return;

const auto& PSODesc = pPipelineStateD3D12->GetDesc();
const PipelineStateDesc& PSODesc = m_pPipelineState->GetDesc();

bool CommitStates = false;
bool CommitScissor = false;
if (!m_pPipelineState)
if (!pOldPipeline)
{
// If no pipeline state is bound, we are working with the fresh command
// list. We have to commit the states set in the context that are not
Expand All @@ -282,22 +281,21 @@ void DeviceContextD3D12Impl::SetPipelineState(IPipelineState* pPipelineState)
}
else
{
const auto& OldPSODesc = m_pPipelineState->GetDesc();
const PipelineStateDesc& OldPSODesc = pOldPipeline->GetDesc();
// Commit all graphics states when switching from compute pipeline
// This is necessary because if the command list had been flushed
// and the first PSO set on the command list was a compute pipeline,
// the states would otherwise never be committed (since m_pPipelineState != nullptr)
CommitStates = !OldPSODesc.IsAnyGraphicsPipeline();
// We also need to update scissor rect if ScissorEnable state has changed
if (OldPSODesc.IsAnyGraphicsPipeline() && PSODesc.IsAnyGraphicsPipeline())
CommitScissor = m_pPipelineState->GetGraphicsPipelineDesc().RasterizerDesc.ScissorEnable != pPipelineStateD3D12->GetGraphicsPipelineDesc().RasterizerDesc.ScissorEnable;
CommitScissor = pOldPipeline->GetGraphicsPipelineDesc().RasterizerDesc.ScissorEnable != m_pPipelineState->GetGraphicsPipelineDesc().RasterizerDesc.ScissorEnable;
pOldPipeline.Release();
}

TDeviceContextBase::SetPipelineState(std::move(pPipelineStateD3D12), 0 /*Dummy*/);

auto& CmdCtx = GetCmdContext();
auto& RootInfo = GetRootTableInfo(PSODesc.PipelineType);
auto* pd3d12RootSig = m_pPipelineState->GetD3D12RootSignature();
CommandContext& CmdCtx = GetCmdContext();
RootTableInfo& RootInfo = GetRootTableInfo(PSODesc.PipelineType);
ID3D12RootSignature* pd3d12RootSig = m_pPipelineState->GetD3D12RootSignature();

if (RootInfo.pd3d12RootSig != pd3d12RootSig)
{
Expand All @@ -316,15 +314,15 @@ void DeviceContextD3D12Impl::SetPipelineState(IPipelineState* pPipelineState)
case PIPELINE_TYPE_GRAPHICS:
case PIPELINE_TYPE_MESH:
{
auto& GraphicsPipeline = m_pPipelineState->GetGraphicsPipelineDesc();
auto& GraphicsCtx = CmdCtx.AsGraphicsContext();
auto* pd3d12PSO = m_pPipelineState->GetD3D12PipelineState();
const GraphicsPipelineDesc& GraphicsPipeline = m_pPipelineState->GetGraphicsPipelineDesc();
GraphicsContext& GraphicsCtx = CmdCtx.AsGraphicsContext();
ID3D12PipelineState* pd3d12PSO = m_pPipelineState->GetD3D12PipelineState();
GraphicsCtx.SetPipelineState(pd3d12PSO);
GraphicsCtx.SetGraphicsRootSignature(pd3d12RootSig);

if (PSODesc.PipelineType == PIPELINE_TYPE_GRAPHICS)
{
auto D3D12Topology = TopologyToD3D12Topology(GraphicsPipeline.PrimitiveTopology);
D3D12_PRIMITIVE_TOPOLOGY D3D12Topology = TopologyToD3D12Topology(GraphicsPipeline.PrimitiveTopology);
GraphicsCtx.SetPrimitiveTopology(D3D12Topology);
}

Expand All @@ -347,16 +345,16 @@ void DeviceContextD3D12Impl::SetPipelineState(IPipelineState* pPipelineState)
}
case PIPELINE_TYPE_COMPUTE:
{
auto* pd3d12PSO = m_pPipelineState->GetD3D12PipelineState();
auto& CompCtx = CmdCtx.AsComputeContext();
ID3D12PipelineState* pd3d12PSO = m_pPipelineState->GetD3D12PipelineState();
ComputeContext& CompCtx = CmdCtx.AsComputeContext();
CompCtx.SetPipelineState(pd3d12PSO);
CompCtx.SetComputeRootSignature(pd3d12RootSig);
break;
}
case PIPELINE_TYPE_RAY_TRACING:
{
auto* pd3d12SO = m_pPipelineState->GetD3D12StateObject();
auto& RTCtx = CmdCtx.AsGraphicsContext4();
ID3D12StateObject* pd3d12SO = m_pPipelineState->GetD3D12StateObject();
GraphicsContext4& RTCtx = CmdCtx.AsGraphicsContext4();
RTCtx.SetRayTracingPipelineState(pd3d12SO);
RTCtx.SetComputeRootSignature(pd3d12RootSig);
break;
Expand Down
24 changes: 9 additions & 15 deletions Graphics/GraphicsEngineOpenGL/src/DeviceContextGLImpl.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2019-2024 Diligent Graphics LLC
* Copyright 2019-2025 Diligent Graphics LLC
* Copyright 2015-2019 Egor Yusov
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -87,25 +87,19 @@ void DeviceContextGLImpl::Begin(Uint32 ImmediateContextId)

void DeviceContextGLImpl::SetPipelineState(IPipelineState* pPipelineState)
{
VERIFY_EXPR(pPipelineState != nullptr);

RefCntAutoPtr<PipelineStateGLImpl> pPipelineStateGLImpl{pPipelineState, PipelineStateGLImpl::IID_InternalImpl};
VERIFY(pPipelineState == nullptr || pPipelineStateGLImpl != nullptr, "Unknown pipeline state object implementation");
if (PipelineStateGLImpl::IsSameObject(m_pPipelineState, pPipelineStateGLImpl))
if (!TDeviceContextBase::SetPipelineState(pPipelineState, PipelineStateGLImpl::IID_InternalImpl))
return;

TDeviceContextBase::SetPipelineState(std::move(pPipelineStateGLImpl), 0 /*Dummy*/);

const auto& Desc = m_pPipelineState->GetDesc();
const PipelineStateDesc& Desc = m_pPipelineState->GetDesc();
if (Desc.PipelineType == PIPELINE_TYPE_COMPUTE)
{
}
else if (Desc.PipelineType == PIPELINE_TYPE_GRAPHICS)
{
const auto& GraphicsPipeline = m_pPipelineState->GetGraphicsPipelineDesc();
const GraphicsPipelineDesc& GraphicsPipeline = m_pPipelineState->GetGraphicsPipelineDesc();
// Set rasterizer state
{
const auto& RasterizerDesc = GraphicsPipeline.RasterizerDesc;
const RasterizerStateDesc& RasterizerDesc = GraphicsPipeline.RasterizerDesc;

m_ContextState.SetFillMode(RasterizerDesc.FillMode);
m_ContextState.SetCullMode(RasterizerDesc.CullMode);
Expand All @@ -126,13 +120,13 @@ void DeviceContextGLImpl::SetPipelineState(IPipelineState* pPipelineState)

// Set blend state
{
const auto& BSDsc = GraphicsPipeline.BlendDesc;
const BlendStateDesc& BSDsc = GraphicsPipeline.BlendDesc;
m_ContextState.SetBlendState(BSDsc, GraphicsPipeline.SampleMask);
}

// Set depth-stencil state
{
const auto& DepthStencilDesc = GraphicsPipeline.DepthStencilDesc;
const DepthStencilStateDesc& DepthStencilDesc = GraphicsPipeline.DepthStencilDesc;

m_ContextState.EnableDepthTest(DepthStencilDesc.DepthEnable);
m_ContextState.EnableDepthWrites(DepthStencilDesc.DepthWriteEnable);
Expand All @@ -141,13 +135,13 @@ void DeviceContextGLImpl::SetPipelineState(IPipelineState* pPipelineState)
m_ContextState.SetStencilWriteMask(DepthStencilDesc.StencilWriteMask);

{
const auto& FrontFace = DepthStencilDesc.FrontFace;
const StencilOpDesc& FrontFace = DepthStencilDesc.FrontFace;
m_ContextState.SetStencilFunc(GL_FRONT, FrontFace.StencilFunc, m_StencilRef, DepthStencilDesc.StencilReadMask);
m_ContextState.SetStencilOp(GL_FRONT, FrontFace.StencilFailOp, FrontFace.StencilDepthFailOp, FrontFace.StencilPassOp);
}

{
const auto& BackFace = DepthStencilDesc.BackFace;
const StencilOpDesc& BackFace = DepthStencilDesc.BackFace;
m_ContextState.SetStencilFunc(GL_BACK, BackFace.StencilFunc, m_StencilRef, DepthStencilDesc.StencilReadMask);
m_ContextState.SetStencilOp(GL_BACK, BackFace.StencilFailOp, BackFace.StencilDepthFailOp, BackFace.StencilPassOp);
}
Expand Down
Loading

0 comments on commit 41e6218

Please sign in to comment.