Add memory barrier to elements shader

The flag read needs acquire semantics. There are a number of ways that
could be expressed, but a generally portable way is to have a barrier
after. However, in the translation to Metal, that barrier needs to be in
uniform control flow. This patch does some workarounds to ensure that.
Author: Raph Levien
Date:   2021-11-02 10:15:52 -07:00
Parent: c648038967
Commit: e50d5c1f58
2 changed files with 46 additions and 23 deletions

View file

@@ -175,6 +175,7 @@ shared State sh_state[WG_SIZE];
 shared uint sh_part_ix;
 shared State sh_prefix;
+shared uint sh_flag;
 
 void main() {
     State th_state[N_ROWS];
@@ -230,27 +231,42 @@ void main() {
             flag = FLAG_PREFIX_READY;
         }
         state[state_flag_index(part_ix)] = flag;
-        if (part_ix != 0) {
-            // step 4 of paper: decoupled lookback
-            uint look_back_ix = part_ix - 1;
+    }
+    if (part_ix != 0) {
+        // step 4 of paper: decoupled lookback
+        uint look_back_ix = part_ix - 1;
         State their_agg;
         uint their_ix = 0;
         while (true) {
-            flag = state[state_flag_index(look_back_ix)];
-            if (flag == FLAG_PREFIX_READY) {
+            if (gl_LocalInvocationID.x == WG_SIZE - 1) {
+                sh_flag = state[state_flag_index(look_back_ix)];
+            }
+            // The flag load is done only in the last thread. However, because the
+            // translation of memoryBarrierBuffer to Metal requires uniform control
+            // flow, we broadcast it to all threads.
+            barrier();
+            memoryBarrierBuffer();
+            uint flag = sh_flag;
+            if (flag == FLAG_PREFIX_READY) {
+                if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                 State their_prefix = State_read(state_prefix_ref(look_back_ix));
                 exclusive = combine_state(their_prefix, exclusive);
+                }
                 break;
             } else if (flag == FLAG_AGGREGATE_READY) {
+                if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                 their_agg = State_read(state_aggregate_ref(look_back_ix));
                 exclusive = combine_state(their_agg, exclusive);
-                look_back_ix--;
-                their_ix = 0;
-                continue;
             }
-            // else spin
+                look_back_ix--;
+                their_ix = 0;
+                continue;
+            }
+            // else spin
+            if (gl_LocalInvocationID.x == WG_SIZE - 1) {
             // Unfortunately there's no guarantee of forward progress of other
             // workgroups, so compute a bit of the aggregate before trying again.
             // In the worst case, spinning stops when the aggregate is complete.
@@ -265,22 +281,29 @@ void main() {
                 if (their_ix == PARTITION_SIZE) {
                     exclusive = combine_state(their_agg, exclusive);
                     if (look_back_ix == 0) {
-                        break;
-                    } else {
-                        look_back_ix--;
-                        their_ix = 0;
+                        sh_flag = FLAG_PREFIX_READY;
                     }
+                    look_back_ix--;
+                    their_ix = 0;
                 }
             }
-            // step 5 of paper: compute inclusive prefix
+            barrier();
+            flag = sh_flag;
+            if (flag == FLAG_PREFIX_READY) {
+                break;
+            }
+        }
+        // step 5 of paper: compute inclusive prefix
+        if (gl_LocalInvocationID.x == WG_SIZE - 1) {
             State inclusive_prefix = combine_state(exclusive, agg);
             sh_prefix = exclusive;
             State_write(state_prefix_ref(part_ix), inclusive_prefix);
         }
-    }
-    memoryBarrierBuffer();
-    if (gl_LocalInvocationID.x == WG_SIZE - 1 && part_ix != 0) {
-        state[state_flag_index(part_ix)] = FLAG_PREFIX_READY;
+        memoryBarrierBuffer();
+        if (gl_LocalInvocationID.x == WG_SIZE - 1) {
+            state[state_flag_index(part_ix)] = FLAG_PREFIX_READY;
+        }
     }
     barrier();
     if (part_ix != 0) {

Binary file not shown.