Use DXIL shader compilation

Integrate DXC to compile HLSL into DXIL for use in DX12. This works
around FXC limitations and unlocks the use of more advanced HLSL features
such as subgroups.

This hardcodes the use of DXIL, but it could be adapted (with a bit of
effort) to choose between DXIL and HLSL at runtime.
This commit is contained in:
Raph Levien 2021-11-11 11:48:58 -08:00
parent 7a021793ee
commit f9d0aa078b
11 changed files with 80 additions and 21 deletions

View file

@ -4,6 +4,7 @@
glslang_validator = glslangValidator glslang_validator = glslangValidator
spirv_cross = spirv-cross spirv_cross = spirv-cross
dxc = dxc
rule glsl rule glsl
command = $glslang_validator -V -o $out $in command = $glslang_validator -V -o $out $in
@ -11,9 +12,13 @@ rule glsl
rule hlsl rule hlsl
command = $spirv_cross --hlsl $in --output $out command = $spirv_cross --hlsl $in --output $out
rule dxil
command = $dxc -T cs_6_0 $in -Fo $out
rule msl rule msl
command = $spirv_cross --msl $in --output $out command = $spirv_cross --msl $in --output $out
build gen/collatz.spv: glsl collatz.comp build gen/collatz.spv: glsl collatz.comp
build gen/collatz.hlsl: hlsl gen/collatz.spv build gen/collatz.hlsl: hlsl gen/collatz.spv
build gen/collatz.dxil: dxil gen/collatz.hlsl
build gen/collatz.msl: msl gen/collatz.spv build gen/collatz.msl: msl gen/collatz.spv

View file

