reflect(wgsl): only analyze active ubo members

This commit is contained in:
chyyran 2024-02-15 18:57:51 -05:00 committed by Ronny Chan
parent 350508a7aa
commit cbac011969
12 changed files with 352 additions and 93 deletions

50
Cargo.lock generated
View file

@ -1009,6 +1009,15 @@ version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
[[package]]
name = "fxhash"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c"
dependencies = [
"byteorder",
]
[[package]] [[package]]
name = "generic-array" name = "generic-array"
version = "0.14.7" version = "0.14.7"
@ -1643,7 +1652,8 @@ dependencies = [
"rspirv", "rspirv",
"rustc-hash", "rustc-hash",
"serde", "serde",
"spirv", "spirv 0.2.0+1.5.4",
"spirv-linker",
"spirv-to-dxil", "spirv-to-dxil",
"thiserror", "thiserror",
] ]
@ -1924,7 +1934,7 @@ dependencies = [
"num-traits", "num-traits",
"petgraph", "petgraph",
"rustc-hash", "rustc-hash",
"spirv", "spirv 0.3.0+sdk-1.3.268.0",
"termcolor", "termcolor",
"thiserror", "thiserror",
"unicode-xid", "unicode-xid",
@ -2612,12 +2622,13 @@ dependencies = [
[[package]] [[package]]
name = "rspirv" name = "rspirv"
version = "0.12.0+sdk-1.3.268.0" version = "0.11.0+1.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69cf3a93856b6e5946537278df0d3075596371b1950ccff012f02b0f7eafec8d" checksum = "1503993b59ca9ae4127365c3293517576d7ce56be9f3d8abb1625c85ddc583ba"
dependencies = [ dependencies = [
"rustc-hash", "fxhash",
"spirv", "num-traits",
"spirv 0.2.0+1.5.4",
] ]
[[package]] [[package]]
@ -2816,6 +2827,16 @@ dependencies = [
"lock_api", "lock_api",
] ]
[[package]]
name = "spirv"
version = "0.2.0+1.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "246bfa38fe3db3f1dfc8ca5a2cdeb7348c78be2112740cc0ec8ef18b6d94f830"
dependencies = [
"bitflags 1.3.2",
"num-traits",
]
[[package]] [[package]]
name = "spirv" name = "spirv"
version = "0.3.0+sdk-1.3.268.0" version = "0.3.0+sdk-1.3.268.0"
@ -2825,6 +2846,17 @@ dependencies = [
"bitflags 2.4.2", "bitflags 2.4.2",
] ]
[[package]]
name = "spirv-linker"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d236255ec7387809e18d57221d0fa428d6f909abc88524944cdfb7ebab01acb"
dependencies = [
"rspirv",
"thiserror",
"topological-sort",
]
[[package]] [[package]]
name = "spirv-to-dxil" name = "spirv-to-dxil"
version = "0.4.6" version = "0.4.6"
@ -3023,6 +3055,12 @@ dependencies = [
"winnow", "winnow",
] ]
[[package]]
name = "topological-sort"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa7c7f42dea4b1b99439786f5633aeb9c14c1b53f75e282803c2ec2ad545873c"
[[package]] [[package]]
name = "tracing" name = "tracing"
version = "0.1.40" version = "0.1.40"

View file

