mirror of
https://github.com/italicsjenga/vello.git
synced 2025-01-08 20:01:30 +11:00
[vello_shaders] Provide workgroup shared memory allocation sizes
Naga traslates workgroup variable declarations to threadgroup address-space entry-point parameters when generating MSL. Metal API validation requires that the memory sizes for these parameters be set explicitly by calling setThreadgroupMemoryLength:index on the MTLComputeCommandEncoder. The crate now calculates the required memory size for global workgroup variables that are accessed by the entry point and provides them alongside the binding list. This is abstracted separately from the binding list. While the current usage that we're aware of is limited to Metal, this information is being provided as part of the generic ComputeShader type instead of a MSL-specific type, as the information itself is computed from the parsed WGSL IR and not specific to Metal.
This commit is contained in:
parent
b52ef32c90
commit
4f445c2e0a
|
@ -76,6 +76,7 @@ fn write_shaders(
|
|||
.iter()
|
||||
.map(|binding| binding.ty)
|
||||
.collect::<Vec<_>>();
|
||||
let wg_bufs = &info.workgroup_buffers;
|
||||
let source = translate(info);
|
||||
writeln!(buf, " {name}: ComputeShader {{")?;
|
||||
writeln!(buf, " name: Cow::Borrowed({:?}),", name)?;
|
||||
|
@ -90,6 +91,11 @@ fn write_shaders(
|
|||
info.workgroup_size
|
||||
)?;
|
||||
writeln!(buf, " bindings: Cow::Borrowed(&{:?}),", bind_tys)?;
|
||||
writeln!(
|
||||
buf,
|
||||
" workgroup_buffers: Cow::Borrowed(&{:?}),",
|
||||
wg_bufs
|
||||
)?;
|
||||
writeln!(buf, " }},")?;
|
||||
}
|
||||
writeln!(buf, " }};")?;
|
||||
|
|
|
@ -2,7 +2,8 @@ use {
|
|||
naga::{
|
||||
front::wgsl,
|
||||
valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags},
|
||||
AddressSpace, ImageClass, Module, StorageAccess, WithSpan,
|
||||
AddressSpace, ArraySize, ConstantInner, ImageClass, Module, ScalarValue, StorageAccess,
|
||||
WithSpan,
|
||||
},
|
||||
std::{
|
||||
collections::{HashMap, HashSet},
|
||||
|
@ -16,7 +17,7 @@ pub mod preprocess;
|
|||
|
||||
pub mod msl;
|
||||
|
||||
use crate::types::{BindType, BindingInfo};
|
||||
use crate::types::{BindType, BindingInfo, WorkgroupBufferInfo};
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
|
@ -37,6 +38,7 @@ pub struct ShaderInfo {
|
|||
pub module_info: ModuleInfo,
|
||||
pub workgroup_size: [u32; 3],
|
||||
pub bindings: Vec<BindingInfo>,
|
||||
pub workgroup_buffers: Vec<WorkgroupBufferInfo>,
|
||||
}
|
||||
|
||||
impl ShaderInfo {
|
||||
|
@ -54,12 +56,53 @@ impl ShaderInfo {
|
|||
.find(|(_, entry)| entry.name.as_str() == entry_point)
|
||||
.ok_or(Error::EntryPointNotFound)?;
|
||||
let mut bindings = vec![];
|
||||
let mut workgroup_buffers = vec![];
|
||||
let mut wg_buffer_idx = 0;
|
||||
let entry_info = module_info.get_entry_point(entry_index);
|
||||
for (var_handle, var) in module.global_variables.iter() {
|
||||
if entry_info[var_handle].is_empty() {
|
||||
continue;
|
||||
}
|
||||
let binding_ty = match module.types[var.ty].inner {
|
||||
naga::TypeInner::BindingArray { base, .. } => &module.types[base].inner,
|
||||
ref ty => ty,
|
||||
};
|
||||
let Some(binding) = &var.binding else {
|
||||
if var.space == AddressSpace::WorkGroup {
|
||||
let index = wg_buffer_idx;
|
||||
wg_buffer_idx += 1;
|
||||
let size_in_bytes = match binding_ty {
|
||||
naga::TypeInner::Array {
|
||||
size: ArraySize::Constant(const_handle),
|
||||
stride,
|
||||
..
|
||||
} => {
|
||||
let size: u32 = match module.constants[*const_handle].inner {
|
||||
ConstantInner::Scalar { value, width: _ } => match value {
|
||||
ScalarValue::Uint(value) => value.try_into().unwrap(),
|
||||
ScalarValue::Sint(value) => value.try_into().unwrap(),
|
||||
_ => continue,
|
||||
},
|
||||
ConstantInner::Composite { .. } => continue,
|
||||
};
|
||||
size * stride
|
||||
},
|
||||
naga::TypeInner::Struct { span, .. } => *span,
|
||||
naga::TypeInner::Scalar { width, ..} => *width as u32,
|
||||
naga::TypeInner::Vector { width, ..} => *width as u32,
|
||||
naga::TypeInner::Matrix { width, ..} => *width as u32,
|
||||
naga::TypeInner::Atomic { width, ..} => *width as u32,
|
||||
_ => {
|
||||
// Not a valid workgroup variable type. At least not one that is used
|
||||
// in our shaders.
|
||||
continue;
|
||||
}
|
||||
};
|
||||
workgroup_buffers.push(WorkgroupBufferInfo {
|
||||
size_in_bytes,
|
||||
index,
|
||||
});
|
||||
}
|
||||
continue;
|
||||
};
|
||||
let mut resource = BindingInfo {
|
||||
|
@ -67,10 +110,6 @@ impl ShaderInfo {
|
|||
location: (binding.group, binding.binding),
|
||||
ty: BindType::Buffer,
|
||||
};
|
||||
let binding_ty = match module.types[var.ty].inner {
|
||||
naga::TypeInner::BindingArray { base, .. } => &module.types[base].inner,
|
||||
ref ty => ty,
|
||||
};
|
||||
if let naga::TypeInner::Image { class, .. } = &binding_ty {
|
||||
resource.ty = BindType::ImageRead;
|
||||
if let ImageClass::Storage { access, .. } = class {
|
||||
|
@ -102,6 +141,7 @@ impl ShaderInfo {
|
|||
module_info,
|
||||
workgroup_size,
|
||||
bindings,
|
||||
workgroup_buffers,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ mod types;
|
|||
#[cfg(feature = "compile")]
|
||||
pub mod compile;
|
||||
|
||||
pub use types::{BindType, BindingInfo};
|
||||
pub use types::{BindType, BindingInfo, WorkgroupBufferInfo};
|
||||
|
||||
use std::borrow::Cow;
|
||||
|
||||
|
@ -13,6 +13,7 @@ pub struct ComputeShader<'a> {
|
|||
pub code: Cow<'a, [u8]>,
|
||||
pub workgroup_size: [u32; 3],
|
||||
pub bindings: Cow<'a, [BindType]>,
|
||||
pub workgroup_buffers: Cow<'a, [WorkgroupBufferInfo]>,
|
||||
}
|
||||
|
||||
pub trait PipelineHost {
|
||||
|
|
|
@ -28,3 +28,10 @@ pub struct BindingInfo {
|
|||
pub location: (u32, u32),
|
||||
pub ty: BindType,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct WorkgroupBufferInfo {
|
||||
pub size_in_bytes: u32,
|
||||
/// The order in which the workgroup variable is declared in the shader module.
|
||||
pub index: u32,
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue