diff --git a/librashader-reflect/src/front/glslang.rs b/librashader-reflect/src/front/glslang.rs index f30abec..88b1170 100644 --- a/librashader-reflect/src/front/glslang.rs +++ b/librashader-reflect/src/front/glslang.rs @@ -41,11 +41,12 @@ pub(crate) fn compile_spirv(source: &ShaderSource) -> Result { pub frag_builder: &'a mut Builder, - pub inputs: FxHashSet, + pub vert_builder: &'a mut Builder, + + // binding -> ID + pub outputs: FxHashMap, + // id -> binding + pub inputs_to_remove: FxHashMap, } impl<'a> LinkInputs<'a> { + /// Get the value of the location of the inout in the module fn find_location(module: &Module, id: spirv::Word) -> Option { module.annotations.iter().find_map(|op| { if op.class.opcode != Op::Decorate { @@ -33,9 +42,11 @@ impl<'a> LinkInputs<'a> { }) } - pub fn new(vert: &'a Module, frag: &'a mut Builder) -> Self { - let mut inputs = FxHashSet::default(); - let mut bindings = FxHashMap::default(); + pub fn new(vert: &'a mut Builder, frag: &'a mut Builder, keep_if_bound: bool) -> Self { + let mut outputs = FxHashMap::default(); + let mut inputs_to_remove = FxHashMap::default(); + let mut inputs = FxHashMap::default(); + for global in frag.module_ref().types_global_values.iter() { if global.class.opcode == spirv::Op::Variable && global.operands[0] == Operand::StorageClass(StorageClass::Input) @@ -45,24 +56,32 @@ impl<'a> LinkInputs<'a> { continue; }; - inputs.insert(id); - bindings.insert(location, id); + inputs_to_remove.insert(id, location); + inputs.insert(location, id); } } } - for global in vert.types_global_values.iter() { + for global in vert.module_ref().types_global_values.iter() { if global.class.opcode == spirv::Op::Variable && global.operands[0] == Operand::StorageClass(StorageClass::Output) { if let Some(id) = global.result_id { - let Some(location) = Self::find_location(vert, id) else { + let Some(location) = Self::find_location(vert.module_ref(), id) else { continue; }; - if let Some(frag_ref) = bindings.get(&location) { - // if something is bound to the same location in the vertex shader, - // we're good. - inputs.remove(&frag_ref); + + // Add to list of outputs + outputs.insert(location, id); + + // Keep the input, if it is bound to both stages.Otherwise, do DCE analysis on + // the input, and remove it regardless if bound, if unused by the fragment stage. + if keep_if_bound { + if let Some(frag_ref) = inputs.get(&location) { + // if something is bound to the same location in the vertex shader, + // we're good. + inputs_to_remove.remove(&frag_ref); + } } } } @@ -70,11 +89,63 @@ impl<'a> LinkInputs<'a> { Self { frag_builder: frag, - inputs, + vert_builder: vert, + outputs, + inputs_to_remove, } } pub fn do_pass(&mut self) { + self.trim_inputs(); + self.downgrade_outputs(); + } + + /// Downgrade dead inputs corresponding to outputs to global variables, keeping existing mappings. + fn downgrade_outputs(&mut self) { + let dead_outputs = self + .inputs_to_remove + .values() + .filter_map(|i| self.outputs.get(i).cloned()) + .collect::>(); + + let module = self.vert_builder.module_mut(); + for global in module.types_global_values.iter_mut() { + if global.class.opcode != spirv::Op::Variable + || global.operands[0] != Operand::StorageClass(StorageClass::Output) + { + continue; + } + + if let Some(id) = global.result_id { + if !dead_outputs.contains(&id) { + continue; + } + + // downgrade the OpVariable if it's in dead_outputs + global.operands[0] = Operand::StorageClass(StorageClass::Private); + } + } + + module.annotations.retain_mut(|op| { + if op.class.opcode != Op::Decorate { + return true; + } + + let Some(Operand::Decoration(Decoration::Location)) = op.operands.get(1) else { + return true; + }; + + let Some(&Operand::IdRef(target)) = op.operands.get(0) else { + return true; + }; + + // If target is in dead outputs, then don't keep it. + !dead_outputs.contains(&target) + }); + } + + // Trim unused fragment shader inputs + fn trim_inputs(&mut self) { let functions = &self.frag_builder.module_ref().functions; // literally if it has any reference at all we can keep it @@ -82,8 +153,8 @@ impl<'a> LinkInputs<'a> { for param in &function.parameters { for op in ¶m.operands { if let Some(word) = op.id_ref_any() { - if self.inputs.contains(&word) { - self.inputs.remove(&word); + if self.inputs_to_remove.contains_key(&word) { + self.inputs_to_remove.remove(&word); } } } @@ -93,8 +164,8 @@ impl<'a> LinkInputs<'a> { for inst in &block.instructions { for op in &inst.operands { if let Some(word) = op.id_ref_any() { - if self.inputs.contains(&word) { - self.inputs.remove(&word); + if self.inputs_to_remove.contains_key(&word) { + self.inputs_to_remove.remove(&word); } } } @@ -103,11 +174,10 @@ impl<'a> LinkInputs<'a> { } // ok well guess we dont - self.frag_builder.module_mut().debug_names.retain(|instr| { for op in &instr.operands { if let Some(word) = op.id_ref_any() { - if self.inputs.contains(&word) { + if self.inputs_to_remove.contains_key(&word) { return false; } } @@ -118,7 +188,7 @@ impl<'a> LinkInputs<'a> { self.frag_builder.module_mut().annotations.retain(|instr| { for op in &instr.operands { if let Some(word) = op.id_ref_any() { - if self.inputs.contains(&word) { + if self.inputs_to_remove.contains_key(&word) { return false; } } @@ -129,7 +199,7 @@ impl<'a> LinkInputs<'a> { for entry_point in self.frag_builder.module_mut().entry_points.iter_mut() { entry_point.operands.retain(|op| { if let Some(word) = op.id_ref_any() { - if self.inputs.contains(&word) { + if self.inputs_to_remove.contains_key(&word) { return false; } } @@ -145,7 +215,7 @@ impl<'a> LinkInputs<'a> { return true; }; - !self.inputs.contains(&id) + !self.inputs_to_remove.contains_key(&id) }); } -} +} \ No newline at end of file