rt(d3d12): allow a pipeline to be available for multiple formats without recompilation

This commit is contained in:
chyyran 2024-09-29 00:15:02 -04:00 committed by Ronny Chan
parent c57e502b78
commit 4285ad2bd1
3 changed files with 70 additions and 18 deletions

View file

@ -327,7 +327,7 @@ impl FilterChainD3D12 {
let mut staging_heap = unsafe {
D3D12DescriptorHeap::new(
device,
// add one, because technically the input image doesn't need to count
// add one, because technically the input image doesn't need to count
(1 + MAX_BINDINGS_COUNT as usize) * shader_count
+ MIPMAP_RESERVED_WORKHEAP_DESCRIPTORS
+ lut_count,
@ -790,7 +790,7 @@ impl FilterChainD3D12 {
let target = &self.output_framebuffers[index];
if pass.pipeline.format != target.format {
if !pass.pipeline.has_format(target.format) {
// eprintln!("recompiling final pipeline");
pass.pipeline.recompile(
target.format,
@ -863,7 +863,7 @@ impl FilterChainD3D12 {
if self.draw_last_pass_feedback {
let feedback_target = &self.output_framebuffers[index];
if pass.pipeline.format != feedback_target.format {
if !pass.pipeline.has_format(feedback_target.format) {
// eprintln!("recompiling final pipeline");
pass.pipeline.recompile(
feedback_target.format,
@ -902,7 +902,7 @@ impl FilterChainD3D12 {
);
}
if pass.pipeline.format != viewport.output.format {
if !pass.pipeline.has_format(viewport.output.format) {
// eprintln!("recompiling final pipeline");
pass.pipeline.recompile(
viewport.output.format,

View file

@ -145,7 +145,7 @@ impl FilterPass {
vbo_type: QuadType,
) -> error::Result<()> {
unsafe {
cmd.SetPipelineState(&self.pipeline.handle);
cmd.SetPipelineState(self.pipeline.pipeline_state(output.output.format));
}
self.build_semantics(

View file

@ -3,9 +3,11 @@ use crate::error::assume_d3d12_init;
use crate::error::FilterChainError::Direct3DOperationError;
use crate::{error, util};
use librashader_cache::{cache_pipeline, cache_shader_object};
use librashader_common::map::FastHashMap;
use librashader_reflect::back::dxil::DxilObject;
use librashader_reflect::back::hlsl::CrossHlslContext;
use librashader_reflect::back::ShaderCompilerOutput;
use std::hash::{Hash, Hasher};
use std::mem::ManuallyDrop;
use std::ops::Deref;
use widestring::u16cstr;
@ -33,9 +35,18 @@ use windows::Win32::Graphics::Direct3D12::{
};
use windows::Win32::Graphics::Dxgi::Common::{DXGI_FORMAT, DXGI_FORMAT_UNKNOWN, DXGI_SAMPLE_DESC};
// bruh why does DXGI_FORMAT not impl hash
#[repr(transparent)]
#[derive(PartialEq, Eq)]
struct HashDxgiFormat(DXGI_FORMAT);
impl Hash for HashDxgiFormat {
fn hash<H: Hasher>(&self, state: &mut H) {
self.0 .0.hash(state);
}
}
pub struct D3D12GraphicsPipeline {
pub(crate) handle: ID3D12PipelineState,
pub(crate) format: DXGI_FORMAT,
render_pipelines: FastHashMap<HashDxgiFormat, ID3D12PipelineState>,
vertex: Vec<u8>,
fragment: Vec<u8>,
cache_disabled: bool,
@ -149,14 +160,14 @@ impl D3D12RootSignature {
}
}
impl D3D12GraphicsPipeline {
pub fn new_from_blobs(
fn make_pipeline_state(
device: &ID3D12Device,
vertex_dxil: IDxcBlob,
fragment_dxil: IDxcBlob,
vertex_dxil: &IDxcBlob,
fragment_dxil: &IDxcBlob,
root_signature: &D3D12RootSignature,
render_format: DXGI_FORMAT,
disable_cache: bool,
) -> error::Result<D3D12GraphicsPipeline> {
) -> error::Result<ID3D12PipelineState> {
let input_element = DrawQuad::get_spirv_cross_vbo_desc();
let pipeline_state: ID3D12PipelineState = unsafe {
@ -228,7 +239,7 @@ impl D3D12GraphicsPipeline {
let pipeline = cache_pipeline(
"d3d12",
&[&vertex_dxil, &fragment_dxil, &render_format.0],
&[vertex_dxil, fragment_dxil, &render_format.0],
|cached: Option<Vec<u8>>| {
if let Some(cached) = cached {
let pipeline_desc = D3D12_GRAPHICS_PIPELINE_STATE_DESC {
@ -259,6 +270,38 @@ impl D3D12GraphicsPipeline {
pipeline
};
Ok(pipeline_state)
}
pub fn pipeline_state(&self, format: DXGI_FORMAT) -> &ID3D12PipelineState {
let Some(pipeline) = self
.render_pipelines
.get(&HashDxgiFormat(format))
.or_else(|| self.render_pipelines.values().next())
else {
panic!("No available render pipeline found");
};
pipeline
}
pub fn new_from_blobs(
device: &ID3D12Device,
vertex_dxil: IDxcBlob,
fragment_dxil: IDxcBlob,
root_signature: &D3D12RootSignature,
render_format: DXGI_FORMAT,
disable_cache: bool,
) -> error::Result<D3D12GraphicsPipeline> {
let pipeline_state = Self::make_pipeline_state(
device,
&vertex_dxil,
&fragment_dxil,
root_signature,
render_format,
disable_cache,
)?;
unsafe {
let vertex = Vec::from(std::slice::from_raw_parts(
vertex_dxil.GetBufferPointer().cast(),
@ -268,9 +311,11 @@ impl D3D12GraphicsPipeline {
fragment_dxil.GetBufferPointer().cast(),
fragment_dxil.GetBufferSize(),
));
let mut render_pipelines = FastHashMap::default();
render_pipelines.insert(HashDxgiFormat(render_format), pipeline_state);
Ok(D3D12GraphicsPipeline {
handle: pipeline_state,
format: render_format,
render_pipelines,
vertex,
fragment,
cache_disabled: disable_cache,
@ -298,19 +343,26 @@ impl D3D12GraphicsPipeline {
)?;
(vertex, fragment)
};
let mut new_pipeline = Self::new_from_blobs(
let new_pipeline = Self::make_pipeline_state(
device,
vertex.cast()?,
fragment.cast()?,
&vertex.cast()?,
&fragment.cast()?,
root_sig,
format,
self.cache_disabled,
)?;
std::mem::swap(self, &mut new_pipeline);
self.render_pipelines
.insert(HashDxgiFormat(format), new_pipeline);
Ok(())
}
pub fn has_format(&self, format: DXGI_FORMAT) -> bool {
self.render_pipelines.contains_key(&HashDxgiFormat(format))
}
pub fn new_from_dxil(
device: &ID3D12Device,
library: &IDxcUtils,