reflect(wgsl): only analyze active ubo members

This commit is contained in:
chyyran 2024-02-15 18:57:51 -05:00 committed by Ronny Chan
parent 350508a7aa
commit cbac011969
12 changed files with 352 additions and 93 deletions

50
Cargo.lock generated
View file

@ -1009,6 +1009,15 @@ version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
[[package]]
name = "fxhash"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c"
dependencies = [
"byteorder",
]
[[package]]
name = "generic-array"
version = "0.14.7"
@ -1643,7 +1652,8 @@ dependencies = [
"rspirv",
"rustc-hash",
"serde",
"spirv",
"spirv 0.2.0+1.5.4",
"spirv-linker",
"spirv-to-dxil",
"thiserror",
]
@ -1924,7 +1934,7 @@ dependencies = [
"num-traits",
"petgraph",
"rustc-hash",
"spirv",
"spirv 0.3.0+sdk-1.3.268.0",
"termcolor",
"thiserror",
"unicode-xid",
@ -2612,12 +2622,13 @@ dependencies = [
[[package]]
name = "rspirv"
version = "0.12.0+sdk-1.3.268.0"
version = "0.11.0+1.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69cf3a93856b6e5946537278df0d3075596371b1950ccff012f02b0f7eafec8d"
checksum = "1503993b59ca9ae4127365c3293517576d7ce56be9f3d8abb1625c85ddc583ba"
dependencies = [
"rustc-hash",
"spirv",
"fxhash",
"num-traits",
"spirv 0.2.0+1.5.4",
]
[[package]]
@ -2816,6 +2827,16 @@ dependencies = [
"lock_api",
]
[[package]]
name = "spirv"
version = "0.2.0+1.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "246bfa38fe3db3f1dfc8ca5a2cdeb7348c78be2112740cc0ec8ef18b6d94f830"
dependencies = [
"bitflags 1.3.2",
"num-traits",
]
[[package]]
name = "spirv"
version = "0.3.0+sdk-1.3.268.0"
@ -2825,6 +2846,17 @@ dependencies = [
"bitflags 2.4.2",
]
[[package]]
name = "spirv-linker"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d236255ec7387809e18d57221d0fa428d6f909abc88524944cdfb7ebab01acb"
dependencies = [
"rspirv",
"thiserror",
"topological-sort",
]
[[package]]
name = "spirv-to-dxil"
version = "0.4.6"
@ -3023,6 +3055,12 @@ dependencies = [
"winnow",
]
[[package]]
name = "topological-sort"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa7c7f42dea4b1b99439786f5633aeb9c14c1b53f75e282803c2ec2ad545873c"
[[package]]
name = "tracing"
version = "0.1.40"

View file

@ -23,10 +23,11 @@ librashader-preprocess = { path = "../librashader-preprocess", version = "0.2.0-
librashader-presets = { path = "../librashader-presets", version = "0.2.0-rc.2" }
spirv_cross = { package = "librashader-spirv-cross", version = "0.25.1", optional = true }
spirv-linker = "0.1.0"
naga = { version = "0.19.0", optional = true }
rspirv = { version = "0.12.0", optional = true }
spirv = { version = "0.3.0", optional = true}
rspirv = { version = "0.11.0", optional = true }
spirv = { version = "0.2.0", optional = true}
serde = { version = "1.0", features = ["derive"], optional = true }

View file

@ -125,6 +125,10 @@ pub enum ShaderReflectError {
#[cfg(feature = "naga")]
#[error("naga-spv")]
NagaInputError(#[from] naga::front::spv::Error),
/// Error when transpiling from naga
#[cfg(feature = "naga")]
#[error("naga-spv")]
NagaReflectError(#[from] naga::WithSpan<naga::valid::ValidationError>),
}
#[cfg(feature = "unstable-naga")]

View file

@ -1,17 +1,21 @@
mod lower_samplers;
pub mod msl;
pub mod spirv;
mod trim_unused_inputs;
pub mod wgsl;
use crate::error::{SemanticsErrorKind, ShaderReflectError};
use bitflags::Flags;
use crate::front::SpirvCompilation;
use naga::valid::{Capabilities, ModuleInfo, ValidationFlags, Validator};
use naga::{
AddressSpace, Binding, GlobalVariable, Handle, ImageClass, Module, ResourceBinding, Scalar,
ScalarKind, TypeInner, VectorSize,
AddressSpace, Binding, Expression, GlobalVariable, Handle, ImageClass, Module, ResourceBinding,
Scalar, ScalarKind, StructMember, TypeInner, VectorSize,
};
use rspirv::binary::Assemble;
use rspirv::dr::Builder;
use rustc_hash::{FxHashMap, FxHashSet};
use crate::reflect::helper::{SemanticErrorBlame, TextureData, UboData};
use crate::reflect::semantics::{
@ -602,6 +606,41 @@ impl NagaReflect {
Ok(())
}
fn collect_uniform_names(
module: &Module,
buffer_handle: Handle<GlobalVariable>,
blame: SemanticErrorBlame,
) -> Result<FxHashSet<&StructMember>, ShaderReflectError> {
let mut names = FxHashSet::default();
let ubo = &module.global_variables[buffer_handle];
let TypeInner::Struct { members, .. } = &module.types[ubo.ty].inner else {
return Err(blame.error(SemanticsErrorKind::InvalidResourceType));
};
// struct access is AccessIndex
for (_, fun) in module.functions.iter() {
for (_, expr) in fun.expressions.iter() {
let &Expression::AccessIndex { base, index } = expr else {
continue;
};
let &Expression::GlobalVariable(base) = &fun.expressions[base] else {
continue;
};
if base == buffer_handle {
let member = members
.get(index as usize)
.ok_or(blame.error(SemanticsErrorKind::InvalidRange(index)))?;
names.insert(member);
}
}
}
Ok(names)
}
fn reflect_buffer_struct_members(
module: &Module,
resource: Handle<GlobalVariable>,
@ -611,7 +650,10 @@ impl NagaReflect {
offset_type: UniformMemberBlock,
blame: SemanticErrorBlame,
) -> Result<(), ShaderReflectError> {
let reachable = Self::collect_uniform_names(&module, resource, blame)?;
let resource = &module.global_variables[resource];
let TypeInner::Struct { members, .. } = &module.types[resource.ty].inner else {
return Err(blame.error(SemanticsErrorKind::InvalidResourceType));
};
@ -620,6 +662,11 @@ impl NagaReflect {
let Some(name) = member.name.clone() else {
return Err(blame.error(SemanticsErrorKind::InvalidRange(member.offset)));
};
if !reachable.contains(member) {
continue;
}
let member_type = &module.types[member.ty].inner;
if let Some(parameter) = semantics.uniform_semantics.get_unique_semantic(&name) {
@ -884,7 +931,6 @@ impl ReflectShader for NagaReflect {
});
let push_constant = self.reflect_push_constant_buffer(vertex_push, fragment_push)?;
let mut meta = BindingMeta::default();
if let Some(ubo) = vertex_ubo {
@ -965,6 +1011,10 @@ impl ReflectShader for NagaReflect {
#[cfg(test)]
mod test {
use crate::reflect::semantics::{Semantic, TextureSemantics, UniformSemantic};
use librashader_common::map::FastHashMap;
use librashader_preprocess::ShaderSource;
use librashader_presets::ShaderPreset;
// #[test]
// pub fn test_into() {
@ -983,4 +1033,19 @@ mod test {
//
// println!("{outputs:#?}");
// }
// #[test]
// pub fn mega_bezel_reflect() {
// let preset = ShaderPreset::try_parse(
// "../test/shaders_slang/bezel/Mega_Bezel/Presets/MBZ__0__SMOOTH-ADV.slangp",
// )
// .unwrap();
//
// let mut uniform_semantics: FastHashMap<String, UniformSemantic> = Default::default();
// let mut texture_semantics: FastHashMap<String, Semantic<TextureSemantics>> = Default::default();
//
//
//
//
// }
}

View file

@ -0,0 +1,129 @@
use rspirv::dr::{Builder, Module, Operand};
use rustc_hash::{FxHashMap, FxHashSet};
use spirv::{Op, StorageClass};
pub struct LinkInputs<'a> {
pub frag_builder: &'a mut Builder,
pub vert: &'a Module,
pub inputs: FxHashMap<spirv::Word, spirv::Word>,
}
impl<'a> LinkInputs<'a> {
fn find_location(module: &Module, id: spirv::Word) -> Option<u32> {
module.annotations.iter().find_map(|op| {
if op.class.opcode != Op::Decorate {
return None;
}
eprintln!("{:?}", op);
return None;
})
}
pub fn new(vert: &'a Module, frag: &'a mut Builder) -> Self {
let mut inputs = FxHashMap::default();
for global in frag.module_ref().types_global_values.iter() {
if global.class.opcode == spirv::Op::Variable
&& global.operands[0] == Operand::StorageClass(StorageClass::Input)
{
if let Some(id) = global.result_id {
inputs.insert(id, 0);
}
}
}
// for global in vert.types_global_values.iter() {
// if global.class.opcode == spirv::Op::Variable
// && global.operands[0] == Operand::StorageClass(StorageClass::Output)
// {
// if let Some(id) = global.result_id {
// inputs.insert(id, 0);
// }
// }
// }
let mut val = Self {
frag_builder: frag,
vert,
inputs,
};
val
}
pub fn do_pass(&mut self) {
let functions = &self.frag_builder.module_ref().functions;
// literally if it has any reference at all we can keep it
for function in functions {
for param in &function.parameters {
for op in &param.operands {
if let Some(word) = op.id_ref_any() {
if self.inputs.contains_key(&word) {
self.inputs.remove(&word);
}
}
}
}
for block in &function.blocks {
for inst in &block.instructions {
for op in &inst.operands {
if let Some(word) = op.id_ref_any() {
if self.inputs.contains_key(&word) {
self.inputs.remove(&word);
}
}
}
}
}
}
// ok well guess we dont
self.frag_builder.module_mut().debug_names.retain(|instr| {
for op in &instr.operands {
if let Some(word) = op.id_ref_any() {
if self.inputs.contains_key(&word) {
return false;
}
}
}
return true;
});
self.frag_builder.module_mut().annotations.retain(|instr| {
for op in &instr.operands {
if let Some(word) = op.id_ref_any() {
if self.inputs.contains_key(&word) {
return false;
}
}
}
return true;
});
for entry_point in self.frag_builder.module_mut().entry_points.iter_mut() {
entry_point.operands.retain(|op| {
if let Some(word) = op.id_ref_any() {
if self.inputs.contains_key(&word) {
return false;
}
}
return true;
})
}
self.frag_builder
.module_mut()
.types_global_values
.retain(|instr| {
let Some(id) = instr.result_id else {
return true;
};
!self.inputs.contains_key(&id)
});
}
}

View file

@ -4,8 +4,8 @@ use crate::back::{CompileShader, ShaderCompilerOutput};
use crate::error::ShaderCompileError;
use crate::reflect::naga::{NagaLoweringOptions, NagaReflect};
use naga::back::wgsl::WriterFlags;
use naga::valid::{Capabilities, ValidationFlags};
use naga::Module;
use naga::valid::{Capabilities, ModuleInfo, ValidationFlags, Validator};
use naga::{Expression, Module, Statement};
impl CompileShader<WGSL> for NagaReflect {
type Options = NagaLoweringOptions;
@ -15,19 +15,20 @@ impl CompileShader<WGSL> for NagaReflect {
mut self,
options: Self::Options,
) -> Result<ShaderCompilerOutput<String, Self::Context>, ShaderCompileError> {
fn write_wgsl(module: &Module) -> Result<String, ShaderCompileError> {
let mut valid =
naga::valid::Validator::new(ValidationFlags::all(), Capabilities::empty());
let info = valid.validate(&module)?;
let wgsl = naga::back::wgsl::write_string(&module, &info, WriterFlags::EXPLICIT_TYPES)?;
fn write_wgsl(module: &Module, info: &ModuleInfo) -> Result<String, ShaderCompileError> {
let wgsl = naga::back::wgsl::write_string(&module, &info, WriterFlags::empty())?;
Ok(wgsl)
}
self.do_lowering(&options);
let fragment = write_wgsl(&self.fragment)?;
let vertex = write_wgsl(&self.vertex)?;
let mut valid = Validator::new(ValidationFlags::all(), Capabilities::empty());
let vertex_info = valid.validate(&self.vertex)?;
let fragment_info = valid.validate(&self.fragment)?;
let fragment = write_wgsl(&self.fragment, &fragment_info)?;
let vertex = write_wgsl(&self.vertex, &vertex_info)?;
Ok(ShaderCompilerOutput {
vertex,
fragment,

View file

@ -26,7 +26,6 @@ thiserror = "1.0.50"
bytemuck = { version = "1.14.0", features = ["derive"] }
array-concat = "0.5.2"
[features]
# workaround for docsrs to not build metal-rs.
wgpu_dx12 = ["wgpu/dx12"]

View file

@ -8,45 +8,45 @@ use wgpu::{Buffer, Device, RenderPass};
// WGPU does vertex expansion
#[repr(C)]
#[derive(Debug, Copy, Clone, Default, Zeroable, Pod)]
struct WgpuVertex {
position: [f32; 2],
texcoord: [f32; 2],
pub struct WgpuVertex {
pub position: [f32; 4],
pub texcoord: [f32; 2],
}
const OFFSCREEN_VBO_DATA: [WgpuVertex; 4] = [
WgpuVertex {
position: [-1.0, -1.0],
position: [-1.0, -1.0, 0.0, 1.0],
texcoord: [0.0, 0.0],
},
WgpuVertex {
position: [-1.0, 1.0],
position: [-1.0, 1.0, 0.0, 1.0],
texcoord: [0.0, 1.0],
},
WgpuVertex {
position: [1.0, -1.0],
position: [1.0, -1.0, 0.0, 1.0],
texcoord: [1.0, 0.0],
},
WgpuVertex {
position: [1.0, 1.0],
position: [1.0, 1.0, 0.0, 1.0],
texcoord: [1.0, 1.0],
},
];
const FINAL_VBO_DATA: [WgpuVertex; 4] = [
WgpuVertex {
position: [0.0, 0.0],
position: [0.0, 0.0, 0.0, 1.0],
texcoord: [0.0, 0.0],
},
WgpuVertex {
position: [0.0, 1.0],
position: [0.0, 1.0, 0.0, 1.0],
texcoord: [0.0, 1.0],
},
WgpuVertex {
position: [1.0, 0.0],
position: [1.0, 0.0, 0.0, 1.0],
texcoord: [1.0, 0.0],
},
WgpuVertex {
position: [1.0, 1.0],
position: [1.0, 1.0, 0.0, 1.0],
texcoord: [1.0, 1.0],
},
];

View file

@ -15,6 +15,7 @@ use rayon::prelude::*;
use std::collections::VecDeque;
use std::path::Path;
use rayon::ThreadPoolBuilder;
use std::sync::Arc;
use crate::buffer::WgpuStagedBuffer;
@ -267,6 +268,12 @@ impl FilterChainWgpu {
#[cfg(target_arch = "wasm32")]
let passes_iter = passes.into_iter();
let thread_pool = ThreadPoolBuilder::new()
.stack_size(10 * 1048576)
.build()
.unwrap();
let filters = thread_pool.install(|| {
let filters: Vec<error::Result<FilterPass>> = passes_iter
.enumerate()
.map(|(index, (config, source, mut reflect))| {
@ -297,7 +304,8 @@ impl FilterChainWgpu {
),
);
let uniform_bindings = reflection.meta.create_binding_map(|param| param.offset());
let uniform_bindings =
reflection.meta.create_binding_map(|param| param.offset());
let render_pass_format: Option<TextureFormat> =
if let Some(format) = config.get_format_override() {
@ -324,6 +332,9 @@ impl FilterChainWgpu {
})
})
.collect();
filters
});
//
let filters: error::Result<Vec<FilterPass>> = filters.into_iter().collect();
let filters = filters?;

View file

@ -1,3 +1,4 @@
use crate::draw_quad::WgpuVertex;
use crate::framebuffer::WgpuOutputView;
use crate::util;
use librashader_reflect::back::wgsl::NagaWgslContext;
@ -152,17 +153,19 @@ impl PipelineLayoutObjects {
module: &self.vertex,
entry_point: &self.vertex_entry_name,
buffers: &[VertexBufferLayout {
array_stride: 4 * std::mem::size_of::<f32>() as wgpu::BufferAddress,
array_stride: std::mem::size_of::<WgpuVertex>() as wgpu::BufferAddress,
step_mode: wgpu::VertexStepMode::Vertex,
attributes: &[
wgpu::VertexAttribute {
format: wgpu::VertexFormat::Float32x2,
offset: 0,
format: wgpu::VertexFormat::Float32x4,
offset: bytemuck::offset_of!(WgpuVertex, position)
as wgpu::BufferAddress,
shader_location: 0,
},
wgpu::VertexAttribute {
format: wgpu::VertexFormat::Float32x2,
offset: (2 * std::mem::size_of::<f32>()) as wgpu::BufferAddress,
offset: bytemuck::offset_of!(WgpuVertex, texcoord)
as wgpu::BufferAddress,
shader_location: 1,
},
],
@ -173,7 +176,7 @@ impl PipelineLayoutObjects {
entry_point: &self.fragment_entry_name,
targets: &[Some(wgpu::ColorTargetState {
format: framebuffer_format,
blend: Some(wgpu::BlendState::REPLACE),
blend: None,
write_mask: wgpu::ColorWrites::ALL,
})],
}),

View file

@ -117,11 +117,19 @@ impl<'a> State<'a> {
let device = Arc::new(device);
let queue = Arc::new(queue);
//
// let preset = ShaderPreset::try_parse(
// "../test/basic.slangp",
// )
// .unwrap();
let preset = ShaderPreset::try_parse(
"../test/shaders_slang/bezel/Mega_Bezel/Presets/MBZ__0__SMOOTH-ADV.slangp",
)
.unwrap();
let preset =
ShaderPreset::try_parse("../test/shaders_slang/crt/crt-royale.slangp").unwrap();
// let preset = ShaderPreset::try_parse(
// "../test/shaders_slang/bezel/Mega_Bezel/Presets/MBZ__0__SMOOTH-ADV.slangp",
// )
// .unwrap();
let chain = FilterChainWgpu::load_from_preset(
preset,

View file

@ -32,5 +32,5 @@ layout(location = 0) out vec4 FragColor;
layout(binding = 1) uniform sampler2D Source;
void main()
{
FragColor = texture(Source, vTexCoord) * params.ColorMod * ColorMod2;
FragColor = texture(Source, vTexCoord) * params.ColorMod;
}