From 4f445c2e0aaeac571fcdc56f997b29231f3f0b4a Mon Sep 17 00:00:00 2001 From: Arman Uguray Date: Wed, 29 Mar 2023 10:27:16 -0700 Subject: [PATCH] [vello_shaders] Provide workgroup shared memory allocation sizes Naga traslates workgroup variable declarations to threadgroup address-space entry-point parameters when generating MSL. Metal API validation requires that the memory sizes for these parameters be set explicitly by calling setThreadgroupMemoryLength:index on the MTLComputeCommandEncoder. The crate now calculates the required memory size for global workgroup variables that are accessed by the entry point and provides them alongside the binding list. This is abstracted separately from the binding list. While the current usage that we're aware of is limited to Metal, this information is being provided as part of the generic ComputeShader type instead of a MSL-specific type, as the information itself is computed from the parsed WGSL IR and not specific to Metal. --- vello_shaders/build.rs | 6 ++++ vello_shaders/src/compile/mod.rs | 52 ++++++++++++++++++++++++++++---- vello_shaders/src/lib.rs | 3 +- vello_shaders/src/types.rs | 7 +++++ 4 files changed, 61 insertions(+), 7 deletions(-) diff --git a/vello_shaders/build.rs b/vello_shaders/build.rs index aa13e97..65ea159 100644 --- a/vello_shaders/build.rs +++ b/vello_shaders/build.rs @@ -76,6 +76,7 @@ fn write_shaders( .iter() .map(|binding| binding.ty) .collect::>(); + let wg_bufs = &info.workgroup_buffers; let source = translate(info); writeln!(buf, " {name}: ComputeShader {{")?; writeln!(buf, " name: Cow::Borrowed({:?}),", name)?; @@ -90,6 +91,11 @@ fn write_shaders( info.workgroup_size )?; writeln!(buf, " bindings: Cow::Borrowed(&{:?}),", bind_tys)?; + writeln!( + buf, + " workgroup_buffers: Cow::Borrowed(&{:?}),", + wg_bufs + )?; writeln!(buf, " }},")?; } writeln!(buf, " }};")?; diff --git a/vello_shaders/src/compile/mod.rs b/vello_shaders/src/compile/mod.rs index 6d88667..c218be6 100644 --- a/vello_shaders/src/compile/mod.rs +++ b/vello_shaders/src/compile/mod.rs @@ -2,7 +2,8 @@ use { naga::{ front::wgsl, valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags}, - AddressSpace, ImageClass, Module, StorageAccess, WithSpan, + AddressSpace, ArraySize, ConstantInner, ImageClass, Module, ScalarValue, StorageAccess, + WithSpan, }, std::{ collections::{HashMap, HashSet}, @@ -16,7 +17,7 @@ pub mod preprocess; pub mod msl; -use crate::types::{BindType, BindingInfo}; +use crate::types::{BindType, BindingInfo, WorkgroupBufferInfo}; #[derive(Error, Debug)] pub enum Error { @@ -37,6 +38,7 @@ pub struct ShaderInfo { pub module_info: ModuleInfo, pub workgroup_size: [u32; 3], pub bindings: Vec, + pub workgroup_buffers: Vec, } impl ShaderInfo { @@ -54,12 +56,53 @@ impl ShaderInfo { .find(|(_, entry)| entry.name.as_str() == entry_point) .ok_or(Error::EntryPointNotFound)?; let mut bindings = vec![]; + let mut workgroup_buffers = vec![]; + let mut wg_buffer_idx = 0; let entry_info = module_info.get_entry_point(entry_index); for (var_handle, var) in module.global_variables.iter() { if entry_info[var_handle].is_empty() { continue; } + let binding_ty = match module.types[var.ty].inner { + naga::TypeInner::BindingArray { base, .. } => &module.types[base].inner, + ref ty => ty, + }; let Some(binding) = &var.binding else { + if var.space == AddressSpace::WorkGroup { + let index = wg_buffer_idx; + wg_buffer_idx += 1; + let size_in_bytes = match binding_ty { + naga::TypeInner::Array { + size: ArraySize::Constant(const_handle), + stride, + .. + } => { + let size: u32 = match module.constants[*const_handle].inner { + ConstantInner::Scalar { value, width: _ } => match value { + ScalarValue::Uint(value) => value.try_into().unwrap(), + ScalarValue::Sint(value) => value.try_into().unwrap(), + _ => continue, + }, + ConstantInner::Composite { .. } => continue, + }; + size * stride + }, + naga::TypeInner::Struct { span, .. } => *span, + naga::TypeInner::Scalar { width, ..} => *width as u32, + naga::TypeInner::Vector { width, ..} => *width as u32, + naga::TypeInner::Matrix { width, ..} => *width as u32, + naga::TypeInner::Atomic { width, ..} => *width as u32, + _ => { + // Not a valid workgroup variable type. At least not one that is used + // in our shaders. + continue; + } + }; + workgroup_buffers.push(WorkgroupBufferInfo { + size_in_bytes, + index, + }); + } continue; }; let mut resource = BindingInfo { @@ -67,10 +110,6 @@ impl ShaderInfo { location: (binding.group, binding.binding), ty: BindType::Buffer, }; - let binding_ty = match module.types[var.ty].inner { - naga::TypeInner::BindingArray { base, .. } => &module.types[base].inner, - ref ty => ty, - }; if let naga::TypeInner::Image { class, .. } = &binding_ty { resource.ty = BindType::ImageRead; if let ImageClass::Storage { access, .. } = class { @@ -102,6 +141,7 @@ impl ShaderInfo { module_info, workgroup_size, bindings, + workgroup_buffers, }) } diff --git a/vello_shaders/src/lib.rs b/vello_shaders/src/lib.rs index ad71439..ded1984 100644 --- a/vello_shaders/src/lib.rs +++ b/vello_shaders/src/lib.rs @@ -3,7 +3,7 @@ mod types; #[cfg(feature = "compile")] pub mod compile; -pub use types::{BindType, BindingInfo}; +pub use types::{BindType, BindingInfo, WorkgroupBufferInfo}; use std::borrow::Cow; @@ -13,6 +13,7 @@ pub struct ComputeShader<'a> { pub code: Cow<'a, [u8]>, pub workgroup_size: [u32; 3], pub bindings: Cow<'a, [BindType]>, + pub workgroup_buffers: Cow<'a, [WorkgroupBufferInfo]>, } pub trait PipelineHost { diff --git a/vello_shaders/src/types.rs b/vello_shaders/src/types.rs index 376ec01..a9db3fe 100644 --- a/vello_shaders/src/types.rs +++ b/vello_shaders/src/types.rs @@ -28,3 +28,10 @@ pub struct BindingInfo { pub location: (u32, u32), pub ty: BindType, } + +#[derive(Clone, Debug)] +pub struct WorkgroupBufferInfo { + pub size_in_bytes: u32, + /// The order in which the workgroup variable is declared in the shader module. + pub index: u32, +}