Merge pull request #265 from linebender/vello_shaders

`vello_shaders` crate for AOT compilation and lightweight integration with external renderer projects
This commit is contained in:
Chad Brokaw 2023-03-29 16:13:15 -04:00 committed by GitHub
commit 0e71aa1d0f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 671 additions and 0 deletions

View file

@ -2,6 +2,8 @@
resolver = "2"
members = [
"crates/shaders",
"integrations/vello_svg",
"examples/headless",

19
crates/shaders/Cargo.toml Normal file
View file

@ -0,0 +1,19 @@
[package]
name = "vello_shaders"
version = "0.1.0"
edition = "2021"
[features]
default = ["compile", "wgsl", "msl"]
compile = ["naga", "thiserror"]
wgsl = []
msl = []
[dependencies]
naga = { git = "https://github.com/gfx-rs/naga", rev = "53d62b9", features = ["wgsl-in", "msl-out", "validate"], optional = true }
thiserror = { version = "1.0.40", optional = true }
[build-dependencies]
naga = { git = "https://github.com/gfx-rs/naga", rev = "53d62b9", features = ["wgsl-in", "msl-out", "validate"] }
thiserror = "1.0.40"

7
crates/shaders/README.md Normal file
View file

@ -0,0 +1,7 @@
The `vello_shaders` crate provides a utility library to integrate the Vello shader modules into any
renderer project. The crate provides the necessary metadata to construct the individual compute
pipelines on any GPU API while leaving the responsibility of all API interactions (such as
resource management and command encoding) up to the client.
The shaders can be pre-compiled to any target shading language at build time based on feature flags.
Currently only WGSL and Metal Shading Language are supported.

107
crates/shaders/build.rs Normal file
View file

@ -0,0 +1,107 @@
// Copyright 2023 The Vello authors
// SPDX-License-Identifier: Apache-2.0 OR MIT
#[path = "src/compile/mod.rs"]
mod compile;
#[path = "src/types.rs"]
mod types;
use std::env;
use std::fmt::Write;
use std::path::Path;
use compile::ShaderInfo;
fn main() {
let out_dir = env::var_os("OUT_DIR").unwrap();
let dest_path = Path::new(&out_dir).join("shaders.rs");
let mut shaders = compile::ShaderInfo::from_dir("../../shader");
// Drop the HashMap and sort by name so that we get deterministic order.
let mut shaders = shaders.drain().collect::<Vec<_>>();
shaders.sort_by(|x, y| x.0.cmp(&y.0));
let mut buf = String::default();
write_types(&mut buf, &shaders).unwrap();
if cfg!(feature = "wgsl") {
write_shaders(&mut buf, "wgsl", &shaders, |info| {
info.source.as_bytes().to_owned()
})
.unwrap();
}
if cfg!(feature = "msl") {
write_shaders(&mut buf, "msl", &shaders, |info| {
compile::msl::translate(info).unwrap().as_bytes().to_owned()
})
.unwrap();
}
std::fs::write(&dest_path, &buf).unwrap();
println!("cargo:rerun-if-changed=../shader");
}
fn write_types(buf: &mut String, shaders: &[(String, ShaderInfo)]) -> Result<(), std::fmt::Error> {
writeln!(buf, "pub struct Shaders<'a> {{")?;
for (name, _) in shaders {
writeln!(buf, " pub {name}: ComputeShader<'a>,")?;
}
writeln!(buf, "}}")?;
writeln!(buf, "pub struct Pipelines<T> {{")?;
for (name, _) in shaders {
writeln!(buf, " pub {name}: T,")?;
}
writeln!(buf, "}}")?;
writeln!(buf, "impl<T> Pipelines<T> {{")?;
writeln!(buf, " pub fn from_shaders<H: PipelineHost<ComputePipeline = T>>(shaders: &Shaders, device: &H::Device, host: &mut H) -> Result<Self, H::Error> {{")?;
writeln!(buf, " Ok(Self {{")?;
for (name, _) in shaders {
writeln!(
buf,
" {name}: host.new_compute_pipeline(device, &shaders.{name})?,"
)?;
}
writeln!(buf, " }})")?;
writeln!(buf, " }}")?;
writeln!(buf, "}}")?;
Ok(())
}
fn write_shaders(
buf: &mut String,
mod_name: &str,
shaders: &[(String, ShaderInfo)],
translate: impl Fn(&ShaderInfo) -> Vec<u8>,
) -> Result<(), std::fmt::Error> {
writeln!(buf, "pub mod {mod_name} {{")?;
writeln!(buf, " use super::*;")?;
writeln!(buf, " use BindType::*;")?;
writeln!(buf, " pub const SHADERS: Shaders<'static> = Shaders {{")?;
for (name, info) in shaders {
let bind_tys = info
.bindings
.iter()
.map(|binding| binding.ty)
.collect::<Vec<_>>();
let wg_bufs = &info.workgroup_buffers;
let source = translate(info);
writeln!(buf, " {name}: ComputeShader {{")?;
writeln!(buf, " name: Cow::Borrowed({:?}),", name)?;
writeln!(
buf,
" code: Cow::Borrowed(&{:?}),",
source.as_slice()
)?;
writeln!(
buf,
" workgroup_size: {:?},",
info.workgroup_size
)?;
writeln!(buf, " bindings: Cow::Borrowed(&{:?}),", bind_tys)?;
writeln!(
buf,
" workgroup_buffers: Cow::Borrowed(&{:?}),",
wg_bufs
)?;
writeln!(buf, " }},")?;
}
writeln!(buf, " }};")?;
writeln!(buf, "}}")?;
Ok(())
}

View file

@ -0,0 +1,197 @@
// Copyright 2023 The Vello authors
// SPDX-License-Identifier: Apache-2.0 OR MIT
use {
naga::{
front::wgsl,
valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags},
AddressSpace, ArraySize, ConstantInner, ImageClass, Module, ScalarValue, StorageAccess,
WithSpan,
},
std::{
collections::{HashMap, HashSet},
path::Path,
},
thiserror::Error,
};
pub mod permutations;
pub mod preprocess;
pub mod msl;
use crate::types::{BindType, BindingInfo, WorkgroupBufferInfo};
#[derive(Error, Debug)]
pub enum Error {
#[error("failed to parse shader: {0}")]
Parse(#[from] wgsl::ParseError),
#[error("failed to validate shader: {0}")]
Validate(#[from] WithSpan<ValidationError>),
#[error("missing entry point function")]
EntryPointNotFound,
}
#[derive(Debug)]
pub struct ShaderInfo {
pub source: String,
pub module: Module,
pub module_info: ModuleInfo,
pub workgroup_size: [u32; 3],
pub bindings: Vec<BindingInfo>,
pub workgroup_buffers: Vec<WorkgroupBufferInfo>,
}
impl ShaderInfo {
pub fn new(source: String, entry_point: &str) -> Result<ShaderInfo, Error> {
let module = wgsl::parse_str(&source)?;
let module_info = naga::valid::Validator::new(
ValidationFlags::all() & !ValidationFlags::CONTROL_FLOW_UNIFORMITY,
Capabilities::all(),
)
.validate(&module)?;
let (entry_index, entry) = module
.entry_points
.iter()
.enumerate()
.find(|(_, entry)| entry.name.as_str() == entry_point)
.ok_or(Error::EntryPointNotFound)?;
let mut bindings = vec![];
let mut workgroup_buffers = vec![];
let mut wg_buffer_idx = 0;
let entry_info = module_info.get_entry_point(entry_index);
for (var_handle, var) in module.global_variables.iter() {
if entry_info[var_handle].is_empty() {
continue;
}
let binding_ty = match module.types[var.ty].inner {
naga::TypeInner::BindingArray { base, .. } => &module.types[base].inner,
ref ty => ty,
};
let Some(binding) = &var.binding else {
if var.space == AddressSpace::WorkGroup {
let index = wg_buffer_idx;
wg_buffer_idx += 1;
let size_in_bytes = match binding_ty {
naga::TypeInner::Array {
size: ArraySize::Constant(const_handle),
stride,
..
} => {
let size: u32 = match module.constants[*const_handle].inner {
ConstantInner::Scalar { value, width: _ } => match value {
ScalarValue::Uint(value) => value.try_into().unwrap(),
ScalarValue::Sint(value) => value.try_into().unwrap(),
_ => continue,
},
ConstantInner::Composite { .. } => continue,
};
size * stride
},
naga::TypeInner::Struct { span, .. } => *span,
naga::TypeInner::Scalar { width, ..} => *width as u32,
naga::TypeInner::Vector { width, ..} => *width as u32,
naga::TypeInner::Matrix { width, ..} => *width as u32,
naga::TypeInner::Atomic { width, ..} => *width as u32,
_ => {
// Not a valid workgroup variable type. At least not one that is used
// in our shaders.
continue;
}
};
workgroup_buffers.push(WorkgroupBufferInfo {
size_in_bytes,
index,
});
}
continue;
};
let mut resource = BindingInfo {
name: var.name.clone(),
location: (binding.group, binding.binding),
ty: BindType::Buffer,
};
if let naga::TypeInner::Image { class, .. } = &binding_ty {
resource.ty = BindType::ImageRead;
if let ImageClass::Storage { access, .. } = class {
if access.contains(StorageAccess::STORE) {
resource.ty = BindType::Image;
}
}
} else {
resource.ty = BindType::BufReadOnly;
match var.space {
AddressSpace::Storage { access } => {
if access.contains(StorageAccess::STORE) {
resource.ty = BindType::Buffer;
}
}
AddressSpace::Uniform => {
resource.ty = BindType::Uniform;
}
_ => {}
}
}
bindings.push(resource);
}
bindings.sort_by_key(|res| res.location);
let workgroup_size = entry.workgroup_size;
Ok(ShaderInfo {
source,
module,
module_info,
workgroup_size,
bindings,
workgroup_buffers,
})
}
pub fn from_dir(shader_dir: impl AsRef<Path>) -> HashMap<String, Self> {
use std::fs;
let shader_dir = shader_dir.as_ref();
let permutation_map = if let Ok(permutations_source) =
std::fs::read_to_string(shader_dir.join("permutations"))
{
permutations::parse(&permutations_source)
} else {
Default::default()
};
println!("{:?}", permutation_map);
let imports = preprocess::get_imports(shader_dir);
let mut info = HashMap::default();
let mut defines = HashSet::default();
defines.insert("full".to_string());
for entry in shader_dir
.read_dir()
.expect("Can read shader import directory")
{
let entry = entry.expect("Can continue reading shader import directory");
if entry.file_type().unwrap().is_file() {
let file_name = entry.file_name();
if let Some(name) = file_name.to_str() {
let suffix = ".wgsl";
if let Some(shader_name) = name.strip_suffix(suffix) {
let contents = fs::read_to_string(shader_dir.join(&file_name))
.expect("Could read shader {shader_name} contents");
if let Some(permutations) = permutation_map.get(shader_name) {
for permutation in permutations {
let mut defines = defines.clone();
defines.extend(permutation.defines.iter().cloned());
let source = preprocess::preprocess(&contents, &defines, &imports);
let shader_info = Self::new(source.clone(), "main").unwrap();
info.insert(permutation.name.clone(), shader_info);
}
} else {
let source = preprocess::preprocess(&contents, &defines, &imports);
let shader_info = Self::new(source.clone(), "main").unwrap();
info.insert(shader_name.to_string(), shader_info);
}
}
}
}
}
info
}
}

View file

@ -0,0 +1,56 @@
// Copyright 2023 The Vello authors
// SPDX-License-Identifier: Apache-2.0 OR MIT
use naga::back::msl;
use super::{BindType, ShaderInfo};
pub fn translate(shader: &ShaderInfo) -> Result<String, msl::Error> {
let mut map = msl::EntryPointResourceMap::default();
let mut buffer_index = 0u8;
let mut image_index = 0u8;
let mut binding_map = msl::BindingMap::default();
for resource in &shader.bindings {
let binding = naga::ResourceBinding {
group: resource.location.0,
binding: resource.location.1,
};
let mut target = msl::BindTarget::default();
match resource.ty {
BindType::Buffer | BindType::BufReadOnly | BindType::Uniform => {
target.buffer = Some(buffer_index);
buffer_index += 1;
}
BindType::Image | BindType::ImageRead => {
target.texture = Some(image_index);
image_index += 1;
}
}
target.mutable = resource.ty.is_mutable();
binding_map.insert(binding, target);
}
map.insert(
"main".to_string(),
msl::EntryPointResources {
resources: binding_map,
push_constant_buffer: None,
sizes_buffer: Some(30),
},
);
let options = msl::Options {
lang_version: (2, 0),
per_entry_point_map: map,
inline_samplers: vec![],
spirv_cross_compatibility: false,
fake_missing_bindings: false,
bounds_check_policies: naga::proc::BoundsCheckPolicies::default(),
zero_initialize_workgroup_memory: false,
};
let (source, _) = msl::write_string(
&shader.module,
&shader.module_info,
&options,
&msl::PipelineOptions::default(),
)?;
Ok(source)
}

View file

@ -0,0 +1,44 @@
// Copyright 2023 The Vello authors
// SPDX-License-Identifier: Apache-2.0 OR MIT
use std::collections::HashMap;
#[derive(Debug)]
pub struct Permutation {
/// The new name for the permutation
pub name: String,
/// Set of defines to apply for the permutation
pub defines: Vec<String>,
}
pub fn parse(source: &str) -> HashMap<String, Vec<Permutation>> {
let mut map: HashMap<String, Vec<Permutation>> = Default::default();
let mut current_source: Option<String> = None;
for line in source.lines() {
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
continue;
}
if let Some(line) = line.strip_prefix('+') {
if let Some(source) = &current_source {
let mut parts = line.split(':').map(|s| s.trim());
let Some(name) = parts.next() else {
continue;
};
let mut defines = vec![];
if let Some(define_list) = parts.next() {
defines.extend(define_list.split(' ').map(|s| s.trim().to_string()));
}
map.entry(source.to_string())
.or_default()
.push(Permutation {
name: name.to_string(),
defines,
});
}
} else {
current_source = Some(line.to_string());
}
}
map
}

View file

@ -0,0 +1,162 @@
// Copyright 2023 The Vello authors
// SPDX-License-Identifier: Apache-2.0 OR MIT
use std::{
collections::{HashMap, HashSet},
fs,
path::Path,
vec,
};
pub fn get_imports(shader_dir: &Path) -> HashMap<String, String> {
let mut imports = HashMap::new();
let imports_dir = shader_dir.join("shared");
for entry in imports_dir
.read_dir()
.expect("Can read shader import directory")
{
let entry = entry.expect("Can continue reading shader import directory");
if entry.file_type().unwrap().is_file() {
let file_name = entry.file_name();
if let Some(name) = file_name.to_str() {
let suffix = ".wgsl";
if let Some(import_name) = name.strip_suffix(suffix) {
let contents = fs::read_to_string(imports_dir.join(&file_name))
.expect("Could read shader {import_name} contents");
imports.insert(import_name.to_owned(), contents);
}
}
}
}
imports
}
pub struct StackItem {
active: bool,
else_passed: bool,
}
pub fn preprocess(
input: &str,
defines: &HashSet<String>,
imports: &HashMap<String, String>,
) -> String {
let mut output = String::with_capacity(input.len());
let mut stack = vec![];
'all_lines: for (line_number, mut line) in input.lines().enumerate() {
loop {
if line.is_empty() {
break;
}
let hash_index = line.find('#');
let comment_index = line.find("//");
let hash_index = match (hash_index, comment_index) {
(Some(hash_index), None) => hash_index,
(Some(hash_index), Some(comment_index)) if hash_index < comment_index => hash_index,
// Add this line to the output - all directives are commented out or there are no directives
_ => break,
};
let directive_start = &line[hash_index + '#'.len_utf8()..];
let directive_len = directive_start
// The first character which can't be part of the directive name marks the end of the directive
// In practise this should always be whitespace, but in theory a 'unit' directive
// could be added
.find(|c: char| !c.is_alphanumeric())
.unwrap_or(directive_start.len());
let directive = &directive_start[..directive_len];
let directive_is_at_start = line.trim_start().starts_with('#');
match directive {
if_item @ ("ifdef" | "ifndef" | "else" | "endif") if !directive_is_at_start => {
eprintln!("#{if_item} directives must be the first non_whitespace items on their line, ignoring (line {line_number})");
break;
}
def_test @ ("ifdef" | "ifndef") => {
let def = directive_start[directive_len..].trim();
let exists = defines.contains(def);
let mode = def_test == "ifdef";
stack.push(StackItem {
active: mode == exists,
else_passed: false,
});
// Don't add this line to the output; instead process the next line
continue 'all_lines;
}
"else" => {
let item = stack.last_mut();
if let Some(item) = item {
if item.else_passed {
eprintln!("Second else for same ifdef/ifndef (line {line_number}); ignoring second else")
} else {
item.else_passed = true;
item.active = !item.active;
}
}
let remainder = directive_start[directive_len..].trim();
if !remainder.is_empty() {
eprintln!("#else directives don't take an argument. `{remainder}` will not be in output (line {line_number})");
}
// Don't add this line to the output; it should be empty (see warning above)
continue 'all_lines;
}
"endif" => {
if stack.pop().is_none() {
eprintln!("Mismatched endif (line {line_number})");
}
let remainder = directive_start[directive_len..].trim();
if !remainder.is_empty() {
eprintln!("#endif directives don't take an argument. `{remainder}` will not be in output (line {line_number})");
}
// Don't add this line to the output; it should be empty (see warning above)
continue 'all_lines;
}
"import" => {
output.push_str(&line[..hash_index]);
let directive_end = &directive_start[directive_len..];
let import_name_start = if let Some(import_name_start) =
directive_end.find(|c: char| !c.is_whitespace())
{
import_name_start
} else {
eprintln!("#import needs a non_whitespace argument (line {line_number})");
continue 'all_lines;
};
let import_name_start = &directive_end[import_name_start..];
let import_name_end_index = import_name_start
// The first character which can't be part of the import name marks the end of the import
.find(|c: char| !(c == '_' || c.is_alphanumeric()))
.unwrap_or(import_name_start.len());
let import_name = &import_name_start[..import_name_end_index];
line = &import_name_start[import_name_end_index..];
let import = imports.get(import_name);
if let Some(import) = import {
// In theory, we can cache this until the top item of the stack changes
// However, in practise there will only ever be at most 2 stack items, so it's reasonable to just recompute it every time
if stack.iter().all(|item| item.active) {
output.push_str(&preprocess(import, defines, imports));
}
} else {
eprintln!("Unknown import `{import_name}` (line {line_number})");
}
continue;
}
val => {
eprintln!("Unknown preprocessor directive `{val}` (line {line_number})");
}
}
}
if stack.iter().all(|item| item.active) {
// Naga does not yet recognize `const` but web does not allow global `let`. We
// use `let` in our canonical sources to satisfy wgsl-analyzer but replace with
// `const` when targeting web.
if line.starts_with("let ") {
output.push_str("const");
output.push_str(&line[3..]);
} else {
output.push_str(line);
}
output.push('\n');
}
}
output
}

