// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense

// A prefix sum.

#version 450

#define N_ROWS 16
#define LG_WG_SIZE 9
#define WG_SIZE (1 << LG_WG_SIZE)
#define PARTITION_SIZE (WG_SIZE * N_ROWS)

layout(local_size_x = WG_SIZE, local_size_y = 1) in;

struct Monoid {
    uint element;
};

layout(set = 0, binding = 0) readonly buffer InBuf {
    Monoid[] inbuf;
};

layout(set = 0, binding = 1) buffer OutBuf {
    Monoid[] outbuf;
};

// These correspond to X, A, P respectively in the prefix sum paper.
#define FLAG_NOT_READY 0
#define FLAG_AGGREGATE_READY 1
#define FLAG_PREFIX_READY 2

struct State {
    uint flag;
    Monoid aggregate;
    Monoid prefix;
};

layout(set = 0, binding = 2) volatile buffer StateBuf {
    uint part_counter;
    State[] state;
};

shared Monoid sh_scratch[WG_SIZE];

Monoid combine_monoid(Monoid a, Monoid b) {
    return Monoid(a.element + b.element);
}

shared uint sh_part_ix;
shared Monoid sh_prefix;
shared uint sh_flag;

void main() {
    Monoid local[N_ROWS];
    // Determine partition to process by atomic counter (described in Section
    // 4.4 of prefix sum paper).
    if (gl_LocalInvocationID.x == 0) {
        sh_part_ix = atomicAdd(part_counter, 1);
    }
    barrier();
    uint part_ix = sh_part_ix;

    uint ix = part_ix * PARTITION_SIZE + gl_LocalInvocationID.x * N_ROWS;

    // TODO: gate buffer read? (evaluate whether shader check or
    // CPU-side padding is better)
    local[0] = inbuf[ix];
    for (uint i = 1; i < N_ROWS; i++) {
        local[i] = combine_monoid(local[i - 1], inbuf[ix + i]);
    }
    Monoid agg = local[N_ROWS - 1];
    sh_scratch[gl_LocalInvocationID.x] = agg;
    for (uint i = 0; i < LG_WG_SIZE; i++) {
        barrier();
        if (gl_LocalInvocationID.x >= (1 << i)) {
            Monoid other = sh_scratch[gl_LocalInvocationID.x - (1 << i)];
            agg = combine_monoid(other, agg);
        }
        barrier();
        sh_scratch[gl_LocalInvocationID.x] = agg;
    }

    // Publish aggregate for this partition
    if (gl_LocalInvocationID.x == WG_SIZE - 1) {
        state[part_ix].aggregate = agg;
        if (part_ix == 0) {
            state[0].prefix = agg;
        }
    }

    // Write flag with release semantics; this is done portably with a barrier.
    memoryBarrierBuffer();
    if (gl_LocalInvocationID.x == WG_SIZE - 1) {
        uint flag = FLAG_AGGREGATE_READY;
        if (part_ix == 0) {
            flag = FLAG_PREFIX_READY;
        }
        state[part_ix].flag = flag;
    }

    Monoid exclusive = Monoid(0);
    if (part_ix != 0) {
        // step 4 of paper: decoupled lookback
        uint look_back_ix = part_ix - 1;

        Monoid their_agg;
        uint their_ix = 0;
        while (true) {
            // Read flag with acquire semantics.
            if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                sh_flag = state[look_back_ix].flag;
            }
            // The flag load is done only in the last thread. However, because the
            // translation of memoryBarrierBuffer to Metal requires uniform control
            // flow, we broadcast it to all threads.
            barrier();
            memoryBarrierBuffer();
            uint flag = sh_flag;

            if (flag == FLAG_PREFIX_READY) {
                if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                    Monoid their_prefix = state[look_back_ix].prefix;
                    exclusive = combine_monoid(their_prefix, exclusive);
                }
                break;
            } else if (flag == FLAG_AGGREGATE_READY) {
                if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                    their_agg = state[look_back_ix].aggregate;
                    exclusive = combine_monoid(their_agg, exclusive);
                }
                look_back_ix--;
                their_ix = 0;
                continue;
            }
            // else spin

            if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                // Unfortunately there's no guarantee of forward progress of other
                // workgroups, so compute a bit of the aggregate before trying again.
                // In the worst case, spinning stops when the aggregate is complete.
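                // Fall back to scanning the stalled partition's input directly:
                // accumulate one element per spin iteration into their_agg; once a
                // full partition has been consumed, fold it into the exclusive
                // prefix and either step back another partition or signal completion
                // to the whole workgroup through sh_flag.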
                Monoid m = inbuf[look_back_ix * PARTITION_SIZE + their_ix];
                if (their_ix == 0) {
                    their_agg = m;
                } else {
                    their_agg = combine_monoid(their_agg, m);
                }
                their_ix++;
                if (their_ix == PARTITION_SIZE) {
                    exclusive = combine_monoid(their_agg, exclusive);
                    if (look_back_ix == 0) {
                        sh_flag = FLAG_PREFIX_READY;
                    } else {
                        look_back_ix--;
                        their_ix = 0;
                    }
                }
            }
            barrier();
            flag = sh_flag;
            if (flag == FLAG_PREFIX_READY) {
                break;
            }
        }
        // step 5 of paper: compute inclusive prefix
        if (gl_LocalInvocationID.x == WG_SIZE - 1) {
            Monoid inclusive_prefix = combine_monoid(exclusive, agg);
            sh_prefix = exclusive;
            state[part_ix].prefix = inclusive_prefix;
        }
        memoryBarrierBuffer();
        if (gl_LocalInvocationID.x == WG_SIZE - 1) {
            state[part_ix].flag = FLAG_PREFIX_READY;
        }
    }
    barrier();
    if (part_ix != 0) {
        exclusive = sh_prefix;
    }

    Monoid row = exclusive;
    if (gl_LocalInvocationID.x > 0) {
        Monoid other = sh_scratch[gl_LocalInvocationID.x - 1];
        row = combine_monoid(row, other);
    }
    for (uint i = 0; i < N_ROWS; i++) {
        Monoid m = combine_monoid(row, local[i]);
        // Make sure buffer allocation is padded appropriately.
        outbuf[ix + i] = m;
    }
}
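
// Host-side setup (a sketch of the assumptions implied by this shader, not a
// definitive pipeline description):
// - Dispatch ceil(n_elements / PARTITION_SIZE) workgroups; with the constants
//   above, each partition covers 8192 elements (WG_SIZE * N_ROWS).
// - Zero-initialize StateBuf before each dispatch so that part_counter starts
//   at 0 and every State.flag reads FLAG_NOT_READY.
// - Pad InBuf and OutBuf to a multiple of PARTITION_SIZE elements, since the
//   reads and writes are not gated (see the TODO above).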