reflect(wgsl): link spirv to remove unused input/outputs

This commit is contained in:
chyyran 2024-02-15 19:22:28 -05:00 committed by Ronny Chan
parent cbac011969
commit c0ecae844c
8 changed files with 94 additions and 53 deletions

16
Cargo.lock generated
View file

@ -1649,10 +1649,10 @@ dependencies = [
"librashader-spirv-cross", "librashader-spirv-cross",
"matches", "matches",
"naga", "naga",
"rspirv", "rspirv 0.12.0+sdk-1.3.268.0",
"rustc-hash", "rustc-hash",
"serde", "serde",
"spirv 0.2.0+1.5.4", "spirv 0.3.0+sdk-1.3.268.0",
"spirv-linker", "spirv-linker",
"spirv-to-dxil", "spirv-to-dxil",
"thiserror", "thiserror",
@ -2631,6 +2631,16 @@ dependencies = [
"spirv 0.2.0+1.5.4", "spirv 0.2.0+1.5.4",
] ]
[[package]]
name = "rspirv"
version = "0.12.0+sdk-1.3.268.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69cf3a93856b6e5946537278df0d3075596371b1950ccff012f02b0f7eafec8d"
dependencies = [
"rustc-hash",
"spirv 0.3.0+sdk-1.3.268.0",
]
[[package]] [[package]]
name = "rust-ini" name = "rust-ini"
version = "0.18.0" version = "0.18.0"
@ -2852,7 +2862,7 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d236255ec7387809e18d57221d0fa428d6f909abc88524944cdfb7ebab01acb" checksum = "5d236255ec7387809e18d57221d0fa428d6f909abc88524944cdfb7ebab01acb"
dependencies = [ dependencies = [
"rspirv", "rspirv 0.11.0+1.5.4",
"thiserror", "thiserror",
"topological-sort", "topological-sort",
] ]

View file

