diff --git a/piet-gpu/shader/elements.comp b/piet-gpu/shader/elements.comp
index d37a2c6..e4bbfec 100644
--- a/piet-gpu/shader/elements.comp
+++ b/piet-gpu/shader/elements.comp
@@ -175,6 +175,7 @@ shared State sh_state[WG_SIZE];
 
 shared uint sh_part_ix;
 shared State sh_prefix;
+shared uint sh_flag;
 
 void main() {
     State th_state[N_ROWS];
@@ -219,38 +220,56 @@ void main() {
 
     // Publish aggregate for this partition
     if (gl_LocalInvocationID.x == WG_SIZE - 1) {
-        // Note: with memory model, we'd want to generate the atomic store version of this.
         State_write(state_aggregate_ref(part_ix), agg);
+        if (part_ix == 0) {
+            State_write(state_prefix_ref(part_ix), agg);
+        }
     }
 
+    // Write flag with release semantics; this is done portably with a barrier.
     memoryBarrierBuffer();
     if (gl_LocalInvocationID.x == WG_SIZE - 1) {
         uint flag = FLAG_AGGREGATE_READY;
         if (part_ix == 0) {
-            State_write(state_prefix_ref(part_ix), agg);
             flag = FLAG_PREFIX_READY;
         }
         state[state_flag_index(part_ix)] = flag;
-        if (part_ix != 0) {
-            // step 4 of paper: decoupled lookback
-            uint look_back_ix = part_ix - 1;
+    }
+    if (part_ix != 0) {
+        // step 4 of paper: decoupled lookback
+        uint look_back_ix = part_ix - 1;
 
-            State their_agg;
-            uint their_ix = 0;
-            while (true) {
-                flag = state[state_flag_index(look_back_ix)];
-                if (flag == FLAG_PREFIX_READY) {
+        State their_agg;
+        uint their_ix = 0;
+        while (true) {
+            // Read flag with acquire semantics.
+            if (gl_LocalInvocationID.x == WG_SIZE - 1) {
+                sh_flag = state[state_flag_index(look_back_ix)];
+            }
+            // The flag load is done only in the last thread. However, because the
+            // translation of memoryBarrierBuffer to Metal requires uniform control
+            // flow, we broadcast it to all threads.
+            barrier();
+            memoryBarrierBuffer();
+            uint flag = sh_flag;
+
+            if (flag == FLAG_PREFIX_READY) {
+                if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                     State their_prefix = State_read(state_prefix_ref(look_back_ix));
                     exclusive = combine_state(their_prefix, exclusive);
-                    break;
-                } else if (flag == FLAG_AGGREGATE_READY) {
+                }
+                break;
+            } else if (flag == FLAG_AGGREGATE_READY) {
+                if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                     their_agg = State_read(state_aggregate_ref(look_back_ix));
                     exclusive = combine_state(their_agg, exclusive);
-                    look_back_ix--;
-                    their_ix = 0;
-                    continue;
                 }
-                // else spin
+                look_back_ix--;
+                their_ix = 0;
+                continue;
+            }
+            // else spin
+            if (gl_LocalInvocationID.x == WG_SIZE - 1) {
                 // Unfortunately there's no guarantee of forward progress of other
                 // workgroups, so compute a bit of the aggregate before trying again.
                 // In the worst case, spinning stops when the aggregate is complete.
@@ -265,22 +284,29 @@ void main() {
                 if (their_ix == PARTITION_SIZE) {
                     exclusive = combine_state(their_agg, exclusive);
                     if (look_back_ix == 0) {
-                        break;
+                        sh_flag = FLAG_PREFIX_READY;
+                    } else {
+                        look_back_ix--;
+                        their_ix = 0;
                     }
-                    look_back_ix--;
-                    their_ix = 0;
                 }
             }
-
-            // step 5 of paper: compute inclusive prefix
+            barrier();
+            flag = sh_flag;
+            if (flag == FLAG_PREFIX_READY) {
+                break;
+            }
+        }
+        // step 5 of paper: compute inclusive prefix
+        if (gl_LocalInvocationID.x == WG_SIZE - 1) {
             State inclusive_prefix = combine_state(exclusive, agg);
             sh_prefix = exclusive;
             State_write(state_prefix_ref(part_ix), inclusive_prefix);
         }
-    }
-    memoryBarrierBuffer();
-    if (gl_LocalInvocationID.x == WG_SIZE - 1 && part_ix != 0) {
-        state[state_flag_index(part_ix)] = FLAG_PREFIX_READY;
+        memoryBarrierBuffer();
+        if (gl_LocalInvocationID.x == WG_SIZE - 1) {
+            state[state_flag_index(part_ix)] = FLAG_PREFIX_READY;
+        }
     }
     barrier();
     if (part_ix != 0) {
diff --git a/piet-gpu/shader/elements.spv b/piet-gpu/shader/elements.spv
index d1fd39a..60517b0 100644
Binary files a/piet-gpu/shader/elements.spv and b/piet-gpu/shader/elements.spv differ
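
Note (not part of the patch): the release/acquire-via-barrier pattern this change adopts can be seen in isolation in the minimal sketch below. WG_SIZE, the flag values, and the Mailbox buffer layout are hypothetical stand-ins, not the definitions in elements.comp; the buffer is declared volatile so the compiler cannot hoist the flag load out of the spin loop.

#version 450

#define WG_SIZE 32
#define FLAG_NOT_READY 0
#define FLAG_READY 1

layout(local_size_x = WG_SIZE) in;

// Hypothetical one-flag, one-payload mailbox between two workgroups.
layout(set = 0, binding = 0) volatile buffer Mailbox {
    uint flag;
    uint payload;
};

shared uint sh_flag;

void main() {
    if (gl_WorkGroupID.x == 0) {
        // Producer: store the payload, then publish the flag. The buffer
        // barrier between the two stores gives the flag store release
        // semantics, and it executes in uniform control flow (outside the ifs).
        if (gl_LocalInvocationID.x == 0) {
            payload = 42u;
        }
        memoryBarrierBuffer();
        if (gl_LocalInvocationID.x == 0) {
            flag = FLAG_READY;
        }
    } else {
        // Consumer: one thread polls the flag, then broadcasts it through
        // shared memory so that barrier()/memoryBarrierBuffer() are reached
        // by the whole workgroup, as the Metal translation requires.
        uint f = FLAG_NOT_READY;
        while (f != FLAG_READY) {
            if (gl_LocalInvocationID.x == 0) {
                sh_flag = flag;
            }
            barrier();             // publish sh_flag to the whole workgroup
            memoryBarrierBuffer(); // acquire: payload loads stay below the flag load
            f = sh_flag;
        }
        // payload is now safe to read. A real kernel, like the one in the
        // patch, does useful work while spinning, since there is no
        // forward-progress guarantee between workgroups.
    }
}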