#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wmissing-braces"
#pragma clang diagnostic ignored "-Wunused-variable"

// NOTE(review): this file looks like shader-compiler output (SPIRV-Cross style
// names such as spvUnsafeArray, _67, param_N); if so, prefer regenerating it
// over hand-editing the kernel logic.
// NOTE(review): header names below follow the usual layout for a Metal kernel
// that uses atomics; confirm against the generator's output.
#include <metal_stdlib>
#include <simd/simd.h>
#include <metal_atomic>

using namespace metal;

// Fixed-size array wrapper exposing operator[] for every address space the
// object itself may live in (thread, device, constant, threadgroup).
// Indexing is unchecked. A size of 0 is stored as one element so the member
// array is never zero-length.
template<typename T, size_t Num>
struct spvUnsafeArray
{
    T elements[Num ? Num : 1];

    thread T& operator [] (size_t pos) thread
    {
        return elements[pos];
    }
    constexpr const thread T& operator [] (size_t pos) const thread
    {
        return elements[pos];
    }

    device T& operator [] (size_t pos) device
    {
        return elements[pos];
    }
    constexpr const device T& operator [] (size_t pos) const device
    {
        return elements[pos];
    }

    constexpr const constant T& operator [] (size_t pos) const constant
    {
        return elements[pos];
    }

    threadgroup T& operator [] (size_t pos) threadgroup
    {
        return elements[pos];
    }
    constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
    {
        return elements[pos];
    }
};

// Value being scanned; combined by addition (see combine_monoid).
// Used for thread-local and threadgroup copies.
struct Monoid
{
    uint element;
};

// Same layout as Monoid; this is the type used inside the buffer structs.
struct Monoid_1
{
    uint element;
};

// Per-partition record shared between threadgroups.
// flag values as written/read by main0:
//   1 -> `aggregate` has been stored for this partition
//   2 -> `prefix` (inclusive prefix up to and including this partition) has been stored
// Any other value is treated as "nothing published yet".
struct State
{
    uint flag;
    Monoid_1 aggregate;
    Monoid_1 prefix;
};

// part_counter hands out partition indices; state[] is declared with one
// element but is indexed by partition number, so the real extent comes from
// the bound buffer.
struct StateBuf
{
    uint part_counter;
    State state[1];
};

// Input elements; declared length 1, indexed by global element index.
struct InBuf
{
    Monoid_1 inbuf[1];
};

// Output elements; declared length 1, indexed by global element index.
struct OutBuf
{
    Monoid_1 outbuf[1];
};

// 512 threads per threadgroup. Each thread handles 16 elements, so one
// partition covers 512 * 16 = 8192 elements (the 8192u literals below).
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u);

// Monoid operation: unsigned addition of the two elements.
static inline __attribute__((always_inline))
Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b)
{
    return Monoid{ a.element + b.element };
}

