rt(d3d12): allow a pipeline to be available for multiple formats without recompilation
This commit is contained in:
parent
c57e502b78
commit
4285ad2bd1
3 changed files with 70 additions and 18 deletions
|
@ -327,7 +327,7 @@ impl FilterChainD3D12 {
|
|||
let mut staging_heap = unsafe {
|
||||
D3D12DescriptorHeap::new(
|
||||
device,
|
||||
// add one, because technically the input image doesn't need to count
|
||||
// add one, because technically the input image doesn't need to count
|
||||
(1 + MAX_BINDINGS_COUNT as usize) * shader_count
|
||||
+ MIPMAP_RESERVED_WORKHEAP_DESCRIPTORS
|
||||
+ lut_count,
|
||||
|
@ -790,7 +790,7 @@ impl FilterChainD3D12 {
|
|||
|
||||
let target = &self.output_framebuffers[index];
|
||||
|
||||
if pass.pipeline.format != target.format {
|
||||
if !pass.pipeline.has_format(target.format) {
|
||||
// eprintln!("recompiling final pipeline");
|
||||
pass.pipeline.recompile(
|
||||
target.format,
|
||||
|
@ -863,7 +863,7 @@ impl FilterChainD3D12 {
|
|||
if self.draw_last_pass_feedback {
|
||||
let feedback_target = &self.output_framebuffers[index];
|
||||
|
||||
if pass.pipeline.format != feedback_target.format {
|
||||
if !pass.pipeline.has_format(feedback_target.format) {
|
||||
// eprintln!("recompiling final pipeline");
|
||||
pass.pipeline.recompile(
|
||||
feedback_target.format,
|
||||
|
@ -902,7 +902,7 @@ impl FilterChainD3D12 {
|
|||
);
|
||||
}
|
||||
|
||||
if pass.pipeline.format != viewport.output.format {
|
||||
if !pass.pipeline.has_format(viewport.output.format) {
|
||||
// eprintln!("recompiling final pipeline");
|
||||
pass.pipeline.recompile(
|
||||
viewport.output.format,
|
||||
|
|
|
@ -145,7 +145,7 @@ impl FilterPass {
|
|||
vbo_type: QuadType,
|
||||
) -> error::Result<()> {
|
||||
unsafe {
|
||||
cmd.SetPipelineState(&self.pipeline.handle);
|
||||
cmd.SetPipelineState(self.pipeline.pipeline_state(output.output.format));
|
||||
}
|
||||
|
||||
self.build_semantics(
|
||||
|
|
|
@ -3,9 +3,11 @@ use crate::error::assume_d3d12_init;
|
|||
use crate::error::FilterChainError::Direct3DOperationError;
|
||||
use crate::{error, util};
|
||||
use librashader_cache::{cache_pipeline, cache_shader_object};
|
||||
use librashader_common::map::FastHashMap;
|
||||
use librashader_reflect::back::dxil::DxilObject;
|
||||
use librashader_reflect::back::hlsl::CrossHlslContext;
|
||||
use librashader_reflect::back::ShaderCompilerOutput;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::mem::ManuallyDrop;
|
||||
use std::ops::Deref;
|
||||
use widestring::u16cstr;
|
||||
|
@ -33,9 +35,18 @@ use windows::Win32::Graphics::Direct3D12::{
|
|||
};
|
||||
use windows::Win32::Graphics::Dxgi::Common::{DXGI_FORMAT, DXGI_FORMAT_UNKNOWN, DXGI_SAMPLE_DESC};
|
||||
|
||||
// bruh why does DXGI_FORMAT not impl hash
|
||||
#[repr(transparent)]
|
||||
#[derive(PartialEq, Eq)]
|
||||
struct HashDxgiFormat(DXGI_FORMAT);
|
||||
impl Hash for HashDxgiFormat {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
self.0 .0.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
pub struct D3D12GraphicsPipeline {
|
||||
pub(crate) handle: ID3D12PipelineState,
|
||||
pub(crate) format: DXGI_FORMAT,
|
||||
render_pipelines: FastHashMap<HashDxgiFormat, ID3D12PipelineState>,
|
||||
vertex: Vec<u8>,
|
||||
fragment: Vec<u8>,
|
||||
cache_disabled: bool,
|
||||
|
@ -149,14 +160,14 @@ impl D3D12RootSignature {
|
|||
}
|
||||
}
|
||||
impl D3D12GraphicsPipeline {
|
||||
pub fn new_from_blobs(
|
||||
fn make_pipeline_state(
|
||||
device: &ID3D12Device,
|
||||
vertex_dxil: IDxcBlob,
|
||||
fragment_dxil: IDxcBlob,
|
||||
vertex_dxil: &IDxcBlob,
|
||||
fragment_dxil: &IDxcBlob,
|
||||
root_signature: &D3D12RootSignature,
|
||||
render_format: DXGI_FORMAT,
|
||||
disable_cache: bool,
|
||||
) -> error::Result<D3D12GraphicsPipeline> {
|
||||
) -> error::Result<ID3D12PipelineState> {
|
||||
let input_element = DrawQuad::get_spirv_cross_vbo_desc();
|
||||
|
||||
let pipeline_state: ID3D12PipelineState = unsafe {
|
||||
|
@ -228,7 +239,7 @@ impl D3D12GraphicsPipeline {
|
|||
|
||||
let pipeline = cache_pipeline(
|
||||
"d3d12",
|
||||
&[&vertex_dxil, &fragment_dxil, &render_format.0],
|
||||
&[vertex_dxil, fragment_dxil, &render_format.0],
|
||||
|cached: Option<Vec<u8>>| {
|
||||
if let Some(cached) = cached {
|
||||
let pipeline_desc = D3D12_GRAPHICS_PIPELINE_STATE_DESC {
|
||||
|
@ -259,6 +270,38 @@ impl D3D12GraphicsPipeline {
|
|||
pipeline
|
||||
};
|
||||
|
||||
Ok(pipeline_state)
|
||||
}
|
||||
|
||||
pub fn pipeline_state(&self, format: DXGI_FORMAT) -> &ID3D12PipelineState {
|
||||
let Some(pipeline) = self
|
||||
.render_pipelines
|
||||
.get(&HashDxgiFormat(format))
|
||||
.or_else(|| self.render_pipelines.values().next())
|
||||
else {
|
||||
panic!("No available render pipeline found");
|
||||
};
|
||||
|
||||
pipeline
|
||||
}
|
||||
|
||||
pub fn new_from_blobs(
|
||||
device: &ID3D12Device,
|
||||
vertex_dxil: IDxcBlob,
|
||||
fragment_dxil: IDxcBlob,
|
||||
root_signature: &D3D12RootSignature,
|
||||
render_format: DXGI_FORMAT,
|
||||
disable_cache: bool,
|
||||
) -> error::Result<D3D12GraphicsPipeline> {
|
||||
let pipeline_state = Self::make_pipeline_state(
|
||||
device,
|
||||
&vertex_dxil,
|
||||
&fragment_dxil,
|
||||
root_signature,
|
||||
render_format,
|
||||
disable_cache,
|
||||
)?;
|
||||
|
||||
unsafe {
|
||||
let vertex = Vec::from(std::slice::from_raw_parts(
|
||||
vertex_dxil.GetBufferPointer().cast(),
|
||||
|
@ -268,9 +311,11 @@ impl D3D12GraphicsPipeline {
|
|||
fragment_dxil.GetBufferPointer().cast(),
|
||||
fragment_dxil.GetBufferSize(),
|
||||
));
|
||||
|
||||
let mut render_pipelines = FastHashMap::default();
|
||||
render_pipelines.insert(HashDxgiFormat(render_format), pipeline_state);
|
||||
Ok(D3D12GraphicsPipeline {
|
||||
handle: pipeline_state,
|
||||
format: render_format,
|
||||
render_pipelines,
|
||||
vertex,
|
||||
fragment,
|
||||
cache_disabled: disable_cache,
|
||||
|
@ -298,19 +343,26 @@ impl D3D12GraphicsPipeline {
|
|||
)?;
|
||||
(vertex, fragment)
|
||||
};
|
||||
let mut new_pipeline = Self::new_from_blobs(
|
||||
|
||||
let new_pipeline = Self::make_pipeline_state(
|
||||
device,
|
||||
vertex.cast()?,
|
||||
fragment.cast()?,
|
||||
&vertex.cast()?,
|
||||
&fragment.cast()?,
|
||||
root_sig,
|
||||
format,
|
||||
self.cache_disabled,
|
||||
)?;
|
||||
|
||||
std::mem::swap(self, &mut new_pipeline);
|
||||
self.render_pipelines
|
||||
.insert(HashDxgiFormat(format), new_pipeline);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn has_format(&self, format: DXGI_FORMAT) -> bool {
|
||||
self.render_pipelines.contains_key(&HashDxgiFormat(format))
|
||||
}
|
||||
|
||||
pub fn new_from_dxil(
|
||||
device: &ID3D12Device,
|
||||
library: &IDxcUtils,
|
||||
|
|
Loading…
Add table
Reference in a new issue