diff --git a/tests/shader/gen/prefix_reduce.hlsl b/tests/shader/gen/prefix_reduce.hlsl new file mode 100644 index 0000000..837a75a --- /dev/null +++ b/tests/shader/gen/prefix_reduce.hlsl @@ -0,0 +1,72 @@ +struct Monoid +{ + uint element; +}; + +static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u); + +ByteAddressBuffer _40 : register(t0); +RWByteAddressBuffer _129 : register(u1); + +static uint3 gl_WorkGroupID; +static uint3 gl_LocalInvocationID; +static uint3 gl_GlobalInvocationID; +struct SPIRV_Cross_Input +{ + uint3 gl_WorkGroupID : SV_GroupID; + uint3 gl_LocalInvocationID : SV_GroupThreadID; + uint3 gl_GlobalInvocationID : SV_DispatchThreadID; +}; + +groupshared Monoid sh_scratch[512]; + +Monoid combine_monoid(Monoid a, Monoid b) +{ + Monoid _22 = { a.element + b.element }; + return _22; +} + +void comp_main() +{ + uint ix = gl_GlobalInvocationID.x * 8u; + Monoid _44; + _44.element = _40.Load(ix * 4 + 0); + Monoid agg; + agg.element = _44.element; + Monoid param_1; + for (uint i = 1u; i < 8u; i++) + { + Monoid param = agg; + Monoid _64; + _64.element = _40.Load((ix + i) * 4 + 0); + param_1.element = _64.element; + agg = combine_monoid(param, param_1); + } + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i_1 = 0u; i_1 < 9u; i_1++) + { + GroupMemoryBarrierWithGroupSync(); + if ((gl_LocalInvocationID.x + uint(1 << int(i_1))) < 512u) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x + uint(1 << int(i_1))]; + Monoid param_2 = agg; + Monoid param_3 = other; + agg = combine_monoid(param_2, param_3); + } + GroupMemoryBarrierWithGroupSync(); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + if (gl_LocalInvocationID.x == 0u) + { + _129.Store(gl_WorkGroupID.x * 4 + 0, agg.element); + } +} + +[numthreads(512, 1, 1)] +void main(SPIRV_Cross_Input stage_input) +{ + gl_WorkGroupID = stage_input.gl_WorkGroupID; + gl_LocalInvocationID = stage_input.gl_LocalInvocationID; + gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID; + comp_main(); +} diff --git a/tests/shader/gen/prefix_reduce.msl b/tests/shader/gen/prefix_reduce.msl new file mode 100644 index 0000000..e1ed0ce --- /dev/null +++ b/tests/shader/gen/prefix_reduce.msl @@ -0,0 +1,68 @@ +#pragma clang diagnostic ignored "-Wmissing-prototypes" + +#include +#include + +using namespace metal; + +struct Monoid +{ + uint element; +}; + +struct Monoid_1 +{ + uint element; +}; + +struct InBuf +{ + Monoid_1 inbuf[1]; +}; + +struct OutBuf +{ + Monoid_1 outbuf[1]; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u); + +static inline __attribute__((always_inline)) +Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b) +{ + return Monoid{ a.element + b.element }; +} + +kernel void main0(const device InBuf& _40 [[buffer(0)]], device OutBuf& _129 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]]) +{ + threadgroup Monoid sh_scratch[512]; + uint ix = gl_GlobalInvocationID.x * 8u; + Monoid agg; + agg.element = _40.inbuf[ix].element; + Monoid param_1; + for (uint i = 1u; i < 8u; i++) + { + Monoid param = agg; + param_1.element = _40.inbuf[ix + i].element; + agg = combine_monoid(param, param_1); + } + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i_1 = 0u; i_1 < 9u; i_1++) + { + threadgroup_barrier(mem_flags::mem_threadgroup); + if ((gl_LocalInvocationID.x + uint(1 << int(i_1))) < 512u) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x + uint(1 << int(i_1))]; + Monoid param_2 = agg; + Monoid param_3 = other; + agg = combine_monoid(param_2, param_3); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + if (gl_LocalInvocationID.x == 0u) + { + _129.outbuf[gl_WorkGroupID.x].element = agg.element; + } +} + diff --git a/tests/shader/gen/prefix_reduce.spv b/tests/shader/gen/prefix_reduce.spv new file mode 100644 index 0000000..d1db3aa Binary files /dev/null and b/tests/shader/gen/prefix_reduce.spv differ diff --git a/tests/shader/gen/prefix_root.hlsl b/tests/shader/gen/prefix_root.hlsl new file mode 100644 index 0000000..2ad617c --- /dev/null +++ b/tests/shader/gen/prefix_root.hlsl @@ -0,0 +1,80 @@ +struct Monoid +{ + uint element; +}; + +static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u); + +static const Monoid _133 = { 0u }; + +RWByteAddressBuffer _42 : register(u0); + +static uint3 gl_LocalInvocationID; +static uint3 gl_GlobalInvocationID; +struct SPIRV_Cross_Input +{ + uint3 gl_LocalInvocationID : SV_GroupThreadID; + uint3 gl_GlobalInvocationID : SV_DispatchThreadID; +}; + +groupshared Monoid sh_scratch[512]; + +Monoid combine_monoid(Monoid a, Monoid b) +{ + Monoid _22 = { a.element + b.element }; + return _22; +} + +void comp_main() +{ + uint ix = gl_GlobalInvocationID.x * 8u; + Monoid _46; + _46.element = _42.Load(ix * 4 + 0); + Monoid local[8]; + local[0].element = _46.element; + Monoid param_1; + for (uint i = 1u; i < 8u; i++) + { + Monoid param = local[i - 1u]; + Monoid _71; + _71.element = _42.Load((ix + i) * 4 + 0); + param_1.element = _71.element; + local[i] = combine_monoid(param, param_1); + } + Monoid agg = local[7]; + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i_1 = 0u; i_1 < 9u; i_1++) + { + GroupMemoryBarrierWithGroupSync(); + if (gl_LocalInvocationID.x >= uint(1 << int(i_1))) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))]; + Monoid param_2 = other; + Monoid param_3 = agg; + agg = combine_monoid(param_2, param_3); + } + GroupMemoryBarrierWithGroupSync(); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + GroupMemoryBarrierWithGroupSync(); + Monoid row = _133; + if (gl_LocalInvocationID.x > 0u) + { + row = sh_scratch[gl_LocalInvocationID.x - 1u]; + } + for (uint i_2 = 0u; i_2 < 8u; i_2++) + { + Monoid param_4 = row; + Monoid param_5 = local[i_2]; + Monoid m = combine_monoid(param_4, param_5); + _42.Store((ix + i_2) * 4 + 0, m.element); + } +} + +[numthreads(512, 1, 1)] +void main(SPIRV_Cross_Input stage_input) +{ + gl_LocalInvocationID = stage_input.gl_LocalInvocationID; + gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID; + comp_main(); +} diff --git a/tests/shader/gen/prefix_root.msl b/tests/shader/gen/prefix_root.msl new file mode 100644 index 0000000..ff02287 --- /dev/null +++ b/tests/shader/gen/prefix_root.msl @@ -0,0 +1,112 @@ +#pragma clang diagnostic ignored "-Wmissing-prototypes" +#pragma clang diagnostic ignored "-Wmissing-braces" + +#include +#include + +using namespace metal; + +template +struct spvUnsafeArray +{ + T elements[Num ? Num : 1]; + + thread T& operator [] (size_t pos) thread + { + return elements[pos]; + } + constexpr const thread T& operator [] (size_t pos) const thread + { + return elements[pos]; + } + + device T& operator [] (size_t pos) device + { + return elements[pos]; + } + constexpr const device T& operator [] (size_t pos) const device + { + return elements[pos]; + } + + constexpr const constant T& operator [] (size_t pos) const constant + { + return elements[pos]; + } + + threadgroup T& operator [] (size_t pos) threadgroup + { + return elements[pos]; + } + constexpr const threadgroup T& operator [] (size_t pos) const threadgroup + { + return elements[pos]; + } +}; + +struct Monoid +{ + uint element; +}; + +struct Monoid_1 +{ + uint element; +}; + +struct DataBuf +{ + Monoid_1 data[1]; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u); + +static inline __attribute__((always_inline)) +Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b) +{ + return Monoid{ a.element + b.element }; +} + +kernel void main0(device DataBuf& _42 [[buffer(0)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]]) +{ + threadgroup Monoid sh_scratch[512]; + uint ix = gl_GlobalInvocationID.x * 8u; + spvUnsafeArray local; + local[0].element = _42.data[ix].element; + Monoid param_1; + for (uint i = 1u; i < 8u; i++) + { + Monoid param = local[i - 1u]; + param_1.element = _42.data[ix + i].element; + local[i] = combine_monoid(param, param_1); + } + Monoid agg = local[7]; + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i_1 = 0u; i_1 < 9u; i_1++) + { + threadgroup_barrier(mem_flags::mem_threadgroup); + if (gl_LocalInvocationID.x >= uint(1 << int(i_1))) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))]; + Monoid param_2 = other; + Monoid param_3 = agg; + agg = combine_monoid(param_2, param_3); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + Monoid row = Monoid{ 0u }; + if (gl_LocalInvocationID.x > 0u) + { + row = sh_scratch[gl_LocalInvocationID.x - 1u]; + } + for (uint i_2 = 0u; i_2 < 8u; i_2++) + { + Monoid param_4 = row; + Monoid param_5 = local[i_2]; + Monoid m = combine_monoid(param_4, param_5); + _42.data[ix + i_2].element = m.element; + } +} + diff --git a/tests/shader/gen/prefix_root.spv b/tests/shader/gen/prefix_root.spv new file mode 100644 index 0000000..70ba31c Binary files /dev/null and b/tests/shader/gen/prefix_root.spv differ diff --git a/tests/shader/gen/prefix_scan.hlsl b/tests/shader/gen/prefix_scan.hlsl new file mode 100644 index 0000000..feeff2e --- /dev/null +++ b/tests/shader/gen/prefix_scan.hlsl @@ -0,0 +1,92 @@ +struct Monoid +{ + uint element; +}; + +static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u); + +static const Monoid _133 = { 0u }; + +RWByteAddressBuffer _42 : register(u0); +RWByteAddressBuffer _143 : register(u1); + +static uint3 gl_WorkGroupID; +static uint3 gl_LocalInvocationID; +static uint3 gl_GlobalInvocationID; +struct SPIRV_Cross_Input +{ + uint3 gl_WorkGroupID : SV_GroupID; + uint3 gl_LocalInvocationID : SV_GroupThreadID; + uint3 gl_GlobalInvocationID : SV_DispatchThreadID; +}; + +groupshared Monoid sh_scratch[512]; + +Monoid combine_monoid(Monoid a, Monoid b) +{ + Monoid _22 = { a.element + b.element }; + return _22; +} + +void comp_main() +{ + uint ix = gl_GlobalInvocationID.x * 8u; + Monoid _46; + _46.element = _42.Load(ix * 4 + 0); + Monoid local[8]; + local[0].element = _46.element; + Monoid param_1; + for (uint i = 1u; i < 8u; i++) + { + Monoid param = local[i - 1u]; + Monoid _71; + _71.element = _42.Load((ix + i) * 4 + 0); + param_1.element = _71.element; + local[i] = combine_monoid(param, param_1); + } + Monoid agg = local[7]; + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i_1 = 0u; i_1 < 9u; i_1++) + { + GroupMemoryBarrierWithGroupSync(); + if (gl_LocalInvocationID.x >= uint(1 << int(i_1))) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))]; + Monoid param_2 = other; + Monoid param_3 = agg; + agg = combine_monoid(param_2, param_3); + } + GroupMemoryBarrierWithGroupSync(); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + GroupMemoryBarrierWithGroupSync(); + Monoid row = _133; + if (gl_WorkGroupID.x > 0u) + { + Monoid _148; + _148.element = _143.Load((gl_WorkGroupID.x - 1u) * 4 + 0); + row.element = _148.element; + } + if (gl_LocalInvocationID.x > 0u) + { + Monoid param_4 = row; + Monoid param_5 = sh_scratch[gl_LocalInvocationID.x - 1u]; + row = combine_monoid(param_4, param_5); + } + for (uint i_2 = 0u; i_2 < 8u; i_2++) + { + Monoid param_6 = row; + Monoid param_7 = local[i_2]; + Monoid m = combine_monoid(param_6, param_7); + _42.Store((ix + i_2) * 4 + 0, m.element); + } +} + +[numthreads(512, 1, 1)] +void main(SPIRV_Cross_Input stage_input) +{ + gl_WorkGroupID = stage_input.gl_WorkGroupID; + gl_LocalInvocationID = stage_input.gl_LocalInvocationID; + gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID; + comp_main(); +} diff --git a/tests/shader/gen/prefix_scan.msl b/tests/shader/gen/prefix_scan.msl new file mode 100644 index 0000000..c1efb22 --- /dev/null +++ b/tests/shader/gen/prefix_scan.msl @@ -0,0 +1,123 @@ +#pragma clang diagnostic ignored "-Wmissing-prototypes" +#pragma clang diagnostic ignored "-Wmissing-braces" + +#include +#include + +using namespace metal; + +template +struct spvUnsafeArray +{ + T elements[Num ? Num : 1]; + + thread T& operator [] (size_t pos) thread + { + return elements[pos]; + } + constexpr const thread T& operator [] (size_t pos) const thread + { + return elements[pos]; + } + + device T& operator [] (size_t pos) device + { + return elements[pos]; + } + constexpr const device T& operator [] (size_t pos) const device + { + return elements[pos]; + } + + constexpr const constant T& operator [] (size_t pos) const constant + { + return elements[pos]; + } + + threadgroup T& operator [] (size_t pos) threadgroup + { + return elements[pos]; + } + constexpr const threadgroup T& operator [] (size_t pos) const threadgroup + { + return elements[pos]; + } +}; + +struct Monoid +{ + uint element; +}; + +struct Monoid_1 +{ + uint element; +}; + +struct DataBuf +{ + Monoid_1 data[1]; +}; + +struct ParentBuf +{ + Monoid_1 parent[1]; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u); + +static inline __attribute__((always_inline)) +Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b) +{ + return Monoid{ a.element + b.element }; +} + +kernel void main0(device DataBuf& _42 [[buffer(0)]], device ParentBuf& _143 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]]) +{ + threadgroup Monoid sh_scratch[512]; + uint ix = gl_GlobalInvocationID.x * 8u; + spvUnsafeArray local; + local[0].element = _42.data[ix].element; + Monoid param_1; + for (uint i = 1u; i < 8u; i++) + { + Monoid param = local[i - 1u]; + param_1.element = _42.data[ix + i].element; + local[i] = combine_monoid(param, param_1); + } + Monoid agg = local[7]; + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i_1 = 0u; i_1 < 9u; i_1++) + { + threadgroup_barrier(mem_flags::mem_threadgroup); + if (gl_LocalInvocationID.x >= uint(1 << int(i_1))) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))]; + Monoid param_2 = other; + Monoid param_3 = agg; + agg = combine_monoid(param_2, param_3); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + Monoid row = Monoid{ 0u }; + if (gl_WorkGroupID.x > 0u) + { + row.element = _143.parent[gl_WorkGroupID.x - 1u].element; + } + if (gl_LocalInvocationID.x > 0u) + { + Monoid param_4 = row; + Monoid param_5 = sh_scratch[gl_LocalInvocationID.x - 1u]; + row = combine_monoid(param_4, param_5); + } + for (uint i_2 = 0u; i_2 < 8u; i_2++) + { + Monoid param_6 = row; + Monoid param_7 = local[i_2]; + Monoid m = combine_monoid(param_6, param_7); + _42.data[ix + i_2].element = m.element; + } +} + diff --git a/tests/shader/gen/prefix_scan.spv b/tests/shader/gen/prefix_scan.spv new file mode 100644 index 0000000..d4216e9 Binary files /dev/null and b/tests/shader/gen/prefix_scan.spv differ