Playing with shader permutations and AOT compilation

This commit is contained in:
Chad Brokaw 2023-01-25 14:27:14 -05:00 committed by Arman Uguray
parent 1a6add81b1
commit 0db71153ad
10 changed files with 588 additions and 0 deletions

View file

@ -9,6 +9,7 @@ members = [
"examples/with_bevy", "examples/with_bevy",
"examples/run_wasm", "examples/run_wasm",
"examples/scenes", "examples/scenes",
"vello_shaders",
] ]
[workspace.package] [workspace.package]

3
shader/permutations Normal file
View file

@ -0,0 +1,3 @@
pathtag_scan
+ pathtag_scan_large
+ pathtag_scan_small: small

17
vello_shaders/Cargo.toml Normal file
View file

@ -0,0 +1,17 @@
[package]
name = "vello_shaders"
version = "0.1.0"
edition = "2021"
[features]
default = ["compile", "wgsl", "msl"]
compile = ["naga"]
wgsl = []
msl = []
[dependencies]
naga = { git = "https://github.com/gfx-rs/naga", features = ["wgsl-in", "msl-out", "validate"], optional = true }
[build-dependencies]
naga = { git = "https://github.com/gfx-rs/naga", features = ["wgsl-in", "msl-out", "validate"] }

99
vello_shaders/build.rs Normal file
View file

@ -0,0 +1,99 @@
#[path = "src/compile/mod.rs"]
mod compile;
#[path = "src/types.rs"]
mod types;
use std::collections::HashMap;
use std::env;
use std::fmt::Write;
use std::path::Path;
use compile::ShaderInfo;
fn main() {
let out_dir = env::var_os("OUT_DIR").unwrap();
let dest_path = Path::new(&out_dir).join("shaders.rs");
let shaders = compile::ShaderInfo::from_dir("../shader");
let mut buf = String::default();
write_types(&mut buf, &shaders).unwrap();
if cfg!(feature = "wgsl") {
write_shaders(&mut buf, "wgsl", &shaders, |info| {
info.source.as_bytes().to_owned()
})
.unwrap();
}
if cfg!(feature = "msl") {
write_shaders(&mut buf, "msl", &shaders, |info| {
compile::msl::translate(info).unwrap().as_bytes().to_owned()
})
.unwrap();
}
std::fs::write(&dest_path, &buf).unwrap();
println!("cargo:rerun-if-changed=../shaders");
}
fn write_types(
buf: &mut String,
shaders: &HashMap<String, ShaderInfo>,
) -> Result<(), std::fmt::Error> {
writeln!(buf, "pub struct Shaders<'a> {{")?;
for (name, _) in shaders {
writeln!(buf, " pub {name}: ComputeShader<'a>,")?;
}
writeln!(buf, "}}")?;
writeln!(buf, "pub struct Pipelines<T> {{")?;
for (name, _) in shaders {
writeln!(buf, " pub {name}: T,")?;
}
writeln!(buf, "}}")?;
writeln!(buf, "impl<T> Pipelines<T> {{")?;
writeln!(buf, " pub fn from_shaders<H: PipelineHost<ComputePipeline = T>>(shaders: &Shaders, device: &H::Device, host: &mut H) -> Result<Self, H::Error> {{")?;
writeln!(buf, " Ok(Self {{")?;
for (name, _) in shaders {
writeln!(
buf,
" {name}: host.new_compute_pipeline(device, &shaders.{name})?,"
)?;
}
writeln!(buf, " }})")?;
writeln!(buf, " }}")?;
writeln!(buf, "}}")?;
Ok(())
}
fn write_shaders(
buf: &mut String,
mod_name: &str,
shaders: &HashMap<String, ShaderInfo>,
translate: impl Fn(&ShaderInfo) -> Vec<u8>,
) -> Result<(), std::fmt::Error> {
writeln!(buf, "pub mod {mod_name} {{")?;
writeln!(buf, " use super::*;")?;
writeln!(buf, " use BindType::*;")?;
writeln!(buf, " pub const SHADERS: Shaders<'static> = Shaders {{")?;
for (name, info) in shaders {
let bind_tys = info
.bindings
.iter()
.map(|binding| binding.ty)
.collect::<Vec<_>>();
let source = translate(info);
writeln!(buf, " {name}: ComputeShader {{")?;
writeln!(buf, " name: Cow::Borrowed({:?}),", name)?;
writeln!(
buf,
" code: Cow::Borrowed(&{:?}),",
source.as_slice()
)?;
writeln!(
buf,
" workgroup_size: {:?},",
info.workgroup_size
)?;
writeln!(buf, " bindings: Cow::Borrowed(&{:?}),", bind_tys)?;
writeln!(buf, " }},")?;
}
writeln!(buf, " }};")?;
writeln!(buf, "}}")?;
Ok(())
}

