From 1ac78695c67996b6a20546937a8fa12c587738fe Mon Sep 17 00:00:00 2001 From: chyyran Date: Mon, 19 Aug 2024 01:31:26 -0400 Subject: [PATCH] reflect: improve inout link algorithm --- librashader-reflect/src/front/glslang.rs | 5 +- .../front/spirv_passes/link_input_outputs.rs | 191 ++++++++++++++++-- 2 files changed, 172 insertions(+), 24 deletions(-) diff --git a/librashader-reflect/src/front/glslang.rs b/librashader-reflect/src/front/glslang.rs index f30abec..88b1170 100644 --- a/librashader-reflect/src/front/glslang.rs +++ b/librashader-reflect/src/front/glslang.rs @@ -41,11 +41,12 @@ pub(crate) fn compile_spirv(source: &ShaderSource) -> Result { pub frag_builder: &'a mut Builder, - pub inputs: FxHashSet, + pub vert_builder: &'a mut Builder, + + pub outputs: Vec<(u32, spirv::Word)>, + // pub inputs: Vec<(u32, spirv::Word)>, + pub inputs_to_remove: FxHashMap, } impl<'a> LinkInputs<'a> { + + /// Get the value of the location of the inout in the module fn find_location(module: &Module, id: spirv::Word) -> Option { module.annotations.iter().find_map(|op| { if op.class.opcode != Op::Decorate { @@ -33,9 +40,37 @@ impl<'a> LinkInputs<'a> { }) } - pub fn new(vert: &'a Module, frag: &'a mut Builder) -> Self { - let mut inputs = FxHashSet::default(); - let mut bindings = FxHashMap::default(); + /// Get a mutable reference to the inout in the module + fn find_location_operand(module: &mut Module, id: spirv::Word) -> Option<&mut u32> { + module.annotations.iter_mut().find_map(|op| { + if op.class.opcode != Op::Decorate { + return None; + } + + let Some(Operand::Decoration(Decoration::Location)) = op.operands.get(1) else { + return None; + }; + + let Some(&Operand::IdRef(target)) = op.operands.get(0) else { + return None; + }; + + if target != id { + return None; + } + + let Some(Operand::LiteralBit32(binding)) = op.operands.get_mut(2) else { + return None; + }; + return Some(binding); + }) + } + + pub fn new(vert: &'a mut Builder, frag: &'a mut Builder, keep_if_bound: bool) -> Self { + let mut outputs = FxHashMap::default(); + let mut inputs_to_remove = FxHashMap::default(); + // let mut inputs = FxHashMap::default(); + for global in frag.module_ref().types_global_values.iter() { if global.class.opcode == spirv::Op::Variable && global.operands[0] == Operand::StorageClass(StorageClass::Input) @@ -45,36 +80,59 @@ impl<'a> LinkInputs<'a> { continue; }; - inputs.insert(id); - bindings.insert(location, id); + inputs_to_remove.insert(id, location); + // inputs.insert(location, id); } } } - for global in vert.types_global_values.iter() { + for global in vert.module_ref().types_global_values.iter() { if global.class.opcode == spirv::Op::Variable && global.operands[0] == Operand::StorageClass(StorageClass::Output) { if let Some(id) = global.result_id { - let Some(location) = Self::find_location(vert, id) else { + let Some(location) = Self::find_location(vert.module_ref(), id) else { continue; }; - if let Some(frag_ref) = bindings.get(&location) { - // if something is bound to the same location in the vertex shader, - // we're good. - inputs.remove(&frag_ref); + + // Add to list of outputs + outputs.insert(location, id); + + // Keep the input, if it is bound to both stages. + // Otherwise, do DCE analysis on the input, and remove it + // regardless if bound. + if keep_if_bound { + if let Some(&frag_ref) = inputs_to_remove.get(&location) { + // if something is bound to the same location in the vertex shader, + // we're good. + inputs_to_remove.remove(&frag_ref); + } } } } } + let mut outputs: Vec<(u32, spirv::Word)> = outputs.into_iter().collect(); + // let mut inputs: Vec<(u32, spirv::Word)> = inputs.into_iter().collect(); + + outputs.sort_by(|&(a, _), &(b, _)| a.cmp(&b)); + Self { frag_builder: frag, - inputs, + vert_builder: vert, + outputs, + // inputs, + inputs_to_remove, } } + pub fn do_pass(&mut self) { + self.trim_inputs(); + self.reorder_inputs(); + } + + fn trim_inputs(&mut self) { let functions = &self.frag_builder.module_ref().functions; // literally if it has any reference at all we can keep it @@ -82,8 +140,8 @@ impl<'a> LinkInputs<'a> { for param in &function.parameters { for op in ¶m.operands { if let Some(word) = op.id_ref_any() { - if self.inputs.contains(&word) { - self.inputs.remove(&word); + if self.inputs_to_remove.contains_key(&word) { + self.inputs_to_remove.remove(&word); } } } @@ -93,8 +151,8 @@ impl<'a> LinkInputs<'a> { for inst in &block.instructions { for op in &inst.operands { if let Some(word) = op.id_ref_any() { - if self.inputs.contains(&word) { - self.inputs.remove(&word); + if self.inputs_to_remove.contains_key(&word) { + self.inputs_to_remove.remove(&word); } } } @@ -107,7 +165,7 @@ impl<'a> LinkInputs<'a> { self.frag_builder.module_mut().debug_names.retain(|instr| { for op in &instr.operands { if let Some(word) = op.id_ref_any() { - if self.inputs.contains(&word) { + if self.inputs_to_remove.contains_key(&word) { return false; } } @@ -118,7 +176,7 @@ impl<'a> LinkInputs<'a> { self.frag_builder.module_mut().annotations.retain(|instr| { for op in &instr.operands { if let Some(word) = op.id_ref_any() { - if self.inputs.contains(&word) { + if self.inputs_to_remove.contains_key(&word) { return false; } } @@ -129,7 +187,7 @@ impl<'a> LinkInputs<'a> { for entry_point in self.frag_builder.module_mut().entry_points.iter_mut() { entry_point.operands.retain(|op| { if let Some(word) = op.id_ref_any() { - if self.inputs.contains(&word) { + if self.inputs_to_remove.contains_key(&word) { return false; } } @@ -145,7 +203,96 @@ impl<'a> LinkInputs<'a> { return true; }; - !self.inputs.contains(&id) + !self.inputs_to_remove.contains_key(&id) }); } + + fn reorder_inputs(&mut self) { + // Preconditions: + // - trim_inputs is called, so all dead inputs are gone from the frag builder + + + // We want to have all the dead inputs get ordered last, but otherwise ensure + // that all locations are ordered. + + let mut dead_inputs = self.inputs_to_remove.values().collect::>(); + dead_inputs.sort(); + + let Some(&&first) = dead_inputs.first() else { + // If there are no dead inputs then things are contiguous + return; + }; + + // Mapping of old bindings -> new bindings + let mut remapping = FxHashMap::default(); + + // Start at the first dead input + let mut alloc = first; + + for (binding, _) in &self.outputs { + if *binding < alloc { + continue + } + + if !dead_inputs.contains(&binding) { + remapping.insert(*binding, alloc); + alloc += 1; + } + } + + // Now assign dead inputs the end + for binding in &dead_inputs { + remapping.insert(**binding, alloc); + alloc += 1; + } + + // eprintln!("dead: {:#?}", dead_inputs); + + // eprintln!("remapping: {:#?}", remapping); + + let frag_clone = self.frag_builder.module_ref().clone(); + let frag_mut = self.frag_builder.module_mut(); + + for global in frag_clone.types_global_values { + if global.class.opcode == spirv::Op::Variable + && global.operands[0] == Operand::StorageClass(StorageClass::Input) + { + if let Some(id) = global.result_id { + let Some(location) = Self::find_location_operand(frag_mut, id) else { + continue; + }; + + let Some(&remapping) = remapping.get(&location) else { + continue + }; + + // eprintln!("frag: remapped {} to {}", *location, remapping); + + *location = remapping; + } + } + } + + let vert_clone = self.vert_builder.module_ref().clone(); + let vert_mut = self.vert_builder.module_mut(); + + for global in vert_clone.types_global_values { + if global.class.opcode == spirv::Op::Variable + && global.operands[0] == Operand::StorageClass(StorageClass::Output) + { + if let Some(id) = global.result_id { + let Some(location) = Self::find_location_operand(vert_mut, id) else { + continue; + }; + + let Some(&remapping) = remapping.get(&location) else { + continue + }; + + // eprintln!("vert: remapped {} to {}", *location, remapping); + *location = remapping; + } + } + } + } }