reflect(wgsl): link spirv to remove unused input/outputs
This commit is contained in:
parent
cbac011969
commit
c0ecae844c
16
Cargo.lock
generated
16
Cargo.lock
generated
|
@ -1649,10 +1649,10 @@ dependencies = [
|
|||
"librashader-spirv-cross",
|
||||
"matches",
|
||||
"naga",
|
||||
"rspirv",
|
||||
"rspirv 0.12.0+sdk-1.3.268.0",
|
||||
"rustc-hash",
|
||||
"serde",
|
||||
"spirv 0.2.0+1.5.4",
|
||||
"spirv 0.3.0+sdk-1.3.268.0",
|
||||
"spirv-linker",
|
||||
"spirv-to-dxil",
|
||||
"thiserror",
|
||||
|
@ -2631,6 +2631,16 @@ dependencies = [
|
|||
"spirv 0.2.0+1.5.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rspirv"
|
||||
version = "0.12.0+sdk-1.3.268.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "69cf3a93856b6e5946537278df0d3075596371b1950ccff012f02b0f7eafec8d"
|
||||
dependencies = [
|
||||
"rustc-hash",
|
||||
"spirv 0.3.0+sdk-1.3.268.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rust-ini"
|
||||
version = "0.18.0"
|
||||
|
@ -2852,7 +2862,7 @@ version = "0.1.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5d236255ec7387809e18d57221d0fa428d6f909abc88524944cdfb7ebab01acb"
|
||||
dependencies = [
|
||||
"rspirv",
|
||||
"rspirv 0.11.0+1.5.4",
|
||||
"thiserror",
|
||||
"topological-sort",
|
||||
]
|
||||
|
|
|
@ -26,8 +26,8 @@ spirv_cross = { package = "librashader-spirv-cross", version = "0.25.1", optiona
|
|||
spirv-linker = "0.1.0"
|
||||
|
||||
naga = { version = "0.19.0", optional = true }
|
||||
rspirv = { version = "0.11.0", optional = true }
|
||||
spirv = { version = "0.2.0", optional = true}
|
||||
rspirv = { version = "0.12.0", optional = true }
|
||||
spirv = { version = "0.3.0", optional = true}
|
||||
|
||||
serde = { version = "1.0", features = ["derive"], optional = true }
|
||||
|
||||
|
|
|
@ -1,23 +1,22 @@
|
|||
mod lower_samplers;
|
||||
pub mod msl;
|
||||
pub mod spirv;
|
||||
mod trim_unused_inputs;
|
||||
mod spirv_passes;
|
||||
pub mod wgsl;
|
||||
|
||||
use crate::error::{SemanticsErrorKind, ShaderReflectError};
|
||||
use bitflags::Flags;
|
||||
|
||||
use crate::front::SpirvCompilation;
|
||||
use naga::valid::{Capabilities, ModuleInfo, ValidationFlags, Validator};
|
||||
use naga::{
|
||||
AddressSpace, Binding, Expression, GlobalVariable, Handle, ImageClass, Module, ResourceBinding,
|
||||
Scalar, ScalarKind, StructMember, TypeInner, VectorSize,
|
||||
};
|
||||
use rspirv::binary::Assemble;
|
||||
use rspirv::dr::Builder;
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
use rustc_hash::FxHashSet;
|
||||
|
||||
use crate::reflect::helper::{SemanticErrorBlame, TextureData, UboData};
|
||||
use crate::reflect::naga::spirv_passes::{link_input_outputs, lower_samplers};
|
||||
use crate::reflect::semantics::{
|
||||
BindingMeta, BindingStage, BufferReflection, MemberOffset, ShaderSemantics, TextureBinding,
|
||||
TextureSemanticMap, TextureSemantics, TextureSizeMeta, TypeInfo, UniformMemberBlock,
|
||||
|
@ -117,20 +116,17 @@ impl TryFrom<&SpirvCompilation> for NagaReflect {
|
|||
type Error = ShaderReflectError;
|
||||
|
||||
fn try_from(compile: &SpirvCompilation) -> Result<Self, Self::Error> {
|
||||
fn lower_fragment_shader(words: &[u32]) -> Vec<u32> {
|
||||
fn load_module(words: &[u32]) -> rspirv::dr::Module {
|
||||
let mut loader = rspirv::dr::Loader::new();
|
||||
rspirv::binary::parse_words(words, &mut loader).unwrap();
|
||||
let module = loader.module();
|
||||
let mut builder = Builder::new_from_module(module);
|
||||
|
||||
let mut pass = lower_samplers::LowerCombinedImageSamplerPass::new(&mut builder);
|
||||
module
|
||||
}
|
||||
|
||||
fn lower_fragment_shader(builder: &mut Builder) {
|
||||
let mut pass = lower_samplers::LowerCombinedImageSamplerPass::new(builder);
|
||||
pass.ensure_op_type_sampler();
|
||||
pass.do_pass();
|
||||
|
||||
let module = builder.module();
|
||||
|
||||
module.assemble()
|
||||
}
|
||||
|
||||
let options = naga::front::spv::Options {
|
||||
|
@ -139,10 +135,20 @@ impl TryFrom<&SpirvCompilation> for NagaReflect {
|
|||
block_ctx_dump_prefix: None,
|
||||
};
|
||||
|
||||
let vertex =
|
||||
naga::front::spv::parse_u8_slice(bytemuck::cast_slice(&compile.vertex), &options)?;
|
||||
let vertex = load_module(&compile.vertex);
|
||||
let fragment = load_module(&compile.fragment);
|
||||
|
||||
let mut fragment = Builder::new_from_module(fragment);
|
||||
lower_fragment_shader(&mut fragment);
|
||||
|
||||
let mut pass = link_input_outputs::LinkInputs::new(&vertex, &mut fragment);
|
||||
pass.do_pass();
|
||||
|
||||
let vertex = vertex.assemble();
|
||||
let fragment = fragment.module().assemble();
|
||||
|
||||
let vertex = naga::front::spv::parse_u8_slice(bytemuck::cast_slice(&vertex), &options)?;
|
||||
|
||||
let fragment = lower_fragment_shader(&compile.fragment);
|
||||
let fragment = naga::front::spv::parse_u8_slice(bytemuck::cast_slice(&fragment), &options)?;
|
||||
|
||||
Ok(NagaReflect { vertex, fragment })
|
||||
|
|
|
@ -1,12 +1,10 @@
|
|||
use rspirv::dr::{Builder, Module, Operand};
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
use spirv::{Op, StorageClass};
|
||||
use spirv::{Decoration, Op, StorageClass};
|
||||
|
||||
pub struct LinkInputs<'a> {
|
||||
pub frag_builder: &'a mut Builder,
|
||||
pub vert: &'a Module,
|
||||
|
||||
pub inputs: FxHashMap<spirv::Word, spirv::Word>,
|
||||
pub inputs: FxHashSet<spirv::Word>,
|
||||
}
|
||||
|
||||
impl<'a> LinkInputs<'a> {
|
||||
|
@ -16,37 +14,62 @@ impl<'a> LinkInputs<'a> {
|
|||
return None;
|
||||
}
|
||||
|
||||
eprintln!("{:?}", op);
|
||||
let Some(Operand::Decoration(Decoration::Location)) = op.operands.get(1) else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let Some(&Operand::IdRef(target)) = op.operands.get(0) else {
|
||||
return None;
|
||||
};
|
||||
|
||||
if target != id {
|
||||
return None;
|
||||
}
|
||||
|
||||
let Some(&Operand::LiteralBit32(binding)) = op.operands.get(2) else {
|
||||
return None;
|
||||
};
|
||||
return Some(binding);
|
||||
})
|
||||
}
|
||||
|
||||
pub fn new(vert: &'a Module, frag: &'a mut Builder) -> Self {
|
||||
let mut inputs = FxHashMap::default();
|
||||
|
||||
let mut inputs = FxHashSet::default();
|
||||
let mut bindings = FxHashMap::default();
|
||||
for global in frag.module_ref().types_global_values.iter() {
|
||||
if global.class.opcode == spirv::Op::Variable
|
||||
&& global.operands[0] == Operand::StorageClass(StorageClass::Input)
|
||||
{
|
||||
if let Some(id) = global.result_id {
|
||||
inputs.insert(id, 0);
|
||||
let Some(location) = Self::find_location(frag.module_ref(), id) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
inputs.insert(id);
|
||||
bindings.insert(location, id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// for global in vert.types_global_values.iter() {
|
||||
// if global.class.opcode == spirv::Op::Variable
|
||||
// && global.operands[0] == Operand::StorageClass(StorageClass::Output)
|
||||
// {
|
||||
// if let Some(id) = global.result_id {
|
||||
// inputs.insert(id, 0);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
for global in vert.types_global_values.iter() {
|
||||
if global.class.opcode == spirv::Op::Variable
|
||||
&& global.operands[0] == Operand::StorageClass(StorageClass::Output)
|
||||
{
|
||||
if let Some(id) = global.result_id {
|
||||
let Some(location) = Self::find_location(vert, id) else {
|
||||
continue;
|
||||
};
|
||||
if let Some(frag_ref) = bindings.get(&location) {
|
||||
// if something is bound to the same location in the vertex shader,
|
||||
// we're good.
|
||||
inputs.remove(&frag_ref);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut val = Self {
|
||||
frag_builder: frag,
|
||||
vert,
|
||||
inputs,
|
||||
};
|
||||
val
|
||||
|
@ -60,7 +83,7 @@ impl<'a> LinkInputs<'a> {
|
|||
for param in &function.parameters {
|
||||
for op in ¶m.operands {
|
||||
if let Some(word) = op.id_ref_any() {
|
||||
if self.inputs.contains_key(&word) {
|
||||
if self.inputs.contains(&word) {
|
||||
self.inputs.remove(&word);
|
||||
}
|
||||
}
|
||||
|
@ -71,7 +94,7 @@ impl<'a> LinkInputs<'a> {
|
|||
for inst in &block.instructions {
|
||||
for op in &inst.operands {
|
||||
if let Some(word) = op.id_ref_any() {
|
||||
if self.inputs.contains_key(&word) {
|
||||
if self.inputs.contains(&word) {
|
||||
self.inputs.remove(&word);
|
||||
}
|
||||
}
|
||||
|
@ -85,7 +108,7 @@ impl<'a> LinkInputs<'a> {
|
|||
self.frag_builder.module_mut().debug_names.retain(|instr| {
|
||||
for op in &instr.operands {
|
||||
if let Some(word) = op.id_ref_any() {
|
||||
if self.inputs.contains_key(&word) {
|
||||
if self.inputs.contains(&word) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -96,7 +119,7 @@ impl<'a> LinkInputs<'a> {
|
|||
self.frag_builder.module_mut().annotations.retain(|instr| {
|
||||
for op in &instr.operands {
|
||||
if let Some(word) = op.id_ref_any() {
|
||||
if self.inputs.contains_key(&word) {
|
||||
if self.inputs.contains(&word) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -107,7 +130,7 @@ impl<'a> LinkInputs<'a> {
|
|||
for entry_point in self.frag_builder.module_mut().entry_points.iter_mut() {
|
||||
entry_point.operands.retain(|op| {
|
||||
if let Some(word) = op.id_ref_any() {
|
||||
if self.inputs.contains_key(&word) {
|
||||
if self.inputs.contains(&word) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -123,7 +146,7 @@ impl<'a> LinkInputs<'a> {
|
|||
return true;
|
||||
};
|
||||
|
||||
!self.inputs.contains_key(&id)
|
||||
!self.inputs.contains(&id)
|
||||
});
|
||||
}
|
||||
}
|
2
librashader-reflect/src/reflect/naga/spirv_passes/mod.rs
Normal file
2
librashader-reflect/src/reflect/naga/spirv_passes/mod.rs
Normal file
|
@ -0,0 +1,2 @@
|
|||
pub mod link_input_outputs;
|
||||
pub mod lower_samplers;
|
|
@ -5,7 +5,7 @@ use crate::error::ShaderCompileError;
|
|||
use crate::reflect::naga::{NagaLoweringOptions, NagaReflect};
|
||||
use naga::back::wgsl::WriterFlags;
|
||||
use naga::valid::{Capabilities, ModuleInfo, ValidationFlags, Validator};
|
||||
use naga::{Expression, Module, Statement};
|
||||
use naga::Module;
|
||||
|
||||
impl CompileShader<WGSL> for NagaReflect {
|
||||
type Options = NagaLoweringOptions;
|
||||
|
|
|
@ -122,14 +122,14 @@ impl<'a> State<'a> {
|
|||
// "../test/basic.slangp",
|
||||
// )
|
||||
// .unwrap();
|
||||
//
|
||||
// let preset =
|
||||
// ShaderPreset::try_parse("../test/shaders_slang/crt/crt-royale.slangp").unwrap();
|
||||
|
||||
let preset =
|
||||
ShaderPreset::try_parse("../test/shaders_slang/crt/crt-royale.slangp").unwrap();
|
||||
|
||||
// let preset = ShaderPreset::try_parse(
|
||||
// "../test/shaders_slang/bezel/Mega_Bezel/Presets/MBZ__0__SMOOTH-ADV.slangp",
|
||||
// )
|
||||
// .unwrap();
|
||||
let preset = ShaderPreset::try_parse(
|
||||
"../test/shaders_slang/bezel/Mega_Bezel/Presets/MBZ__0__SMOOTH-ADV.slangp",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let chain = FilterChainWgpu::load_from_preset(
|
||||
preset,
|
||||
|
|
Loading…
Reference in a new issue