diff --git a/librashader-reflect/src/front/spirv_passes/link_input_outputs.rs b/librashader-reflect/src/front/spirv_passes/link_input_outputs.rs index d35f579..5ec6a24 100644 --- a/librashader-reflect/src/front/spirv_passes/link_input_outputs.rs +++ b/librashader-reflect/src/front/spirv_passes/link_input_outputs.rs @@ -108,8 +108,74 @@ impl<'a> LinkInputs<'a> { .filter_map(|i| self.outputs.get(i).cloned()) .collect::>(); - let module = self.vert_builder.module_mut(); - for global in module.types_global_values.iter_mut() { + let mut pointer_types_to_downgrade = FxHashSet::default(); + + // Map from Pointer type to pointee + let mut pointer_type_pointee = Vec::new(); + + // Map from StorageClass Output to StorageClass Private + let mut downgraded_pointer_types = FxHashMap::default(); + + // First collect all the pointer types that are needed for dead outputs. + for global in self.vert_builder.module_ref().types_global_values.iter() { + if global.class.opcode != spirv::Op::Variable + || global.operands[0] != Operand::StorageClass(StorageClass::Output) + { + continue; + } + + if let Some(id) = global.result_id { + if !dead_outputs.contains(&id) { + continue; + } + + if let Some(result_type) = global.result_type { + pointer_types_to_downgrade.insert(result_type); + } + } + } + + // Collect all the pointee types of pointer types to downgrade + for global in self.vert_builder.module_ref().types_global_values.iter() { + if global.class.opcode != spirv::Op::TypePointer + || global.operands[0] != Operand::StorageClass(StorageClass::Output) + { + continue; + } + + if let Some(id) = global.result_id { + if !pointer_types_to_downgrade.contains(&id) { + continue; + } + + let Some(pointee_type) = global.operands[1].id_ref_any() else { + continue; + }; + + pointer_type_pointee.push((id, pointee_type)); + } + } + + // Create pointer types for everything we saw above with Private storage class. + // We don't have to deal with OpTypeForwardPointer, because PhysicalStorageBuffer + // is not valid in slang shaders, and we're only working with Vulkan inputs. + for (pointer_type, pointee_type) in pointer_type_pointee.into_iter() { + // Create a new private type + let private_pointer_type = + self.vert_builder + .type_pointer(None, StorageClass::Private, pointee_type); + + // Add it to the mapping + downgraded_pointer_types.insert(pointer_type, private_pointer_type); + } + + // Downgrade the OpVariable storage class and reassign the types. + for global in self + .vert_builder + .module_mut() + .types_global_values + .iter_mut() + { if global.class.opcode != spirv::Op::Variable || global.operands[0] != Operand::StorageClass(StorageClass::Output) { @@ -123,10 +189,25 @@ impl<'a> LinkInputs<'a> { // downgrade the OpVariable if it's in dead_outputs global.operands[0] = Operand::StorageClass(StorageClass::Private); + + // Get the result type. If there's no result type it's invalid anyways + // so it doesn't matter that we downgraded early (better downgraded than unmatched) + let Some(result_type) = &mut global.result_type else { + continue; + }; + + let Some(new_type) = downgraded_pointer_types.get(&result_type) else { + // We should have created one above. + continue; + }; + + // Set the type of the OpVariable to the same type with Private storageclass. + *result_type = *new_type; } } - module.annotations.retain_mut(|op| { + // Strip decorations of downgraded variables. + self.vert_builder.module_mut().annotations.retain_mut(|op| { if op.class.opcode != Op::Decorate { return true; } @@ -142,6 +223,27 @@ impl<'a> LinkInputs<'a> { // If target is in dead outputs, then don't keep it. !dead_outputs.contains(&target) }); + + for entry_point in self.vert_builder.module_mut().entry_points.iter_mut() { + let mut index = 0; + entry_point.operands.retain(|s| { + // Skip the execution mode, entry point reference, and name. + if index < 3 { + index += 1; + return true; + } + + index += 1; + + // Ignore any non-IdRef + let Operand::IdRef(id_ref) = s else { + return true; + }; + + // If the entry point contains a dead outputs, remove it from the interface. + !dead_outputs.contains(id_ref) + }); + } } // Trim unused fragment shader inputs @@ -218,4 +320,4 @@ impl<'a> LinkInputs<'a> { !self.inputs_to_remove.contains_key(&id) }); } -} \ No newline at end of file +}