diff --git a/piet-gpu/shader/elements.comp b/piet-gpu/shader/elements.comp index d37a2c6..bca2e2f 100644 --- a/piet-gpu/shader/elements.comp +++ b/piet-gpu/shader/elements.comp @@ -175,6 +175,7 @@ shared State sh_state[WG_SIZE]; shared uint sh_part_ix; shared State sh_prefix; +shared uint sh_flag; void main() { State th_state[N_ROWS]; @@ -230,27 +231,42 @@ void main() { flag = FLAG_PREFIX_READY; } state[state_flag_index(part_ix)] = flag; - if (part_ix != 0) { - // step 4 of paper: decoupled lookback - uint look_back_ix = part_ix - 1; + } + if (part_ix != 0) { + // step 4 of paper: decoupled lookback + uint look_back_ix = part_ix - 1; - State their_agg; - uint their_ix = 0; - while (true) { - flag = state[state_flag_index(look_back_ix)]; - if (flag == FLAG_PREFIX_READY) { + State their_agg; + uint their_ix = 0; + while (true) { + if (gl_LocalInvocationID.x == WG_SIZE - 1) { + sh_flag = state[state_flag_index(look_back_ix)]; + } + // The flag load is done only in the last thread. However, because the + // translation of memoryBarrierBuffer to Metal requires uniform control + // flow, we broadcast it to all threads. + barrier(); + memoryBarrierBuffer(); + uint flag = sh_flag; + + if (flag == FLAG_PREFIX_READY) { + if (gl_LocalInvocationID.x == WG_SIZE - 1) { State their_prefix = State_read(state_prefix_ref(look_back_ix)); exclusive = combine_state(their_prefix, exclusive); - break; - } else if (flag == FLAG_AGGREGATE_READY) { + } + break; + } else if (flag == FLAG_AGGREGATE_READY) { + if (gl_LocalInvocationID.x == WG_SIZE - 1) { their_agg = State_read(state_aggregate_ref(look_back_ix)); exclusive = combine_state(their_agg, exclusive); - look_back_ix--; - their_ix = 0; - continue; } - // else spin + look_back_ix--; + their_ix = 0; + continue; + } + // else spin + if (gl_LocalInvocationID.x == WG_SIZE - 1) { // Unfortunately there's no guarantee of forward progress of other // workgroups, so compute a bit of the aggregate before trying again. // In the worst case, spinning stops when the aggregate is complete. @@ -265,22 +281,29 @@ void main() { if (their_ix == PARTITION_SIZE) { exclusive = combine_state(their_agg, exclusive); if (look_back_ix == 0) { - break; + sh_flag = FLAG_PREFIX_READY; + } else { + look_back_ix--; + their_ix = 0; } - look_back_ix--; - their_ix = 0; } } - - // step 5 of paper: compute inclusive prefix + barrier(); + flag = sh_flag; + if (flag == FLAG_PREFIX_READY) { + break; + } + } + // step 5 of paper: compute inclusive prefix + if (gl_LocalInvocationID.x == WG_SIZE - 1) { State inclusive_prefix = combine_state(exclusive, agg); sh_prefix = exclusive; State_write(state_prefix_ref(part_ix), inclusive_prefix); } - } - memoryBarrierBuffer(); - if (gl_LocalInvocationID.x == WG_SIZE - 1 && part_ix != 0) { - state[state_flag_index(part_ix)] = FLAG_PREFIX_READY; + memoryBarrierBuffer(); + if (gl_LocalInvocationID.x == WG_SIZE - 1) { + state[state_flag_index(part_ix)] = FLAG_PREFIX_READY; + } } barrier(); if (part_ix != 0) { diff --git a/piet-gpu/shader/elements.spv b/piet-gpu/shader/elements.spv index d1fd39a..4daf42d 100644 Binary files a/piet-gpu/shader/elements.spv and b/piet-gpu/shader/elements.spv differ