vello/tests/shader/prefix.comp
Raph Levien a0648a2153 Portability fixes
The MSL translation of the prefix example had its bindings permuted; a
flag prevents this (but, as is typical for shader translation,
potentially creates other problems).

Also use explicit unsigned literal to avoid DXC warnings.
2021-11-11 07:08:39 -08:00

189 lines
5.7 KiB
GLSL

// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense
// A prefix sum.
#version 450
// Number of consecutive input elements each invocation scans sequentially.
#define N_ROWS 16
// Base-2 logarithm of the workgroup size; also the number of steps in the
// workgroup-level scan over shared memory.
#define LG_WG_SIZE 9
// Invocations per workgroup (512).
#define WG_SIZE (1 << LG_WG_SIZE)
// Elements handled by one workgroup (8192); the unit of work that is claimed
// through the atomic partition counter.
#define PARTITION_SIZE (WG_SIZE * N_ROWS)
layout(local_size_x = WG_SIZE, local_size_y = 1) in;
// Value type being scanned. Elements are merged with combine_monoid(), which
// adds the `element` fields; Monoid(0) is used as the identity.
struct Monoid {
// The single unsigned payload that is summed.
uint element;
};
// Input sequence (read-only). It is read without bounds checks, one full
// partition per workgroup, so the allocation presumably has to be padded to a
// multiple of PARTITION_SIZE elements -- TODO confirm on the host side.
layout(set = 0, binding = 0) readonly buffer InBuf {
Monoid[] inbuf;
};
// Output sequence: outbuf[i] receives the inclusive prefix sum of
// inbuf[0..i]. Written without bounds checks, using the same indices as the
// input reads.
layout(set = 0, binding = 1) buffer OutBuf {
Monoid[] outbuf;
};
// Per-partition status values stored in State.flag.
// These correspond to X, A, P respectively in the prefix sum paper.
// Nothing has been published for the partition yet.
#define FLAG_NOT_READY 0
// State.aggregate (the sum of this partition alone) is valid.
#define FLAG_AGGREGATE_READY 1
// State.prefix (the sum of this and all earlier partitions) is valid.
#define FLAG_PREFIX_READY 2
// Published progress of one partition, consumed by later partitions during
// look-back.
struct State {
// One of the FLAG_* values; tells readers which of the fields below is valid.
uint flag;
// Sum of the elements of this partition only.
Monoid aggregate;
// Inclusive sum of all elements up to and including this partition.
Monoid prefix;
};
// Cross-workgroup communication buffer, declared volatile because other
// workgroups update it while this one is polling.
// NOTE(review): the shader relies on part_counter starting at 0 and every
// flag starting as FLAG_NOT_READY, so the buffer is presumably zeroed by the
// host before each dispatch -- confirm.
layout(set = 0, binding = 2) volatile buffer StateBuf {
// Atomic ticket counter; each workgroup takes the next partition index.
uint part_counter;
// One entry per partition, indexed by partition index.
State[] state;
};
// Workgroup-local scan buffer: slot i ends up holding the inclusive sum of
// the per-invocation aggregates of invocations 0..i.
shared Monoid sh_scratch[WG_SIZE];
// Associative merge of two monoid values: the result carries the sum of both
// `element` fields (unsigned arithmetic, so overflow wraps).
Monoid combine_monoid(Monoid a, Monoid b) {
    Monoid merged;
    merged.element = a.element + b.element;
    return merged;
}
// Partition index claimed by invocation 0, broadcast to the whole workgroup.
shared uint sh_part_ix;
// Exclusive prefix of this partition (sum of all earlier partitions), computed
// by the last invocation and broadcast to the whole workgroup.
shared Monoid sh_prefix;
// Flag of the partition currently being inspected during look-back, loaded by
// the last invocation and broadcast so that all invocations branch alike.
shared uint sh_flag;
// Entry point. Each workgroup claims one partition of PARTITION_SIZE elements
// and writes the inclusive prefix sum of that partition to outbuf.
//
// Phases:
//   1. Claim a partition index from the atomic counter.
//   2. Each invocation scans its N_ROWS consecutive elements sequentially.
//   3. The per-invocation totals are scanned across the workgroup in shared
//      memory, yielding the partition aggregate in the last invocation.
//   4. The aggregate is published, then decoupled look-back over earlier
//      partitions produces this partition's exclusive prefix.
//   5. The inclusive prefix is published and the final values are written.
//
// All barrier() calls are reached in workgroup-uniform control flow; this
// depends on every invocation observing the same value of sh_flag, which is
// why each read of sh_flag is followed by a barrier before it may be rewritten.
void main() {
    Monoid local[N_ROWS];
    // Determine partition to process by atomic counter (described in Section
    // 4.4 of prefix sum paper).
    if (gl_LocalInvocationID.x == 0) {
        sh_part_ix = atomicAdd(part_counter, 1u);
    }
    barrier();
    uint part_ix = sh_part_ix;
    uint ix = part_ix * PARTITION_SIZE + gl_LocalInvocationID.x * N_ROWS;
    // TODO: gate buffer read? (evaluate whether shader check or
    // CPU-side padding is better)
    // Sequential inclusive scan of this invocation's N_ROWS elements.
    local[0] = inbuf[ix];
    for (uint i = 1; i < N_ROWS; i++) {
        local[i] = combine_monoid(local[i - 1], inbuf[ix + i]);
    }
    // Workgroup-level inclusive scan of the per-invocation totals. The two
    // barriers per step separate the reads of sh_scratch from the writes.
    Monoid agg = local[N_ROWS - 1];
    sh_scratch[gl_LocalInvocationID.x] = agg;
    for (uint i = 0; i < LG_WG_SIZE; i++) {
        barrier();
        if (gl_LocalInvocationID.x >= (1u << i)) {
            Monoid other = sh_scratch[gl_LocalInvocationID.x - (1u << i)];
            agg = combine_monoid(other, agg);
        }
        barrier();
        sh_scratch[gl_LocalInvocationID.x] = agg;
    }
    // Publish aggregate for this partition. In the last invocation `agg` is
    // the sum of the whole partition; for partition 0 that is also its
    // inclusive prefix.
    if (gl_LocalInvocationID.x == WG_SIZE - 1) {
        state[part_ix].aggregate = agg;
        if (part_ix == 0) {
            state[0].prefix = agg;
        }
    }
    // Write flag with release semantics; this is done portably with a barrier.
    memoryBarrierBuffer();
    if (gl_LocalInvocationID.x == WG_SIZE - 1) {
        uint flag = FLAG_AGGREGATE_READY;
        if (part_ix == 0) {
            flag = FLAG_PREFIX_READY;
        }
        state[part_ix].flag = flag;
    }
    Monoid exclusive = Monoid(0u);
    if (part_ix != 0) {
        // step 4 of paper: decoupled lookback
        // Only the last invocation accumulates `exclusive` and reads `state`;
        // the other invocations just follow the same control flow.
        uint look_back_ix = part_ix - 1;
        Monoid their_agg;
        uint their_ix = 0;
        while (true) {
            // Read flag with acquire semantics.
            if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                sh_flag = state[look_back_ix].flag;
            }
            // The flag load is done only in the last thread. However, because the
            // translation of memoryBarrierBuffer to Metal requires uniform control
            // flow, we broadcast it to all threads.
            barrier();
            memoryBarrierBuffer();
            uint flag = sh_flag;
            // Every invocation must have finished reading sh_flag before the
            // last invocation may overwrite it (at the top of the next
            // iteration, or in the spin path below). Without this barrier
            // invocations could observe different flags and take different
            // branches around the barriers.
            barrier();
            if (flag == FLAG_PREFIX_READY) {
                // The inspected partition already knows the sum of everything
                // up to itself, so the look-back is complete.
                if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                    Monoid their_prefix = state[look_back_ix].prefix;
                    exclusive = combine_monoid(their_prefix, exclusive);
                }
                break;
            } else if (flag == FLAG_AGGREGATE_READY) {
                // Take that partition's own sum and step one partition back.
                // Any partially computed fallback sum is discarded.
                if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                    their_agg = state[look_back_ix].aggregate;
                    exclusive = combine_monoid(their_agg, exclusive);
                }
                look_back_ix--;
                their_ix = 0;
                continue;
            }
            // else spin
            if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                // Unfortunately there's no guarantee of forward progress of other
                // workgroups, so compute a bit of the aggregate before trying again.
                // In the worst case, spinning stops when the aggregate is complete.
                Monoid m = inbuf[look_back_ix * PARTITION_SIZE + their_ix];
                if (their_ix == 0) {
                    their_agg = m;
                } else {
                    their_agg = combine_monoid(their_agg, m);
                }
                their_ix++;
                if (their_ix == PARTITION_SIZE) {
                    // The whole partition has been summed locally.
                    exclusive = combine_monoid(their_agg, exclusive);
                    if (look_back_ix == 0) {
                        // Nothing precedes partition 0: signal completion to
                        // the rest of the workgroup through the shared flag.
                        sh_flag = FLAG_PREFIX_READY;
                    } else {
                        look_back_ix--;
                        their_ix = 0;
                    }
                }
            }
            barrier();
            flag = sh_flag;
            // As above: complete this read in all invocations before the next
            // iteration's write to sh_flag.
            barrier();
            if (flag == FLAG_PREFIX_READY) {
                break;
            }
        }
        // step 5 of paper: compute inclusive prefix
        if (gl_LocalInvocationID.x == WG_SIZE - 1) {
            Monoid inclusive_prefix = combine_monoid(exclusive, agg);
            sh_prefix = exclusive;
            state[part_ix].prefix = inclusive_prefix;
        }
        // Release: the prefix must be visible before the flag announces it.
        memoryBarrierBuffer();
        if (gl_LocalInvocationID.x == WG_SIZE - 1) {
            state[part_ix].flag = FLAG_PREFIX_READY;
        }
    }
    // Makes sh_prefix (and the final sh_scratch contents) visible to all
    // invocations.
    barrier();
    if (part_ix != 0) {
        exclusive = sh_prefix;
    }
    // Offset for this invocation's rows: everything before the partition plus
    // the totals of the preceding invocations within the workgroup.
    Monoid row = exclusive;
    if (gl_LocalInvocationID.x > 0) {
        Monoid other = sh_scratch[gl_LocalInvocationID.x - 1];
        row = combine_monoid(row, other);
    }
    for (uint i = 0; i < N_ROWS; i++) {
        Monoid m = combine_monoid(row, local[i]);
        // Make sure buffer allocation is padded appropriately.
        outbuf[ix + i] = m;
    }
}