diff --git a/librashader-runtime-d3d12/src/filter_chain.rs b/librashader-runtime-d3d12/src/filter_chain.rs index a2a208f..60c7705 100644 --- a/librashader-runtime-d3d12/src/filter_chain.rs +++ b/librashader-runtime-d3d12/src/filter_chain.rs @@ -327,7 +327,7 @@ impl FilterChainD3D12 { let mut staging_heap = unsafe { D3D12DescriptorHeap::new( device, - // add one, because technically the input image doesn't need to count + // add one, because technically the input image doesn't need to count (1 + MAX_BINDINGS_COUNT as usize) * shader_count + MIPMAP_RESERVED_WORKHEAP_DESCRIPTORS + lut_count, @@ -790,7 +790,7 @@ impl FilterChainD3D12 { let target = &self.output_framebuffers[index]; - if pass.pipeline.format != target.format { + if !pass.pipeline.has_format(target.format) { // eprintln!("recompiling final pipeline"); pass.pipeline.recompile( target.format, @@ -863,7 +863,7 @@ impl FilterChainD3D12 { if self.draw_last_pass_feedback { let feedback_target = &self.output_framebuffers[index]; - if pass.pipeline.format != feedback_target.format { + if !pass.pipeline.has_format(feedback_target.format) { // eprintln!("recompiling final pipeline"); pass.pipeline.recompile( feedback_target.format, @@ -902,7 +902,7 @@ impl FilterChainD3D12 { ); } - if pass.pipeline.format != viewport.output.format { + if !pass.pipeline.has_format(viewport.output.format) { // eprintln!("recompiling final pipeline"); pass.pipeline.recompile( viewport.output.format, diff --git a/librashader-runtime-d3d12/src/filter_pass.rs b/librashader-runtime-d3d12/src/filter_pass.rs index 70a5309..d76b926 100644 --- a/librashader-runtime-d3d12/src/filter_pass.rs +++ b/librashader-runtime-d3d12/src/filter_pass.rs @@ -145,7 +145,7 @@ impl FilterPass { vbo_type: QuadType, ) -> error::Result<()> { unsafe { - cmd.SetPipelineState(&self.pipeline.handle); + cmd.SetPipelineState(self.pipeline.pipeline_state(output.output.format)); } self.build_semantics( diff --git a/librashader-runtime-d3d12/src/graphics_pipeline.rs b/librashader-runtime-d3d12/src/graphics_pipeline.rs index 849d899..42770fe 100644 --- a/librashader-runtime-d3d12/src/graphics_pipeline.rs +++ b/librashader-runtime-d3d12/src/graphics_pipeline.rs @@ -3,9 +3,11 @@ use crate::error::assume_d3d12_init; use crate::error::FilterChainError::Direct3DOperationError; use crate::{error, util}; use librashader_cache::{cache_pipeline, cache_shader_object}; +use librashader_common::map::FastHashMap; use librashader_reflect::back::dxil::DxilObject; use librashader_reflect::back::hlsl::CrossHlslContext; use librashader_reflect::back::ShaderCompilerOutput; +use std::hash::{Hash, Hasher}; use std::mem::ManuallyDrop; use std::ops::Deref; use widestring::u16cstr; @@ -33,9 +35,18 @@ use windows::Win32::Graphics::Direct3D12::{ }; use windows::Win32::Graphics::Dxgi::Common::{DXGI_FORMAT, DXGI_FORMAT_UNKNOWN, DXGI_SAMPLE_DESC}; +// bruh why does DXGI_FORMAT not impl hash +#[repr(transparent)] +#[derive(PartialEq, Eq)] +struct HashDxgiFormat(DXGI_FORMAT); +impl Hash for HashDxgiFormat { + fn hash(&self, state: &mut H) { + self.0 .0.hash(state); + } +} + pub struct D3D12GraphicsPipeline { - pub(crate) handle: ID3D12PipelineState, - pub(crate) format: DXGI_FORMAT, + render_pipelines: FastHashMap, vertex: Vec, fragment: Vec, cache_disabled: bool, @@ -149,14 +160,14 @@ impl D3D12RootSignature { } } impl D3D12GraphicsPipeline { - pub fn new_from_blobs( + fn make_pipeline_state( device: &ID3D12Device, - vertex_dxil: IDxcBlob, - fragment_dxil: IDxcBlob, + vertex_dxil: &IDxcBlob, + fragment_dxil: &IDxcBlob, root_signature: &D3D12RootSignature, render_format: DXGI_FORMAT, disable_cache: bool, - ) -> error::Result { + ) -> error::Result { let input_element = DrawQuad::get_spirv_cross_vbo_desc(); let pipeline_state: ID3D12PipelineState = unsafe { @@ -228,7 +239,7 @@ impl D3D12GraphicsPipeline { let pipeline = cache_pipeline( "d3d12", - &[&vertex_dxil, &fragment_dxil, &render_format.0], + &[vertex_dxil, fragment_dxil, &render_format.0], |cached: Option>| { if let Some(cached) = cached { let pipeline_desc = D3D12_GRAPHICS_PIPELINE_STATE_DESC { @@ -259,6 +270,38 @@ impl D3D12GraphicsPipeline { pipeline }; + Ok(pipeline_state) + } + + pub fn pipeline_state(&self, format: DXGI_FORMAT) -> &ID3D12PipelineState { + let Some(pipeline) = self + .render_pipelines + .get(&HashDxgiFormat(format)) + .or_else(|| self.render_pipelines.values().next()) + else { + panic!("No available render pipeline found"); + }; + + pipeline + } + + pub fn new_from_blobs( + device: &ID3D12Device, + vertex_dxil: IDxcBlob, + fragment_dxil: IDxcBlob, + root_signature: &D3D12RootSignature, + render_format: DXGI_FORMAT, + disable_cache: bool, + ) -> error::Result { + let pipeline_state = Self::make_pipeline_state( + device, + &vertex_dxil, + &fragment_dxil, + root_signature, + render_format, + disable_cache, + )?; + unsafe { let vertex = Vec::from(std::slice::from_raw_parts( vertex_dxil.GetBufferPointer().cast(), @@ -268,9 +311,11 @@ impl D3D12GraphicsPipeline { fragment_dxil.GetBufferPointer().cast(), fragment_dxil.GetBufferSize(), )); + + let mut render_pipelines = FastHashMap::default(); + render_pipelines.insert(HashDxgiFormat(render_format), pipeline_state); Ok(D3D12GraphicsPipeline { - handle: pipeline_state, - format: render_format, + render_pipelines, vertex, fragment, cache_disabled: disable_cache, @@ -298,19 +343,26 @@ impl D3D12GraphicsPipeline { )?; (vertex, fragment) }; - let mut new_pipeline = Self::new_from_blobs( + + let new_pipeline = Self::make_pipeline_state( device, - vertex.cast()?, - fragment.cast()?, + &vertex.cast()?, + &fragment.cast()?, root_sig, format, self.cache_disabled, )?; - std::mem::swap(self, &mut new_pipeline); + self.render_pipelines + .insert(HashDxgiFormat(format), new_pipeline); + Ok(()) } + pub fn has_format(&self, format: DXGI_FORMAT) -> bool { + self.render_pipelines.contains_key(&HashDxgiFormat(format)) + } + pub fn new_from_dxil( device: &ID3D12Device, library: &IDxcUtils,