// Inclusive prefix sum over `_67.inbuf`, written to `_372.outbuf`, computed
// in a single dispatch: each threadgroup scans one 8192-element partition
// and obtains the sum of all preceding partitions by looking back through
// the per-partition records in `_43.state`.
//
//   _67   (buffer 0): input elements, read-only.
//   _372  (buffer 1): output elements.
//   _43   (buffer 2): partition counter + per-partition state.
//                     NOTE(review): presumably zero-initialised by the host
//                     before dispatch so unpublished flags read as 0 - confirm.
//   gl_LocalInvocationID: thread position within the threadgroup (x in 0..511).
kernel void main0(const device InBuf& _67 [[buffer(0)]], device OutBuf& _372 [[buffer(1)]], volatile device StateBuf& _43 [[buffer(2)]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]])
{
    threadgroup uint sh_part_ix;
    threadgroup Monoid sh_scratch[512];
    threadgroup uint sh_flag;
    threadgroup Monoid sh_prefix;

    // Thread 0 claims the next partition index from the shared counter and
    // broadcasts it to the rest of the threadgroup through sh_part_ix.
    if (gl_LocalInvocationID.x == 0u)
    {
        uint _47 = atomic_fetch_add_explicit((volatile device atomic_uint*)&_43.part_counter, 1u, memory_order_relaxed);
        sh_part_ix = _47;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    uint part_ix = sh_part_ix;

    // First of this thread's 16 consecutive elements.
    uint ix = (part_ix * 8192u) + (gl_LocalInvocationID.x * 16u);

    // Sequential inclusive scan of this thread's 16 elements.
    spvUnsafeArray<Monoid, 16> local;
    local[0].element = _67.inbuf[ix].element;
    Monoid param_1;
    for (uint i = 1u; i < 16u; i++)
    {
        Monoid param = local[i - 1u];
        param_1.element = _67.inbuf[ix + i].element;
        local[i] = combine_monoid(param, param_1);
    }

    // Inclusive scan of the per-thread totals across the threadgroup:
    // 9 doubling steps cover 2^9 = 512 threads. Each step reads the value
    // (1 << i_1) slots to the left, with barriers separating the reads from
    // the writes of the same step.
    Monoid agg = local[15];
    sh_scratch[gl_LocalInvocationID.x] = agg;
    for (uint i_1 = 0u; i_1 < 9u; i_1++)
    {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (gl_LocalInvocationID.x >= (1u << i_1))
        {
            Monoid other = sh_scratch[gl_LocalInvocationID.x - (1u << i_1)];
            Monoid param_2 = other;
            Monoid param_3 = agg;
            agg = combine_monoid(param_2, param_3);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        sh_scratch[gl_LocalInvocationID.x] = agg;
    }

    // The last thread now holds the whole partition's total. Publish it; for
    // partition 0 the total is also its inclusive prefix.
    if (gl_LocalInvocationID.x == 511u)
    {
        _43.state[part_ix].aggregate.element = agg.element;
        if (part_ix == 0u)
        {
            _43.state[0].prefix.element = agg.element;
        }
    }
    // Device-memory barrier between storing the data and storing the flag
    // that announces it.
    threadgroup_barrier(mem_flags::mem_device);
    if (gl_LocalInvocationID.x == 511u)
    {
        uint flag = 1u;
        if (part_ix == 0u)
        {
            flag = 2u;
        }
        _43.state[part_ix].flag = flag;
    }

    // Sum of everything before this partition. Only thread 511 accumulates it
    // inside the loop; the other threads pick it up from sh_prefix afterwards.
    Monoid exclusive = Monoid{ 0u };
    if (part_ix != 0u)
    {
        uint look_back_ix = part_ix - 1u;
        uint their_ix = 0u;
        Monoid their_prefix;
        Monoid their_agg;
        Monoid m;
        while (true)
        {
            // Thread 511 samples the predecessor's flag and shares it so the
            // whole threadgroup takes the same branch.
            if (gl_LocalInvocationID.x == 511u)
            {
                sh_flag = _43.state[look_back_ix].flag;
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);
            threadgroup_barrier(mem_flags::mem_device);
            uint flag_1 = sh_flag;
            if (flag_1 == 2u)
            {
                // Predecessor has a full prefix: fold it in and stop.
                if (gl_LocalInvocationID.x == 511u)
                {
                    their_prefix.element = _43.state[look_back_ix].prefix.element;
                    Monoid param_4 = their_prefix;
                    Monoid param_5 = exclusive;
                    exclusive = combine_monoid(param_4, param_5);
                }
                break;
            }
            else
            {
                if (flag_1 == 1u)
                {
                    // Only the partition total is available: fold it in and
                    // continue with the partition before it.
                    if (gl_LocalInvocationID.x == 511u)
                    {
                        their_agg.element = _43.state[look_back_ix].aggregate.element;
                        Monoid param_6 = their_agg;
                        Monoid param_7 = exclusive;
                        exclusive = combine_monoid(param_6, param_7);
                    }
                    look_back_ix--;
                    their_ix = 0u;
                    continue;
                }
            }
            // Nothing published yet: instead of only polling, thread 511 adds
            // one more input element of that partition per iteration, so the
            // total is eventually rebuilt from the input itself.
            if (gl_LocalInvocationID.x == 511u)
            {
                m.element = _67.inbuf[(look_back_ix * 8192u) + their_ix].element;
                if (their_ix == 0u)
                {
                    their_agg = m;
                }
                else
                {
                    Monoid param_8 = their_agg;
                    Monoid param_9 = m;
                    their_agg = combine_monoid(param_8, param_9);
                }
                their_ix++;
                if (their_ix == 8192u)
                {
                    // Whole partition summed locally: fold it in.
                    Monoid param_10 = their_agg;
                    Monoid param_11 = exclusive;
                    exclusive = combine_monoid(param_10, param_11);
                    if (look_back_ix == 0u)
                    {
                        // No earlier partition remains; reuse the flag value 2
                        // to make every thread leave the loop below.
                        sh_flag = 2u;
                    }
                    else
                    {
                        look_back_ix--;
                        their_ix = 0u;
                    }
                }
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);
            flag_1 = sh_flag;
            if (flag_1 == 2u)
            {
                break;
            }
        }

        // Publish this partition's inclusive prefix and hand the exclusive
        // prefix to the other threads.
        if (gl_LocalInvocationID.x == 511u)
        {
            Monoid param_12 = exclusive;
            Monoid param_13 = agg;
            Monoid inclusive_prefix = combine_monoid(param_12, param_13);
            sh_prefix = exclusive;
            _43.state[part_ix].prefix.element = inclusive_prefix.element;
        }
        // Same store-data / barrier / store-flag sequence as for the aggregate.
        threadgroup_barrier(mem_flags::mem_device);
        if (gl_LocalInvocationID.x == 511u)
        {
            _43.state[part_ix].flag = 2u;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (part_ix != 0u)
    {
        exclusive = sh_prefix;
    }

    // Offset for this thread's 16 elements: the partition's exclusive prefix
    // plus the scanned totals of all threads to its left.
    Monoid row = exclusive;
    if (gl_LocalInvocationID.x > 0u)
    {
        Monoid other_1 = sh_scratch[gl_LocalInvocationID.x - 1u];
        Monoid param_14 = row;
        Monoid param_15 = other_1;
        row = combine_monoid(param_14, param_15);
    }
    for (uint i_2 = 0u; i_2 < 16u; i_2++)
    {
        Monoid param_16 = row;
        Monoid param_17 = local[i_2];
        Monoid m_1 = combine_monoid(param_16, param_17);
        _372.outbuf[ix + i_2].element = m_1.element;
    }
}