d3d12: make panic free
This commit is contained in:
parent
f5fe3e37ef
commit
a3589cc794
8 changed files with 82 additions and 53 deletions
|
@ -25,6 +25,8 @@ bytemuck = { version = "1.12.3", features = ["derive"] }
|
||||||
array-init = "2.1.0"
|
array-init = "2.1.0"
|
||||||
bit-set = "0.5.3"
|
bit-set = "0.5.3"
|
||||||
|
|
||||||
|
rayon = "1.6.1"
|
||||||
|
|
||||||
[target.'cfg(windows)'.dependencies.windows]
|
[target.'cfg(windows)'.dependencies.windows]
|
||||||
version = "0.44.0"
|
version = "0.44.0"
|
||||||
features = [
|
features = [
|
||||||
|
|
|
@ -11,6 +11,7 @@ use windows::Win32::Graphics::Direct3D12::{
|
||||||
D3D12_DESCRIPTOR_HEAP_TYPE_RTV, D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER,
|
D3D12_DESCRIPTOR_HEAP_TYPE_RTV, D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER,
|
||||||
D3D12_GPU_DESCRIPTOR_HANDLE,
|
D3D12_GPU_DESCRIPTOR_HANDLE,
|
||||||
};
|
};
|
||||||
|
use crate::error::FilterChainError;
|
||||||
|
|
||||||
#[const_trait]
|
#[const_trait]
|
||||||
pub trait D3D12HeapType {
|
pub trait D3D12HeapType {
|
||||||
|
@ -132,6 +133,7 @@ impl<T: D3D12ShaderVisibleHeapType> AsRef<D3D12_GPU_DESCRIPTOR_HANDLE>
|
||||||
for D3D12DescriptorHeapSlotInner<T>
|
for D3D12DescriptorHeapSlotInner<T>
|
||||||
{
|
{
|
||||||
fn as_ref(&self) -> &D3D12_GPU_DESCRIPTOR_HANDLE {
|
fn as_ref(&self) -> &D3D12_GPU_DESCRIPTOR_HANDLE {
|
||||||
|
/// SAFETY: D3D12ShaderVisibleHeapType must have a GPU handle.
|
||||||
self.gpu_handle.as_ref().unwrap()
|
self.gpu_handle.as_ref().unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -298,7 +300,7 @@ impl<T> D3D12DescriptorHeap<T> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
todo!("error need to fail");
|
Err(FilterChainError::DescriptorHeapOverflow)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn alloc_range<const NUM_DESC: usize>(
|
pub fn alloc_range<const NUM_DESC: usize>(
|
||||||
|
|
|
@ -1,16 +1,48 @@
|
||||||
use std::error::Error;
|
use std::error::Error;
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
pub type Result<T> = std::result::Result<T, Box<dyn Error>>;
|
/// Cumulative error type for Direct3D12 filter chains.
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
pub enum FilterChainError {
|
||||||
|
#[error("invariant assumption about d3d11 did not hold. report this as an issue.")]
|
||||||
|
Direct3DOperationError(&'static str),
|
||||||
|
#[error("direct3d driver error")]
|
||||||
|
Direct3DError(#[from] windows::core::Error),
|
||||||
|
#[error("SPIRV reflection error")]
|
||||||
|
SpirvCrossReflectError(#[from] spirv_cross::ErrorCode),
|
||||||
|
#[error("shader preset parse error")]
|
||||||
|
ShaderPresetError(#[from] ParsePresetError),
|
||||||
|
#[error("shader preprocess error")]
|
||||||
|
ShaderPreprocessError(#[from] PreprocessError),
|
||||||
|
#[error("shader compile error")]
|
||||||
|
ShaderCompileError(#[from] ShaderCompileError),
|
||||||
|
#[error("shader reflect error")]
|
||||||
|
ShaderReflectError(#[from] ShaderReflectError),
|
||||||
|
#[error("lut loading error")]
|
||||||
|
LutLoadError(#[from] ImageError),
|
||||||
|
#[error("heap overflow")]
|
||||||
|
DescriptorHeapOverflow,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type Result<T> = std::result::Result<T, FilterChainError>;
|
||||||
|
|
||||||
// todo: make this return error
|
// todo: make this return error
|
||||||
macro_rules! assume_d3d12_init {
|
macro_rules! assume_d3d12_init {
|
||||||
($value:ident, $call:literal) => {
|
($value:ident, $call:literal) => {
|
||||||
let $value = $value.expect($call);
|
let $value = $value.ok_or($crate::error::FilterChainError::Direct3DOperationError(
|
||||||
|
$call,
|
||||||
|
))?;
|
||||||
};
|
};
|
||||||
(mut $value:ident, $call:literal) => {
|
(mut $value:ident, $call:literal) => {
|
||||||
let mut $value = $value.expect($call);
|
let mut $value = $value.ok_or($crate::error::FilterChainError::Direct3DOperationError(
|
||||||
|
$call,
|
||||||
|
))?;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Macro for unwrapping result of a D3D function.
|
/// Macro for unwrapping result of a D3D function.
|
||||||
pub(crate) use assume_d3d12_init;
|
pub(crate) use assume_d3d12_init;
|
||||||
|
use librashader_preprocess::PreprocessError;
|
||||||
|
use librashader_presets::ParsePresetError;
|
||||||
|
use librashader_reflect::error::{ShaderCompileError, ShaderReflectError};
|
||||||
|
use librashader_runtime::image::ImageError;
|
||||||
|
|
|
@ -47,6 +47,7 @@ use windows::Win32::Graphics::Direct3D12::{
|
||||||
use windows::Win32::Graphics::Dxgi::Common::DXGI_FORMAT_UNKNOWN;
|
use windows::Win32::Graphics::Dxgi::Common::DXGI_FORMAT_UNKNOWN;
|
||||||
use windows::Win32::System::Threading::{CreateEventA, ResetEvent, WaitForSingleObject};
|
use windows::Win32::System::Threading::{CreateEventA, ResetEvent, WaitForSingleObject};
|
||||||
use windows::Win32::System::WindowsProgramming::INFINITE;
|
use windows::Win32::System::WindowsProgramming::INFINITE;
|
||||||
|
use crate::error::FilterChainError;
|
||||||
|
|
||||||
type DxilShaderPassMeta = ShaderPassArtifact<impl CompileReflectShader<DXIL, GlslangCompilation>>;
|
type DxilShaderPassMeta = ShaderPassArtifact<impl CompileReflectShader<DXIL, GlslangCompilation>>;
|
||||||
type HlslShaderPassMeta = ShaderPassArtifact<impl CompileReflectShader<HLSL, GlslangCompilation>>;
|
type HlslShaderPassMeta = ShaderPassArtifact<impl CompileReflectShader<HLSL, GlslangCompilation>>;
|
||||||
|
@ -112,20 +113,18 @@ impl FilterChainD3D12 {
|
||||||
let shader_copy = preset.shaders.clone();
|
let shader_copy = preset.shaders.clone();
|
||||||
|
|
||||||
let (passes, semantics) =
|
let (passes, semantics) =
|
||||||
DXIL::compile_preset_passes::<GlslangCompilation, Box<dyn Error>>(
|
DXIL::compile_preset_passes::<GlslangCompilation, FilterChainError>(
|
||||||
preset.shaders,
|
preset.shaders,
|
||||||
&preset.textures,
|
&preset.textures,
|
||||||
)
|
)?;
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let (hlsl_passes, _) = HLSL::compile_preset_passes::<GlslangCompilation, Box<dyn Error>>(
|
let (hlsl_passes, _) = HLSL::compile_preset_passes::<GlslangCompilation, FilterChainError>(
|
||||||
shader_copy,
|
shader_copy,
|
||||||
&preset.textures,
|
&preset.textures,
|
||||||
)
|
)?;
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let samplers = SamplerSet::new(device)?;
|
let samplers = SamplerSet::new(device)?;
|
||||||
let mipmap_gen = D3D12MipmapGen::new(device).unwrap();
|
let mipmap_gen = D3D12MipmapGen::new(device)?;
|
||||||
|
|
||||||
let draw_quad = DrawQuad::new(device)?;
|
let draw_quad = DrawQuad::new(device)?;
|
||||||
let mut staging_heap = D3D12DescriptorHeap::new(
|
let mut staging_heap = D3D12DescriptorHeap::new(
|
||||||
|
@ -138,8 +137,7 @@ impl FilterChainD3D12 {
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let luts =
|
let luts =
|
||||||
FilterChainD3D12::load_luts(device, &mut staging_heap, &preset.textures, &mipmap_gen)
|
FilterChainD3D12::load_luts(device, &mut staging_heap, &preset.textures, &mipmap_gen)?;
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let root_signature = D3D12RootSignature::new(device)?;
|
let root_signature = D3D12RootSignature::new(device)?;
|
||||||
|
|
||||||
|
@ -150,8 +148,7 @@ impl FilterChainD3D12 {
|
||||||
hlsl_passes,
|
hlsl_passes,
|
||||||
&semantics,
|
&semantics,
|
||||||
options.map_or(false, |o| o.force_hlsl_pipeline),
|
options.map_or(false, |o| o.force_hlsl_pipeline),
|
||||||
)
|
)?;
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
// initialize output framebuffers
|
// initialize output framebuffers
|
||||||
let mut output_framebuffers = Vec::new();
|
let mut output_framebuffers = Vec::new();
|
||||||
|
@ -305,22 +302,20 @@ impl FilterChainD3D12 {
|
||||||
|
|
||||||
// Wait until finished
|
// Wait until finished
|
||||||
if unsafe { fence.GetCompletedValue() } < 1 {
|
if unsafe { fence.GetCompletedValue() } < 1 {
|
||||||
unsafe { fence.SetEventOnCompletion(1, fence_event) }
|
unsafe { fence.SetEventOnCompletion(1, fence_event)? };
|
||||||
.ok()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
unsafe { WaitForSingleObject(fence_event, INFINITE) };
|
unsafe { WaitForSingleObject(fence_event, INFINITE) };
|
||||||
unsafe { ResetEvent(fence_event) };
|
unsafe { ResetEvent(fence_event) };
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Reset(&command_pool, None).unwrap();
|
cmd.Reset(&command_pool, None)?;
|
||||||
|
|
||||||
let residuals = mipmap_gen.mipmapping_context(&cmd, &mut work_heap, |context| {
|
let residuals = mipmap_gen.mipmapping_context(&cmd, &mut work_heap, |context| {
|
||||||
for lut in luts.values() {
|
for lut in luts.values() {
|
||||||
lut.generate_mipmaps(context)?;
|
lut.generate_mipmaps(context)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok::<(), Box<dyn Error>>(())
|
Ok::<(), FilterChainError>(())
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -329,9 +324,7 @@ impl FilterChainD3D12 {
|
||||||
queue.Signal(&fence, 2)?;
|
queue.Signal(&fence, 2)?;
|
||||||
//
|
//
|
||||||
if unsafe { fence.GetCompletedValue() } < 2 {
|
if unsafe { fence.GetCompletedValue() } < 2 {
|
||||||
unsafe { fence.SetEventOnCompletion(2, fence_event) }
|
unsafe { fence.SetEventOnCompletion(2, fence_event)? }
|
||||||
.ok()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
unsafe { WaitForSingleObject(fence_event, INFINITE) };
|
unsafe { WaitForSingleObject(fence_event, INFINITE) };
|
||||||
unsafe { CloseHandle(fence_event) };
|
unsafe { CloseHandle(fence_event) };
|
||||||
|
@ -600,24 +593,24 @@ impl FilterChainD3D12 {
|
||||||
source.filter = pass.config.filter;
|
source.filter = pass.config.filter;
|
||||||
source.wrap_mode = pass.config.wrap_mode;
|
source.wrap_mode = pass.config.wrap_mode;
|
||||||
|
|
||||||
if pass.config.mipmap_input && !self.disable_mipmaps {
|
// if pass.config.mipmap_input && !self.disable_mipmaps {
|
||||||
unsafe {
|
// unsafe {
|
||||||
// this is so bad.
|
// // this is so bad.
|
||||||
self.common.mipmap_gen.mipmapping_context(
|
// self.common.mipmap_gen.mipmapping_context(
|
||||||
cmd,
|
// cmd,
|
||||||
&mut self.mipmap_heap,
|
// &mut self.mipmap_heap,
|
||||||
|ctx| {
|
// |ctx| {
|
||||||
ctx.generate_mipmaps(
|
// ctx.generate_mipmaps(
|
||||||
&source.resource,
|
// &source.resource,
|
||||||
source.size().calculate_miplevels() as u16,
|
// source.size().calculate_miplevels() as u16,
|
||||||
source.size,
|
// source.size,
|
||||||
source.format,
|
// source.format,
|
||||||
)?;
|
// )?;
|
||||||
Ok::<(), Box<dyn Error>>(())
|
// Ok::<(), FilterChainError>(())
|
||||||
},
|
// },
|
||||||
)?;
|
// )?;
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
let target = &self.output_framebuffers[index];
|
let target = &self.output_framebuffers[index];
|
||||||
util::d3d12_resource_transition(
|
util::d3d12_resource_transition(
|
||||||
|
|
|
@ -20,6 +20,8 @@ use windows::Win32::Graphics::Direct3D12::{
|
||||||
D3D12_SHADER_VISIBILITY_ALL, D3D12_SHADER_VISIBILITY_PIXEL, D3D_ROOT_SIGNATURE_VERSION_1,
|
D3D12_SHADER_VISIBILITY_ALL, D3D12_SHADER_VISIBILITY_PIXEL, D3D_ROOT_SIGNATURE_VERSION_1,
|
||||||
};
|
};
|
||||||
use windows::Win32::Graphics::Dxgi::Common::{DXGI_FORMAT, DXGI_FORMAT_UNKNOWN, DXGI_SAMPLE_DESC};
|
use windows::Win32::Graphics::Dxgi::Common::{DXGI_FORMAT, DXGI_FORMAT_UNKNOWN, DXGI_SAMPLE_DESC};
|
||||||
|
use crate::error::assume_d3d12_init;
|
||||||
|
use crate::error::FilterChainError::Direct3DOperationError;
|
||||||
|
|
||||||
pub struct D3D12GraphicsPipeline {
|
pub struct D3D12GraphicsPipeline {
|
||||||
pub(crate) handle: ID3D12PipelineState,
|
pub(crate) handle: ID3D12PipelineState,
|
||||||
|
@ -108,8 +110,7 @@ impl D3D12RootSignature {
|
||||||
None,
|
None,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
// SAFETY: if D3D12SerializeRootSignature succeeds then blob is Some
|
assume_d3d12_init!(rs_blob, "D3D12SerializeRootSignature");
|
||||||
let rs_blob = rs_blob.unwrap();
|
|
||||||
let blob = std::slice::from_raw_parts(
|
let blob = std::slice::from_raw_parts(
|
||||||
rs_blob.GetBufferPointer().cast(),
|
rs_blob.GetBufferPointer().cast(),
|
||||||
rs_blob.GetBufferSize(),
|
rs_blob.GetBufferSize(),
|
||||||
|
@ -215,10 +216,10 @@ impl D3D12GraphicsPipeline {
|
||||||
render_format: DXGI_FORMAT,
|
render_format: DXGI_FORMAT,
|
||||||
) -> error::Result<D3D12GraphicsPipeline> {
|
) -> error::Result<D3D12GraphicsPipeline> {
|
||||||
if shader_assembly.vertex.requires_runtime_data() {
|
if shader_assembly.vertex.requires_runtime_data() {
|
||||||
panic!("vertex needs rt data??")
|
return Err(Direct3DOperationError("Compiled DXIL Vertex shader needs unexpected runtime data"))
|
||||||
}
|
}
|
||||||
if shader_assembly.fragment.requires_runtime_data() {
|
if shader_assembly.fragment.requires_runtime_data() {
|
||||||
panic!("fragment needs rt data??")
|
return Err(Direct3DOperationError("Compiled DXIL fragment shader needs unexpected runtime data"))
|
||||||
}
|
}
|
||||||
let vertex_dxil = util::dxc_validate_shader(library, validator, &shader_assembly.vertex)?;
|
let vertex_dxil = util::dxc_validate_shader(library, validator, &shader_assembly.vertex)?;
|
||||||
let fragment_dxil =
|
let fragment_dxil =
|
||||||
|
|
|
@ -27,9 +27,9 @@ mod tests {
|
||||||
fn triangle_d3d12() {
|
fn triangle_d3d12() {
|
||||||
let sample = hello_triangle::d3d12_hello_triangle::Sample::new(
|
let sample = hello_triangle::d3d12_hello_triangle::Sample::new(
|
||||||
// "../test/slang-shaders/crt/crt-lottes.slangp",
|
// "../test/slang-shaders/crt/crt-lottes.slangp",
|
||||||
// "../test/slang-shaders/bezel/Mega_Bezel/Presets/MBZ__0__SMOOTH-ADV.slangp",
|
"../test/slang-shaders/bezel/Mega_Bezel/Presets/MBZ__0__SMOOTH-ADV.slangp",
|
||||||
// "../test/slang-shaders/crt/crt-royale.slangp",
|
// "../test/slang-shaders/crt/crt-royale.slangp",
|
||||||
"../test/slang-shaders/vhs/VHSPro.slangp",
|
// "../test/slang-shaders/vhs/VHSPro.slangp",
|
||||||
&SampleCommandLine {
|
&SampleCommandLine {
|
||||||
use_warp_device: false,
|
use_warp_device: false,
|
||||||
},
|
},
|
||||||
|
|
|
@ -123,10 +123,10 @@ impl<'a> MipmapGenContext<'a> {
|
||||||
impl D3D12MipmapGen {
|
impl D3D12MipmapGen {
|
||||||
pub fn new(device: &ID3D12Device) -> error::Result<D3D12MipmapGen> {
|
pub fn new(device: &ID3D12Device) -> error::Result<D3D12MipmapGen> {
|
||||||
unsafe {
|
unsafe {
|
||||||
let blob = fxc_compile_shader(GENERATE_MIPS_SRC, b"main\0", b"cs_5_1\0").unwrap();
|
let blob = fxc_compile_shader(GENERATE_MIPS_SRC, b"main\0", b"cs_5_1\0")?;
|
||||||
let blob =
|
let blob =
|
||||||
std::slice::from_raw_parts(blob.GetBufferPointer().cast(), blob.GetBufferSize());
|
std::slice::from_raw_parts(blob.GetBufferPointer().cast(), blob.GetBufferSize());
|
||||||
let root_signature: ID3D12RootSignature = device.CreateRootSignature(0, blob).unwrap();
|
let root_signature: ID3D12RootSignature = device.CreateRootSignature(0, blob)?;
|
||||||
|
|
||||||
let desc = D3D12_COMPUTE_PIPELINE_STATE_DESC {
|
let desc = D3D12_COMPUTE_PIPELINE_STATE_DESC {
|
||||||
pRootSignature: windows::core::ManuallyDrop::new(&root_signature),
|
pRootSignature: windows::core::ManuallyDrop::new(&root_signature),
|
||||||
|
|
|
@ -197,7 +197,7 @@ pub fn dxc_compile_shader(
|
||||||
|
|
||||||
if let Ok(buf) = result.GetErrorBuffer() {
|
if let Ok(buf) = result.GetErrorBuffer() {
|
||||||
unsafe {
|
unsafe {
|
||||||
let buf: IDxcBlobUtf8 = buf.cast().unwrap();
|
let buf: IDxcBlobUtf8 = buf.cast()?;
|
||||||
let buf =
|
let buf =
|
||||||
std::slice::from_raw_parts(buf.GetBufferPointer().cast(), buf.GetBufferSize());
|
std::slice::from_raw_parts(buf.GetBufferPointer().cast(), buf.GetBufferSize());
|
||||||
let str = std::str::from_utf8_unchecked(buf);
|
let str = std::str::from_utf8_unchecked(buf);
|
||||||
|
@ -225,12 +225,11 @@ pub fn dxc_validate_shader(
|
||||||
|
|
||||||
unsafe {
|
unsafe {
|
||||||
let result = validator
|
let result = validator
|
||||||
.Validate(&blob, DxcValidatorFlags_InPlaceEdit)
|
.Validate(&blob, DxcValidatorFlags_InPlaceEdit)?;
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
if let Ok(buf) = result.GetErrorBuffer() {
|
if let Ok(buf) = result.GetErrorBuffer() {
|
||||||
unsafe {
|
unsafe {
|
||||||
let buf: IDxcBlobUtf8 = buf.cast().unwrap();
|
let buf: IDxcBlobUtf8 = buf.cast()?;
|
||||||
let buf =
|
let buf =
|
||||||
std::slice::from_raw_parts(buf.GetBufferPointer().cast(), buf.GetBufferSize());
|
std::slice::from_raw_parts(buf.GetBufferPointer().cast(), buf.GetBufferSize());
|
||||||
let str = std::str::from_utf8_unchecked(buf);
|
let str = std::str::from_utf8_unchecked(buf);
|
||||||
|
|
Loading…
Add table
Reference in a new issue