@ -26,8 +26,8 @@ spirv_cross = { package = "librashader-spirv-cross", version = "0.25.1", optiona
spirv-linker = "0.1.0" spirv-linker = "0.1.0"
naga = { version = "0.19.0", optional = true } naga = { version = "0.19.0", optional = true }
rspirv = { version = "0.11.0", optional = true } rspirv = { version = "0.12.0", optional = true }
spirv = { version = "0.2.0", optional = true} spirv = { version = "0.3.0", optional = true}
serde = { version = "1.0", features = ["derive"], optional = true } serde = { version = "1.0", features = ["derive"], optional = true }

View file

@ -1,23 +1,22 @@
mod lower_samplers;
pub mod msl; pub mod msl;
pub mod spirv; pub mod spirv;
mod trim_unused_inputs; mod spirv_passes;
pub mod wgsl; pub mod wgsl;
use crate::error::{SemanticsErrorKind, ShaderReflectError}; use crate::error::{SemanticsErrorKind, ShaderReflectError};
use bitflags::Flags; use bitflags::Flags;
use crate::front::SpirvCompilation; use crate::front::SpirvCompilation;
use naga::valid::{Capabilities, ModuleInfo, ValidationFlags, Validator};
use naga::{ use naga::{
AddressSpace, Binding, Expression, GlobalVariable, Handle, ImageClass, Module, ResourceBinding, AddressSpace, Binding, Expression, GlobalVariable, Handle, ImageClass, Module, ResourceBinding,
Scalar, ScalarKind, StructMember, TypeInner, VectorSize, Scalar, ScalarKind, StructMember, TypeInner, VectorSize,
}; };
use rspirv::binary::Assemble; use rspirv::binary::Assemble;
use rspirv::dr::Builder; use rspirv::dr::Builder;
use rustc_hash::{FxHashMap, FxHashSet}; use rustc_hash::FxHashSet;
use crate::reflect::helper::{SemanticErrorBlame, TextureData, UboData}; use crate::reflect::helper::{SemanticErrorBlame, TextureData, UboData};
use crate::reflect::naga::spirv_passes::{link_input_outputs, lower_samplers};
use crate::reflect::semantics::{ use crate::reflect::semantics::{
BindingMeta, BindingStage, BufferReflection, MemberOffset, ShaderSemantics, TextureBinding, BindingMeta, BindingStage, BufferReflection, MemberOffset, ShaderSemantics, TextureBinding,
TextureSemanticMap, TextureSemantics, TextureSizeMeta, TypeInfo, UniformMemberBlock, TextureSemanticMap, TextureSemantics, TextureSizeMeta, TypeInfo, UniformMemberBlock,
@ -117,20 +116,17 @@ impl TryFrom<&SpirvCompilation> for NagaReflect {
type Error = ShaderReflectError; type Error = ShaderReflectError;
fn try_from(compile: &SpirvCompilation) -> Result<Self, Self::Error> { fn try_from(compile: &SpirvCompilation) -> Result<Self, Self::Error> {
fn lower_fragment_shader(words: &[u32]) -> Vec<u32> { fn load_module(words: &[u32]) -> rspirv::dr::Module {
let mut loader = rspirv::dr::Loader::new(); let mut loader = rspirv::dr::Loader::new();
rspirv::binary::parse_words(words, &mut loader).unwrap(); rspirv::binary::parse_words(words, &mut loader).unwrap();
let module = loader.module(); let module = loader.module();
let mut builder = Builder::new_from_module(module); module
}
let mut pass = lower_samplers::LowerCombinedImageSamplerPass::new(&mut builder);
fn lower_fragment_shader(builder: &mut Builder) {
let mut pass = lower_samplers::LowerCombinedImageSamplerPass::new(builder);
pass.ensure_op_type_sampler(); pass.ensure_op_type_sampler();
pass.do_pass(); pass.do_pass();
let module = builder.module();
module.assemble()
} }
let options = naga::front::spv::Options { let options = naga::front::spv::Options {
@ -139,10 +135,20 @@ impl TryFrom<&SpirvCompilation> for NagaReflect {
block_ctx_dump_prefix: None, block_ctx_dump_prefix: None,
}; };
let vertex = let vertex = load_module(&compile.vertex);
naga::front::spv::parse_u8_slice(bytemuck::cast_slice(&compile.vertex), &options)?; let fragment = load_module(&compile.fragment);
let mut fragment = Builder::new_from_module(fragment);
lower_fragment_shader(&mut fragment);
let mut pass = link_input_outputs::LinkInputs::new(&vertex, &mut fragment);
pass.do_pass();
let vertex = vertex.assemble();
let fragment = fragment.module().assemble();
let vertex = naga::front::spv::parse_u8_slice(bytemuck::cast_slice(&vertex), &options)?;
let fragment = lower_fragment_shader(&compile.fragment);
let fragment = naga::front::spv::parse_u8_slice(bytemuck::cast_slice(&fragment), &options)?; let fragment = naga::front::spv::parse_u8_slice(bytemuck::cast_slice(&fragment), &options)?;
Ok(NagaReflect { vertex, fragment }) Ok(NagaReflect { vertex, fragment })

View file

@ -1,12 +1,10 @@
use rspirv::dr::{Builder, Module, Operand}; use rspirv::dr::{Builder, Module, Operand};
use rustc_hash::{FxHashMap, FxHashSet}; use rustc_hash::{FxHashMap, FxHashSet};
use spirv::{Op, StorageClass}; use spirv::{Decoration, Op, StorageClass};
pub struct LinkInputs<'a> { pub struct LinkInputs<'a> {
pub frag_builder: &'a mut Builder, pub frag_builder: &'a mut Builder,
pub vert: &'a Module, pub inputs: FxHashSet<spirv::Word>,
pub inputs: FxHashMap<spirv::Word, spirv::Word>,
} }
impl<'a> LinkInputs<'a> { impl<'a> LinkInputs<'a> {
@ -16,37 +14,62 @@ impl<'a> LinkInputs<'a> {
return None; return None;
} }
eprintln!("{:?}", op); let Some(Operand::Decoration(Decoration::Location)) = op.operands.get(1) else {
return None; return None;
};
let Some(&Operand::IdRef(target)) = op.operands.get(0) else {
return None;
};
if target != id {
return None;
}
let Some(&Operand::LiteralBit32(binding)) = op.operands.get(2) else {
return None;
};
return Some(binding);
}) })
} }
pub fn new(vert: &'a Module, frag: &'a mut Builder) -> Self { pub fn new(vert: &'a Module, frag: &'a mut Builder) -> Self {
let mut inputs = FxHashMap::default(); let mut inputs = FxHashSet::default();
let mut bindings = FxHashMap::default();
for global in frag.module_ref().types_global_values.iter() { for global in frag.module_ref().types_global_values.iter() {
if global.class.opcode == spirv::Op::Variable if global.class.opcode == spirv::Op::Variable
&& global.operands[0] == Operand::StorageClass(StorageClass::Input) && global.operands[0] == Operand::StorageClass(StorageClass::Input)
{ {
if let Some(id) = global.result_id { if let Some(id) = global.result_id {
inputs.insert(id, 0); let Some(location) = Self::find_location(frag.module_ref(), id) else {
continue;
};
inputs.insert(id);
bindings.insert(location, id);
} }
} }
} }
// for global in vert.types_global_values.iter() { for global in vert.types_global_values.iter() {
// if global.class.opcode == spirv::Op::Variable if global.class.opcode == spirv::Op::Variable
// && global.operands[0] == Operand::StorageClass(StorageClass::Output) && global.operands[0] == Operand::StorageClass(StorageClass::Output)
// { {
// if let Some(id) = global.result_id { if let Some(id) = global.result_id {
// inputs.insert(id, 0); let Some(location) = Self::find_location(vert, id) else {
// } continue;
// } };
// } if let Some(frag_ref) = bindings.get(&location) {
// if something is bound to the same location in the vertex shader,
// we're good.
inputs.remove(&frag_ref);
}
}
}
}
let mut val = Self { let mut val = Self {
frag_builder: frag, frag_builder: frag,
vert,
inputs, inputs,
}; };
val val
@ -60,7 +83,7 @@ impl<'a> LinkInputs<'a> {
for param in &function.parameters { for param in &function.parameters {
for op in &param.operands { for op in &param.operands {
if let Some(word) = op.id_ref_any() { if let Some(word) = op.id_ref_any() {
if self.inputs.contains_key(&word) { if self.inputs.contains(&word) {
self.inputs.remove(&word); self.inputs.remove(&word);
} }
} }
@ -71,7 +94,7 @@ impl<'a> LinkInputs<'a> {
for inst in &block.instructions { for inst in &block.instructions {
for op in &inst.operands { for op in &inst.operands {
if let Some(word) = op.id_ref_any() { if let Some(word) = op.id_ref_any() {
if self.inputs.contains_key(&word) { if self.inputs.contains(&word) {
self.inputs.remove(&word); self.inputs.remove(&word);
} }
} }
@ -85,7 +108,7 @@ impl<'a> LinkInputs<'a> {
self.frag_builder.module_mut().debug_names.retain(|instr| { self.frag_builder.module_mut().debug_names.retain(|instr| {
for op in &instr.operands { for op in &instr.operands {
if let Some(word) = op.id_ref_any() { if let Some(word) = op.id_ref_any() {
if self.inputs.contains_key(&word) { if self.inputs.contains(&word) {
return false; return false;
} }
} }
@ -96,7 +119,7 @@ impl<'a> LinkInputs<'a> {
self.frag_builder.module_mut().annotations.retain(|instr| { self.frag_builder.module_mut().annotations.retain(|instr| {
for op in &instr.operands { for op in &instr.operands {
if let Some(word) = op.id_ref_any() { if let Some(word) = op.id_ref_any() {
if self.inputs.contains_key(&word) { if self.inputs.contains(&word) {
return false; return false;
} }
} }
@ -107,7 +130,7 @@ impl<'a> LinkInputs<'a> {
for entry_point in self.frag_builder.module_mut().entry_points.iter_mut() { for entry_point in self.frag_builder.module_mut().entry_points.iter_mut() {
entry_point.operands.retain(|op| { entry_point.operands.retain(|op| {
if let Some(word) = op.id_ref_any() { if let Some(word) = op.id_ref_any() {
if self.inputs.contains_key(&word) { if self.inputs.contains(&word) {
return false; return false;
} }
} }
@ -123,7 +146,7 @@ impl<'a> LinkInputs<'a> {
return true; return true;
}; };
!self.inputs.contains_key(&id) !self.inputs.contains(&id)
}); });
} }
} }

View file

@ -0,0 +1,2 @@
pub mod link_input_outputs;
pub mod lower_samplers;

View file

@ -5,7 +5,7 @@ use crate::error::ShaderCompileError;
use crate::reflect::naga::{NagaLoweringOptions, NagaReflect}; use crate::reflect::naga::{NagaLoweringOptions, NagaReflect};
use naga::back::wgsl::WriterFlags; use naga::back::wgsl::WriterFlags;
use naga::valid::{Capabilities, ModuleInfo, ValidationFlags, Validator}; use naga::valid::{Capabilities, ModuleInfo, ValidationFlags, Validator};
use naga::{Expression, Module, Statement}; use naga::Module;
impl CompileShader<WGSL> for NagaReflect { impl CompileShader<WGSL> for NagaReflect {
type Options = NagaLoweringOptions; type Options = NagaLoweringOptions;

View file

@ -122,14 +122,14 @@ impl<'a> State<'a> {
// "../test/basic.slangp", // "../test/basic.slangp",
// ) // )
// .unwrap(); // .unwrap();
//
// let preset =
// ShaderPreset::try_parse("../test/shaders_slang/crt/crt-royale.slangp").unwrap();
let preset = let preset = ShaderPreset::try_parse(
ShaderPreset::try_parse("../test/shaders_slang/crt/crt-royale.slangp").unwrap(); "../test/shaders_slang/bezel/Mega_Bezel/Presets/MBZ__0__SMOOTH-ADV.slangp",
)
// let preset = ShaderPreset::try_parse( .unwrap();
// "../test/shaders_slang/bezel/Mega_Bezel/Presets/MBZ__0__SMOOTH-ADV.slangp",
// )
// .unwrap();
let chain = FilterChainWgpu::load_from_preset( let chain = FilterChainWgpu::load_from_preset(
preset, preset,