use crate::draw_quad::DrawQuad; use crate::error::assume_d3d12_init; use crate::error::FilterChainError::Direct3DOperationError; use crate::{error, util}; use librashader_cache::{cache_pipeline, cache_shader_object}; use librashader_reflect::back::dxil::DxilObject; use librashader_reflect::back::hlsl::CrossHlslContext; use librashader_reflect::back::ShaderCompilerOutput; use std::mem::ManuallyDrop; use std::ops::Deref; use widestring::u16cstr; use windows::core::ComInterface; use windows::Win32::Foundation::BOOL; use windows::Win32::Graphics::Direct3D::Dxc::{ CLSID_DxcLibrary, DxcCreateInstance, IDxcBlob, IDxcCompiler, IDxcUtils, IDxcValidator, DXC_CP, }; use windows::Win32::Graphics::Direct3D12::{ D3D12SerializeVersionedRootSignature, ID3D12Device, ID3D12PipelineState, ID3D12RootSignature, D3D12_BLEND_DESC, D3D12_BLEND_INV_SRC_ALPHA, D3D12_BLEND_OP_ADD, D3D12_BLEND_SRC_ALPHA, D3D12_CACHED_PIPELINE_STATE, D3D12_COLOR_WRITE_ENABLE_ALL, D3D12_CULL_MODE_NONE, D3D12_DESCRIPTOR_RANGE1, D3D12_DESCRIPTOR_RANGE_FLAGS, D3D12_DESCRIPTOR_RANGE_FLAG_DATA_VOLATILE, D3D12_DESCRIPTOR_RANGE_FLAG_DESCRIPTORS_VOLATILE, D3D12_DESCRIPTOR_RANGE_TYPE_SAMPLER, D3D12_DESCRIPTOR_RANGE_TYPE_SRV, D3D12_FILL_MODE_SOLID, D3D12_GRAPHICS_PIPELINE_STATE_DESC, D3D12_INPUT_LAYOUT_DESC, D3D12_LOGIC_OP_NOOP, D3D12_PRIMITIVE_TOPOLOGY_TYPE_TRIANGLE, D3D12_RASTERIZER_DESC, D3D12_RENDER_TARGET_BLEND_DESC, D3D12_ROOT_DESCRIPTOR1, D3D12_ROOT_DESCRIPTOR_FLAG_NONE, D3D12_ROOT_DESCRIPTOR_TABLE1, D3D12_ROOT_PARAMETER1, D3D12_ROOT_PARAMETER1_0, D3D12_ROOT_PARAMETER_TYPE_CBV, D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE, D3D12_ROOT_SIGNATURE_DESC1, D3D12_ROOT_SIGNATURE_FLAG_ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT, D3D12_SHADER_BYTECODE, D3D12_SHADER_VISIBILITY_ALL, D3D12_SHADER_VISIBILITY_PIXEL, D3D12_VERSIONED_ROOT_SIGNATURE_DESC, D3D12_VERSIONED_ROOT_SIGNATURE_DESC_0, D3D_ROOT_SIGNATURE_VERSION_1_1, }; use windows::Win32::Graphics::Dxgi::Common::{DXGI_FORMAT, DXGI_FORMAT_UNKNOWN, DXGI_SAMPLE_DESC}; pub struct D3D12GraphicsPipeline { pub(crate) handle: ID3D12PipelineState, pub(crate) format: DXGI_FORMAT, vertex: Vec, fragment: Vec, cache_disabled: bool, } const D3D12_SLANG_ROOT_PARAMETERS: &[D3D12_ROOT_PARAMETER1; 4] = &[ // srvs D3D12_ROOT_PARAMETER1 { ParameterType: D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE, Anonymous: D3D12_ROOT_PARAMETER1_0 { DescriptorTable: D3D12_ROOT_DESCRIPTOR_TABLE1 { NumDescriptorRanges: 1, pDescriptorRanges: &D3D12_DESCRIPTOR_RANGE1 { RangeType: D3D12_DESCRIPTOR_RANGE_TYPE_SRV, NumDescriptors: 16, BaseShaderRegister: 0, RegisterSpace: 0, Flags: D3D12_DESCRIPTOR_RANGE_FLAGS( D3D12_DESCRIPTOR_RANGE_FLAG_DATA_VOLATILE.0 | D3D12_DESCRIPTOR_RANGE_FLAG_DESCRIPTORS_VOLATILE.0, ), OffsetInDescriptorsFromTableStart: 0, }, }, }, ShaderVisibility: D3D12_SHADER_VISIBILITY_PIXEL, }, // samplers D3D12_ROOT_PARAMETER1 { ParameterType: D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE, Anonymous: D3D12_ROOT_PARAMETER1_0 { DescriptorTable: D3D12_ROOT_DESCRIPTOR_TABLE1 { NumDescriptorRanges: 1, pDescriptorRanges: &D3D12_DESCRIPTOR_RANGE1 { RangeType: D3D12_DESCRIPTOR_RANGE_TYPE_SAMPLER, NumDescriptors: 16, BaseShaderRegister: 0, RegisterSpace: 0, Flags: D3D12_DESCRIPTOR_RANGE_FLAG_DESCRIPTORS_VOLATILE, OffsetInDescriptorsFromTableStart: 0, }, }, }, ShaderVisibility: D3D12_SHADER_VISIBILITY_PIXEL, }, // UBO D3D12_ROOT_PARAMETER1 { ParameterType: D3D12_ROOT_PARAMETER_TYPE_CBV, Anonymous: D3D12_ROOT_PARAMETER1_0 { Descriptor: D3D12_ROOT_DESCRIPTOR1 { ShaderRegister: 0, RegisterSpace: 0, Flags: D3D12_ROOT_DESCRIPTOR_FLAG_NONE, }, }, ShaderVisibility: D3D12_SHADER_VISIBILITY_ALL, }, // push D3D12_ROOT_PARAMETER1 { ParameterType: D3D12_ROOT_PARAMETER_TYPE_CBV, Anonymous: D3D12_ROOT_PARAMETER1_0 { Descriptor: D3D12_ROOT_DESCRIPTOR1 { ShaderRegister: 1, RegisterSpace: 0, Flags: D3D12_ROOT_DESCRIPTOR_FLAG_NONE, }, }, ShaderVisibility: D3D12_SHADER_VISIBILITY_ALL, }, ]; const D3D12_SLANG_VERSIONED_ROOT_SIGNATURE: &D3D12_VERSIONED_ROOT_SIGNATURE_DESC = &D3D12_VERSIONED_ROOT_SIGNATURE_DESC { Version: D3D_ROOT_SIGNATURE_VERSION_1_1, Anonymous: D3D12_VERSIONED_ROOT_SIGNATURE_DESC_0 { Desc_1_1: D3D12_ROOT_SIGNATURE_DESC1 { NumParameters: D3D12_SLANG_ROOT_PARAMETERS.len() as u32, pParameters: D3D12_SLANG_ROOT_PARAMETERS.as_ptr(), NumStaticSamplers: 0, pStaticSamplers: std::ptr::null(), Flags: D3D12_ROOT_SIGNATURE_FLAG_ALLOW_INPUT_ASSEMBLER_INPUT_LAYOUT, }, }, }; pub struct D3D12RootSignature { pub(crate) handle: ID3D12RootSignature, } impl D3D12RootSignature { pub fn new(device: &ID3D12Device) -> error::Result { let signature = unsafe { let mut rs_blob = None; D3D12SerializeVersionedRootSignature( D3D12_SLANG_VERSIONED_ROOT_SIGNATURE, &mut rs_blob, None, )?; assume_d3d12_init!(rs_blob, "D3D12SerializeVersionedRootSignature"); let blob = std::slice::from_raw_parts( rs_blob.GetBufferPointer().cast(), rs_blob.GetBufferSize(), ); let root_signature: ID3D12RootSignature = device.CreateRootSignature(0, blob)?; root_signature }; Ok(D3D12RootSignature { handle: signature }) } } impl D3D12GraphicsPipeline { pub fn new_from_blobs( device: &ID3D12Device, vertex_dxil: IDxcBlob, fragment_dxil: IDxcBlob, root_signature: &D3D12RootSignature, render_format: DXGI_FORMAT, disable_cache: bool, ) -> error::Result { let input_element = DrawQuad::get_spirv_cross_vbo_desc(); let pipeline_state: ID3D12PipelineState = unsafe { let pipeline_desc = D3D12_GRAPHICS_PIPELINE_STATE_DESC { pRootSignature: ManuallyDrop::new(Some(root_signature.handle.clone())), VS: D3D12_SHADER_BYTECODE { pShaderBytecode: vertex_dxil.GetBufferPointer(), BytecodeLength: vertex_dxil.GetBufferSize(), }, PS: D3D12_SHADER_BYTECODE { pShaderBytecode: fragment_dxil.GetBufferPointer(), BytecodeLength: fragment_dxil.GetBufferSize(), }, StreamOutput: Default::default(), BlendState: D3D12_BLEND_DESC { RenderTarget: [ D3D12_RENDER_TARGET_BLEND_DESC { BlendEnable: BOOL::from(false), LogicOpEnable: BOOL::from(false), SrcBlend: D3D12_BLEND_SRC_ALPHA, DestBlend: D3D12_BLEND_INV_SRC_ALPHA, BlendOp: D3D12_BLEND_OP_ADD, SrcBlendAlpha: D3D12_BLEND_SRC_ALPHA, DestBlendAlpha: D3D12_BLEND_INV_SRC_ALPHA, BlendOpAlpha: D3D12_BLEND_OP_ADD, LogicOp: D3D12_LOGIC_OP_NOOP, RenderTargetWriteMask: D3D12_COLOR_WRITE_ENABLE_ALL.0 as u8, }, Default::default(), Default::default(), Default::default(), Default::default(), Default::default(), Default::default(), Default::default(), ], ..Default::default() }, SampleMask: u32::MAX, RasterizerState: D3D12_RASTERIZER_DESC { FillMode: D3D12_FILL_MODE_SOLID, CullMode: D3D12_CULL_MODE_NONE, ..Default::default() }, DepthStencilState: Default::default(), InputLayout: D3D12_INPUT_LAYOUT_DESC { pInputElementDescs: input_element.as_ptr(), NumElements: input_element.len() as u32, }, PrimitiveTopologyType: D3D12_PRIMITIVE_TOPOLOGY_TYPE_TRIANGLE, NumRenderTargets: 1, RTVFormats: [ render_format, DXGI_FORMAT_UNKNOWN, DXGI_FORMAT_UNKNOWN, DXGI_FORMAT_UNKNOWN, DXGI_FORMAT_UNKNOWN, DXGI_FORMAT_UNKNOWN, DXGI_FORMAT_UNKNOWN, DXGI_FORMAT_UNKNOWN, ], SampleDesc: DXGI_SAMPLE_DESC { Count: 1, Quality: 0, }, NodeMask: 0, ..Default::default() }; let pipeline = cache_pipeline( "d3d12", &[&vertex_dxil, &fragment_dxil, &render_format.0], |cached: Option>| { if let Some(cached) = cached { let pipeline_desc = D3D12_GRAPHICS_PIPELINE_STATE_DESC { CachedPSO: D3D12_CACHED_PIPELINE_STATE { pCachedBlob: cached.as_ptr().cast(), CachedBlobSizeInBytes: cached.len(), }, pRootSignature: ManuallyDrop::new(Some(root_signature.handle.clone())), ..pipeline_desc }; let pipeline_state = device.CreateGraphicsPipelineState(&pipeline_desc); drop(ManuallyDrop::into_inner(pipeline_desc.pRootSignature)); pipeline_state } else { device.CreateGraphicsPipelineState(&pipeline_desc) } }, |pso: &ID3D12PipelineState| { let cached_pso = pso.GetCachedBlob()?; Ok(cached_pso) }, disable_cache, )?; // cleanup handle drop(ManuallyDrop::into_inner(pipeline_desc.pRootSignature)); pipeline }; unsafe { let vertex = Vec::from(std::slice::from_raw_parts( vertex_dxil.GetBufferPointer().cast(), vertex_dxil.GetBufferSize(), )); let fragment = Vec::from(std::slice::from_raw_parts( fragment_dxil.GetBufferPointer().cast(), fragment_dxil.GetBufferSize(), )); Ok(D3D12GraphicsPipeline { handle: pipeline_state, format: render_format, vertex, fragment, cache_disabled: disable_cache, }) } } pub fn recompile( &mut self, format: DXGI_FORMAT, root_sig: &D3D12RootSignature, device: &ID3D12Device, ) -> error::Result<()> { let (vertex, fragment) = unsafe { let library: IDxcUtils = DxcCreateInstance(&CLSID_DxcLibrary)?; let vertex = library.CreateBlobFromPinned( self.vertex.as_ptr().cast(), self.vertex.len() as u32, DXC_CP(0), )?; let fragment = library.CreateBlobFromPinned( self.fragment.as_ptr().cast(), self.fragment.len() as u32, DXC_CP(0), )?; (vertex, fragment) }; let mut new_pipeline = Self::new_from_blobs( device, vertex.cast()?, fragment.cast()?, root_sig, format, self.cache_disabled, )?; std::mem::swap(self, &mut new_pipeline); Ok(()) } pub fn new_from_dxil( device: &ID3D12Device, library: &IDxcUtils, validator: &IDxcValidator, shader_assembly: &ShaderCompilerOutput, root_signature: &D3D12RootSignature, render_format: DXGI_FORMAT, disable_cache: bool, ) -> error::Result { if shader_assembly.vertex.requires_runtime_data() { return Err(Direct3DOperationError( "Compiled DXIL Vertex shader needs unexpected runtime data", )); } if shader_assembly.fragment.requires_runtime_data() { return Err(Direct3DOperationError( "Compiled DXIL fragment shader needs unexpected runtime data", )); } let vertex_dxil = cache_shader_object( "dxil", &[shader_assembly.vertex.deref()], |&[source]| util::dxc_validate_shader(library, validator, source), |f| Ok(f), disable_cache, )?; let fragment_dxil = cache_shader_object( "dxil", &[shader_assembly.fragment.deref()], |&[source]| util::dxc_validate_shader(library, validator, source), |f| Ok(f), disable_cache, )?; Self::new_from_blobs( device, vertex_dxil, fragment_dxil, root_signature, render_format, disable_cache, ) } pub fn new_from_hlsl( device: &ID3D12Device, library: &IDxcUtils, dxc: &IDxcCompiler, shader_assembly: &ShaderCompilerOutput, root_signature: &D3D12RootSignature, render_format: DXGI_FORMAT, disable_cache: bool, ) -> error::Result { let vertex_dxil = cache_shader_object( "dxil", &[shader_assembly.vertex.as_bytes()], |&[source]| util::dxc_compile_shader(library, dxc, source, u16cstr!("vs_6_0")), |f| Ok(f), disable_cache, )?; let fragment_dxil = cache_shader_object( "dxil", &[shader_assembly.fragment.as_bytes()], |&[source]| util::dxc_compile_shader(library, dxc, source, u16cstr!("ps_6_0")), |f| Ok(f), disable_cache, )?; Self::new_from_blobs( device, vertex_dxil, fragment_dxil, root_signature, render_format, disable_cache, ) } }