d3d12: make panic free

This commit is contained in:
chyyran 2023-02-05 19:58:51 -05:00
parent f5fe3e37ef
commit a3589cc794
8 changed files with 82 additions and 53 deletions

View file

@ -25,6 +25,8 @@ bytemuck = { version = "1.12.3", features = ["derive"] }
array-init = "2.1.0" array-init = "2.1.0"
bit-set = "0.5.3" bit-set = "0.5.3"
rayon = "1.6.1"
[target.'cfg(windows)'.dependencies.windows] [target.'cfg(windows)'.dependencies.windows]
version = "0.44.0" version = "0.44.0"
features = [ features = [

View file

@ -11,6 +11,7 @@ use windows::Win32::Graphics::Direct3D12::{
D3D12_DESCRIPTOR_HEAP_TYPE_RTV, D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER, D3D12_DESCRIPTOR_HEAP_TYPE_RTV, D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER,
D3D12_GPU_DESCRIPTOR_HANDLE, D3D12_GPU_DESCRIPTOR_HANDLE,
}; };
use crate::error::FilterChainError;
#[const_trait] #[const_trait]
pub trait D3D12HeapType { pub trait D3D12HeapType {
@ -132,6 +133,7 @@ impl<T: D3D12ShaderVisibleHeapType> AsRef<D3D12_GPU_DESCRIPTOR_HANDLE>
for D3D12DescriptorHeapSlotInner<T> for D3D12DescriptorHeapSlotInner<T>
{ {
fn as_ref(&self) -> &D3D12_GPU_DESCRIPTOR_HANDLE { fn as_ref(&self) -> &D3D12_GPU_DESCRIPTOR_HANDLE {
/// SAFETY: D3D12ShaderVisibleHeapType must have a GPU handle.
self.gpu_handle.as_ref().unwrap() self.gpu_handle.as_ref().unwrap()
} }
} }
@ -298,7 +300,7 @@ impl<T> D3D12DescriptorHeap<T> {
} }
} }
todo!("error need to fail"); Err(FilterChainError::DescriptorHeapOverflow)
} }
pub fn alloc_range<const NUM_DESC: usize>( pub fn alloc_range<const NUM_DESC: usize>(

View file

@ -1,16 +1,48 @@
use std::error::Error; use std::error::Error;
use thiserror::Error;
pub type Result<T> = std::result::Result<T, Box<dyn Error>>; /// Cumulative error type for Direct3D12 filter chains.
#[derive(Error, Debug)]
pub enum FilterChainError {
#[error("invariant assumption about d3d11 did not hold. report this as an issue.")]
Direct3DOperationError(&'static str),
#[error("direct3d driver error")]
Direct3DError(#[from] windows::core::Error),
#[error("SPIRV reflection error")]
SpirvCrossReflectError(#[from] spirv_cross::ErrorCode),
#[error("shader preset parse error")]
ShaderPresetError(#[from] ParsePresetError),
#[error("shader preprocess error")]
ShaderPreprocessError(#[from] PreprocessError),
#[error("shader compile error")]
ShaderCompileError(#[from] ShaderCompileError),
#[error("shader reflect error")]
ShaderReflectError(#[from] ShaderReflectError),
#[error("lut loading error")]
LutLoadError(#[from] ImageError),
#[error("heap overflow")]
DescriptorHeapOverflow,
}
pub type Result<T> = std::result::Result<T, FilterChainError>;
// todo: make this return error // todo: make this return error
macro_rules! assume_d3d12_init { macro_rules! assume_d3d12_init {
($value:ident, $call:literal) => { ($value:ident, $call:literal) => {
let $value = $value.expect($call); let $value = $value.ok_or($crate::error::FilterChainError::Direct3DOperationError(
$call,
))?;
}; };
(mut $value:ident, $call:literal) => { (mut $value:ident, $call:literal) => {
let mut $value = $value.expect($call); let mut $value = $value.ok_or($crate::error::FilterChainError::Direct3DOperationError(
$call,
))?;
}; };
} }
/// Macro for unwrapping result of a D3D function. /// Macro for unwrapping result of a D3D function.
pub(crate) use assume_d3d12_init; pub(crate) use assume_d3d12_init;
use librashader_preprocess::PreprocessError;
use librashader_presets::ParsePresetError;
use librashader_reflect::error::{ShaderCompileError, ShaderReflectError};
use librashader_runtime::image::ImageError;

View file

@ -47,6 +47,7 @@ use windows::Win32::Graphics::Direct3D12::{
use windows::Win32::Graphics::Dxgi::Common::DXGI_FORMAT_UNKNOWN; use windows::Win32::Graphics::Dxgi::Common::DXGI_FORMAT_UNKNOWN;
use windows::Win32::System::Threading::{CreateEventA, ResetEvent, WaitForSingleObject}; use windows::Win32::System::Threading::{CreateEventA, ResetEvent, WaitForSingleObject};
use windows::Win32::System::WindowsProgramming::INFINITE; use windows::Win32::System::WindowsProgramming::INFINITE;
use crate::error::FilterChainError;
type DxilShaderPassMeta = ShaderPassArtifact<impl CompileReflectShader<DXIL, GlslangCompilation>>; type DxilShaderPassMeta = ShaderPassArtifact<impl CompileReflectShader<DXIL, GlslangCompilation>>;
type HlslShaderPassMeta = ShaderPassArtifact<impl CompileReflectShader<HLSL, GlslangCompilation>>; type HlslShaderPassMeta = ShaderPassArtifact<impl CompileReflectShader<HLSL, GlslangCompilation>>;
@ -112,20 +113,18 @@ impl FilterChainD3D12 {
let shader_copy = preset.shaders.clone(); let shader_copy = preset.shaders.clone();
let (passes, semantics) = let (passes, semantics) =
DXIL::compile_preset_passes::<GlslangCompilation, Box<dyn Error>>( DXIL::compile_preset_passes::<GlslangCompilation, FilterChainError>(
preset.shaders, preset.shaders,
&preset.textures, &preset.textures,
) )?;
.unwrap();
let (hlsl_passes, _) = HLSL::compile_preset_passes::<GlslangCompilation, Box<dyn Error>>( let (hlsl_passes, _) = HLSL::compile_preset_passes::<GlslangCompilation, FilterChainError>(
shader_copy, shader_copy,
&preset.textures, &preset.textures,
) )?;
.unwrap();
let samplers = SamplerSet::new(device)?; let samplers = SamplerSet::new(device)?;
let mipmap_gen = D3D12MipmapGen::new(device).unwrap(); let mipmap_gen = D3D12MipmapGen::new(device)?;
let draw_quad = DrawQuad::new(device)?; let draw_quad = DrawQuad::new(device)?;
let mut staging_heap = D3D12DescriptorHeap::new( let mut staging_heap = D3D12DescriptorHeap::new(
@ -138,8 +137,7 @@ impl FilterChainD3D12 {
)?; )?;
let luts = let luts =
FilterChainD3D12::load_luts(device, &mut staging_heap, &preset.textures, &mipmap_gen) FilterChainD3D12::load_luts(device, &mut staging_heap, &preset.textures, &mipmap_gen)?;
.unwrap();
let root_signature = D3D12RootSignature::new(device)?; let root_signature = D3D12RootSignature::new(device)?;
@ -150,8 +148,7 @@ impl FilterChainD3D12 {
hlsl_passes, hlsl_passes,
&semantics, &semantics,
options.map_or(false, |o| o.force_hlsl_pipeline), options.map_or(false, |o| o.force_hlsl_pipeline),
) )?;
.unwrap();
// initialize output framebuffers // initialize output framebuffers
let mut output_framebuffers = Vec::new(); let mut output_framebuffers = Vec::new();
@ -305,22 +302,20 @@ impl FilterChainD3D12 {
// Wait until finished // Wait until finished
if unsafe { fence.GetCompletedValue() } < 1 { if unsafe { fence.GetCompletedValue() } < 1 {
unsafe { fence.SetEventOnCompletion(1, fence_event) } unsafe { fence.SetEventOnCompletion(1, fence_event)? };
.ok()
.unwrap();
unsafe { WaitForSingleObject(fence_event, INFINITE) }; unsafe { WaitForSingleObject(fence_event, INFINITE) };
unsafe { ResetEvent(fence_event) }; unsafe { ResetEvent(fence_event) };
} }
cmd.Reset(&command_pool, None).unwrap(); cmd.Reset(&command_pool, None)?;
let residuals = mipmap_gen.mipmapping_context(&cmd, &mut work_heap, |context| { let residuals = mipmap_gen.mipmapping_context(&cmd, &mut work_heap, |context| {
for lut in luts.values() { for lut in luts.values() {
lut.generate_mipmaps(context)?; lut.generate_mipmaps(context)?;
} }
Ok::<(), Box<dyn Error>>(()) Ok::<(), FilterChainError>(())
})?; })?;
// //
@ -329,9 +324,7 @@ impl FilterChainD3D12 {
queue.Signal(&fence, 2)?; queue.Signal(&fence, 2)?;
// //
if unsafe { fence.GetCompletedValue() } < 2 { if unsafe { fence.GetCompletedValue() } < 2 {
unsafe { fence.SetEventOnCompletion(2, fence_event) } unsafe { fence.SetEventOnCompletion(2, fence_event)? }
.ok()
.unwrap();
unsafe { WaitForSingleObject(fence_event, INFINITE) }; unsafe { WaitForSingleObject(fence_event, INFINITE) };
unsafe { CloseHandle(fence_event) }; unsafe { CloseHandle(fence_event) };
@ -600,24 +593,24 @@ impl FilterChainD3D12 {
source.filter = pass.config.filter; source.filter = pass.config.filter;
source.wrap_mode = pass.config.wrap_mode; source.wrap_mode = pass.config.wrap_mode;
if pass.config.mipmap_input && !self.disable_mipmaps { // if pass.config.mipmap_input && !self.disable_mipmaps {
unsafe { // unsafe {
// this is so bad. // // this is so bad.
self.common.mipmap_gen.mipmapping_context( // self.common.mipmap_gen.mipmapping_context(
cmd, // cmd,
&mut self.mipmap_heap, // &mut self.mipmap_heap,
|ctx| { // |ctx| {
ctx.generate_mipmaps( // ctx.generate_mipmaps(
&source.resource, // &source.resource,
source.size().calculate_miplevels() as u16, // source.size().calculate_miplevels() as u16,
source.size, // source.size,
source.format, // source.format,
)?; // )?;
Ok::<(), Box<dyn Error>>(()) // Ok::<(), FilterChainError>(())
}, // },
)?; // )?;
} // }
} // }
let target = &self.output_framebuffers[index]; let target = &self.output_framebuffers[index];
util::d3d12_resource_transition( util::d3d12_resource_transition(

View file

@ -20,6 +20,8 @@ use windows::Win32::Graphics::Direct3D12::{
D3D12_SHADER_VISIBILITY_ALL, D3D12_SHADER_VISIBILITY_PIXEL, D3D_ROOT_SIGNATURE_VERSION_1, D3D12_SHADER_VISIBILITY_ALL, D3D12_SHADER_VISIBILITY_PIXEL, D3D_ROOT_SIGNATURE_VERSION_1,
}; };
use windows::Win32::Graphics::Dxgi::Common::{DXGI_FORMAT, DXGI_FORMAT_UNKNOWN, DXGI_SAMPLE_DESC}; use windows::Win32::Graphics::Dxgi::Common::{DXGI_FORMAT, DXGI_FORMAT_UNKNOWN, DXGI_SAMPLE_DESC};
use crate::error::assume_d3d12_init;
use crate::error::FilterChainError::Direct3DOperationError;
pub struct D3D12GraphicsPipeline { pub struct D3D12GraphicsPipeline {
pub(crate) handle: ID3D12PipelineState, pub(crate) handle: ID3D12PipelineState,
@ -108,8 +110,7 @@ impl D3D12RootSignature {
None, None,
)?; )?;
// SAFETY: if D3D12SerializeRootSignature succeeds then blob is Some assume_d3d12_init!(rs_blob, "D3D12SerializeRootSignature");
let rs_blob = rs_blob.unwrap();
let blob = std::slice::from_raw_parts( let blob = std::slice::from_raw_parts(
rs_blob.GetBufferPointer().cast(), rs_blob.GetBufferPointer().cast(),
rs_blob.GetBufferSize(), rs_blob.GetBufferSize(),
@ -215,10 +216,10 @@ impl D3D12GraphicsPipeline {
render_format: DXGI_FORMAT, render_format: DXGI_FORMAT,
) -> error::Result<D3D12GraphicsPipeline> { ) -> error::Result<D3D12GraphicsPipeline> {
if shader_assembly.vertex.requires_runtime_data() { if shader_assembly.vertex.requires_runtime_data() {
panic!("vertex needs rt data??") return Err(Direct3DOperationError("Compiled DXIL Vertex shader needs unexpected runtime data"))
} }
if shader_assembly.fragment.requires_runtime_data() { if shader_assembly.fragment.requires_runtime_data() {
panic!("fragment needs rt data??") return Err(Direct3DOperationError("Compiled DXIL fragment shader needs unexpected runtime data"))
} }
let vertex_dxil = util::dxc_validate_shader(library, validator, &shader_assembly.vertex)?; let vertex_dxil = util::dxc_validate_shader(library, validator, &shader_assembly.vertex)?;
let fragment_dxil = let fragment_dxil =

View file

@ -27,9 +27,9 @@ mod tests {
fn triangle_d3d12() { fn triangle_d3d12() {
let sample = hello_triangle::d3d12_hello_triangle::Sample::new( let sample = hello_triangle::d3d12_hello_triangle::Sample::new(
// "../test/slang-shaders/crt/crt-lottes.slangp", // "../test/slang-shaders/crt/crt-lottes.slangp",
// "../test/slang-shaders/bezel/Mega_Bezel/Presets/MBZ__0__SMOOTH-ADV.slangp", "../test/slang-shaders/bezel/Mega_Bezel/Presets/MBZ__0__SMOOTH-ADV.slangp",
// "../test/slang-shaders/crt/crt-royale.slangp", // "../test/slang-shaders/crt/crt-royale.slangp",
"../test/slang-shaders/vhs/VHSPro.slangp", // "../test/slang-shaders/vhs/VHSPro.slangp",
&SampleCommandLine { &SampleCommandLine {
use_warp_device: false, use_warp_device: false,
}, },

View file

@ -123,10 +123,10 @@ impl<'a> MipmapGenContext<'a> {
impl D3D12MipmapGen { impl D3D12MipmapGen {
pub fn new(device: &ID3D12Device) -> error::Result<D3D12MipmapGen> { pub fn new(device: &ID3D12Device) -> error::Result<D3D12MipmapGen> {
unsafe { unsafe {
let blob = fxc_compile_shader(GENERATE_MIPS_SRC, b"main\0", b"cs_5_1\0").unwrap(); let blob = fxc_compile_shader(GENERATE_MIPS_SRC, b"main\0", b"cs_5_1\0")?;
let blob = let blob =
std::slice::from_raw_parts(blob.GetBufferPointer().cast(), blob.GetBufferSize()); std::slice::from_raw_parts(blob.GetBufferPointer().cast(), blob.GetBufferSize());
let root_signature: ID3D12RootSignature = device.CreateRootSignature(0, blob).unwrap(); let root_signature: ID3D12RootSignature = device.CreateRootSignature(0, blob)?;
let desc = D3D12_COMPUTE_PIPELINE_STATE_DESC { let desc = D3D12_COMPUTE_PIPELINE_STATE_DESC {
pRootSignature: windows::core::ManuallyDrop::new(&root_signature), pRootSignature: windows::core::ManuallyDrop::new(&root_signature),

View file

@ -197,7 +197,7 @@ pub fn dxc_compile_shader(
if let Ok(buf) = result.GetErrorBuffer() { if let Ok(buf) = result.GetErrorBuffer() {
unsafe { unsafe {
let buf: IDxcBlobUtf8 = buf.cast().unwrap(); let buf: IDxcBlobUtf8 = buf.cast()?;
let buf = let buf =
std::slice::from_raw_parts(buf.GetBufferPointer().cast(), buf.GetBufferSize()); std::slice::from_raw_parts(buf.GetBufferPointer().cast(), buf.GetBufferSize());
let str = std::str::from_utf8_unchecked(buf); let str = std::str::from_utf8_unchecked(buf);
@ -225,12 +225,11 @@ pub fn dxc_validate_shader(
unsafe { unsafe {
let result = validator let result = validator
.Validate(&blob, DxcValidatorFlags_InPlaceEdit) .Validate(&blob, DxcValidatorFlags_InPlaceEdit)?;
.unwrap();
if let Ok(buf) = result.GetErrorBuffer() { if let Ok(buf) = result.GetErrorBuffer() {
unsafe { unsafe {
let buf: IDxcBlobUtf8 = buf.cast().unwrap(); let buf: IDxcBlobUtf8 = buf.cast()?;
let buf = let buf =
std::slice::from_raw_parts(buf.GetBufferPointer().cast(), buf.GetBufferSize()); std::slice::from_raw_parts(buf.GetBufferPointer().cast(), buf.GetBufferSize());
let str = std::str::from_utf8_unchecked(buf); let str = std::str::from_utf8_unchecked(buf);