[vello_shaders] Provide workgroup shared memory allocation sizes

Naga traslates workgroup variable declarations to threadgroup
address-space entry-point parameters when generating MSL. Metal API
validation requires that the memory sizes for these parameters be set
explicitly by calling setThreadgroupMemoryLength:index on the
MTLComputeCommandEncoder.

The crate now calculates the required memory size for global workgroup
variables that are accessed by the entry point and provides them
alongside the binding list. This is abstracted separately from the
binding list.

While the current usage that we're aware of is limited to Metal, this
information is being provided as part of the generic ComputeShader type
instead of a MSL-specific type, as the information itself is computed
from the parsed WGSL IR and not specific to Metal.
This commit is contained in:
Arman Uguray 2023-03-29 10:27:16 -07:00
parent b52ef32c90
commit 4f445c2e0a
4 changed files with 61 additions and 7 deletions

View file

@ -76,6 +76,7 @@ fn write_shaders(
.iter()
.map(|binding| binding.ty)
.collect::<Vec<_>>();
let wg_bufs = &info.workgroup_buffers;
let source = translate(info);
writeln!(buf, " {name}: ComputeShader {{")?;
writeln!(buf, " name: Cow::Borrowed({:?}),", name)?;
@ -90,6 +91,11 @@ fn write_shaders(
info.workgroup_size
)?;
writeln!(buf, " bindings: Cow::Borrowed(&{:?}),", bind_tys)?;
writeln!(
buf,
" workgroup_buffers: Cow::Borrowed(&{:?}),",
wg_bufs
)?;
writeln!(buf, " }},")?;
}
writeln!(buf, " }};")?;

View file

@ -2,7 +2,8 @@ use {
naga::{
front::wgsl,
valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags},
AddressSpace, ImageClass, Module, StorageAccess, WithSpan,
AddressSpace, ArraySize, ConstantInner, ImageClass, Module, ScalarValue, StorageAccess,
WithSpan,
},
std::{
collections::{HashMap, HashSet},
@ -16,7 +17,7 @@ pub mod preprocess;
pub mod msl;
use crate::types::{BindType, BindingInfo};
use crate::types::{BindType, BindingInfo, WorkgroupBufferInfo};
#[derive(Error, Debug)]
pub enum Error {
@ -37,6 +38,7 @@ pub struct ShaderInfo {
pub module_info: ModuleInfo,
pub workgroup_size: [u32; 3],
pub bindings: Vec<BindingInfo>,
pub workgroup_buffers: Vec<WorkgroupBufferInfo>,
}
impl ShaderInfo {
@ -54,12 +56,53 @@ impl ShaderInfo {
.find(|(_, entry)| entry.name.as_str() == entry_point)
.ok_or(Error::EntryPointNotFound)?;
let mut bindings = vec![];
let mut workgroup_buffers = vec![];
let mut wg_buffer_idx = 0;
let entry_info = module_info.get_entry_point(entry_index);
for (var_handle, var) in module.global_variables.iter() {
if entry_info[var_handle].is_empty() {
continue;
}
let binding_ty = match module.types[var.ty].inner {
naga::TypeInner::BindingArray { base, .. } => &module.types[base].inner,
ref ty => ty,
};
let Some(binding) = &var.binding else {
if var.space == AddressSpace::WorkGroup {
let index = wg_buffer_idx;
wg_buffer_idx += 1;
let size_in_bytes = match binding_ty {
naga::TypeInner::Array {
size: ArraySize::Constant(const_handle),
stride,
..
} => {
let size: u32 = match module.constants[*const_handle].inner {
ConstantInner::Scalar { value, width: _ } => match value {
ScalarValue::Uint(value) => value.try_into().unwrap(),
ScalarValue::Sint(value) => value.try_into().unwrap(),
_ => continue,
},
ConstantInner::Composite { .. } => continue,
};
size * stride
},
naga::TypeInner::Struct { span, .. } => *span,
naga::TypeInner::Scalar { width, ..} => *width as u32,
naga::TypeInner::Vector { width, ..} => *width as u32,
naga::TypeInner::Matrix { width, ..} => *width as u32,
naga::TypeInner::Atomic { width, ..} => *width as u32,
_ => {
// Not a valid workgroup variable type. At least not one that is used
// in our shaders.
continue;
}
};
workgroup_buffers.push(WorkgroupBufferInfo {
size_in_bytes,
index,
});
}
continue;
};
let mut resource = BindingInfo {
@ -67,10 +110,6 @@ impl ShaderInfo {
location: (binding.group, binding.binding),
ty: BindType::Buffer,
};
let binding_ty = match module.types[var.ty].inner {
naga::TypeInner::BindingArray { base, .. } => &module.types[base].inner,
ref ty => ty,
};
if let naga::TypeInner::Image { class, .. } = &binding_ty {
resource.ty = BindType::ImageRead;
if let ImageClass::Storage { access, .. } = class {
@ -102,6 +141,7 @@ impl ShaderInfo {
module_info,
workgroup_size,
bindings,
workgroup_buffers,
})
}

View file

@ -3,7 +3,7 @@ mod types;
#[cfg(feature = "compile")]
pub mod compile;
pub use types::{BindType, BindingInfo};
pub use types::{BindType, BindingInfo, WorkgroupBufferInfo};
use std::borrow::Cow;
@ -13,6 +13,7 @@ pub struct ComputeShader<'a> {
pub code: Cow<'a, [u8]>,
pub workgroup_size: [u32; 3],
pub bindings: Cow<'a, [BindType]>,
pub workgroup_buffers: Cow<'a, [WorkgroupBufferInfo]>,
}
pub trait PipelineHost {

View file

@ -28,3 +28,10 @@ pub struct BindingInfo {
pub location: (u32, u32),
pub ty: BindType,
}
#[derive(Clone, Debug)]
pub struct WorkgroupBufferInfo {
pub size_in_bytes: u32,
/// The order in which the workgroup variable is declared in the shader module.
pub index: u32,
}