mirror of
https://github.com/italicsjenga/vello.git
synced 2025-01-11 04:51:32 +11:00
a0648a2153
The MSL translation of the prefix example had its bindings permuted; a flag prevents this (but, as is typical for shader translation, potentially creates other problems). Also use explicit unsigned literal to avoid DXC warnings.
124 lines
3.2 KiB
Plaintext
124 lines
3.2 KiB
Plaintext
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
|
#pragma clang diagnostic ignored "-Wmissing-braces"
|
|
|
|
#include <metal_stdlib>
|
|
#include <simd/simd.h>
|
|
|
|
using namespace metal;
|
|
|
|
// Fixed-size array wrapper that can be copied and assigned by value.
//
// "Unsafe" because indexing performs no bounds check: every operator[]
// simply forwards to the underlying C array. One overload is provided per
// address space the wrapper may live in, so that the returned reference
// carries the matching address-space qualifier.
template<typename T, size_t Num>
struct spvUnsafeArray
{
    // A zero-length request still reserves a single slot, because a
    // zero-sized C array is not a valid member.
    T elements[(Num == 0) ? 1 : Num];

    // --- thread address space ---
    thread T& operator [] (size_t pos) thread { return elements[pos]; }
    constexpr const thread T& operator [] (size_t pos) const thread { return elements[pos]; }

    // --- device address space ---
    device T& operator [] (size_t pos) device { return elements[pos]; }
    constexpr const device T& operator [] (size_t pos) const device { return elements[pos]; }

    // --- constant address space (read-only access only) ---
    constexpr const constant T& operator [] (size_t pos) const constant { return elements[pos]; }

    // --- threadgroup address space ---
    threadgroup T& operator [] (size_t pos) threadgroup { return elements[pos]; }
    constexpr const threadgroup T& operator [] (size_t pos) const threadgroup { return elements[pos]; }
};
|
|
|
|
// Scan element as used in thread-local and threadgroup storage.
// combine_monoid() merges two of these by adding their `element` fields.
struct Monoid {
    uint element;  // the value carried through the scan
};
|
|
|
|
// Buffer-side twin of Monoid: same single `uint element` field, but this is
// the type used for the array members of DataBuf and ParentBuf.
struct Monoid_1 {
    uint element;  // the value stored in the buffer
};
|
|
|
|
// Layout of the buffer bound at [[buffer(0)]] in main0, which both reads
// and writes it.
struct DataBuf {
    // Declared with one element, but main0 indexes it with computed offsets
    // (ix + i) -- presumably a stand-in for a runtime-sized array; the
    // caller must size the allocation accordingly.
    Monoid_1 data[1];
};
|
|
|
|
// Layout of the read-only buffer bound at [[buffer(1)]] in main0.
struct ParentBuf {
    // Declared with one element, but main0 reads parent[gl_WorkGroupID.x - 1]
    // -- presumably a stand-in for a runtime-sized array with one entry per
    // workgroup; confirm against the host-side allocation.
    Monoid_1 parent[1];
};
|
|
|
|
// Workgroup dimensions (x, y, z) = (512, 1, 1). Marked [[maybe_unused]]:
// nothing in the visible code reads it, while main0 hard-codes 512 for its
// sh_scratch array -- keep the two in sync if either changes.
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u);
|
|
|
|
// Combines two scan elements by adding their `element` fields.
// Both operands must live in the thread address space; the result is
// returned by value.
static inline __attribute__((always_inline))
Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b)
{
    Monoid combined;
    combined.element = a.element + b.element;
    return combined;
}
|
|
|
|
// Prefix-sum kernel over the `data` buffer.
//
// Each thread owns 8 consecutive elements starting at
// gl_GlobalInvocationID.x * 8. The kernel:
//   1. scans those 8 elements sequentially in thread-local storage,
//   2. scans the per-thread totals across the workgroup through
//      threadgroup memory (9 doubling steps, covering 512 lanes),
//   3. derives the offset that precedes this thread's elements from the
//      previous workgroup's entry in `parent` and the previous lane's total,
//   4. writes offset + local scan back over the input elements.
//
// _42  : data buffer, read and overwritten in place.
// _141 : parent buffer; entry [g - 1] is added to everything in workgroup g
//        (presumably the scanned total of all earlier workgroups -- confirm
//        against the pass that fills it).
//
// No bounds checks are performed; the data buffer must cover every index
// touched by the dispatch.
kernel void main0(device DataBuf& _42 [[buffer(0)]], const device ParentBuf& _141 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]])
{
    // One slot per lane of the workgroup.
    threadgroup Monoid sh_scratch[512];

    const uint lane = gl_LocalInvocationID.x;
    const uint group = gl_WorkGroupID.x;
    const uint base = gl_GlobalInvocationID.x * 8u;

    // Step 1: inclusive scan of this thread's 8 elements.
    spvUnsafeArray<Monoid, 8> partial;
    partial[0].element = _42.data[base].element;
    for (uint k = 1u; k < 8u; k++)
    {
        // combine_monoid takes thread-space references, so both operands
        // are copied into thread-local values first.
        Monoid previous = partial[k - 1u];
        Monoid loaded;
        loaded.element = _42.data[base + k].element;
        partial[k] = combine_monoid(previous, loaded);
    }

    // Step 2: scan the per-thread totals across the workgroup. On each
    // pass a lane folds in the value `stride` lanes to its left; the
    // stride doubles every pass (1, 2, 4, ... 256).
    Monoid running = partial[7];
    sh_scratch[lane] = running;
    for (uint pass = 0u; pass < 9u; pass++)
    {
        const uint stride = 1u << pass;

        // Wait until every lane's previous write is visible.
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (lane >= stride)
        {
            Monoid left = sh_scratch[lane - stride];
            Monoid right = running;
            running = combine_monoid(left, right);
        }

        // Wait until every lane has finished reading before overwriting.
        threadgroup_barrier(mem_flags::mem_threadgroup);
        sh_scratch[lane] = running;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: offset contributed by everything before this thread.
    Monoid offset = Monoid{ 0u };
    if (group > 0u)
    {
        offset.element = _141.parent[group - 1u].element;
    }
    if (lane > 0u)
    {
        Monoid carried = offset;
        Monoid neighbour = sh_scratch[lane - 1u];
        offset = combine_monoid(carried, neighbour);
    }

    // Step 4: apply the offset and write the results back in place.
    for (uint k = 0u; k < 8u; k++)
    {
        Monoid lhs = offset;
        Monoid rhs = partial[k];
        Monoid result = combine_monoid(lhs, rhs);
        _42.data[base + k].element = result.element;
    }
}
|
|
|