View file

@ -0,0 +1,159 @@
use naga::{
front::wgsl,
valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags},
AddressSpace, ImageClass, Module, StorageAccess, WithSpan,
};
use std::{
collections::{HashMap, HashSet},
path::Path,
};
pub mod permutations;
pub mod preprocess;
pub mod msl;
use crate::types::{BindType, BindingInfo};
#[derive(Debug)]
pub enum Error {
Parse(wgsl::ParseError),
Validate(WithSpan<ValidationError>),
EntryPointNotFound,
}
impl From<wgsl::ParseError> for Error {
fn from(e: wgsl::ParseError) -> Self {
Self::Parse(e)
}
}
impl From<WithSpan<ValidationError>> for Error {
fn from(e: WithSpan<ValidationError>) -> Self {
Self::Validate(e)
}
}
#[derive(Debug)]
pub struct ShaderInfo {
pub source: String,
pub module: Module,
pub module_info: ModuleInfo,
pub workgroup_size: [u32; 3],
pub bindings: Vec<BindingInfo>,
}
impl ShaderInfo {
pub fn new(source: String, entry_point: &str) -> Result<ShaderInfo, Error> {
let module = wgsl::parse_str(&source)?;
let module_info = naga::valid::Validator::new(
ValidationFlags::all() & !ValidationFlags::CONTROL_FLOW_UNIFORMITY,
Capabilities::all(),
)
.validate(&module)?;
let (entry_index, entry) = module
.entry_points
.iter()
.enumerate()
.find(|(_, entry)| entry.name.as_str() == entry_point)
.ok_or(Error::EntryPointNotFound)?;
let mut bindings = vec![];
let entry_info = module_info.get_entry_point(entry_index);
for (var_handle, var) in module.global_variables.iter() {
if entry_info[var_handle].is_empty() {
continue;
}
let Some(binding) = &var.binding else {
continue;
};
let mut resource = BindingInfo {
name: var.name.clone(),
location: (binding.group, binding.binding),
ty: BindType::Buffer,
};
let binding_ty = match module.types[var.ty].inner {
naga::TypeInner::BindingArray { base, .. } => &module.types[base].inner,
ref ty => ty,
};
if let naga::TypeInner::Image { class, .. } = &binding_ty {
resource.ty = BindType::ImageRead;
if let ImageClass::Storage { access, .. } = class {
if access.contains(StorageAccess::STORE) {
resource.ty = BindType::Image;
}
}
} else {
resource.ty = BindType::BufReadOnly;
match var.space {
AddressSpace::Storage { access } => {
if access.contains(StorageAccess::STORE) {
resource.ty = BindType::Buffer;
}
}
AddressSpace::Uniform => {
resource.ty = BindType::Uniform;
}
_ => {}
}
}
bindings.push(resource);
}
bindings.sort_by_key(|res| res.location);
let workgroup_size = entry.workgroup_size;
Ok(ShaderInfo {
source,
module,
module_info,
workgroup_size,
bindings,
})
}
pub fn from_dir(shader_dir: impl AsRef<Path>) -> HashMap<String, Self> {
use std::fs;
let shader_dir = shader_dir.as_ref();
let permutation_map = if let Ok(permutations_source) =
std::fs::read_to_string(shader_dir.join("permutations"))
{
permutations::parse(&permutations_source)
} else {
Default::default()
};
println!("{:?}", permutation_map);
let imports = preprocess::get_imports(shader_dir);
let mut info = HashMap::default();
let mut defines = HashSet::default();
defines.insert("full".to_string());
for entry in shader_dir
.read_dir()
.expect("Can read shader import directory")
{
let entry = entry.expect("Can continue reading shader import directory");
if entry.file_type().unwrap().is_file() {
let file_name = entry.file_name();
if let Some(name) = file_name.to_str() {
let suffix = ".wgsl";
if let Some(shader_name) = name.strip_suffix(suffix) {
let contents = fs::read_to_string(shader_dir.join(&file_name))
.expect("Could read shader {shader_name} contents");
if let Some(permutations) = permutation_map.get(shader_name) {
for permutation in permutations {
let mut defines = defines.clone();
defines.extend(permutation.defines.iter().cloned());
let source = preprocess::preprocess(&contents, &defines, &imports);
let shader_info = Self::new(source.clone(), "main").unwrap();
info.insert(permutation.name.clone(), shader_info);
}
} else {
let source = preprocess::preprocess(&contents, &defines, &imports);
let shader_info = Self::new(source.clone(), "main").unwrap();
info.insert(shader_name.to_string(), shader_info);
}
}
}
}
}
info
}
}

