mirror of
https://github.com/italicsjenga/vello.git
synced 2025-01-26 19:26:33 +11:00
a0648a2153
The MSL translation of the prefix example had its bindings permuted; a flag prevents this (but, as is typical for shader translation, potentially creates other problems). Also use explicit unsigned literal to avoid DXC warnings.
262 lines
7.3 KiB
Text
262 lines
7.3 KiB
Text
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
|
#pragma clang diagnostic ignored "-Wmissing-braces"
|
|
#pragma clang diagnostic ignored "-Wunused-variable"
|
|
|
|
#include <metal_stdlib>
|
|
#include <simd/simd.h>
|
|
#include <metal_atomic>
|
|
|
|
using namespace metal;
|
|
|
|
// Fixed-size array wrapper with one subscript operator per Metal address
// space (thread, device, constant, threadgroup). Subscripting forwards
// straight to the underlying C array; no bounds checking is performed.
template<typename T, size_t Num>
struct spvUnsafeArray
{
    // A Num of zero still reserves a single element.
    T elements[(Num == 0) ? 1 : Num];

    // --- thread address space ---
    thread T& operator [] (size_t index) thread { return elements[index]; }
    constexpr const thread T& operator [] (size_t index) const thread { return elements[index]; }

    // --- device address space ---
    device T& operator [] (size_t index) device { return elements[index]; }
    constexpr const device T& operator [] (size_t index) const device { return elements[index]; }

    // --- constant address space (read-only access) ---
    constexpr const constant T& operator [] (size_t index) const constant { return elements[index]; }

    // --- threadgroup address space ---
    threadgroup T& operator [] (size_t index) threadgroup { return elements[index]; }
    constexpr const threadgroup T& operator [] (size_t index) const threadgroup { return elements[index]; }
};
|
|
|
|
// Scan element as used in thread and threadgroup memory by main0.
// combine_monoid() merges two of these by adding their `element` fields.
struct Monoid
{
    uint element; // the value being accumulated
};
|
|
|
|
// Scan element as stored inside the buffer structs (State, InBuf, OutBuf).
// Same single-field layout as Monoid.
struct Monoid_1
{
    uint element; // the value being accumulated
};
|
|
|
|
// Per-partition record exchanged between threadgroups in main0.
struct State
{
    // Status of this partition as written by main0:
    //   1u -> `aggregate` has been stored,
    //   2u -> `prefix` has been stored.
    // Any other value is treated by the kernel as "nothing published yet"
    // (presumably 0 from host-side initialisation -- TODO confirm).
    uint flag;

    // Combined value of the elements of this partition only.
    Monoid_1 aggregate;

    // Combined value of this partition and every partition before it.
    Monoid_1 prefix;
};
|
|
|
|
// Layout of the state buffer bound at buffer(2) of main0.
struct StateBuf
{
    // Counter that main0 increments atomically to hand each threadgroup
    // its partition index.
    uint part_counter;

