rt(d3d12): allow a pipeline to be available for multiple formats without recompilation

This commit is contained in:
chyyran 2024-09-29 00:15:02 -04:00 committed by Ronny Chan
parent c57e502b78
commit 4285ad2bd1
3 changed files with 70 additions and 18 deletions

View file

@ -790,7 +790,7 @@ impl FilterChainD3D12 {
let target = &self.output_framebuffers[index]; let target = &self.output_framebuffers[index];
if pass.pipeline.format != target.format { if !pass.pipeline.has_format(target.format) {
// eprintln!("recompiling final pipeline"); // eprintln!("recompiling final pipeline");
pass.pipeline.recompile( pass.pipeline.recompile(
target.format, target.format,
@ -863,7 +863,7 @@ impl FilterChainD3D12 {
if self.draw_last_pass_feedback { if self.draw_last_pass_feedback {
let feedback_target = &self.output_framebuffers[index]; let feedback_target = &self.output_framebuffers[index];
if pass.pipeline.format != feedback_target.format { if !pass.pipeline.has_format(feedback_target.format) {
// eprintln!("recompiling final pipeline"); // eprintln!("recompiling final pipeline");
pass.pipeline.recompile( pass.pipeline.recompile(
feedback_target.format, feedback_target.format,
@ -902,7 +902,7 @@ impl FilterChainD3D12 {
); );
} }
if pass.pipeline.format != viewport.output.format { if !pass.pipeline.has_format(viewport.output.format) {
// eprintln!("recompiling final pipeline"); // eprintln!("recompiling final pipeline");
pass.pipeline.recompile( pass.pipeline.recompile(
viewport.output.format, viewport.output.format,

View file

@ -145,7 +145,7 @@ impl FilterPass {
vbo_type: QuadType, vbo_type: QuadType,
) -> error::Result<()> { ) -> error::Result<()> {
unsafe { unsafe {
cmd.SetPipelineState(&self.pipeline.handle); cmd.SetPipelineState(self.pipeline.pipeline_state(output.output.format));
} }
self.build_semantics( self.build_semantics(

View file

@ -3,9 +3,11 @@ use crate::error::assume_d3d12_init;
use crate::error::FilterChainError::Direct3DOperationError; use crate::error::FilterChainError::Direct3DOperationError;
use crate::{error, util}; use crate::{error, util};
use librashader_cache::{cache_pipeline, cache_shader_object}; use librashader_cache::{cache_pipeline, cache_shader_object};
use librashader_common::map::FastHashMap;
use librashader_reflect::back::dxil::DxilObject; use librashader_reflect::back::dxil::DxilObject;
use librashader_reflect::back::hlsl::CrossHlslContext; use librashader_reflect::back::hlsl::CrossHlslContext;
use librashader_reflect::back::ShaderCompilerOutput; use librashader_reflect::back::ShaderCompilerOutput;
use std::hash::{Hash, Hasher};
use std::mem::ManuallyDrop; use std::mem::ManuallyDrop;
use std::ops::Deref; use std::ops::Deref;
use widestring::u16cstr; use widestring::u16cstr;
@ -33,9 +35,18 @@ use windows::Win32::Graphics::Direct3D12::{
}; };
use windows::Win32::Graphics::Dxgi::Common::{DXGI_FORMAT, DXGI_FORMAT_UNKNOWN, DXGI_SAMPLE_DESC}; use windows::Win32::Graphics::Dxgi::Common::{DXGI_FORMAT, DXGI_FORMAT_UNKNOWN, DXGI_SAMPLE_DESC};
// bruh why does DXGI_FORMAT not impl hash
#[repr(transparent)]
#[derive(PartialEq, Eq)]
struct HashDxgiFormat(DXGI_FORMAT);
impl Hash for HashDxgiFormat {
fn hash<H: Hasher>(&self, state: &mut H) {
self.0 .0.hash(state);
}
}
pub struct D3D12GraphicsPipeline { pub struct D3D12GraphicsPipeline {
pub(crate) handle: ID3D12PipelineState, render_pipelines: FastHashMap<HashDxgiFormat, ID3D12PipelineState>,
pub(crate) format: DXGI_FORMAT,
vertex: Vec<u8>, vertex: Vec<u8>,
fragment: Vec<u8>, fragment: Vec<u8>,
cache_disabled: bool, cache_disabled: bool,
@ -149,14 +160,14 @@ impl D3D12RootSignature {
} }
} }
impl D3D12GraphicsPipeline { impl D3D12GraphicsPipeline {
pub fn new_from_blobs( fn make_pipeline_state(
device: &ID3D12Device, device: &ID3D12Device,
vertex_dxil: IDxcBlob, vertex_dxil: &IDxcBlob,
fragment_dxil: IDxcBlob, fragment_dxil: &IDxcBlob,
root_signature: &D3D12RootSignature, root_signature: &D3D12RootSignature,
render_format: DXGI_FORMAT, render_format: DXGI_FORMAT,
disable_cache: bool, disable_cache: bool,
) -> error::Result<D3D12GraphicsPipeline> { ) -> error::Result<ID3D12PipelineState> {
let input_element = DrawQuad::get_spirv_cross_vbo_desc(); let input_element = DrawQuad::get_spirv_cross_vbo_desc();
let pipeline_state: ID3D12PipelineState = unsafe { let pipeline_state: ID3D12PipelineState = unsafe {
@ -228,7 +239,7 @@ impl D3D12GraphicsPipeline {
let pipeline = cache_pipeline( let pipeline = cache_pipeline(
"d3d12", "d3d12",
&[&vertex_dxil, &fragment_dxil, &render_format.0], &[vertex_dxil, fragment_dxil, &render_format.0],
|cached: Option<Vec<u8>>| { |cached: Option<Vec<u8>>| {
if let Some(cached) = cached { if let Some(cached) = cached {
let pipeline_desc = D3D12_GRAPHICS_PIPELINE_STATE_DESC { let pipeline_desc = D3D12_GRAPHICS_PIPELINE_STATE_DESC {
@ -259,6 +270,38 @@ impl D3D12GraphicsPipeline {
pipeline pipeline
}; };
Ok(pipeline_state)
}
pub fn pipeline_state(&self, format: DXGI_FORMAT) -> &ID3D12PipelineState {
let Some(pipeline) = self
.render_pipelines
.get(&HashDxgiFormat(format))
.or_else(|| self.render_pipelines.values().next())
else {
panic!("No available render pipeline found");
};
pipeline
}
pub fn new_from_blobs(
device: &ID3D12Device,
vertex_dxil: IDxcBlob,
fragment_dxil: IDxcBlob,
root_signature: &D3D12RootSignature,
render_format: DXGI_FORMAT,
disable_cache: bool,
) -> error::Result<D3D12GraphicsPipeline> {
let pipeline_state = Self::make_pipeline_state(
device,
&vertex_dxil,
&fragment_dxil,
root_signature,
render_format,
disable_cache,
)?;
unsafe { unsafe {
let vertex = Vec::from(std::slice::from_raw_parts( let vertex = Vec::from(std::slice::from_raw_parts(
vertex_dxil.GetBufferPointer().cast(), vertex_dxil.GetBufferPointer().cast(),
@ -268,9 +311,11 @@ impl D3D12GraphicsPipeline {
fragment_dxil.GetBufferPointer().cast(), fragment_dxil.GetBufferPointer().cast(),
fragment_dxil.GetBufferSize(), fragment_dxil.GetBufferSize(),
)); ));
let mut render_pipelines = FastHashMap::default();
render_pipelines.insert(HashDxgiFormat(render_format), pipeline_state);
Ok(D3D12GraphicsPipeline { Ok(D3D12GraphicsPipeline {
handle: pipeline_state, render_pipelines,
format: render_format,
vertex, vertex,
fragment, fragment,
cache_disabled: disable_cache, cache_disabled: disable_cache,
@ -298,19 +343,26 @@ impl D3D12GraphicsPipeline {
)?; )?;
(vertex, fragment) (vertex, fragment)
}; };
let mut new_pipeline = Self::new_from_blobs(
let new_pipeline = Self::make_pipeline_state(
device, device,
vertex.cast()?, &vertex.cast()?,
fragment.cast()?, &fragment.cast()?,
root_sig, root_sig,
format, format,
self.cache_disabled, self.cache_disabled,
)?; )?;
std::mem::swap(self, &mut new_pipeline); self.render_pipelines
.insert(HashDxgiFormat(format), new_pipeline);
Ok(()) Ok(())
} }
pub fn has_format(&self, format: DXGI_FORMAT) -> bool {
self.render_pipelines.contains_key(&HashDxgiFormat(format))
}
pub fn new_from_dxil( pub fn new_from_dxil(
device: &ID3D12Device, device: &ID3D12Device,
library: &IDxcUtils, library: &IDxcUtils,