View file

@ -0,0 +1,49 @@
use naga::back::msl;
use super::{BindType, ShaderInfo};
pub fn translate(shader: &ShaderInfo) -> Result<String, msl::Error> {
let mut map = msl::PerStageMap::default();
let mut buffer_index = 0u8;
let mut image_index = 0u8;
let mut binding_map = msl::BindingMap::default();
for resource in &shader.bindings {
let binding = naga::ResourceBinding {
group: resource.location.0,
binding: resource.location.1,
};
let mut target = msl::BindTarget::default();
match resource.ty {
BindType::Buffer | BindType::BufReadOnly | BindType::Uniform => {
target.buffer = Some(buffer_index);
buffer_index += 1;
}
BindType::Image | BindType::ImageRead => {
target.texture = Some(image_index);
image_index += 1;
}
}
target.mutable = resource.ty.is_mutable();
binding_map.insert(binding, target);
}
map.cs = msl::PerStageResources {
resources: binding_map,
push_constant_buffer: None,
sizes_buffer: Some(30),
};
let options = msl::Options {
lang_version: (2, 0),
per_stage_map: map,
inline_samplers: vec![],
spirv_cross_compatibility: false,
fake_missing_bindings: false,
bounds_check_policies: naga::proc::BoundsCheckPolicies::default(),
};
let (source, _) = msl::write_string(
&shader.module,
&shader.module_info,
&options,
&msl::PipelineOptions::default(),
)?;
Ok(source)
}

View file

@ -0,0 +1,41 @@
use std::collections::HashMap;
#[derive(Debug)]
pub struct Permutation {
/// The new name for the permutation
pub name: String,
/// Set of defines to apply for the permutation
pub defines: Vec<String>,
}
pub fn parse(source: &str) -> HashMap<String, Vec<Permutation>> {
let mut map: HashMap<String, Vec<Permutation>> = Default::default();
let mut current_source: Option<String> = None;
for line in source.lines() {
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
continue;
}
if let Some(line) = line.strip_prefix('+') {
if let Some(source) = &current_source {
let mut parts = line.split(':').map(|s| s.trim());
let Some(name) = parts.next() else {
continue;
};
let mut defines = vec![];
if let Some(define_list) = parts.next() {
defines.extend(define_list.split(' ').map(|s| s.trim().to_string()));
}
map.entry(source.to_string())
.or_default()
.push(Permutation {
name: name.to_string(),
defines,
});
}
} else {
current_source = Some(line.to_string());
}
}
map
}

View file