    // Per-partition records. Declared with one element but indexed by
    // partition index in main0, so the bound buffer presumably holds one
    // State per partition -- TODO confirm against the host code.
    State state[1];
};
|
|
|
|
// Layout of the input buffer bound at buffer(0) of main0.
struct InBuf
{
    // Input elements. Declared with one element but indexed with a runtime
    // index in main0; the real length is presumably set by the bound buffer.
    Monoid_1 inbuf[1];
};
|
|
|
|
// Layout of the output buffer bound at buffer(1) of main0.
struct OutBuf
{
    // Output elements. Declared with one element but indexed with a runtime
    // index in main0; the real length is presumably set by the bound buffer.
    Monoid_1 outbuf[1];
};
|
|
|
|
// Threadgroup dimensions: 512 x 1 x 1. main0 declares a 512-entry shared
// scratch array and treats local thread 511 as the last thread, which
// matches this size. Marked [[maybe_unused]] because nothing in this file
// reads the constant.
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u);
|
|
|
|
// Combines two scan elements: the result's `element` is the sum of the two
// inputs' `element` fields (unsigned arithmetic). Neither argument is
// modified. Always inlined into the caller.
static inline __attribute__((always_inline))
Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b)
{
    Monoid combined;
    combined.element = a.element + b.element;
    return combined;
}
|
|
|
|
// Single-pass inclusive prefix sum over `_67.inbuf`, written to `_372.outbuf`.
//
// Each threadgroup (512 threads) processes one partition of 8192 elements,
// 16 consecutive elements per thread. Partitions communicate through the
// state buffer `_43`: each partition publishes its own aggregate (flag 1u)
// and later its inclusive prefix (flag 2u), and looks back over earlier
// partitions to build the exclusive prefix it must add to its own elements.
// While an earlier partition has published nothing, the last thread sums that
// partition's input directly so the kernel does not depend on the other
// threadgroup making progress.
//
// Parameters:
//   _67   input elements (buffer 0), read only.
//   _372  output elements (buffer 1); element i receives the combination of
//         input elements 0..i.
//   _43   partition counter and per-partition State records (buffer 2).
//         NOTE(review): presumably zero-initialised by the host before
//         dispatch -- confirm against the caller.
//   gl_LocalInvocationID  position of this thread within its threadgroup.
//
// NOTE(review): the ordering of "store aggregate/prefix" before "store flag"
// relies on threadgroup_barrier(mem_flags::mem_device); confirm this provides
// the cross-threadgroup visibility required on the target hardware.
// NOTE(review): this file looks machine-generated; if so, the two barriers
// marked "race fix" below should also be applied to the shader source.
kernel void main0(const device InBuf& _67 [[buffer(0)]], device OutBuf& _372 [[buffer(1)]], volatile device StateBuf& _43 [[buffer(2)]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]])
{
    threadgroup uint sh_part_ix;        // partition index claimed by this threadgroup
    threadgroup Monoid sh_scratch[512]; // per-thread aggregates for the threadgroup scan
    threadgroup uint sh_flag;           // look-back flag broadcast from thread 511
    threadgroup Monoid sh_prefix;       // exclusive prefix broadcast from thread 511

    // Thread 0 claims the next partition index and shares it with the group.
    if (gl_LocalInvocationID.x == 0u)
    {
        uint _47 = atomic_fetch_add_explicit((volatile device atomic_uint*)&_43.part_counter, 1u, memory_order_relaxed);
        sh_part_ix = _47;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    uint part_ix = sh_part_ix;

    // First input element owned by this thread: 8192 per partition, 16 per thread.
    uint ix = (part_ix * 8192u) + (gl_LocalInvocationID.x * 16u);

    // Sequential inclusive scan of this thread's 16 elements.
    spvUnsafeArray<Monoid, 16> local;
    local[0].element = _67.inbuf[ix].element;
    for (uint i = 1u; i < 16u; i++)
    {
        Monoid next;
        next.element = _67.inbuf[ix + i].element;
        local[i] = combine_monoid(local[i - 1u], next);
    }

    // Threadgroup-wide inclusive scan of the per-thread totals
    // (9 doubling steps cover 512 threads).
    Monoid agg = local[15];
    sh_scratch[gl_LocalInvocationID.x] = agg;
    for (uint i_1 = 0u; i_1 < 9u; i_1++)
    {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (gl_LocalInvocationID.x >= (1u << i_1))
        {
            // Copy out of threadgroup memory before combining.
            Monoid other = sh_scratch[gl_LocalInvocationID.x - (1u << i_1)];
            agg = combine_monoid(other, agg);
        }
        // All reads of the previous step must finish before anyone overwrites.
        threadgroup_barrier(mem_flags::mem_threadgroup);
        sh_scratch[gl_LocalInvocationID.x] = agg;
    }

    // The last thread now holds the aggregate of the whole partition; publish it.
    // Partition 0 has no predecessors, so its aggregate is also its prefix.
    if (gl_LocalInvocationID.x == 511u)
    {
        _43.state[part_ix].aggregate.element = agg.element;
        if (part_ix == 0u)
        {
            _43.state[0].prefix.element = agg.element;
        }
    }
    // The value stores above must precede the flag store below.
    threadgroup_barrier(mem_flags::mem_device);
    if (gl_LocalInvocationID.x == 511u)
    {
        uint flag = (part_ix == 0u) ? 2u : 1u;
        _43.state[part_ix].flag = flag;
    }

    // Look back over earlier partitions to build this partition's exclusive prefix.
    // Only thread 511 accumulates; the others follow the same control flow so
    // that every barrier is reached by the whole threadgroup.
    Monoid exclusive = Monoid{ 0u };
    if (part_ix != 0u)
    {
        uint look_back_ix = part_ix - 1u; // partition currently being inspected
        uint their_ix = 0u;               // progress of the fallback sum within it
        Monoid their_prefix;
        Monoid their_agg;
        Monoid m;
        while (true)
        {
            // Thread 511 loads the flag; it is broadcast through sh_flag.
            if (gl_LocalInvocationID.x == 511u)
            {
                sh_flag = _43.state[look_back_ix].flag;
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);
            threadgroup_barrier(mem_flags::mem_device);
            uint flag_1 = sh_flag;
            // Race fix: every thread must have read sh_flag before thread 511
            // may overwrite it (next iteration's load, or the fallback path
            // below). Without this, threads could observe different flag
            // values and diverge around later barriers.
            threadgroup_barrier(mem_flags::mem_threadgroup);

            if (flag_1 == 2u)
            {
                // Inclusive prefix of that partition is available: done.
                if (gl_LocalInvocationID.x == 511u)
                {
                    their_prefix.element = _43.state[look_back_ix].prefix.element;
                    exclusive = combine_monoid(their_prefix, exclusive);
                }
                break;
            }
            else if (flag_1 == 1u)
            {
                // Only the aggregate is available: fold it in, step further back.
                if (gl_LocalInvocationID.x == 511u)
                {
                    their_agg.element = _43.state[look_back_ix].aggregate.element;
                    exclusive = combine_monoid(their_agg, exclusive);
                }
                look_back_ix--;
                their_ix = 0u;
                continue;
            }

            // Nothing published yet: thread 511 sums one more input element of
            // that partition itself before polling the flag again.
            if (gl_LocalInvocationID.x == 511u)
            {
                m.element = _67.inbuf[(look_back_ix * 8192u) + their_ix].element;
                if (their_ix == 0u)
                {
                    their_agg = m;
                }
                else
                {
                    their_agg = combine_monoid(their_agg, m);
                }
                their_ix++;
                if (their_ix == 8192u)
                {
                    // The whole partition has been summed locally.
                    exclusive = combine_monoid(their_agg, exclusive);
                    if (look_back_ix == 0u)
                    {
                        // Reached the first partition: signal completion.
                        sh_flag = 2u;
                    }
                    else
                    {
                        look_back_ix--;
                        their_ix = 0u;
                    }
                }
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);
            flag_1 = sh_flag;
            // Race fix: same hazard as above -- keep thread 511 from reloading
            // sh_flag at the top of the loop before everyone has read it here.
            threadgroup_barrier(mem_flags::mem_threadgroup);
            if (flag_1 == 2u)
            {
                break;
            }
        }

        // Publish this partition's inclusive prefix and share the exclusive
        // prefix with the rest of the threadgroup.
        if (gl_LocalInvocationID.x == 511u)
        {
            Monoid inclusive_prefix = combine_monoid(exclusive, agg);
            sh_prefix = exclusive;
            _43.state[part_ix].prefix.element = inclusive_prefix.element;
        }
        // The prefix store above must precede the flag store below.
        threadgroup_barrier(mem_flags::mem_device);
        if (gl_LocalInvocationID.x == 511u)
        {
            _43.state[part_ix].flag = 2u;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (part_ix != 0u)
    {
        exclusive = sh_prefix;
    }

    // Prefix for this thread's first element: the partition's exclusive prefix
    // combined with the scanned total of all lower-numbered threads.
    Monoid row = exclusive;
    if (gl_LocalInvocationID.x > 0u)
    {
        Monoid other_1 = sh_scratch[gl_LocalInvocationID.x - 1u];
        row = combine_monoid(row, other_1);
    }

    // Apply that prefix to each locally scanned element and store the result.
    for (uint i_2 = 0u; i_2 < 16u; i_2++)
    {
        Monoid m_1 = combine_monoid(row, local[i_2]);
        _372.outbuf[ix + i_2].element = m_1.element;
    }
}
|
|
|