mirror of
https://github.com/italicsjenga/vello.git
synced 2025-01-10 20:51:29 +11:00
69b6632085
Thanks to Jeff Bolz for spotting the write-after-read hazard on the sh_flag accesses. This fixes observed failures on Nvidia Turing and Ampere on DX12.
265 lines
7.4 KiB
Plaintext
265 lines
7.4 KiB
Plaintext
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
|
#pragma clang diagnostic ignored "-Wmissing-braces"
|
|
#pragma clang diagnostic ignored "-Wunused-variable"
|
|
|
|
#include <metal_stdlib>
|
|
#include <simd/simd.h>
|
|
#include <metal_atomic>
|
|
|
|
using namespace metal;
|
|
|
|
// Fixed-length array wrapper with one subscript operator per address space
// (thread, device, constant, threadgroup). Indexing forwards straight to the
// underlying C array; no bounds checking is performed.
template<typename T, size_t Num>
struct spvUnsafeArray
{
    // Backing storage. A zero-length instantiation still reserves one slot so
    // the member array is never declared with size 0.
    T elements[(Num == 0) ? 1 : Num];

    // --- thread address space ---
    thread T& operator [] (size_t index) thread
    {
        return elements[index];
    }
    constexpr const thread T& operator [] (size_t index) const thread
    {
        return elements[index];
    }

    // --- device address space ---
    device T& operator [] (size_t index) device
    {
        return elements[index];
    }
    constexpr const device T& operator [] (size_t index) const device
    {
        return elements[index];
    }

    // --- constant address space (read-only access only) ---
    constexpr const constant T& operator [] (size_t index) const constant
    {
        return elements[index];
    }

    // --- threadgroup address space ---
    threadgroup T& operator [] (size_t index) threadgroup
    {
        return elements[index];
    }
    constexpr const threadgroup T& operator [] (size_t index) const threadgroup
    {
        return elements[index];
    }
};
|
|
|
|
// Scan element as used for thread-local and threadgroup values in this file.
// combine_monoid() merges two of these by adding their `element` fields.
struct Monoid
{
    uint element;  // the value carried through the scan
};
|
|
|
|
// Same single-uint layout as Monoid; this is the variant embedded in the
// buffer-backed structs (State, InBuf, OutBuf).
struct Monoid_1
{
    uint element;  // the value carried through the scan
};
|
|
|
|
// Per-partition record shared between workgroups through StateBuf.
struct State
{
    // Progress marker. main0 writes 1 after storing `aggregate` and 2 after
    // storing `prefix`; any other value is treated by main0 as "not ready".
    uint flag;

    // Combined value of this partition's own elements.
    Monoid_1 aggregate;

    // Combined value of every element up to and including this partition.
    Monoid_1 prefix;
};
|
|
|
|
// Layout of the state buffer bound at [[buffer(2)]] of main0.
struct StateBuf
{
    // Counter that main0 atomically increments to hand each workgroup its
    // partition index.
    uint part_counter;