34
crates/shaders/src/lib.rs Normal file
View file

@ -0,0 +1,34 @@
// Copyright 2023 The Vello authors
// SPDX-License-Identifier: Apache-2.0 OR MIT
mod types;
#[cfg(feature = "compile")]
pub mod compile;
pub use types::{BindType, BindingInfo, WorkgroupBufferInfo};
use std::borrow::Cow;
#[derive(Clone, Debug)]
pub struct ComputeShader<'a> {
pub name: Cow<'a, str>,
pub code: Cow<'a, [u8]>,
pub workgroup_size: [u32; 3],
pub bindings: Cow<'a, [BindType]>,
pub workgroup_buffers: Cow<'a, [WorkgroupBufferInfo]>,
}
pub trait PipelineHost {
type Device;
type ComputePipeline;
type Error;
fn new_compute_pipeline(
&mut self,
device: &Self::Device,
shader: &ComputeShader,
) -> Result<Self::ComputePipeline, Self::Error>;
}
include!(concat!(env!("OUT_DIR"), "/shaders.rs"));

View file

@ -0,0 +1,40 @@
// Copyright 2023 The Vello authors
// SPDX-License-Identifier: Apache-2.0 OR MIT
//! Types that are shared between the main crate and build.
/// The type of resource that will be bound to a slot in a shader.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum BindType {
/// A storage buffer with read/write access.
Buffer,
/// A storage buffer with read only access.
BufReadOnly,
/// A small storage buffer to be used as uniforms.
Uniform,
/// A storage image.
Image,
/// A storage image with read only access.
ImageRead,
// TODO: Sampler, maybe others
}
impl BindType {
pub fn is_mutable(self) -> bool {
matches!(self, Self::Buffer | Self::Image)
}
}
#[derive(Clone, Debug)]
pub struct BindingInfo {
pub name: Option<String>,
pub location: (u32, u32),
pub ty: BindType,
}
#[derive(Clone, Debug)]
pub struct WorkgroupBufferInfo {
pub size_in_bytes: u32,
/// The order in which the workgroup variable is declared in the shader module.
pub index: u32,
}

3
shader/permutations Normal file
View file

@ -0,0 +1,3 @@
pathtag_scan
+ pathtag_scan_large
+ pathtag_scan_small: small