reflect(wgsl): link spirv to remove unused input/outputs

This commit is contained in:
chyyran 2024-02-15 19:22:28 -05:00 committed by Ronny Chan
parent cbac011969
commit c0ecae844c
8 changed files with 94 additions and 53 deletions

16
Cargo.lock generated
View file

@ -1649,10 +1649,10 @@ dependencies = [
"librashader-spirv-cross",
"matches",
"naga",
"rspirv",
"rspirv 0.12.0+sdk-1.3.268.0",
"rustc-hash",
"serde",
"spirv 0.2.0+1.5.4",
"spirv 0.3.0+sdk-1.3.268.0",
"spirv-linker",
"spirv-to-dxil",
"thiserror",
@ -2631,6 +2631,16 @@ dependencies = [
"spirv 0.2.0+1.5.4",
]
[[package]]
name = "rspirv"
version = "0.12.0+sdk-1.3.268.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69cf3a93856b6e5946537278df0d3075596371b1950ccff012f02b0f7eafec8d"
dependencies = [
"rustc-hash",
"spirv 0.3.0+sdk-1.3.268.0",
]
[[package]]
name = "rust-ini"
version = "0.18.0"
@ -2852,7 +2862,7 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d236255ec7387809e18d57221d0fa428d6f909abc88524944cdfb7ebab01acb"
dependencies = [
"rspirv",
"rspirv 0.11.0+1.5.4",
"thiserror",
"topological-sort",
]

View file

@ -26,8 +26,8 @@ spirv_cross = { package = "librashader-spirv-cross", version = "0.25.1", optiona
spirv-linker = "0.1.0"
naga = { version = "0.19.0", optional = true }
rspirv = { version = "0.11.0", optional = true }
spirv = { version = "0.2.0", optional = true}
rspirv = { version = "0.12.0", optional = true }
spirv = { version = "0.3.0", optional = true}
serde = { version = "1.0", features = ["derive"], optional = true }

View file

@ -1,23 +1,22 @@
mod lower_samplers;
pub mod msl;
pub mod spirv;
mod trim_unused_inputs;
mod spirv_passes;
pub mod wgsl;
use crate::error::{SemanticsErrorKind, ShaderReflectError};
use bitflags::Flags;
use crate::front::SpirvCompilation;
use naga::valid::{Capabilities, ModuleInfo, ValidationFlags, Validator};
use naga::{
AddressSpace, Binding, Expression, GlobalVariable, Handle, ImageClass, Module, ResourceBinding,
Scalar, ScalarKind, StructMember, TypeInner, VectorSize,
};
use rspirv::binary::Assemble;
use rspirv::dr::Builder;
use rustc_hash::{FxHashMap, FxHashSet};
use rustc_hash::FxHashSet;
use crate::reflect::helper::{SemanticErrorBlame, TextureData, UboData};
use crate::reflect::naga::spirv_passes::{link_input_outputs, lower_samplers};
use crate::reflect::semantics::{
BindingMeta, BindingStage, BufferReflection, MemberOffset, ShaderSemantics, TextureBinding,
TextureSemanticMap, TextureSemantics, TextureSizeMeta, TypeInfo, UniformMemberBlock,
@ -117,20 +116,17 @@ impl TryFrom<&SpirvCompilation> for NagaReflect {
type Error = ShaderReflectError;
fn try_from(compile: &SpirvCompilation) -> Result<Self, Self::Error> {
fn lower_fragment_shader(words: &[u32]) -> Vec<u32> {
fn load_module(words: &[u32]) -> rspirv::dr::Module {
let mut loader = rspirv::dr::Loader::new();
rspirv::binary::parse_words(words, &mut loader).unwrap();
let module = loader.module();
let mut builder = Builder::new_from_module(module);
let mut pass = lower_samplers::LowerCombinedImageSamplerPass::new(&mut builder);
module
}
fn lower_fragment_shader(builder: &mut Builder) {
let mut pass = lower_samplers::LowerCombinedImageSamplerPass::new(builder);
pass.ensure_op_type_sampler();
pass.do_pass();
let module = builder.module();
module.assemble()
}
let options = naga::front::spv::Options {
@ -139,10 +135,20 @@ impl TryFrom<&SpirvCompilation> for NagaReflect {
block_ctx_dump_prefix: None,
};
let vertex =
naga::front::spv::parse_u8_slice(bytemuck::cast_slice(&compile.vertex), &options)?;
let vertex = load_module(&compile.vertex);
let fragment = load_module(&compile.fragment);
let mut fragment = Builder::new_from_module(fragment);
lower_fragment_shader(&mut fragment);
let mut pass = link_input_outputs::LinkInputs::new(&vertex, &mut fragment);
pass.do_pass();
let vertex = vertex.assemble();
let fragment = fragment.module().assemble();
let vertex = naga::front::spv::parse_u8_slice(bytemuck::cast_slice(&vertex), &options)?;
let fragment = lower_fragment_shader(&compile.fragment);
let fragment = naga::front::spv::parse_u8_slice(bytemuck::cast_slice(&fragment), &options)?;
Ok(NagaReflect { vertex, fragment })

View file

@ -1,12 +1,10 @@
use rspirv::dr::{Builder, Module, Operand};
use rustc_hash::{FxHashMap, FxHashSet};
use spirv::{Op, StorageClass};
use spirv::{Decoration, Op, StorageClass};
pub struct LinkInputs<'a> {
pub frag_builder: &'a mut Builder,
pub vert: &'a Module,
pub inputs: FxHashMap<spirv::Word, spirv::Word>,
pub inputs: FxHashSet<spirv::Word>,
}
impl<'a> LinkInputs<'a> {
@ -16,37 +14,62 @@ impl<'a> LinkInputs<'a> {
return None;
}
eprintln!("{:?}", op);
let Some(Operand::Decoration(Decoration::Location)) = op.operands.get(1) else {
return None;
};
let Some(&Operand::IdRef(target)) = op.operands.get(0) else {
return None;
};
if target != id {
return None;
}
let Some(&Operand::LiteralBit32(binding)) = op.operands.get(2) else {
return None;
};
return Some(binding);
})
}
pub fn new(vert: &'a Module, frag: &'a mut Builder) -> Self {
let mut inputs = FxHashMap::default();
let mut inputs = FxHashSet::default();
let mut bindings = FxHashMap::default();
for global in frag.module_ref().types_global_values.iter() {
if global.class.opcode == spirv::Op::Variable
&& global.operands[0] == Operand::StorageClass(StorageClass::Input)
{
if let Some(id) = global.result_id {
inputs.insert(id, 0);
let Some(location) = Self::find_location(frag.module_ref(), id) else {
continue;
};
inputs.insert(id);
bindings.insert(location, id);
}
}
}
// for global in vert.types_global_values.iter() {
// if global.class.opcode == spirv::Op::Variable
// && global.operands[0] == Operand::StorageClass(StorageClass::Output)
// {
// if let Some(id) = global.result_id {
// inputs.insert(id, 0);
// }
// }
// }
for global in vert.types_global_values.iter() {
if global.class.opcode == spirv::Op::Variable
&& global.operands[0] == Operand::StorageClass(StorageClass::Output)
{
if let Some(id) = global.result_id {
let Some(location) = Self::find_location(vert, id) else {
continue;
};
if let Some(frag_ref) = bindings.get(&location) {
// if something is bound to the same location in the vertex shader,
// we're good.
inputs.remove(&frag_ref);
}
}
}
}
let mut val = Self {
frag_builder: frag,
vert,
inputs,
};
val
@ -60,7 +83,7 @@ impl<'a> LinkInputs<'a> {
for param in &function.parameters {
for op in &param.operands {
if let Some(word) = op.id_ref_any() {
if self.inputs.contains_key(&word) {
if self.inputs.contains(&word) {
self.inputs.remove(&word);
}
}
@ -71,7 +94,7 @@ impl<'a> LinkInputs<'a> {
for inst in &block.instructions {
for op in &inst.operands {
if let Some(word) = op.id_ref_any() {
if self.inputs.contains_key(&word) {
if self.inputs.contains(&word) {
self.inputs.remove(&word);
}
}
@ -85,7 +108,7 @@ impl<'a> LinkInputs<'a> {
self.frag_builder.module_mut().debug_names.retain(|instr| {
for op in &instr.operands {
if let Some(word) = op.id_ref_any() {
if self.inputs.contains_key(&word) {
if self.inputs.contains(&word) {
return false;
}
}
@ -96,7 +119,7 @@ impl<'a> LinkInputs<'a> {
self.frag_builder.module_mut().annotations.retain(|instr| {
for op in &instr.operands {
if let Some(word) = op.id_ref_any() {
if self.inputs.contains_key(&word) {
if self.inputs.contains(&word) {
return false;
}
}
@ -107,7 +130,7 @@ impl<'a> LinkInputs<'a> {
for entry_point in self.frag_builder.module_mut().entry_points.iter_mut() {
entry_point.operands.retain(|op| {
if let Some(word) = op.id_ref_any() {
if self.inputs.contains_key(&word) {
if self.inputs.contains(&word) {
return false;
}
}
@ -123,7 +146,7 @@ impl<'a> LinkInputs<'a> {
return true;
};
!self.inputs.contains_key(&id)
!self.inputs.contains(&id)
});
}
}

View file

@ -0,0 +1,2 @@
pub mod link_input_outputs;
pub mod lower_samplers;

View file

@ -5,7 +5,7 @@ use crate::error::ShaderCompileError;
use crate::reflect::naga::{NagaLoweringOptions, NagaReflect};
use naga::back::wgsl::WriterFlags;
use naga::valid::{Capabilities, ModuleInfo, ValidationFlags, Validator};
use naga::{Expression, Module, Statement};
use naga::Module;
impl CompileShader<WGSL> for NagaReflect {
type Options = NagaLoweringOptions;

View file

@ -122,14 +122,14 @@ impl<'a> State<'a> {
// "../test/basic.slangp",
// )
// .unwrap();
//
// let preset =
// ShaderPreset::try_parse("../test/shaders_slang/crt/crt-royale.slangp").unwrap();
let preset =
ShaderPreset::try_parse("../test/shaders_slang/crt/crt-royale.slangp").unwrap();
// let preset = ShaderPreset::try_parse(
// "../test/shaders_slang/bezel/Mega_Bezel/Presets/MBZ__0__SMOOTH-ADV.slangp",
// )
// .unwrap();
let preset = ShaderPreset::try_parse(
"../test/shaders_slang/bezel/Mega_Bezel/Presets/MBZ__0__SMOOTH-ADV.slangp",
)
.unwrap();
let chain = FilterChainWgpu::load_from_preset(
preset,