    // Declared with one element, but main0 indexes it by partition index, so
    // the bound buffer presumably holds one State per partition - confirm
    // against the host-side allocation.
    State state[1];
};
|
|
|
|
// Layout of the input buffer bound at [[buffer(0)]] of main0.
struct InBuf
{
    // Declared with one element; main0 reads it at indices derived from the
    // partition index, so the real length is presumably set by the host.
    Monoid_1 inbuf[1];
};
|
|
|
|
// Layout of the output buffer bound at [[buffer(1)]] of main0.
struct OutBuf
{
    // Declared with one element; main0 writes it at the same indices it reads
    // from the input buffer, so the real length is presumably set by the host.
    Monoid_1 outbuf[1];
};
|
|
|
|
// Workgroup dimensions declared for this shader: 512 x 1 x 1 threads.
// main0 relies on the same 512 (sh_scratch has 512 slots and invocation 511 is
// treated as the last thread), so the two must stay in agreement.
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u);
|
|
|
|
// Combines two scan elements by adding their `element` fields (unsigned
// arithmetic) and returns the result as a new Monoid.
static inline __attribute__((always_inline))
Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b)
{
    Monoid combined;
    combined.element = a.element + b.element;
    return combined;
}
|
|
|
|
// Single-pass inclusive prefix sum over `_67.inbuf`, written to `_372.outbuf`.
//
// Each workgroup (512 threads) handles one partition of 8192 elements, 16 per
// thread. Workgroups communicate through `_43`:
//   * `part_counter` hands out partition indices,
//   * `state[p].flag` is set to 1 once `state[p].aggregate` is stored and to 2
//     once `state[p].prefix` is stored.
// A workgroup obtains the sum of everything before its partition by walking
// backwards over earlier partitions' state records; while a predecessor has
// published nothing yet, the last thread sums that predecessor's input
// directly instead of only waiting.
//
// Assumes the bound buffers cover every index touched for the dispatched
// number of workgroups - no bounds checks are performed here.
kernel void main0(const device InBuf& _67 [[buffer(0)]], device OutBuf& _372 [[buffer(1)]], volatile device StateBuf& _43 [[buffer(2)]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]])
{
    threadgroup uint sh_part_ix;
    threadgroup Monoid sh_scratch[512];
    threadgroup uint sh_flag;
    threadgroup Monoid sh_prefix;

    const uint lane = gl_LocalInvocationID.x;
    // Invocation 511 is the one that talks to the state buffer.
    const bool is_last_lane = (lane == 511u);

    // Invocation 0 claims the next partition index and shares it with the
    // rest of the workgroup through threadgroup memory.
    if (lane == 0u)
    {
        sh_part_ix = atomic_fetch_add_explicit((volatile device atomic_uint*)&_43.part_counter, 1u, memory_order_relaxed);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    const uint part_ix = sh_part_ix;

    // First element owned by this thread: 8192 per partition, 16 per thread.
    const uint ix = (part_ix * 8192u) + (lane * 16u);

    // Sequential inclusive scan of this thread's 16 input elements.
    spvUnsafeArray<Monoid, 16> local;
    local[0].element = _67.inbuf[ix].element;
    for (uint k = 1u; k < 16u; k++)
    {
        Monoid loaded;
        loaded.element = _67.inbuf[ix + k].element;
        local[k] = combine_monoid(local[k - 1u], loaded);
    }

    // Workgroup-wide inclusive scan of the per-thread totals: in round r each
    // thread folds in the value held 2^r slots below it (9 rounds for 512).
    Monoid agg = local[15];
    sh_scratch[lane] = agg;
    for (uint round = 0u; round < 9u; round++)
    {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        const uint stride = 1u << round;
        if (lane >= stride)
        {
            Monoid other = sh_scratch[lane - stride];
            agg = combine_monoid(other, agg);
        }
        // All reads of this round must finish before any slot is overwritten.
        threadgroup_barrier(mem_flags::mem_threadgroup);
        sh_scratch[lane] = agg;
    }

    // The last thread now holds the whole partition's total; publish it.
    // Partition 0 has nothing before it, so its total is also its prefix.
    if (is_last_lane)
    {
        _43.state[part_ix].aggregate.element = agg.element;
        if (part_ix == 0u)
        {
            _43.state[0].prefix.element = agg.element;
        }
    }
    // Data is stored before the barrier, the flag after it.
    threadgroup_barrier(mem_flags::mem_device);
    if (is_last_lane)
    {
        _43.state[part_ix].flag = (part_ix == 0u) ? 2u : 1u;
    }

    // Sum of everything before this partition; stays 0 for partition 0.
    Monoid exclusive = Monoid{ 0u };
    if (part_ix != 0u)
    {
        uint look_back_ix = part_ix - 1u;
        uint their_ix = 0u;   // progress of the fallback summation below
        Monoid their_agg;     // predecessor aggregate (published or partial)

        while (true)
        {
            // Only the last thread reads the flag; it is broadcast through
            // sh_flag so every thread takes the same branch (and therefore
            // reaches the same barriers).
            if (is_last_lane)
            {
                sh_flag = _43.state[look_back_ix].flag;
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);
            threadgroup_barrier(mem_flags::mem_device);
            uint flag_1 = sh_flag;
            // Keeps the last thread from overwriting sh_flag before every
            // thread has read it (write-after-read hazard).
            threadgroup_barrier(mem_flags::mem_threadgroup);

            if (flag_1 == 2u)
            {
                // Predecessor's full prefix is available: fold it in and stop.
                if (is_last_lane)
                {
                    Monoid their_prefix;
                    their_prefix.element = _43.state[look_back_ix].prefix.element;
                    exclusive = combine_monoid(their_prefix, exclusive);
                }
                break;
            }

            if (flag_1 == 1u)
            {
                // Only the predecessor's own total is available: fold it in
                // and move one partition further back.
                if (is_last_lane)
                {
                    their_agg.element = _43.state[look_back_ix].aggregate.element;
                    exclusive = combine_monoid(their_agg, exclusive);
                }
                look_back_ix--;
                their_ix = 0u;
                continue;
            }

            // Nothing published yet: the last thread adds one more input
            // element of the predecessor partition per iteration, so the
            // loop makes progress without depending on that workgroup.
            if (is_last_lane)
            {
                Monoid m;
                m.element = _67.inbuf[(look_back_ix * 8192u) + their_ix].element;
                if (their_ix == 0u)
                {
                    their_agg = m;
                }
                else
                {
                    their_agg = combine_monoid(their_agg, m);
                }
                their_ix++;

                if (their_ix == 8192u)
                {
                    // Whole predecessor partition summed locally.
                    exclusive = combine_monoid(their_agg, exclusive);
                    if (look_back_ix == 0u)
                    {
                        // Reached the first partition: tell all threads to
                        // leave the loop.
                        sh_flag = 2u;
                    }
                    else
                    {
                        look_back_ix--;
                        their_ix = 0u;
                    }
                }
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);
            flag_1 = sh_flag;
            // Same write-after-read protection as above before the next
            // iteration stores to sh_flag again.
            threadgroup_barrier(mem_flags::mem_threadgroup);
            if (flag_1 == 2u)
            {
                break;
            }
        }

        // Share the exclusive prefix with the workgroup and publish this
        // partition's inclusive prefix for later partitions.
        if (is_last_lane)
        {
            Monoid inclusive_prefix = combine_monoid(exclusive, agg);
            sh_prefix = exclusive;
            _43.state[part_ix].prefix.element = inclusive_prefix.element;
        }
        // Prefix is stored before the barrier, the flag after it.
        threadgroup_barrier(mem_flags::mem_device);
        if (is_last_lane)
        {
            _43.state[part_ix].flag = 2u;
        }
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (part_ix != 0u)
    {
        exclusive = sh_prefix;
    }

    // Offset for this thread's 16 outputs: everything before the partition
    // plus the totals of all lower-numbered threads in the workgroup.
    Monoid row = exclusive;
    if (lane > 0u)
    {
        Monoid below = sh_scratch[lane - 1u];
        row = combine_monoid(row, below);
    }

    for (uint k = 0u; k < 16u; k++)
    {
        Monoid result = combine_monoid(row, local[k]);
        _372.outbuf[ix + k].element = result.element;
    }
}
|
|
|