@ -0,0 +1,159 @@
use std::{
collections::{HashMap, HashSet},
fs,
path::Path,
vec,
};
pub fn get_imports(shader_dir: &Path) -> HashMap<String, String> {
let mut imports = HashMap::new();
let imports_dir = shader_dir.join("shared");
for entry in imports_dir
.read_dir()
.expect("Can read shader import directory")
{
let entry = entry.expect("Can continue reading shader import directory");
if entry.file_type().unwrap().is_file() {
let file_name = entry.file_name();
if let Some(name) = file_name.to_str() {
let suffix = ".wgsl";
if let Some(import_name) = name.strip_suffix(suffix) {
let contents = fs::read_to_string(imports_dir.join(&file_name))
.expect("Could read shader {import_name} contents");
imports.insert(import_name.to_owned(), contents);
}
}
}
}
imports
}
pub struct StackItem {
active: bool,
else_passed: bool,
}
pub fn preprocess(
input: &str,
defines: &HashSet<String>,
imports: &HashMap<String, String>,
) -> String {
let mut output = String::with_capacity(input.len());
let mut stack = vec![];
'all_lines: for (line_number, mut line) in input.lines().enumerate() {
loop {
if line.is_empty() {
break;
}
let hash_index = line.find('#');
let comment_index = line.find("//");
let hash_index = match (hash_index, comment_index) {
(Some(hash_index), None) => hash_index,
(Some(hash_index), Some(comment_index)) if hash_index < comment_index => hash_index,
// Add this line to the output - all directives are commented out or there are no directives
_ => break,
};
let directive_start = &line[hash_index + '#'.len_utf8()..];
let directive_len = directive_start
// The first character which can't be part of the directive name marks the end of the directive
// In practise this should always be whitespace, but in theory a 'unit' directive
// could be added
.find(|c: char| !c.is_alphanumeric())
.unwrap_or(directive_start.len());
let directive = &directive_start[..directive_len];
let directive_is_at_start = line.trim_start().starts_with('#');
match directive {
if_item @ ("ifdef" | "ifndef" | "else" | "endif") if !directive_is_at_start => {
eprintln!("#{if_item} directives must be the first non_whitespace items on their line, ignoring (line {line_number})");
break;
}
def_test @ ("ifdef" | "ifndef") => {
let def = directive_start[directive_len..].trim();
let exists = defines.contains(def);
let mode = def_test == "ifdef";
stack.push(StackItem {
active: mode == exists,
else_passed: false,
});
// Don't add this line to the output; instead process the next line
continue 'all_lines;
}
"else" => {
let item = stack.last_mut();
if let Some(item) = item {
if item.else_passed {
eprintln!("Second else for same ifdef/ifndef (line {line_number}); ignoring second else")
} else {
item.else_passed = true;
item.active = !item.active;
}
}
let remainder = directive_start[directive_len..].trim();
if !remainder.is_empty() {
eprintln!("#else directives don't take an argument. `{remainder}` will not be in output (line {line_number})");
}
// Don't add this line to the output; it should be empty (see warning above)
continue 'all_lines;
}
"endif" => {
if stack.pop().is_none() {
eprintln!("Mismatched endif (line {line_number})");
}
let remainder = directive_start[directive_len..].trim();
if !remainder.is_empty() {
eprintln!("#endif directives don't take an argument. `{remainder}` will not be in output (line {line_number})");
}
// Don't add this line to the output; it should be empty (see warning above)
continue 'all_lines;
}
"import" => {
output.push_str(&line[..hash_index]);
let directive_end = &directive_start[directive_len..];
let import_name_start = if let Some(import_name_start) =
directive_end.find(|c: char| !c.is_whitespace())
{
import_name_start
} else {
eprintln!("#import needs a non_whitespace argument (line {line_number})");
continue 'all_lines;
};
let import_name_start = &directive_end[import_name_start..];
let import_name_end_index = import_name_start
// The first character which can't be part of the import name marks the end of the import
.find(|c: char| !(c == '_' || c.is_alphanumeric()))
.unwrap_or(import_name_start.len());
let import_name = &import_name_start[..import_name_end_index];
line = &import_name_start[import_name_end_index..];
let import = imports.get(import_name);
if let Some(import) = import {
// In theory, we can cache this until the top item of the stack changes
// However, in practise there will only ever be at most 2 stack items, so it's reasonable to just recompute it every time
if stack.iter().all(|item| item.active) {
output.push_str(&preprocess(import, defines, imports));
}
} else {
eprintln!("Unknown import `{import_name}` (line {line_number})");
}
continue;
}
val => {
eprintln!("Unknown preprocessor directive `{val}` (line {line_number})");
}
}
}
if stack.iter().all(|item| item.active) {
// Naga does not yet recognize `const` but web does not allow global `let`. We
// use `let` in our canonical sources to satisfy wgsl-analyzer but replace with
// `const` when targeting web.
if line.starts_with("let ") {
output.push_str("const");
output.push_str(&line[3..]);
} else {
output.push_str(line);
}
output.push('\n');
}
}
output
}

30
vello_shaders/src/lib.rs Normal file
View file

@ -0,0 +1,30 @@
mod types;
#[cfg(feature = "compile")]
pub mod compile;
pub use types::{BindType, BindingInfo};
use std::borrow::Cow;
#[derive(Clone, Debug)]
pub struct ComputeShader<'a> {
pub name: Cow<'a, str>,
pub code: Cow<'a, [u8]>,
pub workgroup_size: [u32; 3],
pub bindings: Cow<'a, [BindType]>,
}
pub trait PipelineHost {
type Device;
type ComputePipeline;
type Error;
fn new_compute_pipeline(
&mut self,
device: &Self::Device,
shader: &ComputeShader,
) -> Result<Self::ComputePipeline, Self::Error>;
}
include!(concat!(env!("OUT_DIR"), "/shaders.rs"));

View file

@ -0,0 +1,30 @@
//! Types that are shared between the main crate and build.
/// The type of resource that will be bound to a slot in a shader.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum BindType {
/// A storage buffer with read/write access.
Buffer,
/// A storage buffer with read only access.
BufReadOnly,
/// A small storage buffer to be used as uniforms.
Uniform,
/// A storage image.
Image,
/// A storage image with read only access.
ImageRead,
// TODO: Sampler, maybe others
}
impl BindType {
pub fn is_mutable(self) -> bool {
matches!(self, Self::Buffer | Self::Image)
}
}
#[derive(Clone, Debug)]
pub struct BindingInfo {
pub name: Option<String>,
pub location: (u32, u32),
pub ty: BindType,
}