Use DXIL shader compilation

Integrate DXC for translating HLSL for use in DX12. This will work
around FXC limitations and unlock the use of more advanced HLSL features
such as subgroups.

This hardcodes the use of DXIL, but it could be adapted (with a bit of
effort) to choose between DXIL and HLSL at runtime.
This commit is contained in:
Raph Levien 2021-11-11 11:48:58 -08:00
parent 7a021793ee
commit f9d0aa078b
11 changed files with 80 additions and 21 deletions

View file

@ -4,6 +4,7 @@
glslang_validator = glslangValidator
spirv_cross = spirv-cross
dxc = dxc
rule glsl
command = $glslang_validator -V -o $out $in
@ -11,9 +12,13 @@ rule glsl
rule hlsl
command = $spirv_cross --hlsl $in --output $out
rule dxil
command = $dxc -T cs_6_0 $in -Fo $out
rule msl
command = $spirv_cross --msl $in --output $out
build gen/collatz.spv: glsl collatz.comp
build gen/collatz.hlsl: hlsl gen/collatz.spv
build gen/collatz.dxil: dxil gen/collatz.hlsl
build gen/collatz.msl: msl gen/collatz.spv

View file

@ -6,7 +6,9 @@ mod wrappers;
use std::{cell::Cell, convert::{TryFrom, TryInto}, mem, ptr};
use winapi::shared::minwindef::TRUE;
use winapi::shared::{dxgi, dxgi1_2, dxgi1_3, dxgitype};
use winapi::shared::{dxgi, dxgi1_2, dxgitype};
#[allow(unused)]
use winapi::shared::dxgi1_3; // for error reporting in debug mode
use winapi::um::d3d12;
use raw_window_handle::{HasRawWindowHandle, RawWindowHandle};
@ -236,8 +238,9 @@ impl crate::backend::Device for Dx12Device {
type Sampler = ();
// Currently this is HLSL source, but we'll probably change it to IR.
type ShaderSource = str;
// Currently due to type inflexibility this is hardcoded to either HLSL or
// DXIL, but it would be nice to be able to handle both at runtime.
type ShaderSource = [u8];
fn create_buffer(&self, size: u64, usage: BufferUsage) -> Result<Self::Buffer, Error> {
// TODO: consider supporting BufferUsage::QUERY_RESOLVE here rather than
@ -411,7 +414,7 @@ impl crate::backend::Device for Dx12Device {
unsafe fn create_compute_pipeline(
&self,
code: &str,
code: &Self::ShaderSource,
bind_types: &[BindType],
) -> Result<Pipeline, Error> {
if u32::try_from(bind_types.len()).is_err() {
@ -442,6 +445,11 @@ impl crate::backend::Device for Dx12Device {
i = end;
}
// We could always have ShaderSource as [u8] even when it's HLSL, and use the
// magic number to distinguish. In any case, for now it's hardcoded as one or
// the other.
/*
// HLSL code path
#[cfg(debug_assertions)]
let flags = winapi::um::d3dcompiler::D3DCOMPILE_DEBUG
| winapi::um::d3dcompiler::D3DCOMPILE_SKIP_OPTIMIZATION;
@ -449,6 +457,11 @@ impl crate::backend::Device for Dx12Device {
let flags = 0;
let shader_blob = ShaderByteCode::compile(code, "cs_5_1", "main", flags)?;
let shader = ShaderByteCode::from_blob(shader_blob);
*/
// DXIL code path
let shader = ShaderByteCode::from_slice(code);
let mut root_parameter = d3d12::D3D12_ROOT_PARAMETER {
ParameterType: d3d12::D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE,
ShaderVisibility: d3d12::D3D12_SHADER_VISIBILITY_ALL,

View file

@ -196,7 +196,7 @@ impl Factory4 {
error_if_failed_else_unit(self.0.EnumAdapters1(id, &mut adapter))?;
let mut desc = mem::zeroed();
(*adapter).GetDesc(&mut desc);
println!("desc: {:?}", desc.Description);
//println!("desc: {:?}", desc.Description);
Ok(Adapter1(ComPtr::from_raw(adapter)))
}
@ -276,6 +276,7 @@ impl SwapChain3 {
}
impl Blob {
#[allow(unused)]
pub unsafe fn print_to_console(blob: &Blob) {
println!("==SHADER COMPILE MESSAGES==");
let message = {
@ -714,13 +715,13 @@ impl RootSignature {
let hresult =
d3d12::D3D12SerializeRootSignature(desc, version, &mut blob, &mut error_blob_ptr);
#[cfg(debug_assertions)]
{
let error_blob = if error_blob_ptr.is_null() {
None
} else {
Some(Blob(ComPtr::from_raw(error_blob_ptr)))
};
#[cfg(debug_assertions)]
{
if let Some(error_blob) = &error_blob {
Blob::print_to_console(error_blob);
}
@ -736,6 +737,7 @@ impl ShaderByteCode {
// `blob` may not be null.
// TODO: this is not super elegant, maybe want to move the get
// operations closer to where they're used.
#[allow(unused)]
pub unsafe fn from_blob(blob: Blob) -> ShaderByteCode {
ShaderByteCode {
bytecode: d3d12::D3D12_SHADER_BYTECODE {
@ -749,6 +751,7 @@ impl ShaderByteCode {
/// Compile a shader from raw HLSL.
///
/// * `target`: example format: `ps_5_1`.
#[allow(unused)]
pub unsafe fn compile(
source: &str,
target: &str,
@ -795,6 +798,24 @@ impl ShaderByteCode {
Ok(Blob(ComPtr::from_raw(shader_blob_ptr)))
}
/// Create bytecode from a slice.
///
/// # Safety
///
/// This call elides the lifetime from the slice. The caller is responsible
/// for making sure the reference remains valid for the lifetime of this
/// object.
#[allow(unused)]
pub unsafe fn from_slice(bytecode: &[u8]) -> ShaderByteCode {
ShaderByteCode {
bytecode: d3d12::D3D12_SHADER_BYTECODE {
BytecodeLength: bytecode.len(),
pShaderBytecode: bytecode.as_ptr() as *const _,
},
blob: None,
}
}
}
impl Fence {
@ -1073,9 +1094,8 @@ pub unsafe fn create_transition_resource_barrier(
resource_barrier
}
#[allow(unused)]
pub unsafe fn enable_debug_layer() -> Result<(), Error> {
println!("enabling debug layer.");
let mut debug_controller: *mut d3d12sdklayers::ID3D12Debug1 = ptr::null_mut();
explain_error(
d3d12::D3D12GetDebugInterface(

View file

@ -369,8 +369,8 @@ impl Session {
}
/// Choose shader code from the available choices.
pub fn choose_shader<'a>(&self, spv: &'a [u8], hlsl: &'a str, msl: &'a str) -> ShaderCode<'a> {
self.0.device.choose_shader(spv, hlsl, msl)
pub fn choose_shader<'a>(&self, spv: &'a [u8], hlsl: &'a str, dxil: &'a [u8], msl: &'a str) -> ShaderCode<'a> {
self.0.device.choose_shader(spv, hlsl, dxil, msl)
}
/// Report the backend type that was chosen.

View file

@ -51,7 +51,7 @@ bitflags! {
}
/// The GPU backend that was selected.
#[derive(Clone, Copy, PartialEq, Eq)]
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum BackendType {
Vulkan,
Dx12,

View file

@ -198,6 +198,7 @@ macro_rules! include_shader {
$device.choose_shader(
include_bytes!(concat!($path_base, ".spv")),
include_str!(concat!($path_base, ".hlsl")),
include_bytes!(concat!($path_base, ".dxil")),
include_str!(concat!($path_base, ".msl")),
)
};

View file

@ -104,6 +104,8 @@ pub enum ShaderCode<'a> {
Spv(&'a [u8]),
/// HLSL (source)
Hlsl(&'a str),
/// DXIL (DX12 intermediate language)
Dxil(&'a [u8]),
/// Metal Shading Language (source)
Msl(&'a str),
}
@ -321,9 +323,10 @@ impl Device {
}
Device::Dx12(d) => {
let shader_code = match code {
ShaderCode::Hlsl(hlsl) => hlsl,
//ShaderCode::Hlsl(hlsl) => hlsl,
ShaderCode::Dxil(dxil) => dxil,
// Panic or return "incompatible shader" error here?
_ => panic!("DX12 backend requires shader code in HLSL format"),
_ => panic!("DX12 backend requires shader code in DXIL format"),
};
d.create_compute_pipeline(shader_code, bind_types)
.map(Pipeline::Dx12)
@ -475,11 +478,12 @@ impl Device {
&self,
_spv: &'a [u8],
_hlsl: &'a str,
_dxil: &'a [u8],
_msl: &'a str,
) -> ShaderCode<'a> {
mux_match! { self;
Device::Vk(_d) => ShaderCode::Spv(_spv),
Device::Dx12(_d) => ShaderCode::Hlsl(_hlsl),
Device::Dx12(_d) => ShaderCode::Dxil(_dxil),
Device::Mtl(_d) => ShaderCode::Msl(_msl),
}
}

View file

@ -4,6 +4,7 @@
glslang_validator = glslangValidator
spirv_cross = spirv-cross
dxc = dxc
# See https://github.com/KhronosGroup/SPIRV-Cross/issues/1248 for
# why we set this.
@ -15,26 +16,34 @@ rule glsl
rule hlsl
command = $spirv_cross --hlsl $in --output $out
rule dxil
command = $dxc -T cs_6_0 $in -Fo $out
rule msl
command = $spirv_cross --msl $in --output $out $msl_flags
build gen/clear.spv: glsl clear.comp
build gen/clear.hlsl: hlsl gen/clear.spv
build gen/clear.dxil: dxil gen/clear.hlsl
build gen/clear.msl: msl gen/clear.spv
build gen/prefix.spv: glsl prefix.comp
build gen/prefix.hlsl: hlsl gen/prefix.spv
build gen/prefix.dxil: dxil gen/prefix.hlsl
build gen/prefix.msl: msl gen/prefix.spv
build gen/prefix_reduce.spv: glsl prefix_reduce.comp
build gen/prefix_reduce.hlsl: hlsl gen/prefix_reduce.spv
build gen/prefix_reduce.dxil: dxil gen/prefix_reduce.hlsl
build gen/prefix_reduce.msl: msl gen/prefix_reduce.spv
build gen/prefix_root.spv: glsl prefix_scan.comp
flags = -DROOT
build gen/prefix_root.hlsl: hlsl gen/prefix_root.spv
build gen/prefix_root.dxil: dxil gen/prefix_root.hlsl
build gen/prefix_root.msl: msl gen/prefix_root.spv
build gen/prefix_scan.spv: glsl prefix_scan.comp
build gen/prefix_scan.hlsl: hlsl gen/prefix_scan.spv
build gen/prefix_scan.dxil: dxil gen/prefix_scan.hlsl
build gen/prefix_scan.msl: msl gen/prefix_scan.spv

View file

@ -80,6 +80,10 @@ fn main() {
flags |= InstanceFlags::DX12;
}
let mut runner = Runner::new(flags);
if style == ReportStyle::Verbose {
// TODO: get adapter name in here too
println!("Backend: {:?}", runner.backend_type());
}
report(&clear::run_clear_test(&mut runner, &config));
if config.groups.matches("prefix") {
report(&prefix::run_prefix_test(&mut runner, &config));

View file

@ -53,11 +53,13 @@ struct PrefixBinding {
pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResult {
let mut result = TestResult::new("prefix sum, decoupled look-back");
/*
// We're good if we're using DXC.
if runner.backend_type() == BackendType::Dx12 {
result.skip("Shader won't compile on FXC");
return result;
}
// This will be configurable.
*/
let n_elements: u64 = config.size.choose(1 << 12, 1 << 24, 1 << 25);
let data: Vec<u32> = (0..n_elements as u32).collect();
let data_buf = runner
@ -68,7 +70,6 @@ pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResul
let code = PrefixCode::new(runner);
let stage = PrefixStage::new(runner, &code, n_elements);
let binding = stage.bind(runner, &code, &data_buf, &out_buf.dev_buf);
// Also will be configurable of course.
let n_iter = config.n_iter;
let mut total_elapsed = 0.0;
for i in 0..n_iter {

View file

@ -27,10 +27,11 @@ pub struct TestResult {
pub enum Status {
Pass,
Fail(String),
#[allow(unused)]
Skipped(String),
}
#[derive(Clone, Copy)]
#[derive(Clone, Copy, PartialEq, Eq)]
pub enum ReportStyle {
Short,
Verbose,
@ -84,6 +85,7 @@ impl TestResult {
self.status = Status::Fail(explanation.into());
}
#[allow(unused)]
pub fn skip(&mut self, explanation: impl Into<String>) {
self.status = Status::Skipped(explanation.into());
}