d3d12: make panic free

This commit is contained in:
chyyran 2023-02-05 19:58:51 -05:00
parent f5fe3e37ef
commit a3589cc794
8 changed files with 82 additions and 53 deletions

View file

@ -25,6 +25,8 @@ bytemuck = { version = "1.12.3", features = ["derive"] }
array-init = "2.1.0"
bit-set = "0.5.3"
rayon = "1.6.1"
[target.'cfg(windows)'.dependencies.windows]
version = "0.44.0"
features = [

View file

@ -11,6 +11,7 @@ use windows::Win32::Graphics::Direct3D12::{
D3D12_DESCRIPTOR_HEAP_TYPE_RTV, D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER,
D3D12_GPU_DESCRIPTOR_HANDLE,
};
use crate::error::FilterChainError;
#[const_trait]
pub trait D3D12HeapType {
@ -132,6 +133,7 @@ impl<T: D3D12ShaderVisibleHeapType> AsRef<D3D12_GPU_DESCRIPTOR_HANDLE>
for D3D12DescriptorHeapSlotInner<T>
{
fn as_ref(&self) -> &D3D12_GPU_DESCRIPTOR_HANDLE {
/// SAFETY: D3D12ShaderVisibleHeapType must have a GPU handle.
self.gpu_handle.as_ref().unwrap()
}
}
@ -298,7 +300,7 @@ impl<T> D3D12DescriptorHeap<T> {
}
}
todo!("error need to fail");
Err(FilterChainError::DescriptorHeapOverflow)
}
pub fn alloc_range<const NUM_DESC: usize>(

View file

@ -1,16 +1,48 @@
use std::error::Error;
use thiserror::Error;
pub type Result<T> = std::result::Result<T, Box<dyn Error>>;
/// Cumulative error type for Direct3D12 filter chains.
#[derive(Error, Debug)]
pub enum FilterChainError {
#[error("invariant assumption about d3d11 did not hold. report this as an issue.")]
Direct3DOperationError(&'static str),
#[error("direct3d driver error")]
Direct3DError(#[from] windows::core::Error),
#[error("SPIRV reflection error")]
SpirvCrossReflectError(#[from] spirv_cross::ErrorCode),
#[error("shader preset parse error")]
ShaderPresetError(#[from] ParsePresetError),
#[error("shader preprocess error")]
ShaderPreprocessError(#[from] PreprocessError),
#[error("shader compile error")]
ShaderCompileError(#[from] ShaderCompileError),
#[error("shader reflect error")]
ShaderReflectError(#[from] ShaderReflectError),
#[error("lut loading error")]
LutLoadError(#[from] ImageError),
#[error("heap overflow")]
DescriptorHeapOverflow,
}
pub type Result<T> = std::result::Result<T, FilterChainError>;
// todo: make this return error
macro_rules! assume_d3d12_init {
($value:ident, $call:literal) => {
let $value = $value.expect($call);
let $value = $value.ok_or($crate::error::FilterChainError::Direct3DOperationError(
$call,
))?;
};
(mut $value:ident, $call:literal) => {
let mut $value = $value.expect($call);
let mut $value = $value.ok_or($crate::error::FilterChainError::Direct3DOperationError(
$call,
))?;
};
}
/// Macro for unwrapping result of a D3D function.
pub(crate) use assume_d3d12_init;
use librashader_preprocess::PreprocessError;
use librashader_presets::ParsePresetError;
use librashader_reflect::error::{ShaderCompileError, ShaderReflectError};
use librashader_runtime::image::ImageError;

View file

@ -47,6 +47,7 @@ use windows::Win32::Graphics::Direct3D12::{
use windows::Win32::Graphics::Dxgi::Common::DXGI_FORMAT_UNKNOWN;
use windows::Win32::System::Threading::{CreateEventA, ResetEvent, WaitForSingleObject};
use windows::Win32::System::WindowsProgramming::INFINITE;
use crate::error::FilterChainError;
type DxilShaderPassMeta = ShaderPassArtifact<impl CompileReflectShader<DXIL, GlslangCompilation>>;
type HlslShaderPassMeta = ShaderPassArtifact<impl CompileReflectShader<HLSL, GlslangCompilation>>;
@ -112,20 +113,18 @@ impl FilterChainD3D12 {
let shader_copy = preset.shaders.clone();
let (passes, semantics) =
DXIL::compile_preset_passes::<GlslangCompilation, Box<dyn Error>>(
DXIL::compile_preset_passes::<GlslangCompilation, FilterChainError>(
preset.shaders,
&preset.textures,
)
.unwrap();
)?;
let (hlsl_passes, _) = HLSL::compile_preset_passes::<GlslangCompilation, Box<dyn Error>>(
let (hlsl_passes, _) = HLSL::compile_preset_passes::<GlslangCompilation, FilterChainError>(
shader_copy,
&preset.textures,
)
.unwrap();
)?;
let samplers = SamplerSet::new(device)?;
let mipmap_gen = D3D12MipmapGen::new(device).unwrap();
let mipmap_gen = D3D12MipmapGen::new(device)?;
let draw_quad = DrawQuad::new(device)?;
let mut staging_heap = D3D12DescriptorHeap::new(
@ -138,8 +137,7 @@ impl FilterChainD3D12 {
)?;
let luts =
FilterChainD3D12::load_luts(device, &mut staging_heap, &preset.textures, &mipmap_gen)
.unwrap();
FilterChainD3D12::load_luts(device, &mut staging_heap, &preset.textures, &mipmap_gen)?;
let root_signature = D3D12RootSignature::new(device)?;
@ -150,8 +148,7 @@ impl FilterChainD3D12 {
hlsl_passes,
&semantics,
options.map_or(false, |o| o.force_hlsl_pipeline),
)
.unwrap();
)?;
// initialize output framebuffers
let mut output_framebuffers = Vec::new();
@ -305,22 +302,20 @@ impl FilterChainD3D12 {
// Wait until finished
if unsafe { fence.GetCompletedValue() } < 1 {
unsafe { fence.SetEventOnCompletion(1, fence_event) }
.ok()
.unwrap();
unsafe { fence.SetEventOnCompletion(1, fence_event)? };
unsafe { WaitForSingleObject(fence_event, INFINITE) };
unsafe { ResetEvent(fence_event) };
}
cmd.Reset(&command_pool, None).unwrap();
cmd.Reset(&command_pool, None)?;
let residuals = mipmap_gen.mipmapping_context(&cmd, &mut work_heap, |context| {
for lut in luts.values() {
lut.generate_mipmaps(context)?;
}
Ok::<(), Box<dyn Error>>(())
Ok::<(), FilterChainError>(())
})?;
//
@ -329,9 +324,7 @@ impl FilterChainD3D12 {
queue.Signal(&fence, 2)?;
//
if unsafe { fence.GetCompletedValue() } < 2 {
unsafe { fence.SetEventOnCompletion(2, fence_event) }
.ok()
.unwrap();
unsafe { fence.SetEventOnCompletion(2, fence_event)? }
unsafe { WaitForSingleObject(fence_event, INFINITE) };
unsafe { CloseHandle(fence_event) };
@ -600,24 +593,24 @@ impl FilterChainD3D12 {
source.filter = pass.config.filter;
source.wrap_mode = pass.config.wrap_mode;
if pass.config.mipmap_input && !self.disable_mipmaps {
unsafe {
// this is so bad.
self.common.mipmap_gen.mipmapping_context(
cmd,
&mut self.mipmap_heap,
|ctx| {
ctx.generate_mipmaps(
&source.resource,
source.size().calculate_miplevels() as u16,
source.size,
source.format,
)?;
Ok::<(), Box<dyn Error>>(())
},
)?;
}
}
// if pass.config.mipmap_input && !self.disable_mipmaps {
// unsafe {
// // this is so bad.
// self.common.mipmap_gen.mipmapping_context(
// cmd,
// &mut self.mipmap_heap,
// |ctx| {
// ctx.generate_mipmaps(
// &source.resource,
// source.size().calculate_miplevels() as u16,
// source.size,
// source.format,
// )?;
// Ok::<(), FilterChainError>(())
// },
// )?;
// }
// }
let target = &self.output_framebuffers[index];
util::d3d12_resource_transition(

View file

@ -20,6 +20,8 @@ use windows::Win32::Graphics::Direct3D12::{
D3D12_SHADER_VISIBILITY_ALL, D3D12_SHADER_VISIBILITY_PIXEL, D3D_ROOT_SIGNATURE_VERSION_1,
};
use windows::Win32::Graphics::Dxgi::Common::{DXGI_FORMAT, DXGI_FORMAT_UNKNOWN, DXGI_SAMPLE_DESC};
use crate::error::assume_d3d12_init;
use crate::error::FilterChainError::Direct3DOperationError;
pub struct D3D12GraphicsPipeline {
pub(crate) handle: ID3D12PipelineState,
@ -108,8 +110,7 @@ impl D3D12RootSignature {
None,
)?;
// SAFETY: if D3D12SerializeRootSignature succeeds then blob is Some
let rs_blob = rs_blob.unwrap();
assume_d3d12_init!(rs_blob, "D3D12SerializeRootSignature");
let blob = std::slice::from_raw_parts(
rs_blob.GetBufferPointer().cast(),
rs_blob.GetBufferSize(),
@ -215,10 +216,10 @@ impl D3D12GraphicsPipeline {
render_format: DXGI_FORMAT,
) -> error::Result<D3D12GraphicsPipeline> {
if shader_assembly.vertex.requires_runtime_data() {
panic!("vertex needs rt data??")
return Err(Direct3DOperationError("Compiled DXIL Vertex shader needs unexpected runtime data"))
}
if shader_assembly.fragment.requires_runtime_data() {
panic!("fragment needs rt data??")
return Err(Direct3DOperationError("Compiled DXIL fragment shader needs unexpected runtime data"))
}
let vertex_dxil = util::dxc_validate_shader(library, validator, &shader_assembly.vertex)?;
let fragment_dxil =

View file

@ -27,9 +27,9 @@ mod tests {
fn triangle_d3d12() {
let sample = hello_triangle::d3d12_hello_triangle::Sample::new(
// "../test/slang-shaders/crt/crt-lottes.slangp",
// "../test/slang-shaders/bezel/Mega_Bezel/Presets/MBZ__0__SMOOTH-ADV.slangp",
"../test/slang-shaders/bezel/Mega_Bezel/Presets/MBZ__0__SMOOTH-ADV.slangp",
// "../test/slang-shaders/crt/crt-royale.slangp",
"../test/slang-shaders/vhs/VHSPro.slangp",
// "../test/slang-shaders/vhs/VHSPro.slangp",
&SampleCommandLine {
use_warp_device: false,
},

View file

@ -123,10 +123,10 @@ impl<'a> MipmapGenContext<'a> {
impl D3D12MipmapGen {
pub fn new(device: &ID3D12Device) -> error::Result<D3D12MipmapGen> {
unsafe {
let blob = fxc_compile_shader(GENERATE_MIPS_SRC, b"main\0", b"cs_5_1\0").unwrap();
let blob = fxc_compile_shader(GENERATE_MIPS_SRC, b"main\0", b"cs_5_1\0")?;
let blob =
std::slice::from_raw_parts(blob.GetBufferPointer().cast(), blob.GetBufferSize());
let root_signature: ID3D12RootSignature = device.CreateRootSignature(0, blob).unwrap();
let root_signature: ID3D12RootSignature = device.CreateRootSignature(0, blob)?;
let desc = D3D12_COMPUTE_PIPELINE_STATE_DESC {
pRootSignature: windows::core::ManuallyDrop::new(&root_signature),

View file

@ -197,7 +197,7 @@ pub fn dxc_compile_shader(
if let Ok(buf) = result.GetErrorBuffer() {
unsafe {
let buf: IDxcBlobUtf8 = buf.cast().unwrap();
let buf: IDxcBlobUtf8 = buf.cast()?;
let buf =
std::slice::from_raw_parts(buf.GetBufferPointer().cast(), buf.GetBufferSize());
let str = std::str::from_utf8_unchecked(buf);
@ -225,12 +225,11 @@ pub fn dxc_validate_shader(
unsafe {
let result = validator
.Validate(&blob, DxcValidatorFlags_InPlaceEdit)
.unwrap();
.Validate(&blob, DxcValidatorFlags_InPlaceEdit)?;
if let Ok(buf) = result.GetErrorBuffer() {
unsafe {
let buf: IDxcBlobUtf8 = buf.cast().unwrap();
let buf: IDxcBlobUtf8 = buf.cast()?;
let buf =
std::slice::from_raw_parts(buf.GetBufferPointer().cast(), buf.GetBufferSize());
let str = std::str::from_utf8_unchecked(buf);