diff --git a/vello_shaders/build.rs b/vello_shaders/build.rs index aa13e97..65ea159 100644 --- a/vello_shaders/build.rs +++ b/vello_shaders/build.rs @@ -76,6 +76,7 @@ fn write_shaders( .iter() .map(|binding| binding.ty) .collect::>(); + let wg_bufs = &info.workgroup_buffers; let source = translate(info); writeln!(buf, " {name}: ComputeShader {{")?; writeln!(buf, " name: Cow::Borrowed({:?}),", name)?; @@ -90,6 +91,11 @@ fn write_shaders( info.workgroup_size )?; writeln!(buf, " bindings: Cow::Borrowed(&{:?}),", bind_tys)?; + writeln!( + buf, + " workgroup_buffers: Cow::Borrowed(&{:?}),", + wg_bufs + )?; writeln!(buf, " }},")?; } writeln!(buf, " }};")?; diff --git a/vello_shaders/src/compile/mod.rs b/vello_shaders/src/compile/mod.rs index 6d88667..c218be6 100644 --- a/vello_shaders/src/compile/mod.rs +++ b/vello_shaders/src/compile/mod.rs @@ -2,7 +2,8 @@ use { naga::{ front::wgsl, valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags}, - AddressSpace, ImageClass, Module, StorageAccess, WithSpan, + AddressSpace, ArraySize, ConstantInner, ImageClass, Module, ScalarValue, StorageAccess, + WithSpan, }, std::{ collections::{HashMap, HashSet}, @@ -16,7 +17,7 @@ pub mod preprocess; pub mod msl; -use crate::types::{BindType, BindingInfo}; +use crate::types::{BindType, BindingInfo, WorkgroupBufferInfo}; #[derive(Error, Debug)] pub enum Error { @@ -37,6 +38,7 @@ pub struct ShaderInfo { pub module_info: ModuleInfo, pub workgroup_size: [u32; 3], pub bindings: Vec, + pub workgroup_buffers: Vec, } impl ShaderInfo { @@ -54,12 +56,53 @@ impl ShaderInfo { .find(|(_, entry)| entry.name.as_str() == entry_point) .ok_or(Error::EntryPointNotFound)?; let mut bindings = vec![]; + let mut workgroup_buffers = vec![]; + let mut wg_buffer_idx = 0; let entry_info = module_info.get_entry_point(entry_index); for (var_handle, var) in module.global_variables.iter() { if entry_info[var_handle].is_empty() { continue; } + let binding_ty = match module.types[var.ty].inner { + naga::TypeInner::BindingArray { base, .. } => &module.types[base].inner, + ref ty => ty, + }; let Some(binding) = &var.binding else { + if var.space == AddressSpace::WorkGroup { + let index = wg_buffer_idx; + wg_buffer_idx += 1; + let size_in_bytes = match binding_ty { + naga::TypeInner::Array { + size: ArraySize::Constant(const_handle), + stride, + .. + } => { + let size: u32 = match module.constants[*const_handle].inner { + ConstantInner::Scalar { value, width: _ } => match value { + ScalarValue::Uint(value) => value.try_into().unwrap(), + ScalarValue::Sint(value) => value.try_into().unwrap(), + _ => continue, + }, + ConstantInner::Composite { .. } => continue, + }; + size * stride + }, + naga::TypeInner::Struct { span, .. } => *span, + naga::TypeInner::Scalar { width, ..} => *width as u32, + naga::TypeInner::Vector { width, ..} => *width as u32, + naga::TypeInner::Matrix { width, ..} => *width as u32, + naga::TypeInner::Atomic { width, ..} => *width as u32, + _ => { + // Not a valid workgroup variable type. At least not one that is used + // in our shaders. + continue; + } + }; + workgroup_buffers.push(WorkgroupBufferInfo { + size_in_bytes, + index, + }); + } continue; }; let mut resource = BindingInfo { @@ -67,10 +110,6 @@ impl ShaderInfo { location: (binding.group, binding.binding), ty: BindType::Buffer, }; - let binding_ty = match module.types[var.ty].inner { - naga::TypeInner::BindingArray { base, .. } => &module.types[base].inner, - ref ty => ty, - }; if let naga::TypeInner::Image { class, .. } = &binding_ty { resource.ty = BindType::ImageRead; if let ImageClass::Storage { access, .. } = class { @@ -102,6 +141,7 @@ impl ShaderInfo { module_info, workgroup_size, bindings, + workgroup_buffers, }) } diff --git a/vello_shaders/src/lib.rs b/vello_shaders/src/lib.rs index ad71439..ded1984 100644 --- a/vello_shaders/src/lib.rs +++ b/vello_shaders/src/lib.rs @@ -3,7 +3,7 @@ mod types; #[cfg(feature = "compile")] pub mod compile; -pub use types::{BindType, BindingInfo}; +pub use types::{BindType, BindingInfo, WorkgroupBufferInfo}; use std::borrow::Cow; @@ -13,6 +13,7 @@ pub struct ComputeShader<'a> { pub code: Cow<'a, [u8]>, pub workgroup_size: [u32; 3], pub bindings: Cow<'a, [BindType]>, + pub workgroup_buffers: Cow<'a, [WorkgroupBufferInfo]>, } pub trait PipelineHost { diff --git a/vello_shaders/src/types.rs b/vello_shaders/src/types.rs index 376ec01..a9db3fe 100644 --- a/vello_shaders/src/types.rs +++ b/vello_shaders/src/types.rs @@ -28,3 +28,10 @@ pub struct BindingInfo { pub location: (u32, u32), pub ty: BindType, } + +#[derive(Clone, Debug)] +pub struct WorkgroupBufferInfo { + pub size_in_bytes: u32, + /// The order in which the workgroup variable is declared in the shader module. + pub index: u32, +}