reflect: improve inout link algorithm
This commit is contained in:
parent
cd14bca23a
commit
9b40c10466
2 changed files with 96 additions and 25 deletions
|
@ -41,11 +41,12 @@ pub(crate) fn compile_spirv(source: &ShaderSource) -> Result<SpirvCompilation, S
|
|||
let vertex = load_module(&vertex);
|
||||
let fragment = load_module(&fragment);
|
||||
let mut fragment = Builder::new_from_module(fragment);
|
||||
let mut vertex = Builder::new_from_module(vertex);
|
||||
|
||||
let mut pass = link_input_outputs::LinkInputs::new(&vertex, &mut fragment);
|
||||
let mut pass = link_input_outputs::LinkInputs::new(&mut vertex, &mut fragment, false);
|
||||
pass.do_pass();
|
||||
|
||||
let vertex = vertex.assemble();
|
||||
let vertex = vertex.module().assemble();
|
||||
let fragment = fragment.module().assemble();
|
||||
|
||||
Ok(SpirvCompilation { vertex, fragment })
|
||||
|
|
|
@ -2,12 +2,21 @@ use rspirv::dr::{Builder, Module, Operand};
|
|||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
use spirv::{Decoration, Op, StorageClass};
|
||||
|
||||
/// Do DCE on inputs of the fragment shader, then
|
||||
/// link by downgrading outputs of unused fragment inputs
|
||||
/// to global variables on the vertex shader.
|
||||
pub struct LinkInputs<'a> {
|
||||
pub frag_builder: &'a mut Builder,
|
||||
pub inputs: FxHashSet<spirv::Word>,
|
||||
pub vert_builder: &'a mut Builder,
|
||||
|
||||
// binding -> ID
|
||||
pub outputs: FxHashMap<u32, spirv::Word>,
|
||||
// id -> binding
|
||||
pub inputs_to_remove: FxHashMap<spirv::Word, u32>,
|
||||
}
|
||||
|
||||
impl<'a> LinkInputs<'a> {
|
||||
/// Get the value of the location of the inout in the module
|
||||
fn find_location(module: &Module, id: spirv::Word) -> Option<u32> {
|
||||
module.annotations.iter().find_map(|op| {
|
||||
if op.class.opcode != Op::Decorate {
|
||||
|
@ -33,9 +42,11 @@ impl<'a> LinkInputs<'a> {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn new(vert: &'a Module, frag: &'a mut Builder) -> Self {
|
||||
let mut inputs = FxHashSet::default();
|
||||
let mut bindings = FxHashMap::default();
|
||||
pub fn new(vert: &'a mut Builder, frag: &'a mut Builder, keep_if_bound: bool) -> Self {
|
||||
let mut outputs = FxHashMap::default();
|
||||
let mut inputs_to_remove = FxHashMap::default();
|
||||
let mut inputs = FxHashMap::default();
|
||||
|
||||
for global in frag.module_ref().types_global_values.iter() {
|
||||
if global.class.opcode == spirv::Op::Variable
|
||||
&& global.operands[0] == Operand::StorageClass(StorageClass::Input)
|
||||
|
@ -45,24 +56,32 @@ impl<'a> LinkInputs<'a> {
|
|||
continue;
|
||||
};
|
||||
|
||||
inputs.insert(id);
|
||||
bindings.insert(location, id);
|
||||
inputs_to_remove.insert(id, location);
|
||||
inputs.insert(location, id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for global in vert.types_global_values.iter() {
|
||||
for global in vert.module_ref().types_global_values.iter() {
|
||||
if global.class.opcode == spirv::Op::Variable
|
||||
&& global.operands[0] == Operand::StorageClass(StorageClass::Output)
|
||||
{
|
||||
if let Some(id) = global.result_id {
|
||||
let Some(location) = Self::find_location(vert, id) else {
|
||||
let Some(location) = Self::find_location(vert.module_ref(), id) else {
|
||||
continue;
|
||||
};
|
||||
if let Some(frag_ref) = bindings.get(&location) {
|
||||
// if something is bound to the same location in the vertex shader,
|
||||
// we're good.
|
||||
inputs.remove(&frag_ref);
|
||||
|
||||
// Add to list of outputs
|
||||
outputs.insert(location, id);
|
||||
|
||||
// Keep the input, if it is bound to both stages.Otherwise, do DCE analysis on
|
||||
// the input, and remove it regardless if bound, if unused by the fragment stage.
|
||||
if keep_if_bound {
|
||||
if let Some(frag_ref) = inputs.get(&location) {
|
||||
// if something is bound to the same location in the vertex shader,
|
||||
// we're good.
|
||||
inputs_to_remove.remove(&frag_ref);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -70,11 +89,63 @@ impl<'a> LinkInputs<'a> {
|
|||
|
||||
Self {
|
||||
frag_builder: frag,
|
||||
inputs,
|
||||
vert_builder: vert,
|
||||
outputs,
|
||||
inputs_to_remove,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn do_pass(&mut self) {
|
||||
self.trim_inputs();
|
||||
self.downgrade_outputs();
|
||||
}
|
||||
|
||||
/// Downgrade dead inputs corresponding to outputs to global variables, keeping existing mappings.
|
||||
fn downgrade_outputs(&mut self) {
|
||||
let dead_outputs = self
|
||||
.inputs_to_remove
|
||||
.values()
|
||||
.filter_map(|i| self.outputs.get(i).cloned())
|
||||
.collect::<FxHashSet<spirv::Word>>();
|
||||
|
||||
let module = self.vert_builder.module_mut();
|
||||
for global in module.types_global_values.iter_mut() {
|
||||
if global.class.opcode != spirv::Op::Variable
|
||||
|| global.operands[0] != Operand::StorageClass(StorageClass::Output)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(id) = global.result_id {
|
||||
if !dead_outputs.contains(&id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// downgrade the OpVariable if it's in dead_outputs
|
||||
global.operands[0] = Operand::StorageClass(StorageClass::Private);
|
||||
}
|
||||
}
|
||||
|
||||
module.annotations.retain_mut(|op| {
|
||||
if op.class.opcode != Op::Decorate {
|
||||
return true;
|
||||
}
|
||||
|
||||
let Some(Operand::Decoration(Decoration::Location)) = op.operands.get(1) else {
|
||||
return true;
|
||||
};
|
||||
|
||||
let Some(&Operand::IdRef(target)) = op.operands.get(0) else {
|
||||
return true;
|
||||
};
|
||||
|
||||
// If target is in dead outputs, then don't keep it.
|
||||
!dead_outputs.contains(&target)
|
||||
});
|
||||
}
|
||||
|
||||
// Trim unused fragment shader inputs
|
||||
fn trim_inputs(&mut self) {
|
||||
let functions = &self.frag_builder.module_ref().functions;
|
||||
|
||||
// literally if it has any reference at all we can keep it
|
||||
|
@ -82,8 +153,8 @@ impl<'a> LinkInputs<'a> {
|
|||
for param in &function.parameters {
|
||||
for op in ¶m.operands {
|
||||
if let Some(word) = op.id_ref_any() {
|
||||
if self.inputs.contains(&word) {
|
||||
self.inputs.remove(&word);
|
||||
if self.inputs_to_remove.contains_key(&word) {
|
||||
self.inputs_to_remove.remove(&word);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -93,8 +164,8 @@ impl<'a> LinkInputs<'a> {
|
|||
for inst in &block.instructions {
|
||||
for op in &inst.operands {
|
||||
if let Some(word) = op.id_ref_any() {
|
||||
if self.inputs.contains(&word) {
|
||||
self.inputs.remove(&word);
|
||||
if self.inputs_to_remove.contains_key(&word) {
|
||||
self.inputs_to_remove.remove(&word);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -103,11 +174,10 @@ impl<'a> LinkInputs<'a> {
|
|||
}
|
||||
|
||||
// ok well guess we dont
|
||||
|
||||
self.frag_builder.module_mut().debug_names.retain(|instr| {
|
||||
for op in &instr.operands {
|
||||
if let Some(word) = op.id_ref_any() {
|
||||
if self.inputs.contains(&word) {
|
||||
if self.inputs_to_remove.contains_key(&word) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -118,7 +188,7 @@ impl<'a> LinkInputs<'a> {
|
|||
self.frag_builder.module_mut().annotations.retain(|instr| {
|
||||
for op in &instr.operands {
|
||||
if let Some(word) = op.id_ref_any() {
|
||||
if self.inputs.contains(&word) {
|
||||
if self.inputs_to_remove.contains_key(&word) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -129,7 +199,7 @@ impl<'a> LinkInputs<'a> {
|
|||
for entry_point in self.frag_builder.module_mut().entry_points.iter_mut() {
|
||||
entry_point.operands.retain(|op| {
|
||||
if let Some(word) = op.id_ref_any() {
|
||||
if self.inputs.contains(&word) {
|
||||
if self.inputs_to_remove.contains_key(&word) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -145,7 +215,7 @@ impl<'a> LinkInputs<'a> {
|
|||
return true;
|
||||
};
|
||||
|
||||
!self.inputs.contains(&id)
|
||||
!self.inputs_to_remove.contains_key(&id)
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Add table
Reference in a new issue