@ -6,7 +6,9 @@ mod wrappers;
use std::{cell::Cell, convert::{TryFrom, TryInto}, mem, ptr}; use std::{cell::Cell, convert::{TryFrom, TryInto}, mem, ptr};
use winapi::shared::minwindef::TRUE; use winapi::shared::minwindef::TRUE;
use winapi::shared::{dxgi, dxgi1_2, dxgi1_3, dxgitype}; use winapi::shared::{dxgi, dxgi1_2, dxgitype};
#[allow(unused)]
use winapi::shared::dxgi1_3; // for error reporting in debug mode
use winapi::um::d3d12; use winapi::um::d3d12;
use raw_window_handle::{HasRawWindowHandle, RawWindowHandle}; use raw_window_handle::{HasRawWindowHandle, RawWindowHandle};
@ -236,8 +238,9 @@ impl crate::backend::Device for Dx12Device {
type Sampler = (); type Sampler = ();
// Currently this is HLSL source, but we'll probably change it to IR. // Currently due to type inflexibility this is hardcoded to either HLSL or
type ShaderSource = str; // DXIL, but it would be nice to be able to handle both at runtime.
type ShaderSource = [u8];
fn create_buffer(&self, size: u64, usage: BufferUsage) -> Result<Self::Buffer, Error> { fn create_buffer(&self, size: u64, usage: BufferUsage) -> Result<Self::Buffer, Error> {
// TODO: consider supporting BufferUsage::QUERY_RESOLVE here rather than // TODO: consider supporting BufferUsage::QUERY_RESOLVE here rather than
@ -411,7 +414,7 @@ impl crate::backend::Device for Dx12Device {
unsafe fn create_compute_pipeline( unsafe fn create_compute_pipeline(
&self, &self,
code: &str, code: &Self::ShaderSource,
bind_types: &[BindType], bind_types: &[BindType],
) -> Result<Pipeline, Error> { ) -> Result<Pipeline, Error> {
if u32::try_from(bind_types.len()).is_err() { if u32::try_from(bind_types.len()).is_err() {
@ -442,6 +445,11 @@ impl crate::backend::Device for Dx12Device {
i = end; i = end;
} }
// We could always have ShaderSource as [u8] even when it's HLSL, and use the
// magic number to distinguish. In any case, for now it's hardcoded as one or
// the other.
/*
// HLSL code path
#[cfg(debug_assertions)] #[cfg(debug_assertions)]
let flags = winapi::um::d3dcompiler::D3DCOMPILE_DEBUG let flags = winapi::um::d3dcompiler::D3DCOMPILE_DEBUG
| winapi::um::d3dcompiler::D3DCOMPILE_SKIP_OPTIMIZATION; | winapi::um::d3dcompiler::D3DCOMPILE_SKIP_OPTIMIZATION;
@ -449,6 +457,11 @@ impl crate::backend::Device for Dx12Device {
let flags = 0; let flags = 0;
let shader_blob = ShaderByteCode::compile(code, "cs_5_1", "main", flags)?; let shader_blob = ShaderByteCode::compile(code, "cs_5_1", "main", flags)?;
let shader = ShaderByteCode::from_blob(shader_blob); let shader = ShaderByteCode::from_blob(shader_blob);
*/
// DXIL code path
let shader = ShaderByteCode::from_slice(code);
let mut root_parameter = d3d12::D3D12_ROOT_PARAMETER { let mut root_parameter = d3d12::D3D12_ROOT_PARAMETER {
ParameterType: d3d12::D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE, ParameterType: d3d12::D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE,
ShaderVisibility: d3d12::D3D12_SHADER_VISIBILITY_ALL, ShaderVisibility: d3d12::D3D12_SHADER_VISIBILITY_ALL,

View file

@ -196,7 +196,7 @@ impl Factory4 {
error_if_failed_else_unit(self.0.EnumAdapters1(id, &mut adapter))?; error_if_failed_else_unit(self.0.EnumAdapters1(id, &mut adapter))?;
let mut desc = mem::zeroed(); let mut desc = mem::zeroed();
(*adapter).GetDesc(&mut desc); (*adapter).GetDesc(&mut desc);
println!("desc: {:?}", desc.Description); //println!("desc: {:?}", desc.Description);
Ok(Adapter1(ComPtr::from_raw(adapter))) Ok(Adapter1(ComPtr::from_raw(adapter)))
} }
@ -276,6 +276,7 @@ impl SwapChain3 {
} }
impl Blob { impl Blob {
#[allow(unused)]
pub unsafe fn print_to_console(blob: &Blob) { pub unsafe fn print_to_console(blob: &Blob) {
println!("==SHADER COMPILE MESSAGES=="); println!("==SHADER COMPILE MESSAGES==");
let message = { let message = {
@ -714,13 +715,13 @@ impl RootSignature {
let hresult = let hresult =
d3d12::D3D12SerializeRootSignature(desc, version, &mut blob, &mut error_blob_ptr); d3d12::D3D12SerializeRootSignature(desc, version, &mut blob, &mut error_blob_ptr);
#[cfg(debug_assertions)]
{
let error_blob = if error_blob_ptr.is_null() { let error_blob = if error_blob_ptr.is_null() {
None None
} else { } else {
Some(Blob(ComPtr::from_raw(error_blob_ptr))) Some(Blob(ComPtr::from_raw(error_blob_ptr)))
}; };
#[cfg(debug_assertions)]
{
if let Some(error_blob) = &error_blob { if let Some(error_blob) = &error_blob {
Blob::print_to_console(error_blob); Blob::print_to_console(error_blob);
} }
@ -736,6 +737,7 @@ impl ShaderByteCode {
// `blob` may not be null. // `blob` may not be null.
// TODO: this is not super elegant, maybe want to move the get // TODO: this is not super elegant, maybe want to move the get
// operations closer to where they're used. // operations closer to where they're used.
#[allow(unused)]
pub unsafe fn from_blob(blob: Blob) -> ShaderByteCode { pub unsafe fn from_blob(blob: Blob) -> ShaderByteCode {
ShaderByteCode { ShaderByteCode {
bytecode: d3d12::D3D12_SHADER_BYTECODE { bytecode: d3d12::D3D12_SHADER_BYTECODE {
@ -749,6 +751,7 @@ impl ShaderByteCode {
/// Compile a shader from raw HLSL. /// Compile a shader from raw HLSL.
/// ///
/// * `target`: example format: `ps_5_1`. /// * `target`: example format: `ps_5_1`.
#[allow(unused)]
pub unsafe fn compile( pub unsafe fn compile(
source: &str, source: &str,
target: &str, target: &str,
@ -795,6 +798,24 @@ impl ShaderByteCode {
Ok(Blob(ComPtr::from_raw(shader_blob_ptr))) Ok(Blob(ComPtr::from_raw(shader_blob_ptr)))
} }
/// Wrap already-compiled shader bytecode held in a byte slice.
///
/// Only the slice's pointer and length are recorded; no bytes are copied
/// and no owning blob is attached to the result.
///
/// # Safety
///
/// The returned object does not carry the slice's lifetime. The caller is
/// responsible for keeping `bytecode` valid for as long as the returned
/// `ShaderByteCode` is in use.
// `allow(unused)`: the HLSL and DXIL code paths are currently selected by
// editing the source, so the constructors of the inactive path are idle.
#[allow(unused)]
pub unsafe fn from_slice(bytecode: &[u8]) -> ShaderByteCode {
    // Describe the borrowed bytes to D3D12 as a (pointer, length) pair.
    let raw = d3d12::D3D12_SHADER_BYTECODE {
        BytecodeLength: bytecode.len(),
        pShaderBytecode: bytecode.as_ptr() as *const _,
    };
    ShaderByteCode {
        bytecode: raw,
        blob: None,
    }
}
} }
impl Fence { impl Fence {
@ -1073,9 +1094,8 @@ pub unsafe fn create_transition_resource_barrier(
resource_barrier resource_barrier
} }
#[allow(unused)]
pub unsafe fn enable_debug_layer() -> Result<(), Error> { pub unsafe fn enable_debug_layer() -> Result<(), Error> {
println!("enabling debug layer.");
let mut debug_controller: *mut d3d12sdklayers::ID3D12Debug1 = ptr::null_mut(); let mut debug_controller: *mut d3d12sdklayers::ID3D12Debug1 = ptr::null_mut();
explain_error( explain_error(
d3d12::D3D12GetDebugInterface( d3d12::D3D12GetDebugInterface(

View file

@ -369,8 +369,8 @@ impl Session {
} }
/// Choose shader code from the available choices. /// Choose shader code from the available choices.
pub fn choose_shader<'a>(&self, spv: &'a [u8], hlsl: &'a str, msl: &'a str) -> ShaderCode<'a> { pub fn choose_shader<'a>(&self, spv: &'a [u8], hlsl: &'a str, dxil: &'a [u8], msl: &'a str) -> ShaderCode<'a> {
self.0.device.choose_shader(spv, hlsl, msl) self.0.device.choose_shader(spv, hlsl, dxil, msl)
} }
/// Report the backend type that was chosen. /// Report the backend type that was chosen.

View file

@ -51,7 +51,7 @@ bitflags! {
} }
/// The GPU backend that was selected. /// The GPU backend that was selected.
#[derive(Clone, Copy, PartialEq, Eq)] #[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum BackendType { pub enum BackendType {
Vulkan, Vulkan,
Dx12, Dx12,

View file

@ -198,6 +198,7 @@ macro_rules! include_shader {
$device.choose_shader( $device.choose_shader(
include_bytes!(concat!($path_base, ".spv")), include_bytes!(concat!($path_base, ".spv")),
include_str!(concat!($path_base, ".hlsl")), include_str!(concat!($path_base, ".hlsl")),
include_bytes!(concat!($path_base, ".dxil")),
include_str!(concat!($path_base, ".msl")), include_str!(concat!($path_base, ".msl")),
) )
}; };

View file

@ -104,6 +104,8 @@ pub enum ShaderCode<'a> {
Spv(&'a [u8]), Spv(&'a [u8]),
/// HLSL (source) /// HLSL (source)
Hlsl(&'a str), Hlsl(&'a str),
/// DXIL (DX12 intermediate language)
Dxil(&'a [u8]),
/// Metal Shading Language (source) /// Metal Shading Language (source)
Msl(&'a str), Msl(&'a str),
} }
@ -321,9 +323,10 @@ impl Device {
} }
Device::Dx12(d) => { Device::Dx12(d) => {
let shader_code = match code { let shader_code = match code {
ShaderCode::Hlsl(hlsl) => hlsl, //ShaderCode::Hlsl(hlsl) => hlsl,
ShaderCode::Dxil(dxil) => dxil,
// Panic or return "incompatible shader" error here? // Panic or return "incompatible shader" error here?
_ => panic!("DX12 backend requires shader code in HLSL format"), _ => panic!("DX12 backend requires shader code in DXIL format"),
}; };
d.create_compute_pipeline(shader_code, bind_types) d.create_compute_pipeline(shader_code, bind_types)
.map(Pipeline::Dx12) .map(Pipeline::Dx12)
@ -475,11 +478,12 @@ impl Device {
&self, &self,
_spv: &'a [u8], _spv: &'a [u8],
_hlsl: &'a str, _hlsl: &'a str,
_dxil: &'a [u8],
_msl: &'a str, _msl: &'a str,
) -> ShaderCode<'a> { ) -> ShaderCode<'a> {
mux_match! { self; mux_match! { self;
Device::Vk(_d) => ShaderCode::Spv(_spv), Device::Vk(_d) => ShaderCode::Spv(_spv),
Device::Dx12(_d) => ShaderCode::Hlsl(_hlsl), Device::Dx12(_d) => ShaderCode::Dxil(_dxil),
Device::Mtl(_d) => ShaderCode::Msl(_msl), Device::Mtl(_d) => ShaderCode::Msl(_msl),
} }
} }

View file

@ -4,6 +4,7 @@
glslang_validator = glslangValidator glslang_validator = glslangValidator
spirv_cross = spirv-cross spirv_cross = spirv-cross
dxc = dxc
# See https://github.com/KhronosGroup/SPIRV-Cross/issues/1248 for # See https://github.com/KhronosGroup/SPIRV-Cross/issues/1248 for
# why we set this. # why we set this.
@ -15,26 +16,34 @@ rule glsl
rule hlsl rule hlsl
command = $spirv_cross --hlsl $in --output $out command = $spirv_cross --hlsl $in --output $out
rule dxil
command = $dxc -T cs_6_0 $in -Fo $out
rule msl rule msl
command = $spirv_cross --msl $in --output $out $msl_flags command = $spirv_cross --msl $in --output $out $msl_flags
build gen/clear.spv: glsl clear.comp build gen/clear.spv: glsl clear.comp
build gen/clear.hlsl: hlsl gen/clear.spv build gen/clear.hlsl: hlsl gen/clear.spv
build gen/clear.dxil: dxil gen/clear.hlsl
build gen/clear.msl: msl gen/clear.spv build gen/clear.msl: msl gen/clear.spv
build gen/prefix.spv: glsl prefix.comp build gen/prefix.spv: glsl prefix.comp
build gen/prefix.hlsl: hlsl gen/prefix.spv build gen/prefix.hlsl: hlsl gen/prefix.spv
build gen/prefix.dxil: dxil gen/prefix.hlsl
build gen/prefix.msl: msl gen/prefix.spv build gen/prefix.msl: msl gen/prefix.spv
build gen/prefix_reduce.spv: glsl prefix_reduce.comp build gen/prefix_reduce.spv: glsl prefix_reduce.comp
build gen/prefix_reduce.hlsl: hlsl gen/prefix_reduce.spv build gen/prefix_reduce.hlsl: hlsl gen/prefix_reduce.spv
build gen/prefix_reduce.dxil: dxil gen/prefix_reduce.hlsl
build gen/prefix_reduce.msl: msl gen/prefix_reduce.spv build gen/prefix_reduce.msl: msl gen/prefix_reduce.spv
build gen/prefix_root.spv: glsl prefix_scan.comp build gen/prefix_root.spv: glsl prefix_scan.comp
flags = -DROOT flags = -DROOT
build gen/prefix_root.hlsl: hlsl gen/prefix_root.spv build gen/prefix_root.hlsl: hlsl gen/prefix_root.spv
build gen/prefix_root.dxil: dxil gen/prefix_root.hlsl
build gen/prefix_root.msl: msl gen/prefix_root.spv build gen/prefix_root.msl: msl gen/prefix_root.spv
build gen/prefix_scan.spv: glsl prefix_scan.comp build gen/prefix_scan.spv: glsl prefix_scan.comp
build gen/prefix_scan.hlsl: hlsl gen/prefix_scan.spv build gen/prefix_scan.hlsl: hlsl gen/prefix_scan.spv
build gen/prefix_scan.dxil: dxil gen/prefix_scan.hlsl
build gen/prefix_scan.msl: msl gen/prefix_scan.spv build gen/prefix_scan.msl: msl gen/prefix_scan.spv

View file

@ -80,6 +80,10 @@ fn main() {
flags |= InstanceFlags::DX12; flags |= InstanceFlags::DX12;
} }
let mut runner = Runner::new(flags); let mut runner = Runner::new(flags);
if style == ReportStyle::Verbose {
// TODO: get adapter name in here too
println!("Backend: {:?}", runner.backend_type());
}
report(&clear::run_clear_test(&mut runner, &config)); report(&clear::run_clear_test(&mut runner, &config));
if config.groups.matches("prefix") { if config.groups.matches("prefix") {
report(&prefix::run_prefix_test(&mut runner, &config)); report(&prefix::run_prefix_test(&mut runner, &config));

View file

@ -53,11 +53,13 @@ struct PrefixBinding {
pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResult { pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResult {
let mut result = TestResult::new("prefix sum, decoupled look-back"); let mut result = TestResult::new("prefix sum, decoupled look-back");
/*
// We're good if we're using DXC.
if runner.backend_type() == BackendType::Dx12 { if runner.backend_type() == BackendType::Dx12 {
result.skip("Shader won't compile on FXC"); result.skip("Shader won't compile on FXC");
return result; return result;
} }
// This will be configurable. */
let n_elements: u64 = config.size.choose(1 << 12, 1 << 24, 1 << 25); let n_elements: u64 = config.size.choose(1 << 12, 1 << 24, 1 << 25);
let data: Vec<u32> = (0..n_elements as u32).collect(); let data: Vec<u32> = (0..n_elements as u32).collect();
let data_buf = runner let data_buf = runner
@ -68,7 +70,6 @@ pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResul
let code = PrefixCode::new(runner); let code = PrefixCode::new(runner);
let stage = PrefixStage::new(runner, &code, n_elements); let stage = PrefixStage::new(runner, &code, n_elements);
let binding = stage.bind(runner, &code, &data_buf, &out_buf.dev_buf); let binding = stage.bind(runner, &code, &data_buf, &out_buf.dev_buf);
// Also will be configurable of course.
let n_iter = config.n_iter; let n_iter = config.n_iter;
let mut total_elapsed = 0.0; let mut total_elapsed = 0.0;
for i in 0..n_iter { for i in 0..n_iter {

View file

@ -27,10 +27,11 @@ pub struct TestResult {
pub enum Status { pub enum Status {
Pass, Pass,
Fail(String), Fail(String),
#[allow(unused)]
Skipped(String), Skipped(String),
} }
#[derive(Clone, Copy)] #[derive(Clone, Copy, PartialEq, Eq)]
pub enum ReportStyle { pub enum ReportStyle {
Short, Short,
Verbose, Verbose,
@ -84,6 +85,7 @@ impl TestResult {
self.status = Status::Fail(explanation.into()); self.status = Status::Fail(explanation.into());
} }
#[allow(unused)]
pub fn skip(&mut self, explanation: impl Into<String>) { pub fn skip(&mut self, explanation: impl Into<String>) {
self.status = Status::Skipped(explanation.into()); self.status = Status::Skipped(explanation.into());
} }