@ -23,10 +23,11 @@ librashader-preprocess = { path = "../librashader-preprocess", version = "0.2.0-
librashader-presets = { path = "../librashader-presets", version = "0.2.0-rc.2" } librashader-presets = { path = "../librashader-presets", version = "0.2.0-rc.2" }
spirv_cross = { package = "librashader-spirv-cross", version = "0.25.1", optional = true } spirv_cross = { package = "librashader-spirv-cross", version = "0.25.1", optional = true }
spirv-linker = "0.1.0"
naga = { version = "0.19.0", optional = true } naga = { version = "0.19.0", optional = true }
rspirv = { version = "0.12.0", optional = true } rspirv = { version = "0.11.0", optional = true }
spirv = { version = "0.3.0", optional = true} spirv = { version = "0.2.0", optional = true}
serde = { version = "1.0", features = ["derive"], optional = true } serde = { version = "1.0", features = ["derive"], optional = true }

View file

@ -125,6 +125,10 @@ pub enum ShaderReflectError {
#[cfg(feature = "naga")] #[cfg(feature = "naga")]
#[error("naga-spv")] #[error("naga-spv")]
NagaInputError(#[from] naga::front::spv::Error), NagaInputError(#[from] naga::front::spv::Error),
/// Error when transpiling from naga
#[cfg(feature = "naga")]
#[error("naga-spv")]
NagaReflectError(#[from] naga::WithSpan<naga::valid::ValidationError>),
} }
#[cfg(feature = "unstable-naga")] #[cfg(feature = "unstable-naga")]

View file

@ -1,17 +1,21 @@
mod lower_samplers; mod lower_samplers;
pub mod msl; pub mod msl;
pub mod spirv; pub mod spirv;
mod trim_unused_inputs;
pub mod wgsl; pub mod wgsl;
use crate::error::{SemanticsErrorKind, ShaderReflectError}; use crate::error::{SemanticsErrorKind, ShaderReflectError};
use bitflags::Flags;
use crate::front::SpirvCompilation; use crate::front::SpirvCompilation;
use naga::valid::{Capabilities, ModuleInfo, ValidationFlags, Validator};
use naga::{ use naga::{
AddressSpace, Binding, GlobalVariable, Handle, ImageClass, Module, ResourceBinding, Scalar, AddressSpace, Binding, Expression, GlobalVariable, Handle, ImageClass, Module, ResourceBinding,
ScalarKind, TypeInner, VectorSize, Scalar, ScalarKind, StructMember, TypeInner, VectorSize,
}; };
use rspirv::binary::Assemble; use rspirv::binary::Assemble;
use rspirv::dr::Builder; use rspirv::dr::Builder;
use rustc_hash::{FxHashMap, FxHashSet};
use crate::reflect::helper::{SemanticErrorBlame, TextureData, UboData}; use crate::reflect::helper::{SemanticErrorBlame, TextureData, UboData};
use crate::reflect::semantics::{ use crate::reflect::semantics::{
@ -602,6 +606,41 @@ impl NagaReflect {
Ok(()) Ok(())
} }
fn collect_uniform_names(
module: &Module,
buffer_handle: Handle<GlobalVariable>,
blame: SemanticErrorBlame,
) -> Result<FxHashSet<&StructMember>, ShaderReflectError> {
let mut names = FxHashSet::default();
let ubo = &module.global_variables[buffer_handle];
let TypeInner::Struct { members, .. } = &module.types[ubo.ty].inner else {
return Err(blame.error(SemanticsErrorKind::InvalidResourceType));
};
// struct access is AccessIndex
for (_, fun) in module.functions.iter() {
for (_, expr) in fun.expressions.iter() {
let &Expression::AccessIndex { base, index } = expr else {
continue;
};
let &Expression::GlobalVariable(base) = &fun.expressions[base] else {
continue;
};
if base == buffer_handle {
let member = members
.get(index as usize)
.ok_or(blame.error(SemanticsErrorKind::InvalidRange(index)))?;
names.insert(member);
}
}
}
Ok(names)
}
fn reflect_buffer_struct_members( fn reflect_buffer_struct_members(
module: &Module, module: &Module,
resource: Handle<GlobalVariable>, resource: Handle<GlobalVariable>,
@ -611,7 +650,10 @@ impl NagaReflect {
offset_type: UniformMemberBlock, offset_type: UniformMemberBlock,
blame: SemanticErrorBlame, blame: SemanticErrorBlame,
) -> Result<(), ShaderReflectError> { ) -> Result<(), ShaderReflectError> {
let reachable = Self::collect_uniform_names(&module, resource, blame)?;
let resource = &module.global_variables[resource]; let resource = &module.global_variables[resource];
let TypeInner::Struct { members, .. } = &module.types[resource.ty].inner else { let TypeInner::Struct { members, .. } = &module.types[resource.ty].inner else {
return Err(blame.error(SemanticsErrorKind::InvalidResourceType)); return Err(blame.error(SemanticsErrorKind::InvalidResourceType));
}; };
@ -620,6 +662,11 @@ impl NagaReflect {
let Some(name) = member.name.clone() else { let Some(name) = member.name.clone() else {
return Err(blame.error(SemanticsErrorKind::InvalidRange(member.offset))); return Err(blame.error(SemanticsErrorKind::InvalidRange(member.offset)));
}; };
if !reachable.contains(member) {
continue;
}
let member_type = &module.types[member.ty].inner; let member_type = &module.types[member.ty].inner;
if let Some(parameter) = semantics.uniform_semantics.get_unique_semantic(&name) { if let Some(parameter) = semantics.uniform_semantics.get_unique_semantic(&name) {
@ -884,7 +931,6 @@ impl ReflectShader for NagaReflect {
}); });
let push_constant = self.reflect_push_constant_buffer(vertex_push, fragment_push)?; let push_constant = self.reflect_push_constant_buffer(vertex_push, fragment_push)?;
let mut meta = BindingMeta::default(); let mut meta = BindingMeta::default();
if let Some(ubo) = vertex_ubo { if let Some(ubo) = vertex_ubo {
@ -965,6 +1011,10 @@ impl ReflectShader for NagaReflect {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::reflect::semantics::{Semantic, TextureSemantics, UniformSemantic};
use librashader_common::map::FastHashMap;
use librashader_preprocess::ShaderSource;
use librashader_presets::ShaderPreset;
// #[test] // #[test]
// pub fn test_into() { // pub fn test_into() {
@ -983,4 +1033,19 @@ mod test {
// //
// println!("{outputs:#?}"); // println!("{outputs:#?}");
// } // }
// #[test]
// pub fn mega_bezel_reflect() {
// let preset = ShaderPreset::try_parse(
// "../test/shaders_slang/bezel/Mega_Bezel/Presets/MBZ__0__SMOOTH-ADV.slangp",
// )
// .unwrap();
//
// let mut uniform_semantics: FastHashMap<String, UniformSemantic> = Default::default();
// let mut texture_semantics: FastHashMap<String, Semantic<TextureSemantics>> = Default::default();
//
//
//
//
// }
} }

View file

@ -0,0 +1,129 @@
use rspirv::dr::{Builder, Module, Operand};
use rustc_hash::{FxHashMap, FxHashSet};
use spirv::{Op, StorageClass};
pub struct LinkInputs<'a> {
pub frag_builder: &'a mut Builder,
pub vert: &'a Module,
pub inputs: FxHashMap<spirv::Word, spirv::Word>,
}
impl<'a> LinkInputs<'a> {
fn find_location(module: &Module, id: spirv::Word) -> Option<u32> {
module.annotations.iter().find_map(|op| {
if op.class.opcode != Op::Decorate {
return None;
}
eprintln!("{:?}", op);
return None;
})
}
pub fn new(vert: &'a Module, frag: &'a mut Builder) -> Self {
let mut inputs = FxHashMap::default();
for global in frag.module_ref().types_global_values.iter() {
if global.class.opcode == spirv::Op::Variable
&& global.operands[0] == Operand::StorageClass(StorageClass::Input)
{
if let Some(id) = global.result_id {
inputs.insert(id, 0);
}
}
}
// for global in vert.types_global_values.iter() {
// if global.class.opcode == spirv::Op::Variable
// && global.operands[0] == Operand::StorageClass(StorageClass::Output)
// {
// if let Some(id) = global.result_id {
// inputs.insert(id, 0);
// }
// }
// }
let mut val = Self {
frag_builder: frag,
vert,
inputs,
};
val
}
pub fn do_pass(&mut self) {
let functions = &self.frag_builder.module_ref().functions;
// literally if it has any reference at all we can keep it
for function in functions {
for param in &function.parameters {
for op in &param.operands {
if let Some(word) = op.id_ref_any() {
if self.inputs.contains_key(&word) {
self.inputs.remove(&word);
}
}
}
}
for block in &function.blocks {
for inst in &block.instructions {
for op in &inst.operands {
if let Some(word) = op.id_ref_any() {
if self.inputs.contains_key(&word) {
self.inputs.remove(&word);
}
}
}
}
}
}
// ok well guess we dont
self.frag_builder.module_mut().debug_names.retain(|instr| {
for op in &instr.operands {
if let Some(word) = op.id_ref_any() {
if self.inputs.contains_key(&word) {
return false;
}
}
}
return true;
});
self.frag_builder.module_mut().annotations.retain(|instr| {
for op in &instr.operands {
if let Some(word) = op.id_ref_any() {
if self.inputs.contains_key(&word) {
return false;
}
}
}
return true;
});
for entry_point in self.frag_builder.module_mut().entry_points.iter_mut() {
entry_point.operands.retain(|op| {
if let Some(word) = op.id_ref_any() {
if self.inputs.contains_key(&word) {
return false;
}
}
return true;
})
}
self.frag_builder
.module_mut()
.types_global_values
.retain(|instr| {
let Some(id) = instr.result_id else {
return true;
};
!self.inputs.contains_key(&id)
});
}
}

View file

@ -4,8 +4,8 @@ use crate::back::{CompileShader, ShaderCompilerOutput};
use crate::error::ShaderCompileError; use crate::error::ShaderCompileError;
use crate::reflect::naga::{NagaLoweringOptions, NagaReflect}; use crate::reflect::naga::{NagaLoweringOptions, NagaReflect};
use naga::back::wgsl::WriterFlags; use naga::back::wgsl::WriterFlags;
use naga::valid::{Capabilities, ValidationFlags}; use naga::valid::{Capabilities, ModuleInfo, ValidationFlags, Validator};
use naga::Module; use naga::{Expression, Module, Statement};
impl CompileShader<WGSL> for NagaReflect { impl CompileShader<WGSL> for NagaReflect {
type Options = NagaLoweringOptions; type Options = NagaLoweringOptions;
@ -15,19 +15,20 @@ impl CompileShader<WGSL> for NagaReflect {
mut self, mut self,
options: Self::Options, options: Self::Options,
) -> Result<ShaderCompilerOutput<String, Self::Context>, ShaderCompileError> { ) -> Result<ShaderCompilerOutput<String, Self::Context>, ShaderCompileError> {
fn write_wgsl(module: &Module) -> Result<String, ShaderCompileError> { fn write_wgsl(module: &Module, info: &ModuleInfo) -> Result<String, ShaderCompileError> {
let mut valid = let wgsl = naga::back::wgsl::write_string(&module, &info, WriterFlags::empty())?;
naga::valid::Validator::new(ValidationFlags::all(), Capabilities::empty());
let info = valid.validate(&module)?;
let wgsl = naga::back::wgsl::write_string(&module, &info, WriterFlags::EXPLICIT_TYPES)?;
Ok(wgsl) Ok(wgsl)
} }
self.do_lowering(&options); self.do_lowering(&options);
let fragment = write_wgsl(&self.fragment)?; let mut valid = Validator::new(ValidationFlags::all(), Capabilities::empty());
let vertex = write_wgsl(&self.vertex)?;
let vertex_info = valid.validate(&self.vertex)?;
let fragment_info = valid.validate(&self.fragment)?;
let fragment = write_wgsl(&self.fragment, &fragment_info)?;
let vertex = write_wgsl(&self.vertex, &vertex_info)?;
Ok(ShaderCompilerOutput { Ok(ShaderCompilerOutput {
vertex, vertex,
fragment, fragment,

View file

@ -26,7 +26,6 @@ thiserror = "1.0.50"
bytemuck = { version = "1.14.0", features = ["derive"] } bytemuck = { version = "1.14.0", features = ["derive"] }
array-concat = "0.5.2" array-concat = "0.5.2"
[features] [features]
# workaround for docsrs to not build metal-rs. # workaround for docsrs to not build metal-rs.
wgpu_dx12 = ["wgpu/dx12"] wgpu_dx12 = ["wgpu/dx12"]

View file

@ -8,45 +8,45 @@ use wgpu::{Buffer, Device, RenderPass};
// WGPU does vertex expansion // WGPU does vertex expansion
#[repr(C)] #[repr(C)]
#[derive(Debug, Copy, Clone, Default, Zeroable, Pod)] #[derive(Debug, Copy, Clone, Default, Zeroable, Pod)]
struct WgpuVertex { pub struct WgpuVertex {
position: [f32; 2], pub position: [f32; 4],
texcoord: [f32; 2], pub texcoord: [f32; 2],
} }
const OFFSCREEN_VBO_DATA: [WgpuVertex; 4] = [ const OFFSCREEN_VBO_DATA: [WgpuVertex; 4] = [
WgpuVertex { WgpuVertex {
position: [-1.0, -1.0], position: [-1.0, -1.0, 0.0, 1.0],
texcoord: [0.0, 0.0], texcoord: [0.0, 0.0],
}, },
WgpuVertex { WgpuVertex {
position: [-1.0, 1.0], position: [-1.0, 1.0, 0.0, 1.0],
texcoord: [0.0, 1.0], texcoord: [0.0, 1.0],
}, },
WgpuVertex { WgpuVertex {
position: [1.0, -1.0], position: [1.0, -1.0, 0.0, 1.0],
texcoord: [1.0, 0.0], texcoord: [1.0, 0.0],
}, },
WgpuVertex { WgpuVertex {
position: [1.0, 1.0], position: [1.0, 1.0, 0.0, 1.0],
texcoord: [1.0, 1.0], texcoord: [1.0, 1.0],
}, },
]; ];
const FINAL_VBO_DATA: [WgpuVertex; 4] = [ const FINAL_VBO_DATA: [WgpuVertex; 4] = [
WgpuVertex { WgpuVertex {
position: [0.0, 0.0], position: [0.0, 0.0, 0.0, 1.0],
texcoord: [0.0, 0.0], texcoord: [0.0, 0.0],
}, },
WgpuVertex { WgpuVertex {
position: [0.0, 1.0], position: [0.0, 1.0, 0.0, 1.0],
texcoord: [0.0, 1.0], texcoord: [0.0, 1.0],
}, },
WgpuVertex { WgpuVertex {
position: [1.0, 0.0], position: [1.0, 0.0, 0.0, 1.0],
texcoord: [1.0, 0.0], texcoord: [1.0, 0.0],
}, },
WgpuVertex { WgpuVertex {
position: [1.0, 1.0], position: [1.0, 1.0, 0.0, 1.0],
texcoord: [1.0, 1.0], texcoord: [1.0, 1.0],
}, },
]; ];

View file

@ -15,6 +15,7 @@ use rayon::prelude::*;
use std::collections::VecDeque; use std::collections::VecDeque;
use std::path::Path; use std::path::Path;
use rayon::ThreadPoolBuilder;
use std::sync::Arc; use std::sync::Arc;
use crate::buffer::WgpuStagedBuffer; use crate::buffer::WgpuStagedBuffer;
@ -267,6 +268,12 @@ impl FilterChainWgpu {
#[cfg(target_arch = "wasm32")] #[cfg(target_arch = "wasm32")]
let passes_iter = passes.into_iter(); let passes_iter = passes.into_iter();
let thread_pool = ThreadPoolBuilder::new()
.stack_size(10 * 1048576)
.build()
.unwrap();
let filters = thread_pool.install(|| {
let filters: Vec<error::Result<FilterPass>> = passes_iter let filters: Vec<error::Result<FilterPass>> = passes_iter
.enumerate() .enumerate()
.map(|(index, (config, source, mut reflect))| { .map(|(index, (config, source, mut reflect))| {
@ -297,7 +304,8 @@ impl FilterChainWgpu {
), ),
); );
let uniform_bindings = reflection.meta.create_binding_map(|param| param.offset()); let uniform_bindings =
reflection.meta.create_binding_map(|param| param.offset());
let render_pass_format: Option<TextureFormat> = let render_pass_format: Option<TextureFormat> =
if let Some(format) = config.get_format_override() { if let Some(format) = config.get_format_override() {
@ -324,6 +332,9 @@ impl FilterChainWgpu {
}) })
}) })
.collect(); .collect();
filters
});
// //
let filters: error::Result<Vec<FilterPass>> = filters.into_iter().collect(); let filters: error::Result<Vec<FilterPass>> = filters.into_iter().collect();
let filters = filters?; let filters = filters?;

View file

@ -1,3 +1,4 @@
use crate::draw_quad::WgpuVertex;
use crate::framebuffer::WgpuOutputView; use crate::framebuffer::WgpuOutputView;
use crate::util; use crate::util;
use librashader_reflect::back::wgsl::NagaWgslContext; use librashader_reflect::back::wgsl::NagaWgslContext;
@ -152,17 +153,19 @@ impl PipelineLayoutObjects {
module: &self.vertex, module: &self.vertex,
entry_point: &self.vertex_entry_name, entry_point: &self.vertex_entry_name,
buffers: &[VertexBufferLayout { buffers: &[VertexBufferLayout {
array_stride: 4 * std::mem::size_of::<f32>() as wgpu::BufferAddress, array_stride: std::mem::size_of::<WgpuVertex>() as wgpu::BufferAddress,
step_mode: wgpu::VertexStepMode::Vertex, step_mode: wgpu::VertexStepMode::Vertex,
attributes: &[ attributes: &[
wgpu::VertexAttribute { wgpu::VertexAttribute {
format: wgpu::VertexFormat::Float32x2, format: wgpu::VertexFormat::Float32x4,
offset: 0, offset: bytemuck::offset_of!(WgpuVertex, position)
as wgpu::BufferAddress,
shader_location: 0, shader_location: 0,
}, },
wgpu::VertexAttribute { wgpu::VertexAttribute {
format: wgpu::VertexFormat::Float32x2, format: wgpu::VertexFormat::Float32x2,
offset: (2 * std::mem::size_of::<f32>()) as wgpu::BufferAddress, offset: bytemuck::offset_of!(WgpuVertex, texcoord)
as wgpu::BufferAddress,
shader_location: 1, shader_location: 1,
}, },
], ],
@ -173,7 +176,7 @@ impl PipelineLayoutObjects {
entry_point: &self.fragment_entry_name, entry_point: &self.fragment_entry_name,
targets: &[Some(wgpu::ColorTargetState { targets: &[Some(wgpu::ColorTargetState {
format: framebuffer_format, format: framebuffer_format,
blend: Some(wgpu::BlendState::REPLACE), blend: None,
write_mask: wgpu::ColorWrites::ALL, write_mask: wgpu::ColorWrites::ALL,
})], })],
}), }),

View file

@ -117,11 +117,19 @@ impl<'a> State<'a> {
let device = Arc::new(device); let device = Arc::new(device);
let queue = Arc::new(queue); let queue = Arc::new(queue);
//
// let preset = ShaderPreset::try_parse(
// "../test/basic.slangp",
// )
// .unwrap();
let preset = ShaderPreset::try_parse( let preset =
"../test/shaders_slang/bezel/Mega_Bezel/Presets/MBZ__0__SMOOTH-ADV.slangp", ShaderPreset::try_parse("../test/shaders_slang/crt/crt-royale.slangp").unwrap();
)
.unwrap(); // let preset = ShaderPreset::try_parse(
// "../test/shaders_slang/bezel/Mega_Bezel/Presets/MBZ__0__SMOOTH-ADV.slangp",
// )
// .unwrap();
let chain = FilterChainWgpu::load_from_preset( let chain = FilterChainWgpu::load_from_preset(
preset, preset,

View file

@ -32,5 +32,5 @@ layout(location = 0) out vec4 FragColor;
layout(binding = 1) uniform sampler2D Source; layout(binding = 1) uniform sampler2D Source;
void main() void main()
{ {
FragColor = texture(Source, vTexCoord) * params.ColorMod * ColorMod2; FragColor = texture(Source, vTexCoord) * params.ColorMod;
} }