reflect: improve inout link algorithm
This commit is contained in:
parent
d1e49b7eb4
commit
1ac78695c6
2 changed files with 172 additions and 24 deletions
|
@ -41,11 +41,12 @@ pub(crate) fn compile_spirv(source: &ShaderSource) -> Result<SpirvCompilation, S
|
||||||
let vertex = load_module(&vertex);
|
let vertex = load_module(&vertex);
|
||||||
let fragment = load_module(&fragment);
|
let fragment = load_module(&fragment);
|
||||||
let mut fragment = Builder::new_from_module(fragment);
|
let mut fragment = Builder::new_from_module(fragment);
|
||||||
|
let mut vertex = Builder::new_from_module(vertex);
|
||||||
|
|
||||||
let mut pass = link_input_outputs::LinkInputs::new(&vertex, &mut fragment);
|
let mut pass = link_input_outputs::LinkInputs::new(&mut vertex, &mut fragment, false);
|
||||||
pass.do_pass();
|
pass.do_pass();
|
||||||
|
|
||||||
let vertex = vertex.assemble();
|
let vertex = vertex.module().assemble();
|
||||||
let fragment = fragment.module().assemble();
|
let fragment = fragment.module().assemble();
|
||||||
|
|
||||||
Ok(SpirvCompilation { vertex, fragment })
|
Ok(SpirvCompilation { vertex, fragment })
|
||||||
|
|
|
@ -1,13 +1,20 @@
|
||||||
use rspirv::dr::{Builder, Module, Operand};
|
use rspirv::dr::{Builder, Module, Operand};
|
||||||
use rustc_hash::{FxHashMap, FxHashSet};
|
use rustc_hash::{FxHashMap};
|
||||||
use spirv::{Decoration, Op, StorageClass};
|
use spirv::{Decoration, Op, StorageClass};
|
||||||
|
|
||||||
|
/// Do DCE on inputs and link
|
||||||
pub struct LinkInputs<'a> {
|
pub struct LinkInputs<'a> {
|
||||||
pub frag_builder: &'a mut Builder,
|
pub frag_builder: &'a mut Builder,
|
||||||
pub inputs: FxHashSet<spirv::Word>,
|
pub vert_builder: &'a mut Builder,
|
||||||
|
|
||||||
|
pub outputs: Vec<(u32, spirv::Word)>,
|
||||||
|
// pub inputs: Vec<(u32, spirv::Word)>,
|
||||||
|
pub inputs_to_remove: FxHashMap<spirv::Word, u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> LinkInputs<'a> {
|
impl<'a> LinkInputs<'a> {
|
||||||
|
|
||||||
|
/// Get the value of the location of the inout in the module
|
||||||
fn find_location(module: &Module, id: spirv::Word) -> Option<u32> {
|
fn find_location(module: &Module, id: spirv::Word) -> Option<u32> {
|
||||||
module.annotations.iter().find_map(|op| {
|
module.annotations.iter().find_map(|op| {
|
||||||
if op.class.opcode != Op::Decorate {
|
if op.class.opcode != Op::Decorate {
|
||||||
|
@ -33,9 +40,37 @@ impl<'a> LinkInputs<'a> {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new(vert: &'a Module, frag: &'a mut Builder) -> Self {
|
/// Get a mutable reference to the inout in the module
|
||||||
let mut inputs = FxHashSet::default();
|
fn find_location_operand(module: &mut Module, id: spirv::Word) -> Option<&mut u32> {
|
||||||
let mut bindings = FxHashMap::default();
|
module.annotations.iter_mut().find_map(|op| {
|
||||||
|
if op.class.opcode != Op::Decorate {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(Operand::Decoration(Decoration::Location)) = op.operands.get(1) else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(&Operand::IdRef(target)) = op.operands.get(0) else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
if target != id {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(Operand::LiteralBit32(binding)) = op.operands.get_mut(2) else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
return Some(binding);
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new(vert: &'a mut Builder, frag: &'a mut Builder, keep_if_bound: bool) -> Self {
|
||||||
|
let mut outputs = FxHashMap::default();
|
||||||
|
let mut inputs_to_remove = FxHashMap::default();
|
||||||
|
// let mut inputs = FxHashMap::default();
|
||||||
|
|
||||||
for global in frag.module_ref().types_global_values.iter() {
|
for global in frag.module_ref().types_global_values.iter() {
|
||||||
if global.class.opcode == spirv::Op::Variable
|
if global.class.opcode == spirv::Op::Variable
|
||||||
&& global.operands[0] == Operand::StorageClass(StorageClass::Input)
|
&& global.operands[0] == Operand::StorageClass(StorageClass::Input)
|
||||||
|
@ -45,36 +80,59 @@ impl<'a> LinkInputs<'a> {
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
|
||||||
inputs.insert(id);
|
inputs_to_remove.insert(id, location);
|
||||||
bindings.insert(location, id);
|
// inputs.insert(location, id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for global in vert.types_global_values.iter() {
|
for global in vert.module_ref().types_global_values.iter() {
|
||||||
if global.class.opcode == spirv::Op::Variable
|
if global.class.opcode == spirv::Op::Variable
|
||||||
&& global.operands[0] == Operand::StorageClass(StorageClass::Output)
|
&& global.operands[0] == Operand::StorageClass(StorageClass::Output)
|
||||||
{
|
{
|
||||||
if let Some(id) = global.result_id {
|
if let Some(id) = global.result_id {
|
||||||
let Some(location) = Self::find_location(vert, id) else {
|
let Some(location) = Self::find_location(vert.module_ref(), id) else {
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
if let Some(frag_ref) = bindings.get(&location) {
|
|
||||||
|
// Add to list of outputs
|
||||||
|
outputs.insert(location, id);
|
||||||
|
|
||||||
|
// Keep the input, if it is bound to both stages.
|
||||||
|
// Otherwise, do DCE analysis on the input, and remove it
|
||||||
|
// regardless if bound.
|
||||||
|
if keep_if_bound {
|
||||||
|
if let Some(&frag_ref) = inputs_to_remove.get(&location) {
|
||||||
// if something is bound to the same location in the vertex shader,
|
// if something is bound to the same location in the vertex shader,
|
||||||
// we're good.
|
// we're good.
|
||||||
inputs.remove(&frag_ref);
|
inputs_to_remove.remove(&frag_ref);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut outputs: Vec<(u32, spirv::Word)> = outputs.into_iter().collect();
|
||||||
|
// let mut inputs: Vec<(u32, spirv::Word)> = inputs.into_iter().collect();
|
||||||
|
|
||||||
|
outputs.sort_by(|&(a, _), &(b, _)| a.cmp(&b));
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
frag_builder: frag,
|
frag_builder: frag,
|
||||||
inputs,
|
vert_builder: vert,
|
||||||
|
outputs,
|
||||||
|
// inputs,
|
||||||
|
inputs_to_remove,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
pub fn do_pass(&mut self) {
|
pub fn do_pass(&mut self) {
|
||||||
|
self.trim_inputs();
|
||||||
|
self.reorder_inputs();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn trim_inputs(&mut self) {
|
||||||
let functions = &self.frag_builder.module_ref().functions;
|
let functions = &self.frag_builder.module_ref().functions;
|
||||||
|
|
||||||
// literally if it has any reference at all we can keep it
|
// literally if it has any reference at all we can keep it
|
||||||
|
@ -82,8 +140,8 @@ impl<'a> LinkInputs<'a> {
|
||||||
for param in &function.parameters {
|
for param in &function.parameters {
|
||||||
for op in ¶m.operands {
|
for op in ¶m.operands {
|
||||||
if let Some(word) = op.id_ref_any() {
|
if let Some(word) = op.id_ref_any() {
|
||||||
if self.inputs.contains(&word) {
|
if self.inputs_to_remove.contains_key(&word) {
|
||||||
self.inputs.remove(&word);
|
self.inputs_to_remove.remove(&word);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -93,8 +151,8 @@ impl<'a> LinkInputs<'a> {
|
||||||
for inst in &block.instructions {
|
for inst in &block.instructions {
|
||||||
for op in &inst.operands {
|
for op in &inst.operands {
|
||||||
if let Some(word) = op.id_ref_any() {
|
if let Some(word) = op.id_ref_any() {
|
||||||
if self.inputs.contains(&word) {
|
if self.inputs_to_remove.contains_key(&word) {
|
||||||
self.inputs.remove(&word);
|
self.inputs_to_remove.remove(&word);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -107,7 +165,7 @@ impl<'a> LinkInputs<'a> {
|
||||||
self.frag_builder.module_mut().debug_names.retain(|instr| {
|
self.frag_builder.module_mut().debug_names.retain(|instr| {
|
||||||
for op in &instr.operands {
|
for op in &instr.operands {
|
||||||
if let Some(word) = op.id_ref_any() {
|
if let Some(word) = op.id_ref_any() {
|
||||||
if self.inputs.contains(&word) {
|
if self.inputs_to_remove.contains_key(&word) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -118,7 +176,7 @@ impl<'a> LinkInputs<'a> {
|
||||||
self.frag_builder.module_mut().annotations.retain(|instr| {
|
self.frag_builder.module_mut().annotations.retain(|instr| {
|
||||||
for op in &instr.operands {
|
for op in &instr.operands {
|
||||||
if let Some(word) = op.id_ref_any() {
|
if let Some(word) = op.id_ref_any() {
|
||||||
if self.inputs.contains(&word) {
|
if self.inputs_to_remove.contains_key(&word) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -129,7 +187,7 @@ impl<'a> LinkInputs<'a> {
|
||||||
for entry_point in self.frag_builder.module_mut().entry_points.iter_mut() {
|
for entry_point in self.frag_builder.module_mut().entry_points.iter_mut() {
|
||||||
entry_point.operands.retain(|op| {
|
entry_point.operands.retain(|op| {
|
||||||
if let Some(word) = op.id_ref_any() {
|
if let Some(word) = op.id_ref_any() {
|
||||||
if self.inputs.contains(&word) {
|
if self.inputs_to_remove.contains_key(&word) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -145,7 +203,96 @@ impl<'a> LinkInputs<'a> {
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
!self.inputs.contains(&id)
|
!self.inputs_to_remove.contains_key(&id)
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn reorder_inputs(&mut self) {
|
||||||
|
// Preconditions:
|
||||||
|
// - trim_inputs is called, so all dead inputs are gone from the frag builder
|
||||||
|
|
||||||
|
|
||||||
|
// We want to have all the dead inputs get ordered last, but otherwise ensure
|
||||||
|
// that all locations are ordered.
|
||||||
|
|
||||||
|
let mut dead_inputs = self.inputs_to_remove.values().collect::<Vec<_>>();
|
||||||
|
dead_inputs.sort();
|
||||||
|
|
||||||
|
let Some(&&first) = dead_inputs.first() else {
|
||||||
|
// If there are no dead inputs then things are contiguous
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Mapping of old bindings -> new bindings
|
||||||
|
let mut remapping = FxHashMap::default();
|
||||||
|
|
||||||
|
// Start at the first dead input
|
||||||
|
let mut alloc = first;
|
||||||
|
|
||||||
|
for (binding, _) in &self.outputs {
|
||||||
|
if *binding < alloc {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !dead_inputs.contains(&binding) {
|
||||||
|
remapping.insert(*binding, alloc);
|
||||||
|
alloc += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now assign dead inputs the end
|
||||||
|
for binding in &dead_inputs {
|
||||||
|
remapping.insert(**binding, alloc);
|
||||||
|
alloc += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// eprintln!("dead: {:#?}", dead_inputs);
|
||||||
|
|
||||||
|
// eprintln!("remapping: {:#?}", remapping);
|
||||||
|
|
||||||
|
let frag_clone = self.frag_builder.module_ref().clone();
|
||||||
|
let frag_mut = self.frag_builder.module_mut();
|
||||||
|
|
||||||
|
for global in frag_clone.types_global_values {
|
||||||
|
if global.class.opcode == spirv::Op::Variable
|
||||||
|
&& global.operands[0] == Operand::StorageClass(StorageClass::Input)
|
||||||
|
{
|
||||||
|
if let Some(id) = global.result_id {
|
||||||
|
let Some(location) = Self::find_location_operand(frag_mut, id) else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(&remapping) = remapping.get(&location) else {
|
||||||
|
continue
|
||||||
|
};
|
||||||
|
|
||||||
|
// eprintln!("frag: remapped {} to {}", *location, remapping);
|
||||||
|
|
||||||
|
*location = remapping;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let vert_clone = self.vert_builder.module_ref().clone();
|
||||||
|
let vert_mut = self.vert_builder.module_mut();
|
||||||
|
|
||||||
|
for global in vert_clone.types_global_values {
|
||||||
|
if global.class.opcode == spirv::Op::Variable
|
||||||
|
&& global.operands[0] == Operand::StorageClass(StorageClass::Output)
|
||||||
|
{
|
||||||
|
if let Some(id) = global.result_id {
|
||||||
|
let Some(location) = Self::find_location_operand(vert_mut, id) else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(&remapping) = remapping.get(&location) else {
|
||||||
|
continue
|
||||||
|
};
|
||||||
|
|
||||||
|
// eprintln!("vert: remapped {} to {}", *location, remapping);
|
||||||
|
*location = remapping;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue