diff --git a/piet-gpu-hal/examples/shader/build.ninja b/piet-gpu-hal/examples/shader/build.ninja index f1c6328..3b9cf3f 100644 --- a/piet-gpu-hal/examples/shader/build.ninja +++ b/piet-gpu-hal/examples/shader/build.ninja @@ -4,6 +4,7 @@ glslang_validator = glslangValidator spirv_cross = spirv-cross +dxc = dxc rule glsl command = $glslang_validator -V -o $out $in @@ -11,9 +12,13 @@ rule glsl rule hlsl command = $spirv_cross --hlsl $in --output $out +rule dxil + command = $dxc -T cs_6_0 $in -Fo $out + rule msl command = $spirv_cross --msl $in --output $out build gen/collatz.spv: glsl collatz.comp build gen/collatz.hlsl: hlsl gen/collatz.spv +build gen/collatz.dxil: dxil gen/collatz.hlsl build gen/collatz.msl: msl gen/collatz.spv diff --git a/piet-gpu-hal/src/dx12.rs b/piet-gpu-hal/src/dx12.rs index 0fb7dfd..66befa5 100644 --- a/piet-gpu-hal/src/dx12.rs +++ b/piet-gpu-hal/src/dx12.rs @@ -6,7 +6,9 @@ mod wrappers; use std::{cell::Cell, convert::{TryFrom, TryInto}, mem, ptr}; use winapi::shared::minwindef::TRUE; -use winapi::shared::{dxgi, dxgi1_2, dxgi1_3, dxgitype}; +use winapi::shared::{dxgi, dxgi1_2, dxgitype}; +#[allow(unused)] +use winapi::shared::dxgi1_3; // for error reporting in debug mode use winapi::um::d3d12; use raw_window_handle::{HasRawWindowHandle, RawWindowHandle}; @@ -236,8 +238,9 @@ impl crate::backend::Device for Dx12Device { type Sampler = (); - // Currently this is HLSL source, but we'll probably change it to IR. - type ShaderSource = str; + // Currently due to type inflexibility this is hardcoded to either HLSL or + // DXIL, but it would be nice to be able to handle both at runtime. + type ShaderSource = [u8]; fn create_buffer(&self, size: u64, usage: BufferUsage) -> Result { // TODO: consider supporting BufferUsage::QUERY_RESOLVE here rather than @@ -411,7 +414,7 @@ impl crate::backend::Device for Dx12Device { unsafe fn create_compute_pipeline( &self, - code: &str, + code: &Self::ShaderSource, bind_types: &[BindType], ) -> Result { if u32::try_from(bind_types.len()).is_err() { @@ -442,6 +445,11 @@ impl crate::backend::Device for Dx12Device { i = end; } + // We could always have ShaderSource as [u8] even when it's HLSL, and use the + // magic number to distinguish. In any case, for now it's hardcoded as one or + // the other. + /* + // HLSL code path #[cfg(debug_assertions)] let flags = winapi::um::d3dcompiler::D3DCOMPILE_DEBUG | winapi::um::d3dcompiler::D3DCOMPILE_SKIP_OPTIMIZATION; @@ -449,6 +457,11 @@ impl crate::backend::Device for Dx12Device { let flags = 0; let shader_blob = ShaderByteCode::compile(code, "cs_5_1", "main", flags)?; let shader = ShaderByteCode::from_blob(shader_blob); + */ + + // DXIL code path + let shader = ShaderByteCode::from_slice(code); + let mut root_parameter = d3d12::D3D12_ROOT_PARAMETER { ParameterType: d3d12::D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE, ShaderVisibility: d3d12::D3D12_SHADER_VISIBILITY_ALL, diff --git a/piet-gpu-hal/src/dx12/wrappers.rs b/piet-gpu-hal/src/dx12/wrappers.rs index add0dda..dd834fa 100644 --- a/piet-gpu-hal/src/dx12/wrappers.rs +++ b/piet-gpu-hal/src/dx12/wrappers.rs @@ -196,7 +196,7 @@ impl Factory4 { error_if_failed_else_unit(self.0.EnumAdapters1(id, &mut adapter))?; let mut desc = mem::zeroed(); (*adapter).GetDesc(&mut desc); - println!("desc: {:?}", desc.Description); + //println!("desc: {:?}", desc.Description); Ok(Adapter1(ComPtr::from_raw(adapter))) } @@ -276,6 +276,7 @@ impl SwapChain3 { } impl Blob { + #[allow(unused)] pub unsafe fn print_to_console(blob: &Blob) { println!("==SHADER COMPILE MESSAGES=="); let message = { @@ -714,13 +715,13 @@ impl RootSignature { let hresult = d3d12::D3D12SerializeRootSignature(desc, version, &mut blob, &mut error_blob_ptr); - let error_blob = if error_blob_ptr.is_null() { - None - } else { - Some(Blob(ComPtr::from_raw(error_blob_ptr))) - }; #[cfg(debug_assertions)] { + let error_blob = if error_blob_ptr.is_null() { + None + } else { + Some(Blob(ComPtr::from_raw(error_blob_ptr))) + }; if let Some(error_blob) = &error_blob { Blob::print_to_console(error_blob); } @@ -736,6 +737,7 @@ impl ShaderByteCode { // `blob` may not be null. // TODO: this is not super elegant, maybe want to move the get // operations closer to where they're used. + #[allow(unused)] pub unsafe fn from_blob(blob: Blob) -> ShaderByteCode { ShaderByteCode { bytecode: d3d12::D3D12_SHADER_BYTECODE { @@ -749,6 +751,7 @@ impl ShaderByteCode { /// Compile a shader from raw HLSL. /// /// * `target`: example format: `ps_5_1`. + #[allow(unused)] pub unsafe fn compile( source: &str, target: &str, @@ -795,6 +798,24 @@ impl ShaderByteCode { Ok(Blob(ComPtr::from_raw(shader_blob_ptr))) } + + /// Create bytecode from a slice. + /// + /// # Safety + /// + /// This call elides the lifetime from the slice. The caller is responsible + /// for making sure the reference remains valid for the lifetime of this + /// object. + #[allow(unused)] + pub unsafe fn from_slice(bytecode: &[u8]) -> ShaderByteCode { + ShaderByteCode { + bytecode: d3d12::D3D12_SHADER_BYTECODE { + BytecodeLength: bytecode.len(), + pShaderBytecode: bytecode.as_ptr() as *const _, + }, + blob: None, + } + } } impl Fence { @@ -1073,9 +1094,8 @@ pub unsafe fn create_transition_resource_barrier( resource_barrier } +#[allow(unused)] pub unsafe fn enable_debug_layer() -> Result<(), Error> { - println!("enabling debug layer."); - let mut debug_controller: *mut d3d12sdklayers::ID3D12Debug1 = ptr::null_mut(); explain_error( d3d12::D3D12GetDebugInterface( diff --git a/piet-gpu-hal/src/hub.rs b/piet-gpu-hal/src/hub.rs index db6de2a..2acfee0 100644 --- a/piet-gpu-hal/src/hub.rs +++ b/piet-gpu-hal/src/hub.rs @@ -369,8 +369,8 @@ impl Session { } /// Choose shader code from the available choices. - pub fn choose_shader<'a>(&self, spv: &'a [u8], hlsl: &'a str, msl: &'a str) -> ShaderCode<'a> { - self.0.device.choose_shader(spv, hlsl, msl) + pub fn choose_shader<'a>(&self, spv: &'a [u8], hlsl: &'a str, dxil: &'a [u8], msl: &'a str) -> ShaderCode<'a> { + self.0.device.choose_shader(spv, hlsl, dxil, msl) } /// Report the backend type that was chosen. diff --git a/piet-gpu-hal/src/lib.rs b/piet-gpu-hal/src/lib.rs index d74bfb0..05e2394 100644 --- a/piet-gpu-hal/src/lib.rs +++ b/piet-gpu-hal/src/lib.rs @@ -51,7 +51,7 @@ bitflags! { } /// The GPU backend that was selected. -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Clone, Copy, PartialEq, Eq, Debug)] pub enum BackendType { Vulkan, Dx12, diff --git a/piet-gpu-hal/src/macros.rs b/piet-gpu-hal/src/macros.rs index 38897a8..a4a441e 100644 --- a/piet-gpu-hal/src/macros.rs +++ b/piet-gpu-hal/src/macros.rs @@ -198,6 +198,7 @@ macro_rules! include_shader { $device.choose_shader( include_bytes!(concat!($path_base, ".spv")), include_str!(concat!($path_base, ".hlsl")), + include_bytes!(concat!($path_base, ".dxil")), include_str!(concat!($path_base, ".msl")), ) }; diff --git a/piet-gpu-hal/src/mux.rs b/piet-gpu-hal/src/mux.rs index d153478..a0ea28a 100644 --- a/piet-gpu-hal/src/mux.rs +++ b/piet-gpu-hal/src/mux.rs @@ -104,6 +104,8 @@ pub enum ShaderCode<'a> { Spv(&'a [u8]), /// HLSL (source) Hlsl(&'a str), + /// DXIL (DX12 intermediate language) + Dxil(&'a [u8]), /// Metal Shading Language (source) Msl(&'a str), } @@ -321,9 +323,10 @@ impl Device { } Device::Dx12(d) => { let shader_code = match code { - ShaderCode::Hlsl(hlsl) => hlsl, + //ShaderCode::Hlsl(hlsl) => hlsl, + ShaderCode::Dxil(dxil) => dxil, // Panic or return "incompatible shader" error here? - _ => panic!("DX12 backend requires shader code in HLSL format"), + _ => panic!("DX12 backend requires shader code in DXIL format"), }; d.create_compute_pipeline(shader_code, bind_types) .map(Pipeline::Dx12) @@ -475,11 +478,12 @@ impl Device { &self, _spv: &'a [u8], _hlsl: &'a str, + _dxil: &'a [u8], _msl: &'a str, ) -> ShaderCode<'a> { mux_match! { self; Device::Vk(_d) => ShaderCode::Spv(_spv), - Device::Dx12(_d) => ShaderCode::Hlsl(_hlsl), + Device::Dx12(_d) => ShaderCode::Dxil(_dxil), Device::Mtl(_d) => ShaderCode::Msl(_msl), } } diff --git a/tests/shader/build.ninja b/tests/shader/build.ninja index f4dc4ae..19297c9 100644 --- a/tests/shader/build.ninja +++ b/tests/shader/build.ninja @@ -4,6 +4,7 @@ glslang_validator = glslangValidator spirv_cross = spirv-cross +dxc = dxc # See https://github.com/KhronosGroup/SPIRV-Cross/issues/1248 for # why we set this. @@ -15,26 +16,34 @@ rule glsl rule hlsl command = $spirv_cross --hlsl $in --output $out +rule dxil + command = $dxc -T cs_6_0 $in -Fo $out + rule msl command = $spirv_cross --msl $in --output $out $msl_flags build gen/clear.spv: glsl clear.comp build gen/clear.hlsl: hlsl gen/clear.spv +build gen/clear.dxil: dxil gen/clear.hlsl build gen/clear.msl: msl gen/clear.spv build gen/prefix.spv: glsl prefix.comp build gen/prefix.hlsl: hlsl gen/prefix.spv +build gen/prefix.dxil: dxil gen/prefix.hlsl build gen/prefix.msl: msl gen/prefix.spv build gen/prefix_reduce.spv: glsl prefix_reduce.comp build gen/prefix_reduce.hlsl: hlsl gen/prefix_reduce.spv +build gen/prefix_reduce.dxil: dxil gen/prefix_reduce.hlsl build gen/prefix_reduce.msl: msl gen/prefix_reduce.spv build gen/prefix_root.spv: glsl prefix_scan.comp flags = -DROOT build gen/prefix_root.hlsl: hlsl gen/prefix_root.spv +build gen/prefix_root.dxil: dxil gen/prefix_root.hlsl build gen/prefix_root.msl: msl gen/prefix_root.spv build gen/prefix_scan.spv: glsl prefix_scan.comp build gen/prefix_scan.hlsl: hlsl gen/prefix_scan.spv +build gen/prefix_scan.dxil: dxil gen/prefix_scan.hlsl build gen/prefix_scan.msl: msl gen/prefix_scan.spv diff --git a/tests/src/main.rs b/tests/src/main.rs index 647e8db..adefa7f 100644 --- a/tests/src/main.rs +++ b/tests/src/main.rs @@ -80,6 +80,10 @@ fn main() { flags |= InstanceFlags::DX12; } let mut runner = Runner::new(flags); + if style == ReportStyle::Verbose { + // TODO: get adapter name in here too + println!("Backend: {:?}", runner.backend_type()); + } report(&clear::run_clear_test(&mut runner, &config)); if config.groups.matches("prefix") { report(&prefix::run_prefix_test(&mut runner, &config)); diff --git a/tests/src/prefix.rs b/tests/src/prefix.rs index b668fac..a2e52c3 100644 --- a/tests/src/prefix.rs +++ b/tests/src/prefix.rs @@ -53,11 +53,13 @@ struct PrefixBinding { pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResult { let mut result = TestResult::new("prefix sum, decoupled look-back"); + /* + // We're good if we're using DXC. if runner.backend_type() == BackendType::Dx12 { result.skip("Shader won't compile on FXC"); return result; } - // This will be configurable. + */ let n_elements: u64 = config.size.choose(1 << 12, 1 << 24, 1 << 25); let data: Vec = (0..n_elements as u32).collect(); let data_buf = runner @@ -68,7 +70,6 @@ pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResul let code = PrefixCode::new(runner); let stage = PrefixStage::new(runner, &code, n_elements); let binding = stage.bind(runner, &code, &data_buf, &out_buf.dev_buf); - // Also will be configurable of course. let n_iter = config.n_iter; let mut total_elapsed = 0.0; for i in 0..n_iter { diff --git a/tests/src/test_result.rs b/tests/src/test_result.rs index a223ff0..e582c63 100644 --- a/tests/src/test_result.rs +++ b/tests/src/test_result.rs @@ -27,10 +27,11 @@ pub struct TestResult { pub enum Status { Pass, Fail(String), + #[allow(unused)] Skipped(String), } -#[derive(Clone, Copy)] +#[derive(Clone, Copy, PartialEq, Eq)] pub enum ReportStyle { Short, Verbose, @@ -84,6 +85,7 @@ impl TestResult { self.status = Status::Fail(explanation.into()); } + #[allow(unused)] pub fn skip(&mut self, explanation: impl Into) { self.status = Status::Skipped(explanation.into()); }