2021-11-06 16:25:56 -07:00
|
|
|
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
|
|
|
|
|
|
|
#include <metal_stdlib>
|
|
|
|
#include <simd/simd.h>
|
|
|
|
|
|
|
|
using namespace metal;
|
|
|
|
|
|
|
|
// Value type the kernel accumulates in thread and threadgroup memory.
// It wraps a single unsigned integer; combine_monoid() merges two of
// these by adding their `element` fields.
struct Monoid
{
    uint element;  // the accumulated value
};
|
|
|
|
|
|
|
|
// Element type of the InBuf / OutBuf arrays. It has the same single
// `uint element` field as Monoid; the kernel copies that field to and
// from Monoid values when reading input and writing output.
struct Monoid_1
{
    uint element;  // the stored value
};
|
|
|
|
|
|
|
|
// Layout of the read-only input buffer bound at [[buffer(0)]] of main0.
// The array is declared with one element, but main0 indexes it at
// gl_GlobalInvocationID.x * 8 + i, so the declared extent is only a
// placeholder.
// NOTE(review): the real length presumably comes from the size of the
// bound buffer -- confirm against the host-side dispatch code.
struct InBuf
{
    Monoid_1 inbuf[1];
};
|
|
|
|
|
|
|
|
// Layout of the writable output buffer bound at [[buffer(1)]] of main0.
// main0 writes one entry per threadgroup, at index gl_WorkGroupID.x, so
// the declared extent of one element is only a placeholder.
// NOTE(review): the real length presumably comes from the size of the
// bound buffer -- confirm against the host-side dispatch code.
struct OutBuf
{
    Monoid_1 outbuf[1];
};
|
|
|
|
|
|
|
|
// Threadgroup dimensions recorded for this shader: 512 x 1 x 1. The x
// extent matches the 512-entry scratch array and the `< 512u` bound used
// in main0. Marked [[maybe_unused]] because nothing in this file reads it.
// NOTE(review): the host is presumably expected to dispatch with this
// threadgroup size -- confirm against the pipeline setup.
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u);
|
|
|
|
|
|
|
|
// Merges two Monoid values: the result's `element` is the (wrapping,
// unsigned) sum of the operands' `element` fields. Neither operand is
// modified; the result is returned by value.
static inline __attribute__((always_inline))
Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b)
{
    Monoid merged;
    merged.element = a.element + b.element;
    return merged;
}
|
|
|
|
|
2021-11-11 06:59:27 -08:00
|
|
|
// Reduction kernel: each threadgroup combines a contiguous run of input
// elements into a single output element.
//
//   _40   (buffer 0) : input elements, read only.
//   _127  (buffer 1) : output; entry gl_WorkGroupID.x receives this
//                      threadgroup's combined value.
//
// Phase 1: every thread folds 8 consecutive inputs, starting at
//          gl_GlobalInvocationID.x * 8, into a private aggregate.
// Phase 2: the aggregates are combined across the threadgroup through
//          threadgroup memory, doubling the stride each round
//          (1, 2, ..., 256). After the last round thread 0 holds the
//          combination of all 512 per-thread aggregates.
//
// NOTE(review): assumes a threadgroup width of 512 (see gl_WorkGroupSize)
// and that the input buffer holds at least 8 elements for every thread
// dispatched; no bounds checks are performed here.
kernel void main0(const device InBuf& _40 [[buffer(0)]], device OutBuf& _127 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]])
{
    constexpr uint kRowsPerThread = 8u;   // inputs folded by each thread
    constexpr uint kGroupSize = 512u;     // threads per threadgroup

    threadgroup Monoid sh_scratch[kGroupSize];

    const uint lane = gl_LocalInvocationID.x;
    const uint base = gl_GlobalInvocationID.x * kRowsPerThread;

    // Phase 1: sequential fold over this thread's own slice of the input.
    Monoid agg = Monoid{ _40.inbuf[base].element };
    for (uint k = 1u; k < kRowsPerThread; ++k)
    {
        const Monoid next = Monoid{ _40.inbuf[base + k].element };
        agg = combine_monoid(agg, next);
    }
    sh_scratch[lane] = agg;

    // Phase 2: each round, a thread pulls in the aggregate published by the
    // thread `stride` lanes above it (when that lane exists).
    for (uint stride = 1u; stride < kGroupSize; stride <<= 1u)
    {
        // Make the previous round's writes visible before anyone reads.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        const uint partner = lane + stride;
        if (partner < kGroupSize)
        {
            const Monoid other = sh_scratch[partner];
            agg = combine_monoid(agg, other);
        }

        // All reads of this round must finish before slots are overwritten.
        threadgroup_barrier(mem_flags::mem_threadgroup);
        sh_scratch[lane] = agg;
    }

    // Only the first thread holds the full threadgroup result.
    if (lane == 0u)
    {
        _127.outbuf[gl_WorkGroupID.x].element = agg.element;
    }
}
|
|
|
|
|