From f9d0aa078bb8dfcef42025535809a6613600d465 Mon Sep 17 00:00:00 2001 From: Raph Levien Date: Thu, 11 Nov 2021 11:48:58 -0800 Subject: [PATCH] Use DXIL shader compilation Integrate DXC for translating HLSL for use in DX12. This will work around FXC limitations and unlock the use of more advanced HLSL features such as subgroups. This hardcodes the use of DXIL, but it could be adapted (with a bit of effort) to choose between DXIL and HLSL at runtime. --- piet-gpu-hal/examples/shader/build.ninja | 5 ++++ piet-gpu-hal/src/dx12.rs | 21 +++++++++++--- piet-gpu-hal/src/dx12/wrappers.rs | 36 ++++++++++++++++++------ piet-gpu-hal/src/hub.rs | 4 +-- piet-gpu-hal/src/lib.rs | 2 +- piet-gpu-hal/src/macros.rs | 1 + piet-gpu-hal/src/mux.rs | 10 +++++-- tests/shader/build.ninja | 9 ++++++ tests/src/main.rs | 4 +++ tests/src/prefix.rs | 5 ++-- tests/src/test_result.rs | 4 ++- 11 files changed, 80 insertions(+), 21 deletions(-) diff --git a/piet-gpu-hal/examples/shader/build.ninja b/piet-gpu-hal/examples/shader/build.ninja index f1c6328..3b9cf3f 100644 --- a/piet-gpu-hal/examples/shader/build.ninja +++ b/piet-gpu-hal/examples/shader/build.ninja @@ -4,6 +4,7 @@ glslang_validator = glslangValidator spirv_cross = spirv-cross +dxc = dxc rule glsl command = $glslang_validator -V -o $out $in @@ -11,9 +12,13 @@ rule glsl rule hlsl command = $spirv_cross --hlsl $in --output $out +rule dxil + command = $dxc -T cs_6_0 $in -Fo $out + rule msl command = $spirv_cross --msl $in --output $out build gen/collatz.spv: glsl collatz.comp build gen/collatz.hlsl: hlsl gen/collatz.spv +build gen/collatz.dxil: dxil gen/collatz.hlsl build gen/collatz.msl: msl gen/collatz.spv diff --git a/piet-gpu-hal/src/dx12.rs b/piet-gpu-hal/src/dx12.rs index 0fb7dfd..66befa5 100644 --- a/piet-gpu-hal/src/dx12.rs +++ b/piet-gpu-hal/src/dx12.rs @@ -6,7 +6,9 @@ mod wrappers; use std::{cell::Cell, convert::{TryFrom, TryInto}, mem, ptr}; use winapi::shared::minwindef::TRUE; -use winapi::shared::{dxgi, dxgi1_2, dxgi1_3, dxgitype}; +use winapi::shared::{dxgi, dxgi1_2, dxgitype}; +#[allow(unused)] +use winapi::shared::dxgi1_3; // for error reporting in debug mode use winapi::um::d3d12; use raw_window_handle::{HasRawWindowHandle, RawWindowHandle}; @@ -236,8 +238,9 @@ impl crate::backend::Device for Dx12Device { type Sampler = (); - // Currently this is HLSL source, but we'll probably change it to IR. - type ShaderSource = str; + // Currently due to type inflexibility this is hardcoded to either HLSL or + // DXIL, but it would be nice to be able to handle both at runtime. + type ShaderSource = [u8]; fn create_buffer(&self, size: u64, usage: BufferUsage) -> Result { // TODO: consider supporting BufferUsage::QUERY_RESOLVE here rather than @@ -411,7 +414,7 @@ impl crate::backend::Device for Dx12Device { unsafe fn create_compute_pipeline( &self, - code: &str, + code: &Self::ShaderSource, bind_types: &[BindType], ) -> Result { if u32::try_from(bind_types.len()).is_err() { @@ -442,6 +445,11 @@ impl crate::backend::Device for Dx12Device { i = end; } + // We could always have ShaderSource as [u8] even when it's HLSL, and use the + // magic number to distinguish. In any case, for now it's hardcoded as one or + // the other. + /* + // HLSL code path #[cfg(debug_assertions)] let flags = winapi::um::d3dcompiler::D3DCOMPILE_DEBUG | winapi::um::d3dcompiler::D3DCOMPILE_SKIP_OPTIMIZATION; @@ -449,6 +457,11 @@ impl crate::backend::Device for Dx12Device { let flags = 0; let shader_blob = ShaderByteCode::compile(code, "cs_5_1", "main", flags)?; let shader = ShaderByteCode::from_blob(shader_blob); + */ + + // DXIL code path + let shader = ShaderByteCode::from_slice(code); + let mut root_parameter = d3d12::D3D12_ROOT_PARAMETER { ParameterType: d3d12::D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE, ShaderVisibility: d3d12::D3D12_SHADER_VISIBILITY_ALL, diff --git a/piet-gpu-hal/src/dx12/wrappers.rs b/piet-gpu-hal/src/dx12/wrappers.rs index add0dda..dd834fa 100644 --- a/piet-gpu-hal/src/dx12/wrappers.rs +++ b/piet-gpu-hal/src/dx12/wrappers.rs @@ -196,7 +196,7 @@ impl Factory4 { error_if_failed_else_unit(self.0.EnumAdapters1(id, &mut adapter))?; let mut desc = mem::zeroed(); (*adapter).GetDesc(&mut desc); - println!("desc: {:?}", desc.Description); + //println!("desc: {:?}", desc.Description); Ok(Adapter1(ComPtr::from_raw(adapter))) } @@ -276,6 +276,7 @@ impl SwapChain3 { } impl Blob { + #[allow(unused)] pub unsafe fn print_to_console(blob: &Blob) { println!("==SHADER COMPILE MESSAGES=="); let message = { @@ -714,13 +715,13 @@ impl RootSignature { let hresult = d3d12::D3D12SerializeRootSignature(desc, version, &mut blob, &mut error_blob_ptr); - let error_blob = if error_blob_ptr.is_null() { - None - } else { - Some(Blob(ComPtr::from_raw(error_blob_ptr))) - }; #[cfg(debug_assertions)] { + let error_blob = if error_blob_ptr.is_null() { + None + } else { + Some(Blob(ComPtr::from_raw(error_blob_ptr))) + }; if let Some(error_blob) = &error_blob { Blob::print_to_console(error_blob); } @@ -736,6 +737,7 @@ impl ShaderByteCode { // `blob` may not be null. // TODO: this is not super elegant, maybe want to move the get // operations closer to where they're used. + #[allow(unused)] pub unsafe fn from_blob(blob: Blob) -> ShaderByteCode { ShaderByteCode { bytecode: d3d12::D3D12_SHADER_BYTECODE { @@ -749,6 +751,7 @@ impl ShaderByteCode { /// Compile a shader from raw HLSL. /// /// * `target`: example format: `ps_5_1`. + #[allow(unused)] pub unsafe fn compile( source: &str, target: &str, @@ -795,6 +798,24 @@ impl ShaderByteCode { Ok(Blob(ComPtr::from_raw(shader_blob_ptr))) } + + /// Create bytecode from a slice. + /// + /// # Safety + /// + /// This call elides the lifetime from the slice. The caller is responsible + /// for making sure the reference remains valid for the lifetime of this + /// object. + #[allow(unused)] + pub unsafe fn from_slice(bytecode: &[u8]) -> ShaderByteCode { + ShaderByteCode { + bytecode: d3d12::D3D12_SHADER_BYTECODE { + BytecodeLength: bytecode.len(), + pShaderBytecode: bytecode.as_ptr() as *const _, + }, + blob: None, + } + } } impl Fence { @@ -1073,9 +1094,8 @@ pub unsafe fn create_transition_resource_barrier( resource_barrier } +#[allow(unused)] pub unsafe fn enable_debug_layer() -> Result<(), Error> { - println!("enabling debug layer."); - let mut debug_controller: *mut d3d12sdklayers::ID3D12Debug1 = ptr::null_mut(); explain_error( d3d12::D3D12GetDebugInterface( diff --git a/piet-gpu-hal/src/hub.rs b/piet-gpu-hal/src/hub.rs index db6de2a..2acfee0 100644 --- a/piet-gpu-hal/src/hub.rs +++ b/piet-gpu-hal/src/hub.rs @@ -369,8 +369,8 @@ impl Session { } /// Choose shader code from the available choices. - pub fn choose_shader<'a>(&self, spv: &'a [u8], hlsl: &'a str, msl: &'a str) -> ShaderCode<'a> { - self.0.device.choose_shader(spv, hlsl, msl) + pub fn choose_shader<'a>(&self, spv: &'a [u8], hlsl: &'a str, dxil: &'a [u8], msl: &'a str) -> ShaderCode<'a> { + self.0.device.choose_shader(spv, hlsl, dxil, msl) } /// Report the backend type that was chosen. diff --git a/piet-gpu-hal/src/lib.rs b/piet-gpu-hal/src/lib.rs index d74bfb0..05e2394 100644 --- a/piet-gpu-hal/src/lib.rs +++ b/piet-gpu-hal/src/lib.rs @@ -51,7 +51,7 @@ bitflags! { } /// The GPU backend that was selected. -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Clone, Copy, PartialEq, Eq, Debug)] pub enum BackendType { Vulkan, Dx12, diff --git a/piet-gpu-hal/src/macros.rs b/piet-gpu-hal/src/macros.rs index 38897a8..a4a441e 100644 --- a/piet-gpu-hal/src/macros.rs +++ b/piet-gpu-hal/src/macros.rs @@ -198,6 +198,7 @@ macro_rules! include_shader { $device.choose_shader( include_bytes!(concat!($path_base, ".spv")), include_str!(concat!($path_base, ".hlsl")), + include_bytes!(concat!($path_base, ".dxil")), include_str!(concat!($path_base, ".msl")), ) }; diff --git a/piet-gpu-hal/src/mux.rs b/piet-gpu-hal/src/mux.rs index d153478..a0ea28a 100644 --- a/piet-gpu-hal/src/mux.rs +++ b/piet-gpu-hal/src/mux.rs @@ -104,6 +104,8 @@ pub enum ShaderCode<'a> { Spv(&'a [u8]), /// HLSL (source) Hlsl(&'a str), + /// DXIL (DX12 intermediate language) + Dxil(&'a [u8]), /// Metal Shading Language (source) Msl(&'a str), } @@ -321,9 +323,10 @@ impl Device { } Device::Dx12(d) => { let shader_code = match code { - ShaderCode::Hlsl(hlsl) => hlsl, + //ShaderCode::Hlsl(hlsl) => hlsl, + ShaderCode::Dxil(dxil) => dxil, // Panic or return "incompatible shader" error here? - _ => panic!("DX12 backend requires shader code in HLSL format"), + _ => panic!("DX12 backend requires shader code in DXIL format"), }; d.create_compute_pipeline(shader_code, bind_types) .map(Pipeline::Dx12) @@ -475,11 +478,12 @@ impl Device { &self, _spv: &'a [u8], _hlsl: &'a str, + _dxil: &'a [u8], _msl: &'a str, ) -> ShaderCode<'a> { mux_match! { self; Device::Vk(_d) => ShaderCode::Spv(_spv), - Device::Dx12(_d) => ShaderCode::Hlsl(_hlsl), + Device::Dx12(_d) => ShaderCode::Dxil(_dxil), Device::Mtl(_d) => ShaderCode::Msl(_msl), } } diff --git a/tests/shader/build.ninja b/tests/shader/build.ninja index f4dc4ae..19297c9 100644 --- a/tests/shader/build.ninja +++ b/tests/shader/build.ninja @@ -4,6 +4,7 @@ glslang_validator = glslangValidator spirv_cross = spirv-cross +dxc = dxc # See https://github.com/KhronosGroup/SPIRV-Cross/issues/1248 for # why we set this. @@ -15,26 +16,34 @@ rule glsl rule hlsl command = $spirv_cross --hlsl $in --output $out +rule dxil + command = $dxc -T cs_6_0 $in -Fo $out + rule msl command = $spirv_cross --msl $in --output $out $msl_flags build gen/clear.spv: glsl clear.comp build gen/clear.hlsl: hlsl gen/clear.spv +build gen/clear.dxil: dxil gen/clear.hlsl build gen/clear.msl: msl gen/clear.spv build gen/prefix.spv: glsl prefix.comp build gen/prefix.hlsl: hlsl gen/prefix.spv +build gen/prefix.dxil: dxil gen/prefix.hlsl build gen/prefix.msl: msl gen/prefix.spv build gen/prefix_reduce.spv: glsl prefix_reduce.comp build gen/prefix_reduce.hlsl: hlsl gen/prefix_reduce.spv +build gen/prefix_reduce.dxil: dxil gen/prefix_reduce.hlsl build gen/prefix_reduce.msl: msl gen/prefix_reduce.spv build gen/prefix_root.spv: glsl prefix_scan.comp flags = -DROOT build gen/prefix_root.hlsl: hlsl gen/prefix_root.spv +build gen/prefix_root.dxil: dxil gen/prefix_root.hlsl build gen/prefix_root.msl: msl gen/prefix_root.spv build gen/prefix_scan.spv: glsl prefix_scan.comp build gen/prefix_scan.hlsl: hlsl gen/prefix_scan.spv +build gen/prefix_scan.dxil: dxil gen/prefix_scan.hlsl build gen/prefix_scan.msl: msl gen/prefix_scan.spv diff --git a/tests/src/main.rs b/tests/src/main.rs index 647e8db..adefa7f 100644 --- a/tests/src/main.rs +++ b/tests/src/main.rs @@ -80,6 +80,10 @@ fn main() { flags |= InstanceFlags::DX12; } let mut runner = Runner::new(flags); + if style == ReportStyle::Verbose { + // TODO: get adapter name in here too + println!("Backend: {:?}", runner.backend_type()); + } report(&clear::run_clear_test(&mut runner, &config)); if config.groups.matches("prefix") { report(&prefix::run_prefix_test(&mut runner, &config)); diff --git a/tests/src/prefix.rs b/tests/src/prefix.rs index b668fac..a2e52c3 100644 --- a/tests/src/prefix.rs +++ b/tests/src/prefix.rs @@ -53,11 +53,13 @@ struct PrefixBinding { pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResult { let mut result = TestResult::new("prefix sum, decoupled look-back"); + /* + // We're good if we're using DXC. if runner.backend_type() == BackendType::Dx12 { result.skip("Shader won't compile on FXC"); return result; } - // This will be configurable. + */ let n_elements: u64 = config.size.choose(1 << 12, 1 << 24, 1 << 25); let data: Vec = (0..n_elements as u32).collect(); let data_buf = runner @@ -68,7 +70,6 @@ pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResul let code = PrefixCode::new(runner); let stage = PrefixStage::new(runner, &code, n_elements); let binding = stage.bind(runner, &code, &data_buf, &out_buf.dev_buf); - // Also will be configurable of course. let n_iter = config.n_iter; let mut total_elapsed = 0.0; for i in 0..n_iter { diff --git a/tests/src/test_result.rs b/tests/src/test_result.rs index a223ff0..e582c63 100644 --- a/tests/src/test_result.rs +++ b/tests/src/test_result.rs @@ -27,10 +27,11 @@ pub struct TestResult { pub enum Status { Pass, Fail(String), + #[allow(unused)] Skipped(String), } -#[derive(Clone, Copy)] +#[derive(Clone, Copy, PartialEq, Eq)] pub enum ReportStyle { Short, Verbose, @@ -84,6 +85,7 @@ impl TestResult { self.status = Status::Fail(explanation.into()); } + #[allow(unused)] pub fn skip(&mut self, explanation: impl Into) { self.status = Status::Skipped(explanation.into()); }