// Metal Shading Language compute shader: per-threadgroup inclusive prefix sum
// over a buffer of uint "monoid" elements (the combine operation is addition).
//
// Each preprocessor directive must sit on its own line; otherwise everything
// following the first directive is consumed as part of it.
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wmissing-braces"

#include <metal_stdlib>
#include <simd/simd.h>

using namespace metal;

// Fixed-size array wrapper with one operator[] overload per Metal address
// space, so the same type can be indexed wherever an instance lives.
// No bounds checking is performed.
template<typename T, size_t Num>
struct spvUnsafeArray
{
    // `Num ? Num : 1` keeps the storage at least one element long when Num is 0.
    T elements[Num ? Num : 1];

    thread T& operator [] (size_t pos) thread
    {
        return elements[pos];
    }
    constexpr const thread T& operator [] (size_t pos) const thread
    {
        return elements[pos];
    }

    device T& operator [] (size_t pos) device
    {
        return elements[pos];
    }
    constexpr const device T& operator [] (size_t pos) const device
    {
        return elements[pos];
    }

    constexpr const constant T& operator [] (size_t pos) const constant
    {
        return elements[pos];
    }

    threadgroup T& operator [] (size_t pos) threadgroup
    {
        return elements[pos];
    }
    constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
    {
        return elements[pos];
    }
};

// Value type used for thread-local and threadgroup computation.
struct Monoid
{
    uint element;
};

// Same layout as Monoid; this is the element type stored in the device buffer.
struct Monoid_1
{
    uint element;
};

// Buffer bound at [[buffer(0)]]. Declared with a single element but indexed
// well past 0 by the kernel.
// NOTE(review): presumably a runtime-sized array whose real length is set by
// the host-side allocation -- confirm against the caller.
struct DataBuf
{
    Monoid_1 data[1];
};

// Threadgroup size expected by the kernel: 512 x 1 x 1 threads.
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u);

// Combines two monoid values by adding their `element` fields.
static inline __attribute__((always_inline))
Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b)
{
    return Monoid{ a.element + b.element };
}

// Computes an inclusive prefix sum of `_42.data` in place, independently for
// each threadgroup.
//
// Every thread owns 8 consecutive elements starting at
// `gl_GlobalInvocationID.x * 8`, so a 512-thread group covers 4096 elements.
// The kernel performs no bounds checks and, as far as this file shows, does
// not carry a running total from one threadgroup to the next.
// NOTE(review): assumes the buffer holds at least 8 elements per dispatched
// thread -- confirm against the dispatch code.
kernel void main0(device DataBuf& _42 [[buffer(0)]],
                  uint3 gl_GlobalInvocationID [[thread_position_in_grid]],
                  uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]])
{
    // One slot per thread in the group, holding that thread's running aggregate.
    threadgroup Monoid sh_scratch[512];

    // Index of the first of this thread's 8 elements.
    uint ix = gl_GlobalInvocationID.x * 8u;

    // Phase 1: sequential inclusive scan of this thread's 8 elements.
    spvUnsafeArray<Monoid, 8> local;
    local[0].element = _42.data[ix].element;
    Monoid param_1;
    for (uint i = 1u; i < 8u; i++)
    {
        Monoid param = local[i - 1u];
        param_1.element = _42.data[ix + i].element;
        local[i] = combine_monoid(param, param_1);
    }

    // Total of this thread's 8 elements.
    Monoid agg = local[7];
    sh_scratch[gl_LocalInvocationID.x] = agg;

    // Phase 2: scan the per-thread totals across the threadgroup in 9 rounds
    // (2^9 == 512). In round i each thread folds in the value published by the
    // thread (1 << i) slots below it, if there is one.
    for (uint i_1 = 0u; i_1 < 9u; i_1++)
    {
        // Wait until every thread's latest value is stored before reading.
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (gl_LocalInvocationID.x >= (1u << i_1))
        {
            Monoid other = sh_scratch[gl_LocalInvocationID.x - (1u << i_1)];
            Monoid param_2 = other;
            Monoid param_3 = agg;
            agg = combine_monoid(param_2, param_3);
        }
        // Wait until every thread has finished reading before overwriting.
        threadgroup_barrier(mem_flags::mem_threadgroup);
        sh_scratch[gl_LocalInvocationID.x] = agg;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Phase 3: `row` is the combined total of all preceding threads in the
    // group (zero for the first thread); add it to each local partial sum and
    // write the result back over the input.
    Monoid row = Monoid{ 0u };
    if (gl_LocalInvocationID.x > 0u)
    {
        row = sh_scratch[gl_LocalInvocationID.x - 1u];
    }
    for (uint i_2 = 0u; i_2 < 8u; i_2++)
    {
        Monoid param_4 = row;
        Monoid param_5 = local[i_2];
        Monoid m = combine_monoid(param_4, param_5);
        _42.data[ix + i_2].element = m.element;
    }
}