// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense

// The scan pass of a tree-reduction prefix scan.
//
// Each workgroup rewrites one partition of PARTITION_SIZE elements of `data`
// in place with its inclusive prefix combination. Two variants are selected
// at compile time:
//   - ROOT defined:     prefixes are taken within the workgroup only.
//   - ROOT not defined: each workgroup additionally folds in
//                       parent[gl_WorkGroupID.x - 1] as the prefix
//                       contributed by the preceding workgroups.
//                       NOTE(review): this presumes `parent` already holds an
//                       inclusive scan of the per-partition aggregates;
//                       confirm against the dispatching code.

#version 450

#define N_ROWS 8
#define LG_WG_SIZE 9
#define WG_SIZE (1 << LG_WG_SIZE)
#define PARTITION_SIZE (WG_SIZE * N_ROWS)

layout(local_size_x = WG_SIZE, local_size_y = 1) in;

struct Monoid {
    uint element;
};

layout(set = 0, binding = 0) buffer DataBuf {
    Monoid[] data;
};

#ifndef ROOT
layout(set = 0, binding = 1) buffer ParentBuf {
    Monoid[] parent;
};
#endif

// One slot per invocation; used to scan the per-invocation aggregates.
shared Monoid sh_scratch[WG_SIZE];

// The associative operation of the monoid (addition; identity is Monoid(0)).
Monoid combine_monoid(Monoid a, Monoid b) {
    return Monoid(a.element + b.element);
}

// Returns data[index] when index < n_elements, otherwise the monoid
// identity. Padding with the identity keeps out-of-range slots from
// affecting any in-range result.
Monoid load_element(uint index, uint n_elements) {
    Monoid m = Monoid(0);
    if (index < n_elements) {
        m = data[index];
    }
    return m;
}

void main() {
    // Number of elements in the bound range of `data`; every access below is
    // gated on it so that a buffer whose size is not a multiple of
    // PARTITION_SIZE is never read or written out of bounds.
    uint n_elements = uint(data.length());

    // Each invocation owns N_ROWS consecutive elements starting at ix.
    uint ix = gl_GlobalInvocationID.x * N_ROWS;

    // Sequential inclusive scan over this invocation's own elements.
    Monoid local[N_ROWS];
    local[0] = load_element(ix, n_elements);
    for (uint i = 1; i < N_ROWS; i++) {
        local[i] = combine_monoid(local[i - 1], load_element(ix + i, n_elements));
    }

    // Inclusive scan of the per-invocation aggregates across the workgroup
    // through shared memory, doubling the look-back distance each round.
    // No invocation may exit early: every invocation must reach each
    // barrier() below.
    Monoid agg = local[N_ROWS - 1];
    sh_scratch[gl_LocalInvocationID.x] = agg;
    for (uint i = 0; i < LG_WG_SIZE; i++) {
        uint stride = 1u << i;
        // Make the previous round's stores visible before reading them.
        barrier();
        if (gl_LocalInvocationID.x >= stride) {
            Monoid other = sh_scratch[gl_LocalInvocationID.x - stride];
            agg = combine_monoid(other, agg);
        }
        // All reads of this round must finish before any slot is overwritten.
        barrier();
        sh_scratch[gl_LocalInvocationID.x] = agg;
    }
    barrier();

    // `row` is the combination of everything that precedes this invocation's
    // first element.
    // This could be a semigroup instead of a monoid if we reworked the
    // conditional logic, but that might impact performance. (The gated loads
    // above rely on the identity element as well.)
    Monoid row = Monoid(0);
#ifdef ROOT
    if (gl_LocalInvocationID.x > 0) {
        row = sh_scratch[gl_LocalInvocationID.x - 1];
    }
#else
    if (gl_WorkGroupID.x > 0) {
        row = parent[gl_WorkGroupID.x - 1];
    }
    if (gl_LocalInvocationID.x > 0) {
        row = combine_monoid(row, sh_scratch[gl_LocalInvocationID.x - 1]);
    }
#endif

    // Write the final prefixes back in place, skipping slots past the end of
    // the buffer.
    for (uint i = 0; i < N_ROWS; i++) {
        if (ix + i < n_elements) {
            data[ix + i] = combine_monoid(row, local[i]);
        }
    }
}