vello/tests/shader/gen/prefix_scan.msl

124 lines
3.2 KiB
Plaintext
Raw Normal View History

2021-11-07 10:25:56 +11:00
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wmissing-braces"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
// Fixed-size array wrapper: stores Num elements of T in `elements` (padded to
// a single element when Num == 0, presumably because zero-length arrays are
// not allowed) and provides operator[] for every Metal address space the
// wrapper can live in. No overload checks its index, hence "Unsafe".
template<typename T, size_t Num>
struct spvUnsafeArray
{
    T elements[Num ? Num : 1];

    // Access when the array lives in thread (per-thread) memory.
    thread T& operator [] (size_t index) thread { return elements[index]; }
    constexpr const thread T& operator [] (size_t index) const thread { return elements[index]; }

    // Access when the array lives in device memory.
    device T& operator [] (size_t index) device { return elements[index]; }
    constexpr const device T& operator [] (size_t index) const device { return elements[index]; }

    // Access when the array lives in constant memory (read-only, so only a
    // const overload is provided).
    constexpr const constant T& operator [] (size_t index) const constant { return elements[index]; }

    // Access when the array lives in threadgroup memory.
    threadgroup T& operator [] (size_t index) threadgroup { return elements[index]; }
    constexpr const threadgroup T& operator [] (size_t index) const threadgroup { return elements[index]; }
};
// Scan value used for per-thread and threadgroup working storage in main0
// (`local`, `sh_scratch`, `agg`, `row`). combine_monoid adds `element`.
struct Monoid
{
uint element;
};
// Element type of the device buffers (DataBuf, ParentBuf). It has the same single
// `element` field as Monoid; main0 copies `.element` between the two types.
// NOTE(review): presumably a separate buffer-layout copy of the shader type
// emitted by SPIRV-Cross. Keep the layout in sync with the host-side buffer.
struct Monoid_1
{
uint element;
};
// Buffer bound at buffer(0) in main0; scanned in place.
// `data` is declared with length 1 but main0 indexes it far past that
// (`ix + i`). The [1] stands in for a runtime-sized buffer array.
struct DataBuf
{
Monoid_1 data[1];
};
// Buffer bound at buffer(1) in main0 (read-only there). main0 reads entry
// `gl_WorkGroupID.x - 1` as the carry-in for threadgroup gl_WorkGroupID.x.
// Like DataBuf, the [1] stands in for a runtime-sized buffer array.
struct ParentBuf
{
Monoid_1 parent[1];
};
// Threadgroup size the shader was written for (512 x 1 x 1). main0 never reads
// it (hence [[maybe_unused]]); instead it hard-codes 512 as the sh_scratch size
// and 9 (= log2(512)) scan steps.
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u);
// Monoid operation for the scan: the result's `element` is a.element + b.element
// (unsigned addition, so it wraps modulo 2^32).
static inline __attribute__((always_inline))
Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b)
{
    Monoid combined;
    combined.element = a.element + b.element;
    return combined;
}
// Inclusive prefix-sum kernel. Each thread scans 8 consecutive elements of _42
// in place, and threads cooperate through threadgroup memory. The kernel
// assumes 512 threads per threadgroup: sh_scratch has 512 slots and the
// cross-thread scan runs 9 (= log2(512)) steps.
//
// _42  : data buffer, read and overwritten with scanned values.
// _143 : per-threadgroup carry-ins. Entry g-1 is combined into every output of
//        threadgroup g (threadgroup 0 starts from 0).
//        NOTE(review): presumably the inclusive scan of earlier threadgroups'
//        totals, produced by a previous pass. Confirm against the host code.
// NOTE(review): there is no bounds check on ix, so the dispatch must cover
// the buffer exactly (8 elements per thread).
kernel void main0(device DataBuf& _42 [[buffer(0)]], const device ParentBuf& _143 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]])
{
    threadgroup Monoid sh_scratch[512];

    // Phase 1: serial inclusive scan of this thread's 8 elements into `local`.
    uint ix = gl_GlobalInvocationID.x * 8u;
    spvUnsafeArray<Monoid, 8> local;
    local[0].element = _42.data[ix].element;
    Monoid param_1;
    for (uint i = 1u; i < 8u; i++)
    {
        Monoid param = local[i - 1u];
        param_1.element = _42.data[ix + i].element;
        local[i] = combine_monoid(param, param_1);
    }

    // Phase 2: Hillis-Steele inclusive scan of the per-thread totals across the
    // threadgroup. At step i each thread with index >= 2^i folds in the
    // partial result 2^i slots below it. The first barrier makes the previous
    // step's writes visible. The second stops a thread from overwriting its
    // slot while a higher thread may still be reading it in this step.
    Monoid agg = local[7];
    sh_scratch[gl_LocalInvocationID.x] = agg;
    for (uint i_1 = 0u; i_1 < 9u; i_1++)
    {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (gl_LocalInvocationID.x >= uint(1 << int(i_1)))
        {
            Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))];
            Monoid param_2 = other;
            Monoid param_3 = agg;
            // Lower-index partial goes on the left.
            agg = combine_monoid(param_2, param_3);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        sh_scratch[gl_LocalInvocationID.x] = agg;
    }
    // Make the final sh_scratch writes visible before reading a neighbour's slot.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Phase 3: compute `row`, the combined value of everything before this
    // thread's first element. It is the carry-in from earlier threadgroups
    // plus the inclusive total of all lower-indexed threads in this threadgroup.
    Monoid row = Monoid{ 0u };
    if (gl_WorkGroupID.x > 0u)
    {
        row.element = _143.parent[gl_WorkGroupID.x - 1u].element;
    }
    if (gl_LocalInvocationID.x > 0u)
    {
        Monoid param_4 = row;
        Monoid param_5 = sh_scratch[gl_LocalInvocationID.x - 1u];
        row = combine_monoid(param_4, param_5);
    }

    // Phase 4: write the inclusive prefix for each of this thread's elements.
    for (uint i_2 = 0u; i_2 < 8u; i_2++)
    {
        Monoid param_6 = row;
        Monoid param_7 = local[i_2];
        Monoid m = combine_monoid(param_6, param_7);
        _42.data[ix + i_2].element = m.element;
    }
}