From 3820e4b2f414955003c99240256bec054d3259df Mon Sep 17 00:00:00 2001 From: Raph Levien Date: Sat, 6 Nov 2021 21:43:09 -0700 Subject: [PATCH] Add missing file Also add finish_timestamps call, which is needed for DX12 (there are other issues but this is an easy fix for that one). --- piet-gpu-hal/examples/collatz.rs | 1 + tests/shader/prefix_scan.comp | 77 ++++++++++++++++++++++++++++++++ tests/src/runner.rs | 1 + 3 files changed, 79 insertions(+) create mode 100644 tests/shader/prefix_scan.comp diff --git a/piet-gpu-hal/examples/collatz.rs b/piet-gpu-hal/examples/collatz.rs index e974cde..cad508e 100644 --- a/piet-gpu-hal/examples/collatz.rs +++ b/piet-gpu-hal/examples/collatz.rs @@ -21,6 +21,7 @@ fn main() { cmd_buf.write_timestamp(&query_pool, 0); cmd_buf.dispatch(&pipeline, &descriptor_set, (256, 1, 1), (1, 1, 1)); cmd_buf.write_timestamp(&query_pool, 1); + cmd_buf.finish_timestamps(&query_pool); cmd_buf.host_barrier(); cmd_buf.finish(); let submitted = session.run_cmd_buf(cmd_buf, &[], &[]).unwrap(); diff --git a/tests/shader/prefix_scan.comp b/tests/shader/prefix_scan.comp new file mode 100644 index 0000000..59903ab --- /dev/null +++ b/tests/shader/prefix_scan.comp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense + +// A scan for a tree reduction prefix scan (either root or not, by ifdef). + +#version 450 + +#define N_ROWS 8 +#define LG_WG_SIZE 9 +#define WG_SIZE (1 << LG_WG_SIZE) +#define PARTITION_SIZE (WG_SIZE * N_ROWS) + +layout(local_size_x = WG_SIZE, local_size_y = 1) in; + +struct Monoid { + uint element; +}; + +layout(set = 0, binding = 0) buffer DataBuf { + Monoid[] data; +}; + +#ifndef ROOT +layout(set = 0, binding = 1) buffer ParentBuf { + Monoid[] parent; +}; +#endif + +shared Monoid sh_scratch[WG_SIZE]; + +Monoid combine_monoid(Monoid a, Monoid b) { + return Monoid(a.element + b.element); +} + +void main() { + Monoid local[N_ROWS]; + + uint ix = gl_GlobalInvocationID.x * N_ROWS; + + // TODO: gate buffer read + local[0] = data[ix]; + for (uint i = 1; i < N_ROWS; i++) { + local[i] = combine_monoid(local[i - 1], data[ix + i]); + } + Monoid agg = local[N_ROWS - 1]; + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i = 0; i < LG_WG_SIZE; i++) { + barrier(); + if (gl_LocalInvocationID.x >= (1 << i)) { + Monoid other = sh_scratch[gl_LocalInvocationID.x - (1 << i)]; + agg = combine_monoid(other, agg); + } + barrier(); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + + barrier(); + // This could be a semigroup instead of a monoid if we reworked the + // conditional logic, but that might impact performance. + Monoid row = Monoid(0); +#ifdef ROOT + if (gl_LocalInvocationID.x > 0) { + row = sh_scratch[gl_LocalInvocationID.x - 1]; + } +#else + if (gl_WorkGroupID.x > 0) { + row = parent[gl_WorkGroupID.x - 1]; + } + if (gl_LocalInvocationID.x > 0) { + row = combine_monoid(row, sh_scratch[gl_LocalInvocationID.x - 1]); + } +#endif + for (uint i = 0; i < N_ROWS; i++) { + Monoid m = combine_monoid(row, local[i]); + // TODO: gate buffer write + data[ix + i] = m; + } +} diff --git a/tests/src/runner.rs b/tests/src/runner.rs index ce89961..9bfde3b 100644 --- a/tests/src/runner.rs +++ b/tests/src/runner.rs @@ -76,6 +76,7 @@ impl Runner { pub unsafe fn submit(&mut self, commands: Commands) -> f64 { let mut cmd_buf = commands.cmd_buf; let query_pool = commands.query_pool; + cmd_buf.finish_timestamps(&query_pool); cmd_buf.host_barrier(); cmd_buf.finish(); let submitted = self.session.run_cmd_buf(cmd_buf, &[], &[]).unwrap();