From c0ecae844c041460669fcd617a2b4cf6b63977b2 Mon Sep 17 00:00:00 2001 From: chyyran Date: Thu, 15 Feb 2024 19:22:28 -0500 Subject: [PATCH] reflect(wgsl): link spirv to remove unused input/outputs --- Cargo.lock | 16 +++- librashader-reflect/Cargo.toml | 4 +- librashader-reflect/src/reflect/naga/mod.rs | 36 +++++---- .../link_input_outputs.rs} | 73 ++++++++++++------- .../naga/{ => spirv_passes}/lower_samplers.rs | 0 .../src/reflect/naga/spirv_passes/mod.rs | 2 + librashader-reflect/src/reflect/naga/wgsl.rs | 2 +- .../tests/hello_triangle.rs | 14 ++-- 8 files changed, 94 insertions(+), 53 deletions(-) rename librashader-reflect/src/reflect/naga/{trim_unused_inputs.rs => spirv_passes/link_input_outputs.rs} (60%) rename librashader-reflect/src/reflect/naga/{ => spirv_passes}/lower_samplers.rs (100%) create mode 100644 librashader-reflect/src/reflect/naga/spirv_passes/mod.rs diff --git a/Cargo.lock b/Cargo.lock index a3fc38a..b7b5e03 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1649,10 +1649,10 @@ dependencies = [ "librashader-spirv-cross", "matches", "naga", - "rspirv", + "rspirv 0.12.0+sdk-1.3.268.0", "rustc-hash", "serde", - "spirv 0.2.0+1.5.4", + "spirv 0.3.0+sdk-1.3.268.0", "spirv-linker", "spirv-to-dxil", "thiserror", @@ -2631,6 +2631,16 @@ dependencies = [ "spirv 0.2.0+1.5.4", ] +[[package]] +name = "rspirv" +version = "0.12.0+sdk-1.3.268.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cf3a93856b6e5946537278df0d3075596371b1950ccff012f02b0f7eafec8d" +dependencies = [ + "rustc-hash", + "spirv 0.3.0+sdk-1.3.268.0", +] + [[package]] name = "rust-ini" version = "0.18.0" @@ -2852,7 +2862,7 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d236255ec7387809e18d57221d0fa428d6f909abc88524944cdfb7ebab01acb" dependencies = [ - "rspirv", + "rspirv 0.11.0+1.5.4", "thiserror", "topological-sort", ] diff --git a/librashader-reflect/Cargo.toml b/librashader-reflect/Cargo.toml index a1a7a21..5167111 100644 --- a/librashader-reflect/Cargo.toml +++ b/librashader-reflect/Cargo.toml @@ -26,8 +26,8 @@ spirv_cross = { package = "librashader-spirv-cross", version = "0.25.1", optiona spirv-linker = "0.1.0" naga = { version = "0.19.0", optional = true } -rspirv = { version = "0.11.0", optional = true } -spirv = { version = "0.2.0", optional = true} +rspirv = { version = "0.12.0", optional = true } +spirv = { version = "0.3.0", optional = true} serde = { version = "1.0", features = ["derive"], optional = true } diff --git a/librashader-reflect/src/reflect/naga/mod.rs b/librashader-reflect/src/reflect/naga/mod.rs index 77691c7..f4e725b 100644 --- a/librashader-reflect/src/reflect/naga/mod.rs +++ b/librashader-reflect/src/reflect/naga/mod.rs @@ -1,23 +1,22 @@ -mod lower_samplers; pub mod msl; pub mod spirv; -mod trim_unused_inputs; +mod spirv_passes; pub mod wgsl; use crate::error::{SemanticsErrorKind, ShaderReflectError}; use bitflags::Flags; use crate::front::SpirvCompilation; -use naga::valid::{Capabilities, ModuleInfo, ValidationFlags, Validator}; use naga::{ AddressSpace, Binding, Expression, GlobalVariable, Handle, ImageClass, Module, ResourceBinding, Scalar, ScalarKind, StructMember, TypeInner, VectorSize, }; use rspirv::binary::Assemble; use rspirv::dr::Builder; -use rustc_hash::{FxHashMap, FxHashSet}; +use rustc_hash::FxHashSet; use crate::reflect::helper::{SemanticErrorBlame, TextureData, UboData}; +use crate::reflect::naga::spirv_passes::{link_input_outputs, lower_samplers}; use crate::reflect::semantics::{ BindingMeta, BindingStage, BufferReflection, MemberOffset, ShaderSemantics, TextureBinding, TextureSemanticMap, TextureSemantics, TextureSizeMeta, TypeInfo, UniformMemberBlock, @@ -117,20 +116,17 @@ impl TryFrom<&SpirvCompilation> for NagaReflect { type Error = ShaderReflectError; fn try_from(compile: &SpirvCompilation) -> Result { - fn lower_fragment_shader(words: &[u32]) -> Vec { + fn load_module(words: &[u32]) -> rspirv::dr::Module { let mut loader = rspirv::dr::Loader::new(); rspirv::binary::parse_words(words, &mut loader).unwrap(); let module = loader.module(); - let mut builder = Builder::new_from_module(module); - - let mut pass = lower_samplers::LowerCombinedImageSamplerPass::new(&mut builder); + module + } + fn lower_fragment_shader(builder: &mut Builder) { + let mut pass = lower_samplers::LowerCombinedImageSamplerPass::new(builder); pass.ensure_op_type_sampler(); pass.do_pass(); - - let module = builder.module(); - - module.assemble() } let options = naga::front::spv::Options { @@ -139,10 +135,20 @@ impl TryFrom<&SpirvCompilation> for NagaReflect { block_ctx_dump_prefix: None, }; - let vertex = - naga::front::spv::parse_u8_slice(bytemuck::cast_slice(&compile.vertex), &options)?; + let vertex = load_module(&compile.vertex); + let fragment = load_module(&compile.fragment); + + let mut fragment = Builder::new_from_module(fragment); + lower_fragment_shader(&mut fragment); + + let mut pass = link_input_outputs::LinkInputs::new(&vertex, &mut fragment); + pass.do_pass(); + + let vertex = vertex.assemble(); + let fragment = fragment.module().assemble(); + + let vertex = naga::front::spv::parse_u8_slice(bytemuck::cast_slice(&vertex), &options)?; - let fragment = lower_fragment_shader(&compile.fragment); let fragment = naga::front::spv::parse_u8_slice(bytemuck::cast_slice(&fragment), &options)?; Ok(NagaReflect { vertex, fragment }) diff --git a/librashader-reflect/src/reflect/naga/trim_unused_inputs.rs b/librashader-reflect/src/reflect/naga/spirv_passes/link_input_outputs.rs similarity index 60% rename from librashader-reflect/src/reflect/naga/trim_unused_inputs.rs rename to librashader-reflect/src/reflect/naga/spirv_passes/link_input_outputs.rs index 6d431dc..f1a91fa 100644 --- a/librashader-reflect/src/reflect/naga/trim_unused_inputs.rs +++ b/librashader-reflect/src/reflect/naga/spirv_passes/link_input_outputs.rs @@ -1,12 +1,10 @@ use rspirv::dr::{Builder, Module, Operand}; use rustc_hash::{FxHashMap, FxHashSet}; -use spirv::{Op, StorageClass}; +use spirv::{Decoration, Op, StorageClass}; pub struct LinkInputs<'a> { pub frag_builder: &'a mut Builder, - pub vert: &'a Module, - - pub inputs: FxHashMap, + pub inputs: FxHashSet, } impl<'a> LinkInputs<'a> { @@ -16,37 +14,62 @@ impl<'a> LinkInputs<'a> { return None; } - eprintln!("{:?}", op); - return None; + let Some(Operand::Decoration(Decoration::Location)) = op.operands.get(1) else { + return None; + }; + + let Some(&Operand::IdRef(target)) = op.operands.get(0) else { + return None; + }; + + if target != id { + return None; + } + + let Some(&Operand::LiteralBit32(binding)) = op.operands.get(2) else { + return None; + }; + return Some(binding); }) } pub fn new(vert: &'a Module, frag: &'a mut Builder) -> Self { - let mut inputs = FxHashMap::default(); - + let mut inputs = FxHashSet::default(); + let mut bindings = FxHashMap::default(); for global in frag.module_ref().types_global_values.iter() { if global.class.opcode == spirv::Op::Variable && global.operands[0] == Operand::StorageClass(StorageClass::Input) { if let Some(id) = global.result_id { - inputs.insert(id, 0); + let Some(location) = Self::find_location(frag.module_ref(), id) else { + continue; + }; + + inputs.insert(id); + bindings.insert(location, id); } } } - // for global in vert.types_global_values.iter() { - // if global.class.opcode == spirv::Op::Variable - // && global.operands[0] == Operand::StorageClass(StorageClass::Output) - // { - // if let Some(id) = global.result_id { - // inputs.insert(id, 0); - // } - // } - // } + for global in vert.types_global_values.iter() { + if global.class.opcode == spirv::Op::Variable + && global.operands[0] == Operand::StorageClass(StorageClass::Output) + { + if let Some(id) = global.result_id { + let Some(location) = Self::find_location(vert, id) else { + continue; + }; + if let Some(frag_ref) = bindings.get(&location) { + // if something is bound to the same location in the vertex shader, + // we're good. + inputs.remove(&frag_ref); + } + } + } + } let mut val = Self { frag_builder: frag, - vert, inputs, }; val @@ -60,7 +83,7 @@ impl<'a> LinkInputs<'a> { for param in &function.parameters { for op in ¶m.operands { if let Some(word) = op.id_ref_any() { - if self.inputs.contains_key(&word) { + if self.inputs.contains(&word) { self.inputs.remove(&word); } } @@ -71,7 +94,7 @@ impl<'a> LinkInputs<'a> { for inst in &block.instructions { for op in &inst.operands { if let Some(word) = op.id_ref_any() { - if self.inputs.contains_key(&word) { + if self.inputs.contains(&word) { self.inputs.remove(&word); } } @@ -85,7 +108,7 @@ impl<'a> LinkInputs<'a> { self.frag_builder.module_mut().debug_names.retain(|instr| { for op in &instr.operands { if let Some(word) = op.id_ref_any() { - if self.inputs.contains_key(&word) { + if self.inputs.contains(&word) { return false; } } @@ -96,7 +119,7 @@ impl<'a> LinkInputs<'a> { self.frag_builder.module_mut().annotations.retain(|instr| { for op in &instr.operands { if let Some(word) = op.id_ref_any() { - if self.inputs.contains_key(&word) { + if self.inputs.contains(&word) { return false; } } @@ -107,7 +130,7 @@ impl<'a> LinkInputs<'a> { for entry_point in self.frag_builder.module_mut().entry_points.iter_mut() { entry_point.operands.retain(|op| { if let Some(word) = op.id_ref_any() { - if self.inputs.contains_key(&word) { + if self.inputs.contains(&word) { return false; } } @@ -123,7 +146,7 @@ impl<'a> LinkInputs<'a> { return true; }; - !self.inputs.contains_key(&id) + !self.inputs.contains(&id) }); } } diff --git a/librashader-reflect/src/reflect/naga/lower_samplers.rs b/librashader-reflect/src/reflect/naga/spirv_passes/lower_samplers.rs similarity index 100% rename from librashader-reflect/src/reflect/naga/lower_samplers.rs rename to librashader-reflect/src/reflect/naga/spirv_passes/lower_samplers.rs diff --git a/librashader-reflect/src/reflect/naga/spirv_passes/mod.rs b/librashader-reflect/src/reflect/naga/spirv_passes/mod.rs new file mode 100644 index 0000000..2af9848 --- /dev/null +++ b/librashader-reflect/src/reflect/naga/spirv_passes/mod.rs @@ -0,0 +1,2 @@ +pub mod link_input_outputs; +pub mod lower_samplers; diff --git a/librashader-reflect/src/reflect/naga/wgsl.rs b/librashader-reflect/src/reflect/naga/wgsl.rs index 90389da..388ba3f 100644 --- a/librashader-reflect/src/reflect/naga/wgsl.rs +++ b/librashader-reflect/src/reflect/naga/wgsl.rs @@ -5,7 +5,7 @@ use crate::error::ShaderCompileError; use crate::reflect::naga::{NagaLoweringOptions, NagaReflect}; use naga::back::wgsl::WriterFlags; use naga::valid::{Capabilities, ModuleInfo, ValidationFlags, Validator}; -use naga::{Expression, Module, Statement}; +use naga::Module; impl CompileShader for NagaReflect { type Options = NagaLoweringOptions; diff --git a/librashader-runtime-wgpu/tests/hello_triangle.rs b/librashader-runtime-wgpu/tests/hello_triangle.rs index cf6c9c4..1a60a23 100644 --- a/librashader-runtime-wgpu/tests/hello_triangle.rs +++ b/librashader-runtime-wgpu/tests/hello_triangle.rs @@ -122,14 +122,14 @@ impl<'a> State<'a> { // "../test/basic.slangp", // ) // .unwrap(); + // + // let preset = + // ShaderPreset::try_parse("../test/shaders_slang/crt/crt-royale.slangp").unwrap(); - let preset = - ShaderPreset::try_parse("../test/shaders_slang/crt/crt-royale.slangp").unwrap(); - - // let preset = ShaderPreset::try_parse( - // "../test/shaders_slang/bezel/Mega_Bezel/Presets/MBZ__0__SMOOTH-ADV.slangp", - // ) - // .unwrap(); + let preset = ShaderPreset::try_parse( + "../test/shaders_slang/bezel/Mega_Bezel/Presets/MBZ__0__SMOOTH-ADV.slangp", + ) + .unwrap(); let chain = FilterChainWgpu::load_from_preset( preset,