mirror of
https://github.com/italicsjenga/vello.git
synced 2025-01-25 18:56:35 +11:00
Use DXIL shader compilation
Integrate DXC for translating HLSL for use in DX12. This will work around FXC limitations and unlock the use of more advanced HLSL features such as subgroups. This hardcodes the use of DXIL, but it could be adapted (with a bit of effort) to choose between DXIL and HLSL at runtime.
This commit is contained in:
parent
7a021793ee
commit
f9d0aa078b
11 changed files with 80 additions and 21 deletions
|
@ -4,6 +4,7 @@
|
|||
|
||||
glslang_validator = glslangValidator
|
||||
spirv_cross = spirv-cross
|
||||
dxc = dxc
|
||||
|
||||
rule glsl
|
||||
command = $glslang_validator -V -o $out $in
|
||||
|
@ -11,9 +12,13 @@ rule glsl
|
|||
rule hlsl
|
||||
command = $spirv_cross --hlsl $in --output $out
|
||||
|
||||
rule dxil
|
||||
command = $dxc -T cs_6_0 $in -Fo $out
|
||||
|
||||
rule msl
|
||||
command = $spirv_cross --msl $in --output $out
|
||||
|
||||
build gen/collatz.spv: glsl collatz.comp
|
||||
build gen/collatz.hlsl: hlsl gen/collatz.spv
|
||||
build gen/collatz.dxil: dxil gen/collatz.hlsl
|
||||
build gen/collatz.msl: msl gen/collatz.spv
|
||||
|
|
|
@ -6,7 +6,9 @@ mod wrappers;
|
|||
use std::{cell::Cell, convert::{TryFrom, TryInto}, mem, ptr};
|
||||
|
||||
use winapi::shared::minwindef::TRUE;
|
||||
use winapi::shared::{dxgi, dxgi1_2, dxgi1_3, dxgitype};
|
||||
use winapi::shared::{dxgi, dxgi1_2, dxgitype};
|
||||
#[allow(unused)]
|
||||
use winapi::shared::dxgi1_3; // for error reporting in debug mode
|
||||
use winapi::um::d3d12;
|
||||
|
||||
use raw_window_handle::{HasRawWindowHandle, RawWindowHandle};
|
||||
|
@ -236,8 +238,9 @@ impl crate::backend::Device for Dx12Device {
|
|||
|
||||
type Sampler = ();
|
||||
|
||||
// Currently this is HLSL source, but we'll probably change it to IR.
|
||||
type ShaderSource = str;
|
||||
// Currently due to type inflexibility this is hardcoded to either HLSL or
|
||||
// DXIL, but it would be nice to be able to handle both at runtime.
|
||||
type ShaderSource = [u8];
|
||||
|
||||
fn create_buffer(&self, size: u64, usage: BufferUsage) -> Result<Self::Buffer, Error> {
|
||||
// TODO: consider supporting BufferUsage::QUERY_RESOLVE here rather than
|
||||
|
@ -411,7 +414,7 @@ impl crate::backend::Device for Dx12Device {
|
|||
|
||||
unsafe fn create_compute_pipeline(
|
||||
&self,
|
||||
code: &str,
|
||||
code: &Self::ShaderSource,
|
||||
bind_types: &[BindType],
|
||||
) -> Result<Pipeline, Error> {
|
||||
if u32::try_from(bind_types.len()).is_err() {
|
||||
|
@ -442,6 +445,11 @@ impl crate::backend::Device for Dx12Device {
|
|||
i = end;
|
||||
}
|
||||
|
||||
// We could always have ShaderSource as [u8] even when it's HLSL, and use the
|
||||
// magic number to distinguish. In any case, for now it's hardcoded as one or
|
||||
// the other.
|
||||
/*
|
||||
// HLSL code path
|
||||
#[cfg(debug_assertions)]
|
||||
let flags = winapi::um::d3dcompiler::D3DCOMPILE_DEBUG
|
||||
| winapi::um::d3dcompiler::D3DCOMPILE_SKIP_OPTIMIZATION;
|
||||
|
@ -449,6 +457,11 @@ impl crate::backend::Device for Dx12Device {
|
|||
let flags = 0;
|
||||
let shader_blob = ShaderByteCode::compile(code, "cs_5_1", "main", flags)?;
|
||||
let shader = ShaderByteCode::from_blob(shader_blob);
|
||||
*/
|
||||
|
||||
// DXIL code path
|
||||
let shader = ShaderByteCode::from_slice(code);
|
||||
|
||||
let mut root_parameter = d3d12::D3D12_ROOT_PARAMETER {
|
||||
ParameterType: d3d12::D3D12_ROOT_PARAMETER_TYPE_DESCRIPTOR_TABLE,
|
||||
ShaderVisibility: d3d12::D3D12_SHADER_VISIBILITY_ALL,
|
||||
|
|
|
@ -196,7 +196,7 @@ impl Factory4 {
|
|||
error_if_failed_else_unit(self.0.EnumAdapters1(id, &mut adapter))?;
|
||||
let mut desc = mem::zeroed();
|
||||
(*adapter).GetDesc(&mut desc);
|
||||
println!("desc: {:?}", desc.Description);
|
||||
//println!("desc: {:?}", desc.Description);
|
||||
Ok(Adapter1(ComPtr::from_raw(adapter)))
|
||||
}
|
||||
|
||||
|
@ -276,6 +276,7 @@ impl SwapChain3 {
|
|||
}
|
||||
|
||||
impl Blob {
|
||||
#[allow(unused)]
|
||||
pub unsafe fn print_to_console(blob: &Blob) {
|
||||
println!("==SHADER COMPILE MESSAGES==");
|
||||
let message = {
|
||||
|
@ -714,13 +715,13 @@ impl RootSignature {
|
|||
let hresult =
|
||||
d3d12::D3D12SerializeRootSignature(desc, version, &mut blob, &mut error_blob_ptr);
|
||||
|
||||
let error_blob = if error_blob_ptr.is_null() {
|
||||
None
|
||||
} else {
|
||||
Some(Blob(ComPtr::from_raw(error_blob_ptr)))
|
||||
};
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
let error_blob = if error_blob_ptr.is_null() {
|
||||
None
|
||||
} else {
|
||||
Some(Blob(ComPtr::from_raw(error_blob_ptr)))
|
||||
};
|
||||
if let Some(error_blob) = &error_blob {
|
||||
Blob::print_to_console(error_blob);
|
||||
}
|
||||
|
@ -736,6 +737,7 @@ impl ShaderByteCode {
|
|||
// `blob` may not be null.
|
||||
// TODO: this is not super elegant, maybe want to move the get
|
||||
// operations closer to where they're used.
|
||||
#[allow(unused)]
|
||||
pub unsafe fn from_blob(blob: Blob) -> ShaderByteCode {
|
||||
ShaderByteCode {
|
||||
bytecode: d3d12::D3D12_SHADER_BYTECODE {
|
||||
|
@ -749,6 +751,7 @@ impl ShaderByteCode {
|
|||
/// Compile a shader from raw HLSL.
|
||||
///
|
||||
/// * `target`: example format: `ps_5_1`.
|
||||
#[allow(unused)]
|
||||
pub unsafe fn compile(
|
||||
source: &str,
|
||||
target: &str,
|
||||
|
@ -795,6 +798,24 @@ impl ShaderByteCode {
|
|||
|
||||
Ok(Blob(ComPtr::from_raw(shader_blob_ptr)))
|
||||
}
|
||||
|
||||
/// Create bytecode from a slice.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// This call elides the lifetime from the slice. The caller is responsible
|
||||
/// for making sure the reference remains valid for the lifetime of this
|
||||
/// object.
|
||||
#[allow(unused)]
|
||||
pub unsafe fn from_slice(bytecode: &[u8]) -> ShaderByteCode {
|
||||
ShaderByteCode {
|
||||
bytecode: d3d12::D3D12_SHADER_BYTECODE {
|
||||
BytecodeLength: bytecode.len(),
|
||||
pShaderBytecode: bytecode.as_ptr() as *const _,
|
||||
},
|
||||
blob: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Fence {
|
||||
|
@ -1073,9 +1094,8 @@ pub unsafe fn create_transition_resource_barrier(
|
|||
resource_barrier
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub unsafe fn enable_debug_layer() -> Result<(), Error> {
|
||||
println!("enabling debug layer.");
|
||||
|
||||
let mut debug_controller: *mut d3d12sdklayers::ID3D12Debug1 = ptr::null_mut();
|
||||
explain_error(
|
||||
d3d12::D3D12GetDebugInterface(
|
||||
|
|
|
@ -369,8 +369,8 @@ impl Session {
|
|||
}
|
||||
|
||||
/// Choose shader code from the available choices.
|
||||
pub fn choose_shader<'a>(&self, spv: &'a [u8], hlsl: &'a str, msl: &'a str) -> ShaderCode<'a> {
|
||||
self.0.device.choose_shader(spv, hlsl, msl)
|
||||
pub fn choose_shader<'a>(&self, spv: &'a [u8], hlsl: &'a str, dxil: &'a [u8], msl: &'a str) -> ShaderCode<'a> {
|
||||
self.0.device.choose_shader(spv, hlsl, dxil, msl)
|
||||
}
|
||||
|
||||
/// Report the backend type that was chosen.
|
||||
|
|
|
@ -51,7 +51,7 @@ bitflags! {
|
|||
}
|
||||
|
||||
/// The GPU backend that was selected.
|
||||
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
|
||||
pub enum BackendType {
|
||||
Vulkan,
|
||||
Dx12,
|
||||
|
|
|
@ -198,6 +198,7 @@ macro_rules! include_shader {
|
|||
$device.choose_shader(
|
||||
include_bytes!(concat!($path_base, ".spv")),
|
||||
include_str!(concat!($path_base, ".hlsl")),
|
||||
include_bytes!(concat!($path_base, ".dxil")),
|
||||
include_str!(concat!($path_base, ".msl")),
|
||||
)
|
||||
};
|
||||
|
|
|
@ -104,6 +104,8 @@ pub enum ShaderCode<'a> {
|
|||
Spv(&'a [u8]),
|
||||
/// HLSL (source)
|
||||
Hlsl(&'a str),
|
||||
/// DXIL (DX12 intermediate language)
|
||||
Dxil(&'a [u8]),
|
||||
/// Metal Shading Language (source)
|
||||
Msl(&'a str),
|
||||
}
|
||||
|
@ -321,9 +323,10 @@ impl Device {
|
|||
}
|
||||
Device::Dx12(d) => {
|
||||
let shader_code = match code {
|
||||
ShaderCode::Hlsl(hlsl) => hlsl,
|
||||
//ShaderCode::Hlsl(hlsl) => hlsl,
|
||||
ShaderCode::Dxil(dxil) => dxil,
|
||||
// Panic or return "incompatible shader" error here?
|
||||
_ => panic!("DX12 backend requires shader code in HLSL format"),
|
||||
_ => panic!("DX12 backend requires shader code in DXIL format"),
|
||||
};
|
||||
d.create_compute_pipeline(shader_code, bind_types)
|
||||
.map(Pipeline::Dx12)
|
||||
|
@ -475,11 +478,12 @@ impl Device {
|
|||
&self,
|
||||
_spv: &'a [u8],
|
||||
_hlsl: &'a str,
|
||||
_dxil: &'a [u8],
|
||||
_msl: &'a str,
|
||||
) -> ShaderCode<'a> {
|
||||
mux_match! { self;
|
||||
Device::Vk(_d) => ShaderCode::Spv(_spv),
|
||||
Device::Dx12(_d) => ShaderCode::Hlsl(_hlsl),
|
||||
Device::Dx12(_d) => ShaderCode::Dxil(_dxil),
|
||||
Device::Mtl(_d) => ShaderCode::Msl(_msl),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
glslang_validator = glslangValidator
|
||||
spirv_cross = spirv-cross
|
||||
dxc = dxc
|
||||
|
||||
# See https://github.com/KhronosGroup/SPIRV-Cross/issues/1248 for
|
||||
# why we set this.
|
||||
|
@ -15,26 +16,34 @@ rule glsl
|
|||
rule hlsl
|
||||
command = $spirv_cross --hlsl $in --output $out
|
||||
|
||||
rule dxil
|
||||
command = $dxc -T cs_6_0 $in -Fo $out
|
||||
|
||||
rule msl
|
||||
command = $spirv_cross --msl $in --output $out $msl_flags
|
||||
|
||||
build gen/clear.spv: glsl clear.comp
|
||||
build gen/clear.hlsl: hlsl gen/clear.spv
|
||||
build gen/clear.dxil: dxil gen/clear.hlsl
|
||||
build gen/clear.msl: msl gen/clear.spv
|
||||
|
||||
build gen/prefix.spv: glsl prefix.comp
|
||||
build gen/prefix.hlsl: hlsl gen/prefix.spv
|
||||
build gen/prefix.dxil: dxil gen/prefix.hlsl
|
||||
build gen/prefix.msl: msl gen/prefix.spv
|
||||
|
||||
build gen/prefix_reduce.spv: glsl prefix_reduce.comp
|
||||
build gen/prefix_reduce.hlsl: hlsl gen/prefix_reduce.spv
|
||||
build gen/prefix_reduce.dxil: dxil gen/prefix_reduce.hlsl
|
||||
build gen/prefix_reduce.msl: msl gen/prefix_reduce.spv
|
||||
|
||||
build gen/prefix_root.spv: glsl prefix_scan.comp
|
||||
flags = -DROOT
|
||||
build gen/prefix_root.hlsl: hlsl gen/prefix_root.spv
|
||||
build gen/prefix_root.dxil: dxil gen/prefix_root.hlsl
|
||||
build gen/prefix_root.msl: msl gen/prefix_root.spv
|
||||
|
||||
build gen/prefix_scan.spv: glsl prefix_scan.comp
|
||||
build gen/prefix_scan.hlsl: hlsl gen/prefix_scan.spv
|
||||
build gen/prefix_scan.dxil: dxil gen/prefix_scan.hlsl
|
||||
build gen/prefix_scan.msl: msl gen/prefix_scan.spv
|
||||
|
|
|
@ -80,6 +80,10 @@ fn main() {
|
|||
flags |= InstanceFlags::DX12;
|
||||
}
|
||||
let mut runner = Runner::new(flags);
|
||||
if style == ReportStyle::Verbose {
|
||||
// TODO: get adapter name in here too
|
||||
println!("Backend: {:?}", runner.backend_type());
|
||||
}
|
||||
report(&clear::run_clear_test(&mut runner, &config));
|
||||
if config.groups.matches("prefix") {
|
||||
report(&prefix::run_prefix_test(&mut runner, &config));
|
||||
|
|
|
@ -53,11 +53,13 @@ struct PrefixBinding {
|
|||
|
||||
pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResult {
|
||||
let mut result = TestResult::new("prefix sum, decoupled look-back");
|
||||
/*
|
||||
// We're good if we're using DXC.
|
||||
if runner.backend_type() == BackendType::Dx12 {
|
||||
result.skip("Shader won't compile on FXC");
|
||||
return result;
|
||||
}
|
||||
// This will be configurable.
|
||||
*/
|
||||
let n_elements: u64 = config.size.choose(1 << 12, 1 << 24, 1 << 25);
|
||||
let data: Vec<u32> = (0..n_elements as u32).collect();
|
||||
let data_buf = runner
|
||||
|
@ -68,7 +70,6 @@ pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResul
|
|||
let code = PrefixCode::new(runner);
|
||||
let stage = PrefixStage::new(runner, &code, n_elements);
|
||||
let binding = stage.bind(runner, &code, &data_buf, &out_buf.dev_buf);
|
||||
// Also will be configurable of course.
|
||||
let n_iter = config.n_iter;
|
||||
let mut total_elapsed = 0.0;
|
||||
for i in 0..n_iter {
|
||||
|
|
|
@ -27,10 +27,11 @@ pub struct TestResult {
|
|||
pub enum Status {
|
||||
Pass,
|
||||
Fail(String),
|
||||
#[allow(unused)]
|
||||
Skipped(String),
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ReportStyle {
|
||||
Short,
|
||||
Verbose,
|
||||
|
@ -84,6 +85,7 @@ impl TestResult {
|
|||
self.status = Status::Fail(explanation.into());
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn skip(&mut self, explanation: impl Into<String>) {
|
||||
self.status = Status::Skipped(explanation.into());
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue