diff --git a/Cargo.lock b/Cargo.lock index 64fc6fe..e5b2eaa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -85,6 +85,12 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a" +[[package]] +name = "bytemuck" +version = "1.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72957246c41db82b8ef88a5486143830adeb8227ef9837740bdec67724cf2c5b" + [[package]] name = "byteorder" version = "1.3.4" @@ -897,6 +903,15 @@ dependencies = [ "wio", ] +[[package]] +name = "piet-gpu-tests" +version = "0.1.0" +dependencies = [ + "bytemuck", + "clap", + "piet-gpu-hal", +] + [[package]] name = "piet-gpu-types" version = "0.0.0" diff --git a/Cargo.toml b/Cargo.toml index f71f2de..bfa0030 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,5 +4,6 @@ members = [ "piet-gpu", "piet-gpu-derive", "piet-gpu-hal", - "piet-gpu-types" + "piet-gpu-types", + "tests" ] diff --git a/piet-gpu-hal/examples/collatz.rs b/piet-gpu-hal/examples/collatz.rs index e974cde..cad508e 100644 --- a/piet-gpu-hal/examples/collatz.rs +++ b/piet-gpu-hal/examples/collatz.rs @@ -21,6 +21,7 @@ fn main() { cmd_buf.write_timestamp(&query_pool, 0); cmd_buf.dispatch(&pipeline, &descriptor_set, (256, 1, 1), (1, 1, 1)); cmd_buf.write_timestamp(&query_pool, 1); + cmd_buf.finish_timestamps(&query_pool); cmd_buf.host_barrier(); cmd_buf.finish(); let submitted = session.run_cmd_buf(cmd_buf, &[], &[]).unwrap(); diff --git a/tests/Cargo.toml b/tests/Cargo.toml new file mode 100644 index 0000000..a987c9e --- /dev/null +++ b/tests/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "piet-gpu-tests" +version = "0.1.0" +authors = ["Raph Levien "] +description = "Tests for piet-gpu shaders and generic GPU capabilities." +license = "MIT/Apache-2.0" +edition = "2021" + +[dependencies] +clap = "2.33" +bytemuck = "1.7.2" + +[dependencies.piet-gpu-hal] +path = "../piet-gpu-hal" diff --git a/tests/shader/build.ninja b/tests/shader/build.ninja new file mode 100644 index 0000000..93a0b66 --- /dev/null +++ b/tests/shader/build.ninja @@ -0,0 +1,32 @@ +# Build file for shaders. + +# You must have Vulkan tools in your path, or patch here. + +glslang_validator = glslangValidator +spirv_cross = spirv-cross + +rule glsl + command = $glslang_validator $flags -V -o $out $in + +rule hlsl + command = $spirv_cross --hlsl $in --output $out + +rule msl + command = $spirv_cross --msl $in --output $out + +build gen/prefix.spv: glsl prefix.comp +build gen/prefix.hlsl: hlsl gen/prefix.spv +build gen/prefix.msl: msl gen/prefix.spv + +build gen/prefix_reduce.spv: glsl prefix_reduce.comp +build gen/prefix_reduce.hlsl: hlsl gen/prefix_reduce.spv +build gen/prefix_reduce.msl: msl gen/prefix_reduce.spv + +build gen/prefix_root.spv: glsl prefix_scan.comp + flags = -DROOT +build gen/prefix_root.hlsl: hlsl gen/prefix_root.spv +build gen/prefix_root.msl: msl gen/prefix_root.spv + +build gen/prefix_scan.spv: glsl prefix_scan.comp +build gen/prefix_scan.hlsl: hlsl gen/prefix_scan.spv +build gen/prefix_scan.msl: msl gen/prefix_scan.spv diff --git a/tests/shader/gen/collatz.hlsl b/tests/shader/gen/collatz.hlsl new file mode 100644 index 0000000..2f4861f --- /dev/null +++ b/tests/shader/gen/collatz.hlsl @@ -0,0 +1,43 @@ +static const uint3 gl_WorkGroupSize = uint3(1u, 1u, 1u); + +RWByteAddressBuffer _53 : register(u1); +ByteAddressBuffer _59 : register(t0); + +static uint3 gl_GlobalInvocationID; +struct SPIRV_Cross_Input +{ + uint3 gl_GlobalInvocationID : SV_DispatchThreadID; +}; + +uint collatz_iterations(inout uint n) +{ + uint i = 0u; + while (n != 1u) + { + if ((n % 2u) == 0u) + { + n /= 2u; + } + else + { + n = (3u * n) + 1u; + } + i++; + } + return i; +} + +void comp_main() +{ + uint index = gl_GlobalInvocationID.x; + uint param = _59.Load(index * 4 + 0); + uint _65 = collatz_iterations(param); + _53.Store(index * 4 + 0, _65); +} + +[numthreads(1, 1, 1)] +void main(SPIRV_Cross_Input stage_input) +{ + gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID; + comp_main(); +} diff --git a/tests/shader/gen/collatz.msl b/tests/shader/gen/collatz.msl new file mode 100644 index 0000000..87cc7b5 --- /dev/null +++ b/tests/shader/gen/collatz.msl @@ -0,0 +1,46 @@ +#pragma clang diagnostic ignored "-Wmissing-prototypes" + +#include +#include + +using namespace metal; + +struct OutBuf +{ + uint out_buf[1]; +}; + +struct InBuf +{ + uint in_buf[1]; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u); + +static inline __attribute__((always_inline)) +uint collatz_iterations(thread uint& n) +{ + uint i = 0u; + while (n != 1u) + { + if ((n % 2u) == 0u) + { + n /= 2u; + } + else + { + n = (3u * n) + 1u; + } + i++; + } + return i; +} + +kernel void main0(device OutBuf& _53 [[buffer(0)]], const device InBuf& _59 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]]) +{ + uint index = gl_GlobalInvocationID.x; + uint param = _59.in_buf[index]; + uint _65 = collatz_iterations(param); + _53.out_buf[index] = _65; +} + diff --git a/tests/shader/gen/collatz.spv b/tests/shader/gen/collatz.spv new file mode 100644 index 0000000..638894c Binary files /dev/null and b/tests/shader/gen/collatz.spv differ diff --git a/tests/shader/gen/prefix.hlsl b/tests/shader/gen/prefix.hlsl new file mode 100644 index 0000000..c0600e2 --- /dev/null +++ b/tests/shader/gen/prefix.hlsl @@ -0,0 +1,223 @@ +struct Monoid +{ + uint element; +}; + +struct State +{ + uint flag; + Monoid aggregate; + Monoid prefix; +}; + +static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u); + +static const Monoid _187 = { 0u }; + +globallycoherent RWByteAddressBuffer _43 : register(u2); +ByteAddressBuffer _67 : register(t0); +RWByteAddressBuffer _374 : register(u1); + +static uint3 gl_LocalInvocationID; +struct SPIRV_Cross_Input +{ + uint3 gl_LocalInvocationID : SV_GroupThreadID; +}; + +groupshared uint sh_part_ix; +groupshared Monoid sh_scratch[512]; +groupshared uint sh_flag; +groupshared Monoid sh_prefix; + +Monoid combine_monoid(Monoid a, Monoid b) +{ + Monoid _22 = { a.element + b.element }; + return _22; +} + +void comp_main() +{ + if (gl_LocalInvocationID.x == 0u) + { + uint _47; + _43.InterlockedAdd(0, 1u, _47); + sh_part_ix = _47; + } + GroupMemoryBarrierWithGroupSync(); + uint part_ix = sh_part_ix; + uint ix = (part_ix * 8192u) + (gl_LocalInvocationID.x * 16u); + Monoid _71; + _71.element = _67.Load(ix * 4 + 0); + Monoid local[16]; + local[0].element = _71.element; + Monoid param_1; + for (uint i = 1u; i < 16u; i++) + { + Monoid param = local[i - 1u]; + Monoid _94; + _94.element = _67.Load((ix + i) * 4 + 0); + param_1.element = _94.element; + local[i] = combine_monoid(param, param_1); + } + Monoid agg = local[15]; + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i_1 = 0u; i_1 < 9u; i_1++) + { + GroupMemoryBarrierWithGroupSync(); + if (gl_LocalInvocationID.x >= uint(1 << int(i_1))) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))]; + Monoid param_2 = other; + Monoid param_3 = agg; + agg = combine_monoid(param_2, param_3); + } + GroupMemoryBarrierWithGroupSync(); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + if (gl_LocalInvocationID.x == 511u) + { + _43.Store(part_ix * 12 + 8, agg.element); + if (part_ix == 0u) + { + _43.Store(12, agg.element); + } + } + DeviceMemoryBarrier(); + if (gl_LocalInvocationID.x == 511u) + { + uint flag = 1u; + if (part_ix == 0u) + { + flag = 2u; + } + _43.Store(part_ix * 12 + 4, flag); + } + Monoid exclusive = _187; + if (part_ix != 0u) + { + uint look_back_ix = part_ix - 1u; + uint their_ix = 0u; + Monoid their_prefix; + Monoid their_agg; + Monoid m; + while (true) + { + if (gl_LocalInvocationID.x == 511u) + { + sh_flag = _43.Load(look_back_ix * 12 + 4); + } + GroupMemoryBarrierWithGroupSync(); + DeviceMemoryBarrier(); + uint flag_1 = sh_flag; + if (flag_1 == 2u) + { + if (gl_LocalInvocationID.x == 511u) + { + Monoid _225; + _225.element = _43.Load(look_back_ix * 12 + 12); + their_prefix.element = _225.element; + Monoid param_4 = their_prefix; + Monoid param_5 = exclusive; + exclusive = combine_monoid(param_4, param_5); + } + break; + } + else + { + if (flag_1 == 1u) + { + if (gl_LocalInvocationID.x == 511u) + { + Monoid _247; + _247.element = _43.Load(look_back_ix * 12 + 8); + their_agg.element = _247.element; + Monoid param_6 = their_agg; + Monoid param_7 = exclusive; + exclusive = combine_monoid(param_6, param_7); + } + look_back_ix--; + their_ix = 0u; + continue; + } + } + if (gl_LocalInvocationID.x == 511u) + { + Monoid _269; + _269.element = _67.Load(((look_back_ix * 8192u) + their_ix) * 4 + 0); + m.element = _269.element; + if (their_ix == 0u) + { + their_agg = m; + } + else + { + Monoid param_8 = their_agg; + Monoid param_9 = m; + their_agg = combine_monoid(param_8, param_9); + } + their_ix++; + if (their_ix == 8192u) + { + Monoid param_10 = their_agg; + Monoid param_11 = exclusive; + exclusive = combine_monoid(param_10, param_11); + if (look_back_ix == 0u) + { + sh_flag = 2u; + } + else + { + look_back_ix--; + their_ix = 0u; + } + } + } + GroupMemoryBarrierWithGroupSync(); + flag_1 = sh_flag; + if (flag_1 == 2u) + { + break; + } + } + if (gl_LocalInvocationID.x == 511u) + { + Monoid param_12 = exclusive; + Monoid param_13 = agg; + Monoid inclusive_prefix = combine_monoid(param_12, param_13); + sh_prefix = exclusive; + _43.Store(part_ix * 12 + 12, inclusive_prefix.element); + } + DeviceMemoryBarrier(); + if (gl_LocalInvocationID.x == 511u) + { + _43.Store(part_ix * 12 + 4, 2u); + } + } + GroupMemoryBarrierWithGroupSync(); + if (part_ix != 0u) + { + exclusive = sh_prefix; + } + Monoid row = exclusive; + if (gl_LocalInvocationID.x > 0u) + { + Monoid other_1 = sh_scratch[gl_LocalInvocationID.x - 1u]; + Monoid param_14 = row; + Monoid param_15 = other_1; + row = combine_monoid(param_14, param_15); + } + for (uint i_2 = 0u; i_2 < 16u; i_2++) + { + Monoid param_16 = row; + Monoid param_17 = local[i_2]; + Monoid m_1 = combine_monoid(param_16, param_17); + _374.Store((ix + i_2) * 4 + 0, m_1.element); + } +} + +[numthreads(512, 1, 1)] +void main(SPIRV_Cross_Input stage_input) +{ + gl_LocalInvocationID = stage_input.gl_LocalInvocationID; + comp_main(); +} diff --git a/tests/shader/gen/prefix.msl b/tests/shader/gen/prefix.msl new file mode 100644 index 0000000..ecdf8bd --- /dev/null +++ b/tests/shader/gen/prefix.msl @@ -0,0 +1,262 @@ +#pragma clang diagnostic ignored "-Wmissing-prototypes" +#pragma clang diagnostic ignored "-Wmissing-braces" +#pragma clang diagnostic ignored "-Wunused-variable" + +#include +#include +#include + +using namespace metal; + +template +struct spvUnsafeArray +{ + T elements[Num ? Num : 1]; + + thread T& operator [] (size_t pos) thread + { + return elements[pos]; + } + constexpr const thread T& operator [] (size_t pos) const thread + { + return elements[pos]; + } + + device T& operator [] (size_t pos) device + { + return elements[pos]; + } + constexpr const device T& operator [] (size_t pos) const device + { + return elements[pos]; + } + + constexpr const constant T& operator [] (size_t pos) const constant + { + return elements[pos]; + } + + threadgroup T& operator [] (size_t pos) threadgroup + { + return elements[pos]; + } + constexpr const threadgroup T& operator [] (size_t pos) const threadgroup + { + return elements[pos]; + } +}; + +struct Monoid +{ + uint element; +}; + +struct Monoid_1 +{ + uint element; +}; + +struct State +{ + uint flag; + Monoid_1 aggregate; + Monoid_1 prefix; +}; + +struct StateBuf +{ + uint part_counter; + State state[1]; +}; + +struct InBuf +{ + Monoid_1 inbuf[1]; +}; + +struct OutBuf +{ + Monoid_1 outbuf[1]; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u); + +static inline __attribute__((always_inline)) +Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b) +{ + return Monoid{ a.element + b.element }; +} + +kernel void main0(volatile device StateBuf& _43 [[buffer(0)]], const device InBuf& _67 [[buffer(1)]], device OutBuf& _374 [[buffer(2)]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]]) +{ + threadgroup uint sh_part_ix; + threadgroup Monoid sh_scratch[512]; + threadgroup uint sh_flag; + threadgroup Monoid sh_prefix; + if (gl_LocalInvocationID.x == 0u) + { + uint _47 = atomic_fetch_add_explicit((volatile device atomic_uint*)&_43.part_counter, 1u, memory_order_relaxed); + sh_part_ix = _47; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + uint part_ix = sh_part_ix; + uint ix = (part_ix * 8192u) + (gl_LocalInvocationID.x * 16u); + spvUnsafeArray local; + local[0].element = _67.inbuf[ix].element; + Monoid param_1; + for (uint i = 1u; i < 16u; i++) + { + Monoid param = local[i - 1u]; + param_1.element = _67.inbuf[ix + i].element; + local[i] = combine_monoid(param, param_1); + } + Monoid agg = local[15]; + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i_1 = 0u; i_1 < 9u; i_1++) + { + threadgroup_barrier(mem_flags::mem_threadgroup); + if (gl_LocalInvocationID.x >= uint(1 << int(i_1))) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))]; + Monoid param_2 = other; + Monoid param_3 = agg; + agg = combine_monoid(param_2, param_3); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + if (gl_LocalInvocationID.x == 511u) + { + _43.state[part_ix].aggregate.element = agg.element; + if (part_ix == 0u) + { + _43.state[0].prefix.element = agg.element; + } + } + threadgroup_barrier(mem_flags::mem_device); + if (gl_LocalInvocationID.x == 511u) + { + uint flag = 1u; + if (part_ix == 0u) + { + flag = 2u; + } + _43.state[part_ix].flag = flag; + } + Monoid exclusive = Monoid{ 0u }; + if (part_ix != 0u) + { + uint look_back_ix = part_ix - 1u; + uint their_ix = 0u; + Monoid their_prefix; + Monoid their_agg; + Monoid m; + while (true) + { + if (gl_LocalInvocationID.x == 511u) + { + sh_flag = _43.state[look_back_ix].flag; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_device); + uint flag_1 = sh_flag; + if (flag_1 == 2u) + { + if (gl_LocalInvocationID.x == 511u) + { + their_prefix.element = _43.state[look_back_ix].prefix.element; + Monoid param_4 = their_prefix; + Monoid param_5 = exclusive; + exclusive = combine_monoid(param_4, param_5); + } + break; + } + else + { + if (flag_1 == 1u) + { + if (gl_LocalInvocationID.x == 511u) + { + their_agg.element = _43.state[look_back_ix].aggregate.element; + Monoid param_6 = their_agg; + Monoid param_7 = exclusive; + exclusive = combine_monoid(param_6, param_7); + } + look_back_ix--; + their_ix = 0u; + continue; + } + } + if (gl_LocalInvocationID.x == 511u) + { + m.element = _67.inbuf[(look_back_ix * 8192u) + their_ix].element; + if (their_ix == 0u) + { + their_agg = m; + } + else + { + Monoid param_8 = their_agg; + Monoid param_9 = m; + their_agg = combine_monoid(param_8, param_9); + } + their_ix++; + if (their_ix == 8192u) + { + Monoid param_10 = their_agg; + Monoid param_11 = exclusive; + exclusive = combine_monoid(param_10, param_11); + if (look_back_ix == 0u) + { + sh_flag = 2u; + } + else + { + look_back_ix--; + their_ix = 0u; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + flag_1 = sh_flag; + if (flag_1 == 2u) + { + break; + } + } + if (gl_LocalInvocationID.x == 511u) + { + Monoid param_12 = exclusive; + Monoid param_13 = agg; + Monoid inclusive_prefix = combine_monoid(param_12, param_13); + sh_prefix = exclusive; + _43.state[part_ix].prefix.element = inclusive_prefix.element; + } + threadgroup_barrier(mem_flags::mem_device); + if (gl_LocalInvocationID.x == 511u) + { + _43.state[part_ix].flag = 2u; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (part_ix != 0u) + { + exclusive = sh_prefix; + } + Monoid row = exclusive; + if (gl_LocalInvocationID.x > 0u) + { + Monoid other_1 = sh_scratch[gl_LocalInvocationID.x - 1u]; + Monoid param_14 = row; + Monoid param_15 = other_1; + row = combine_monoid(param_14, param_15); + } + for (uint i_2 = 0u; i_2 < 16u; i_2++) + { + Monoid param_16 = row; + Monoid param_17 = local[i_2]; + Monoid m_1 = combine_monoid(param_16, param_17); + _374.outbuf[ix + i_2].element = m_1.element; + } +} + diff --git a/tests/shader/gen/prefix.spv b/tests/shader/gen/prefix.spv new file mode 100644 index 0000000..170a569 Binary files /dev/null and b/tests/shader/gen/prefix.spv differ diff --git a/tests/shader/gen/prefix_reduce.hlsl b/tests/shader/gen/prefix_reduce.hlsl new file mode 100644 index 0000000..837a75a --- /dev/null +++ b/tests/shader/gen/prefix_reduce.hlsl @@ -0,0 +1,72 @@ +struct Monoid +{ + uint element; +}; + +static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u); + +ByteAddressBuffer _40 : register(t0); +RWByteAddressBuffer _129 : register(u1); + +static uint3 gl_WorkGroupID; +static uint3 gl_LocalInvocationID; +static uint3 gl_GlobalInvocationID; +struct SPIRV_Cross_Input +{ + uint3 gl_WorkGroupID : SV_GroupID; + uint3 gl_LocalInvocationID : SV_GroupThreadID; + uint3 gl_GlobalInvocationID : SV_DispatchThreadID; +}; + +groupshared Monoid sh_scratch[512]; + +Monoid combine_monoid(Monoid a, Monoid b) +{ + Monoid _22 = { a.element + b.element }; + return _22; +} + +void comp_main() +{ + uint ix = gl_GlobalInvocationID.x * 8u; + Monoid _44; + _44.element = _40.Load(ix * 4 + 0); + Monoid agg; + agg.element = _44.element; + Monoid param_1; + for (uint i = 1u; i < 8u; i++) + { + Monoid param = agg; + Monoid _64; + _64.element = _40.Load((ix + i) * 4 + 0); + param_1.element = _64.element; + agg = combine_monoid(param, param_1); + } + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i_1 = 0u; i_1 < 9u; i_1++) + { + GroupMemoryBarrierWithGroupSync(); + if ((gl_LocalInvocationID.x + uint(1 << int(i_1))) < 512u) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x + uint(1 << int(i_1))]; + Monoid param_2 = agg; + Monoid param_3 = other; + agg = combine_monoid(param_2, param_3); + } + GroupMemoryBarrierWithGroupSync(); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + if (gl_LocalInvocationID.x == 0u) + { + _129.Store(gl_WorkGroupID.x * 4 + 0, agg.element); + } +} + +[numthreads(512, 1, 1)] +void main(SPIRV_Cross_Input stage_input) +{ + gl_WorkGroupID = stage_input.gl_WorkGroupID; + gl_LocalInvocationID = stage_input.gl_LocalInvocationID; + gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID; + comp_main(); +} diff --git a/tests/shader/gen/prefix_reduce.msl b/tests/shader/gen/prefix_reduce.msl new file mode 100644 index 0000000..e1ed0ce --- /dev/null +++ b/tests/shader/gen/prefix_reduce.msl @@ -0,0 +1,68 @@ +#pragma clang diagnostic ignored "-Wmissing-prototypes" + +#include +#include + +using namespace metal; + +struct Monoid +{ + uint element; +}; + +struct Monoid_1 +{ + uint element; +}; + +struct InBuf +{ + Monoid_1 inbuf[1]; +}; + +struct OutBuf +{ + Monoid_1 outbuf[1]; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u); + +static inline __attribute__((always_inline)) +Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b) +{ + return Monoid{ a.element + b.element }; +} + +kernel void main0(const device InBuf& _40 [[buffer(0)]], device OutBuf& _129 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]]) +{ + threadgroup Monoid sh_scratch[512]; + uint ix = gl_GlobalInvocationID.x * 8u; + Monoid agg; + agg.element = _40.inbuf[ix].element; + Monoid param_1; + for (uint i = 1u; i < 8u; i++) + { + Monoid param = agg; + param_1.element = _40.inbuf[ix + i].element; + agg = combine_monoid(param, param_1); + } + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i_1 = 0u; i_1 < 9u; i_1++) + { + threadgroup_barrier(mem_flags::mem_threadgroup); + if ((gl_LocalInvocationID.x + uint(1 << int(i_1))) < 512u) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x + uint(1 << int(i_1))]; + Monoid param_2 = agg; + Monoid param_3 = other; + agg = combine_monoid(param_2, param_3); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + if (gl_LocalInvocationID.x == 0u) + { + _129.outbuf[gl_WorkGroupID.x].element = agg.element; + } +} + diff --git a/tests/shader/gen/prefix_reduce.spv b/tests/shader/gen/prefix_reduce.spv new file mode 100644 index 0000000..d1db3aa Binary files /dev/null and b/tests/shader/gen/prefix_reduce.spv differ diff --git a/tests/shader/gen/prefix_root.hlsl b/tests/shader/gen/prefix_root.hlsl new file mode 100644 index 0000000..2ad617c --- /dev/null +++ b/tests/shader/gen/prefix_root.hlsl @@ -0,0 +1,80 @@ +struct Monoid +{ + uint element; +}; + +static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u); + +static const Monoid _133 = { 0u }; + +RWByteAddressBuffer _42 : register(u0); + +static uint3 gl_LocalInvocationID; +static uint3 gl_GlobalInvocationID; +struct SPIRV_Cross_Input +{ + uint3 gl_LocalInvocationID : SV_GroupThreadID; + uint3 gl_GlobalInvocationID : SV_DispatchThreadID; +}; + +groupshared Monoid sh_scratch[512]; + +Monoid combine_monoid(Monoid a, Monoid b) +{ + Monoid _22 = { a.element + b.element }; + return _22; +} + +void comp_main() +{ + uint ix = gl_GlobalInvocationID.x * 8u; + Monoid _46; + _46.element = _42.Load(ix * 4 + 0); + Monoid local[8]; + local[0].element = _46.element; + Monoid param_1; + for (uint i = 1u; i < 8u; i++) + { + Monoid param = local[i - 1u]; + Monoid _71; + _71.element = _42.Load((ix + i) * 4 + 0); + param_1.element = _71.element; + local[i] = combine_monoid(param, param_1); + } + Monoid agg = local[7]; + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i_1 = 0u; i_1 < 9u; i_1++) + { + GroupMemoryBarrierWithGroupSync(); + if (gl_LocalInvocationID.x >= uint(1 << int(i_1))) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))]; + Monoid param_2 = other; + Monoid param_3 = agg; + agg = combine_monoid(param_2, param_3); + } + GroupMemoryBarrierWithGroupSync(); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + GroupMemoryBarrierWithGroupSync(); + Monoid row = _133; + if (gl_LocalInvocationID.x > 0u) + { + row = sh_scratch[gl_LocalInvocationID.x - 1u]; + } + for (uint i_2 = 0u; i_2 < 8u; i_2++) + { + Monoid param_4 = row; + Monoid param_5 = local[i_2]; + Monoid m = combine_monoid(param_4, param_5); + _42.Store((ix + i_2) * 4 + 0, m.element); + } +} + +[numthreads(512, 1, 1)] +void main(SPIRV_Cross_Input stage_input) +{ + gl_LocalInvocationID = stage_input.gl_LocalInvocationID; + gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID; + comp_main(); +} diff --git a/tests/shader/gen/prefix_root.msl b/tests/shader/gen/prefix_root.msl new file mode 100644 index 0000000..ff02287 --- /dev/null +++ b/tests/shader/gen/prefix_root.msl @@ -0,0 +1,112 @@ +#pragma clang diagnostic ignored "-Wmissing-prototypes" +#pragma clang diagnostic ignored "-Wmissing-braces" + +#include +#include + +using namespace metal; + +template +struct spvUnsafeArray +{ + T elements[Num ? Num : 1]; + + thread T& operator [] (size_t pos) thread + { + return elements[pos]; + } + constexpr const thread T& operator [] (size_t pos) const thread + { + return elements[pos]; + } + + device T& operator [] (size_t pos) device + { + return elements[pos]; + } + constexpr const device T& operator [] (size_t pos) const device + { + return elements[pos]; + } + + constexpr const constant T& operator [] (size_t pos) const constant + { + return elements[pos]; + } + + threadgroup T& operator [] (size_t pos) threadgroup + { + return elements[pos]; + } + constexpr const threadgroup T& operator [] (size_t pos) const threadgroup + { + return elements[pos]; + } +}; + +struct Monoid +{ + uint element; +}; + +struct Monoid_1 +{ + uint element; +}; + +struct DataBuf +{ + Monoid_1 data[1]; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u); + +static inline __attribute__((always_inline)) +Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b) +{ + return Monoid{ a.element + b.element }; +} + +kernel void main0(device DataBuf& _42 [[buffer(0)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]]) +{ + threadgroup Monoid sh_scratch[512]; + uint ix = gl_GlobalInvocationID.x * 8u; + spvUnsafeArray local; + local[0].element = _42.data[ix].element; + Monoid param_1; + for (uint i = 1u; i < 8u; i++) + { + Monoid param = local[i - 1u]; + param_1.element = _42.data[ix + i].element; + local[i] = combine_monoid(param, param_1); + } + Monoid agg = local[7]; + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i_1 = 0u; i_1 < 9u; i_1++) + { + threadgroup_barrier(mem_flags::mem_threadgroup); + if (gl_LocalInvocationID.x >= uint(1 << int(i_1))) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))]; + Monoid param_2 = other; + Monoid param_3 = agg; + agg = combine_monoid(param_2, param_3); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + Monoid row = Monoid{ 0u }; + if (gl_LocalInvocationID.x > 0u) + { + row = sh_scratch[gl_LocalInvocationID.x - 1u]; + } + for (uint i_2 = 0u; i_2 < 8u; i_2++) + { + Monoid param_4 = row; + Monoid param_5 = local[i_2]; + Monoid m = combine_monoid(param_4, param_5); + _42.data[ix + i_2].element = m.element; + } +} + diff --git a/tests/shader/gen/prefix_root.spv b/tests/shader/gen/prefix_root.spv new file mode 100644 index 0000000..70ba31c Binary files /dev/null and b/tests/shader/gen/prefix_root.spv differ diff --git a/tests/shader/gen/prefix_scan.hlsl b/tests/shader/gen/prefix_scan.hlsl new file mode 100644 index 0000000..feeff2e --- /dev/null +++ b/tests/shader/gen/prefix_scan.hlsl @@ -0,0 +1,92 @@ +struct Monoid +{ + uint element; +}; + +static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u); + +static const Monoid _133 = { 0u }; + +RWByteAddressBuffer _42 : register(u0); +RWByteAddressBuffer _143 : register(u1); + +static uint3 gl_WorkGroupID; +static uint3 gl_LocalInvocationID; +static uint3 gl_GlobalInvocationID; +struct SPIRV_Cross_Input +{ + uint3 gl_WorkGroupID : SV_GroupID; + uint3 gl_LocalInvocationID : SV_GroupThreadID; + uint3 gl_GlobalInvocationID : SV_DispatchThreadID; +}; + +groupshared Monoid sh_scratch[512]; + +Monoid combine_monoid(Monoid a, Monoid b) +{ + Monoid _22 = { a.element + b.element }; + return _22; +} + +void comp_main() +{ + uint ix = gl_GlobalInvocationID.x * 8u; + Monoid _46; + _46.element = _42.Load(ix * 4 + 0); + Monoid local[8]; + local[0].element = _46.element; + Monoid param_1; + for (uint i = 1u; i < 8u; i++) + { + Monoid param = local[i - 1u]; + Monoid _71; + _71.element = _42.Load((ix + i) * 4 + 0); + param_1.element = _71.element; + local[i] = combine_monoid(param, param_1); + } + Monoid agg = local[7]; + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i_1 = 0u; i_1 < 9u; i_1++) + { + GroupMemoryBarrierWithGroupSync(); + if (gl_LocalInvocationID.x >= uint(1 << int(i_1))) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))]; + Monoid param_2 = other; + Monoid param_3 = agg; + agg = combine_monoid(param_2, param_3); + } + GroupMemoryBarrierWithGroupSync(); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + GroupMemoryBarrierWithGroupSync(); + Monoid row = _133; + if (gl_WorkGroupID.x > 0u) + { + Monoid _148; + _148.element = _143.Load((gl_WorkGroupID.x - 1u) * 4 + 0); + row.element = _148.element; + } + if (gl_LocalInvocationID.x > 0u) + { + Monoid param_4 = row; + Monoid param_5 = sh_scratch[gl_LocalInvocationID.x - 1u]; + row = combine_monoid(param_4, param_5); + } + for (uint i_2 = 0u; i_2 < 8u; i_2++) + { + Monoid param_6 = row; + Monoid param_7 = local[i_2]; + Monoid m = combine_monoid(param_6, param_7); + _42.Store((ix + i_2) * 4 + 0, m.element); + } +} + +[numthreads(512, 1, 1)] +void main(SPIRV_Cross_Input stage_input) +{ + gl_WorkGroupID = stage_input.gl_WorkGroupID; + gl_LocalInvocationID = stage_input.gl_LocalInvocationID; + gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID; + comp_main(); +} diff --git a/tests/shader/gen/prefix_scan.msl b/tests/shader/gen/prefix_scan.msl new file mode 100644 index 0000000..c1efb22 --- /dev/null +++ b/tests/shader/gen/prefix_scan.msl @@ -0,0 +1,123 @@ +#pragma clang diagnostic ignored "-Wmissing-prototypes" +#pragma clang diagnostic ignored "-Wmissing-braces" + +#include +#include + +using namespace metal; + +template +struct spvUnsafeArray +{ + T elements[Num ? Num : 1]; + + thread T& operator [] (size_t pos) thread + { + return elements[pos]; + } + constexpr const thread T& operator [] (size_t pos) const thread + { + return elements[pos]; + } + + device T& operator [] (size_t pos) device + { + return elements[pos]; + } + constexpr const device T& operator [] (size_t pos) const device + { + return elements[pos]; + } + + constexpr const constant T& operator [] (size_t pos) const constant + { + return elements[pos]; + } + + threadgroup T& operator [] (size_t pos) threadgroup + { + return elements[pos]; + } + constexpr const threadgroup T& operator [] (size_t pos) const threadgroup + { + return elements[pos]; + } +}; + +struct Monoid +{ + uint element; +}; + +struct Monoid_1 +{ + uint element; +}; + +struct DataBuf +{ + Monoid_1 data[1]; +}; + +struct ParentBuf +{ + Monoid_1 parent[1]; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u); + +static inline __attribute__((always_inline)) +Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b) +{ + return Monoid{ a.element + b.element }; +} + +kernel void main0(device DataBuf& _42 [[buffer(0)]], device ParentBuf& _143 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]]) +{ + threadgroup Monoid sh_scratch[512]; + uint ix = gl_GlobalInvocationID.x * 8u; + spvUnsafeArray local; + local[0].element = _42.data[ix].element; + Monoid param_1; + for (uint i = 1u; i < 8u; i++) + { + Monoid param = local[i - 1u]; + param_1.element = _42.data[ix + i].element; + local[i] = combine_monoid(param, param_1); + } + Monoid agg = local[7]; + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i_1 = 0u; i_1 < 9u; i_1++) + { + threadgroup_barrier(mem_flags::mem_threadgroup); + if (gl_LocalInvocationID.x >= uint(1 << int(i_1))) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))]; + Monoid param_2 = other; + Monoid param_3 = agg; + agg = combine_monoid(param_2, param_3); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + Monoid row = Monoid{ 0u }; + if (gl_WorkGroupID.x > 0u) + { + row.element = _143.parent[gl_WorkGroupID.x - 1u].element; + } + if (gl_LocalInvocationID.x > 0u) + { + Monoid param_4 = row; + Monoid param_5 = sh_scratch[gl_LocalInvocationID.x - 1u]; + row = combine_monoid(param_4, param_5); + } + for (uint i_2 = 0u; i_2 < 8u; i_2++) + { + Monoid param_6 = row; + Monoid param_7 = local[i_2]; + Monoid m = combine_monoid(param_6, param_7); + _42.data[ix + i_2].element = m.element; + } +} + diff --git a/tests/shader/gen/prefix_scan.spv b/tests/shader/gen/prefix_scan.spv new file mode 100644 index 0000000..d4216e9 Binary files /dev/null and b/tests/shader/gen/prefix_scan.spv differ diff --git a/tests/shader/prefix.comp b/tests/shader/prefix.comp new file mode 100644 index 0000000..ed5bcbc --- /dev/null +++ b/tests/shader/prefix.comp @@ -0,0 +1,188 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense + +// A prefix sum. + +#version 450 + +#define N_ROWS 16 +#define LG_WG_SIZE 9 +#define WG_SIZE (1 << LG_WG_SIZE) +#define PARTITION_SIZE (WG_SIZE * N_ROWS) + +layout(local_size_x = WG_SIZE, local_size_y = 1) in; + +struct Monoid { + uint element; +}; + +layout(set = 0, binding = 0) readonly buffer InBuf { + Monoid[] inbuf; +}; + +layout(set = 0, binding = 1) buffer OutBuf { + Monoid[] outbuf; +}; + +// These correspond to X, A, P respectively in the prefix sum paper. +#define FLAG_NOT_READY 0 +#define FLAG_AGGREGATE_READY 1 +#define FLAG_PREFIX_READY 2 + +struct State { + uint flag; + Monoid aggregate; + Monoid prefix; +}; + +layout(set = 0, binding = 2) volatile buffer StateBuf { + uint part_counter; + State[] state; +}; + +shared Monoid sh_scratch[WG_SIZE]; + +Monoid combine_monoid(Monoid a, Monoid b) { + return Monoid(a.element + b.element); +} + +shared uint sh_part_ix; +shared Monoid sh_prefix; +shared uint sh_flag; + +void main() { + Monoid local[N_ROWS]; + // Determine partition to process by atomic counter (described in Section + // 4.4 of prefix sum paper). + if (gl_LocalInvocationID.x == 0) { + sh_part_ix = atomicAdd(part_counter, 1); + } + barrier(); + uint part_ix = sh_part_ix; + + uint ix = part_ix * PARTITION_SIZE + gl_LocalInvocationID.x * N_ROWS; + + // TODO: gate buffer read? (evaluate whether shader check or + // CPU-side padding is better) + local[0] = inbuf[ix]; + for (uint i = 1; i < N_ROWS; i++) { + local[i] = combine_monoid(local[i - 1], inbuf[ix + i]); + } + Monoid agg = local[N_ROWS - 1]; + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i = 0; i < LG_WG_SIZE; i++) { + barrier(); + if (gl_LocalInvocationID.x >= (1 << i)) { + Monoid other = sh_scratch[gl_LocalInvocationID.x - (1 << i)]; + agg = combine_monoid(other, agg); + } + barrier(); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + + // Publish aggregate for this partition + if (gl_LocalInvocationID.x == WG_SIZE - 1) { + state[part_ix].aggregate = agg; + if (part_ix == 0) { + state[0].prefix = agg; + } + } + // Write flag with release semantics; this is done portably with a barrier. + memoryBarrierBuffer(); + if (gl_LocalInvocationID.x == WG_SIZE - 1) { + uint flag = FLAG_AGGREGATE_READY; + if (part_ix == 0) { + flag = FLAG_PREFIX_READY; + } + state[part_ix].flag = flag; + } + + Monoid exclusive = Monoid(0); + if (part_ix != 0) { + // step 4 of paper: decoupled lookback + uint look_back_ix = part_ix - 1; + + Monoid their_agg; + uint their_ix = 0; + while (true) { + // Read flag with acquire semantics. + if (gl_LocalInvocationID.x == WG_SIZE - 1) { + sh_flag = state[look_back_ix].flag; + } + // The flag load is done only in the last thread. However, because the + // translation of memoryBarrierBuffer to Metal requires uniform control + // flow, we broadcast it to all threads. + barrier(); + memoryBarrierBuffer(); + uint flag = sh_flag; + + if (flag == FLAG_PREFIX_READY) { + if (gl_LocalInvocationID.x == WG_SIZE - 1) { + Monoid their_prefix = state[look_back_ix].prefix; + exclusive = combine_monoid(their_prefix, exclusive); + } + break; + } else if (flag == FLAG_AGGREGATE_READY) { + if (gl_LocalInvocationID.x == WG_SIZE - 1) { + their_agg = state[look_back_ix].aggregate; + exclusive = combine_monoid(their_agg, exclusive); + } + look_back_ix--; + their_ix = 0; + continue; + } + // else spin + + if (gl_LocalInvocationID.x == WG_SIZE - 1) { + // Unfortunately there's no guarantee of forward progress of other + // workgroups, so compute a bit of the aggregate before trying again. + // In the worst case, spinning stops when the aggregate is complete. + Monoid m = inbuf[look_back_ix * PARTITION_SIZE + their_ix]; + if (their_ix == 0) { + their_agg = m; + } else { + their_agg = combine_monoid(their_agg, m); + } + their_ix++; + if (their_ix == PARTITION_SIZE) { + exclusive = combine_monoid(their_agg, exclusive); + if (look_back_ix == 0) { + sh_flag = FLAG_PREFIX_READY; + } else { + look_back_ix--; + their_ix = 0; + } + } + } + barrier(); + flag = sh_flag; + if (flag == FLAG_PREFIX_READY) { + break; + } + } + // step 5 of paper: compute inclusive prefix + if (gl_LocalInvocationID.x == WG_SIZE - 1) { + Monoid inclusive_prefix = combine_monoid(exclusive, agg); + sh_prefix = exclusive; + state[part_ix].prefix = inclusive_prefix; + } + memoryBarrierBuffer(); + if (gl_LocalInvocationID.x == WG_SIZE - 1) { + state[part_ix].flag = FLAG_PREFIX_READY; + } + } + barrier(); + if (part_ix != 0) { + exclusive = sh_prefix; + } + + Monoid row = exclusive; + if (gl_LocalInvocationID.x > 0) { + Monoid other = sh_scratch[gl_LocalInvocationID.x - 1]; + row = combine_monoid(row, other); + } + for (uint i = 0; i < N_ROWS; i++) { + Monoid m = combine_monoid(row, local[i]); + // Make sure buffer allocation is padded appropriately. + outbuf[ix + i] = m; + } +} diff --git a/tests/shader/prefix_reduce.comp b/tests/shader/prefix_reduce.comp new file mode 100644 index 0000000..378da88 --- /dev/null +++ b/tests/shader/prefix_reduce.comp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense + +// The reduction phase for prefix sum implemented as a tree reduction. + +#version 450 + +#define N_ROWS 8 +#define LG_WG_SIZE 9 +#define WG_SIZE (1 << LG_WG_SIZE) +#define PARTITION_SIZE (WG_SIZE * N_ROWS) + +layout(local_size_x = WG_SIZE, local_size_y = 1) in; + +struct Monoid { + uint element; +}; + +layout(set = 0, binding = 0) readonly buffer InBuf { + Monoid[] inbuf; +}; + +layout(set = 0, binding = 1) buffer OutBuf { + Monoid[] outbuf; +}; + +shared Monoid sh_scratch[WG_SIZE]; + +Monoid combine_monoid(Monoid a, Monoid b) { + return Monoid(a.element + b.element); +} + +void main() { + uint ix = gl_GlobalInvocationID.x * N_ROWS; + // TODO: gate buffer read + Monoid agg = inbuf[ix]; + for (uint i = 1; i < N_ROWS; i++) { + agg = combine_monoid(agg, inbuf[ix + i]); + } + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i = 0; i < LG_WG_SIZE; i++) { + barrier(); + // We could make this predicate tighter, but would it help? + if (gl_LocalInvocationID.x + (1 << i) < WG_SIZE) { + Monoid other = sh_scratch[gl_LocalInvocationID.x + (1 << i)]; + agg = combine_monoid(agg, other); + } + barrier(); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + if (gl_LocalInvocationID.x == 0) { + outbuf[gl_WorkGroupID.x] = agg; + } +} diff --git a/tests/shader/prefix_scan.comp b/tests/shader/prefix_scan.comp new file mode 100644 index 0000000..59903ab --- /dev/null +++ b/tests/shader/prefix_scan.comp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense + +// A scan for a tree reduction prefix scan (either root or not, by ifdef). + +#version 450 + +#define N_ROWS 8 +#define LG_WG_SIZE 9 +#define WG_SIZE (1 << LG_WG_SIZE) +#define PARTITION_SIZE (WG_SIZE * N_ROWS) + +layout(local_size_x = WG_SIZE, local_size_y = 1) in; + +struct Monoid { + uint element; +}; + +layout(set = 0, binding = 0) buffer DataBuf { + Monoid[] data; +}; + +#ifndef ROOT +layout(set = 0, binding = 1) buffer ParentBuf { + Monoid[] parent; +}; +#endif + +shared Monoid sh_scratch[WG_SIZE]; + +Monoid combine_monoid(Monoid a, Monoid b) { + return Monoid(a.element + b.element); +} + +void main() { + Monoid local[N_ROWS]; + + uint ix = gl_GlobalInvocationID.x * N_ROWS; + + // TODO: gate buffer read + local[0] = data[ix]; + for (uint i = 1; i < N_ROWS; i++) { + local[i] = combine_monoid(local[i - 1], data[ix + i]); + } + Monoid agg = local[N_ROWS - 1]; + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i = 0; i < LG_WG_SIZE; i++) { + barrier(); + if (gl_LocalInvocationID.x >= (1 << i)) { + Monoid other = sh_scratch[gl_LocalInvocationID.x - (1 << i)]; + agg = combine_monoid(other, agg); + } + barrier(); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + + barrier(); + // This could be a semigroup instead of a monoid if we reworked the + // conditional logic, but that might impact performance. + Monoid row = Monoid(0); +#ifdef ROOT + if (gl_LocalInvocationID.x > 0) { + row = sh_scratch[gl_LocalInvocationID.x - 1]; + } +#else + if (gl_WorkGroupID.x > 0) { + row = parent[gl_WorkGroupID.x - 1]; + } + if (gl_LocalInvocationID.x > 0) { + row = combine_monoid(row, sh_scratch[gl_LocalInvocationID.x - 1]); + } +#endif + for (uint i = 0; i < N_ROWS; i++) { + Monoid m = combine_monoid(row, local[i]); + // TODO: gate buffer write + data[ix + i] = m; + } +} diff --git a/tests/src/config.rs b/tests/src/config.rs new file mode 100644 index 0000000..50bd3be --- /dev/null +++ b/tests/src/config.rs @@ -0,0 +1,72 @@ +// Copyright 2021 The piet-gpu authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Also licensed under MIT license, at your choice. + +//! Test config parameters. + +use clap::ArgMatches; + +pub struct Config { + pub groups: Groups, + pub size: Size, +} + +pub struct Groups(String); + +pub enum Size { + Small, + Medium, + Large, +} + +impl Config { + pub fn from_matches(matches: &ArgMatches) -> Config { + let groups = Groups::from_str(matches.value_of("groups").unwrap_or("all")); + let size = Size::from_str(matches.value_of("size").unwrap_or("m")); + Config { + groups, size + } + } +} + +impl Groups { + pub fn from_str(s: &str) -> Groups { + Groups(s.to_string()) + } + + pub fn matches(&self, group_name: &str) -> bool { + self.0 == "all" || self.0 == group_name + } +} + +impl Size { + fn from_str(s: &str) -> Size { + if s == "small" || s == "s" { + Size::Small + } else if s == "large" || s == "l" { + Size::Large + } else { + Size::Medium + } + } + + pub fn choose(&self, small: T, medium: T, large: T) -> T { + match self { + Size::Small => small, + Size::Medium => medium, + Size::Large => large, + } + } +} diff --git a/tests/src/main.rs b/tests/src/main.rs new file mode 100644 index 0000000..b7bc1d9 --- /dev/null +++ b/tests/src/main.rs @@ -0,0 +1,77 @@ +// Copyright 2021 The piet-gpu authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Also licensed under MIT license, at your choice. + +//! Tests for piet-gpu shaders and GPU capabilities. + +mod config; +mod prefix; +mod prefix_tree; +mod runner; +mod test_result; + +use clap::{App, Arg}; + +use crate::config::Config; +use crate::runner::Runner; +use crate::test_result::{ReportStyle, TestResult}; + +fn main() { + let matches = App::new("piet-gpu-tests") + .arg( + Arg::with_name("verbose") + .short("v") + .long("verbose") + .help("Verbose reporting of results"), + ) + .arg( + Arg::with_name("groups") + .short("g") + .long("groups") + .help("Groups to run") + .takes_value(true) + ) + .arg( + Arg::with_name("size") + .short("s") + .long("size") + .help("Size of tests") + .takes_value(true) + ) + .arg( + Arg::with_name("n_iter") + .short("n") + .long("n_iter") + .help("Number of iterations") + .takes_value(true) + ) + .get_matches(); + let style = if matches.is_present("verbose") { + ReportStyle::Verbose + } else { + ReportStyle::Short + }; + let config = Config::from_matches(&matches); + unsafe { + let report = |test_result: &TestResult| { + test_result.report(style); + }; + let mut runner = Runner::new(); + if config.groups.matches("prefix") { + report(&prefix::run_prefix_test(&mut runner, &config)); + report(&prefix_tree::run_prefix_test(&mut runner, &config)); + } + } +} diff --git a/tests/src/prefix.rs b/tests/src/prefix.rs new file mode 100644 index 0000000..adc58b4 --- /dev/null +++ b/tests/src/prefix.rs @@ -0,0 +1,149 @@ +// Copyright 2021 The piet-gpu authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Also licensed under MIT license, at your choice. + +use piet_gpu_hal::{include_shader, BufferUsage, DescriptorSet}; +use piet_gpu_hal::{Buffer, Pipeline}; + +use crate::config::Config; +use crate::runner::{Commands, Runner}; +use crate::test_result::TestResult; + +const WG_SIZE: u64 = 512; +const N_ROWS: u64 = 16; +const ELEMENTS_PER_WG: u64 = WG_SIZE * N_ROWS; + +/// The shader code for the prefix sum example. +/// +/// A code struct can be created once and reused any number of times. +struct PrefixCode { + pipeline: Pipeline, +} + +/// The stage resources for the prefix sum example. +/// +/// A stage resources struct is specific to a particular problem size +/// and queue. +struct PrefixStage { + // This is the actual problem size but perhaps it would be better to + // treat it as a capacity. + n_elements: u64, + state_buf: Buffer, +} + +/// The binding for the prefix sum example. +struct PrefixBinding { + descriptor_set: DescriptorSet, +} + +pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResult { + let mut result = TestResult::new("prefix sum, decoupled look-back"); + // This will be configurable. + let n_elements: u64 = config.size.choose(1 << 12, 1 << 24, 1 << 25); + let data: Vec = (0..n_elements as u32).collect(); + let data_buf = runner + .session + .create_buffer_init(&data, BufferUsage::STORAGE) + .unwrap(); + let out_buf = runner.buf_down(data_buf.size()); + let code = PrefixCode::new(runner); + let stage = PrefixStage::new(runner, n_elements); + let binding = stage.bind(runner, &code, &data_buf, &out_buf.dev_buf); + // Also will be configurable of course. + let n_iter = 1000; + let mut total_elapsed = 0.0; + for i in 0..n_iter { + let mut commands = runner.commands(); + commands.write_timestamp(0); + stage.record(&mut commands, &code, &binding); + commands.write_timestamp(1); + if i == 0 { + commands.cmd_buf.memory_barrier(); + commands.download(&out_buf); + } + total_elapsed += runner.submit(commands); + if i == 0 { + let mut dst: Vec = Default::default(); + out_buf.read(&mut dst); + if let Some(failure) = verify(&dst) { + result.fail(format!("failure at {}", failure)); + } + } + } + result.timing(total_elapsed, n_elements * n_iter); + result +} + +impl PrefixCode { + unsafe fn new(runner: &mut Runner) -> PrefixCode { + let code = include_shader!(&runner.session, "../shader/gen/prefix"); + let pipeline = runner + .session + .create_simple_compute_pipeline(code, 3) + .unwrap(); + PrefixCode { pipeline } + } +} + +impl PrefixStage { + unsafe fn new(runner: &mut Runner, n_elements: u64) -> PrefixStage { + let n_workgroups = (n_elements + ELEMENTS_PER_WG - 1) / ELEMENTS_PER_WG; + let state_buf_size = 4 + 12 * n_workgroups; + let state_buf = runner + .session + .create_buffer(state_buf_size, BufferUsage::STORAGE | BufferUsage::COPY_DST) + .unwrap(); + PrefixStage { + n_elements, + state_buf, + } + } + + unsafe fn bind( + &self, + runner: &mut Runner, + code: &PrefixCode, + in_buf: &Buffer, + out_buf: &Buffer, + ) -> PrefixBinding { + let descriptor_set = runner + .session + .create_simple_descriptor_set(&code.pipeline, &[in_buf, out_buf, &self.state_buf]) + .unwrap(); + PrefixBinding { descriptor_set } + } + + unsafe fn record(&self, commands: &mut Commands, code: &PrefixCode, bindings: &PrefixBinding) { + let n_workgroups = (self.n_elements + ELEMENTS_PER_WG - 1) / ELEMENTS_PER_WG; + commands.cmd_buf.clear_buffer(&self.state_buf, None); + commands.cmd_buf.memory_barrier(); + commands.cmd_buf.dispatch( + &code.pipeline, + &bindings.descriptor_set, + (n_workgroups as u32, 1, 1), + (WG_SIZE as u32, 1, 1), + ); + // One thing that's missing here is registering the buffers so + // they can be safely dropped by Rust code before the execution + // of the command buffer completes. + } +} + +// Verify that the data is OEIS A000217 +fn verify(data: &[u32]) -> Option { + data.iter() + .enumerate() + .position(|(i, val)| ((i * (i + 1)) / 2) as u32 != *val) +} diff --git a/tests/src/prefix_tree.rs b/tests/src/prefix_tree.rs new file mode 100644 index 0000000..1f78202 --- /dev/null +++ b/tests/src/prefix_tree.rs @@ -0,0 +1,211 @@ +// Copyright 2021 The piet-gpu authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Also licensed under MIT license, at your choice. + +use piet_gpu_hal::{include_shader, BufferUsage, DescriptorSet}; +use piet_gpu_hal::{Buffer, Pipeline}; + +use crate::config::Config; +use crate::runner::{Commands, Runner}; +use crate::test_result::TestResult; + +const WG_SIZE: u64 = 512; +const N_ROWS: u64 = 8; +const ELEMENTS_PER_WG: u64 = WG_SIZE * N_ROWS; + +struct PrefixTreeCode { + reduce_pipeline: Pipeline, + scan_pipeline: Pipeline, + root_pipeline: Pipeline, +} + +struct PrefixTreeStage { + sizes: Vec, + tmp_bufs: Vec, +} + +struct PrefixTreeBinding { + // All but the first and last can be moved to stage. + descriptor_sets: Vec, +} + +pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResult { + let mut result = TestResult::new("prefix sum, tree reduction"); + // This will be configurable. Note though that the current code is + // prone to reading and writing past the end of buffers if this is + // not a power of the number of elements processed in a workgroup. + let n_elements: u64 = config.size.choose(1 << 12, 1 << 24, 1 << 24); + let data: Vec = (0..n_elements as u32).collect(); + let data_buf = runner + .session + .create_buffer_init(&data, BufferUsage::STORAGE) + .unwrap(); + let out_buf = runner.buf_down(data_buf.size()); + let code = PrefixTreeCode::new(runner); + let stage = PrefixTreeStage::new(runner, n_elements); + let binding = stage.bind(runner, &code, &out_buf.dev_buf); + // Also will be configurable of course. + let n_iter = 1000; + let mut total_elapsed = 0.0; + for i in 0..n_iter { + let mut commands = runner.commands(); + commands.cmd_buf.copy_buffer(&data_buf, &out_buf.dev_buf); + commands.cmd_buf.memory_barrier(); + commands.write_timestamp(0); + stage.record(&mut commands, &code, &binding); + commands.write_timestamp(1); + if i == 0 { + commands.cmd_buf.memory_barrier(); + commands.download(&out_buf); + } + total_elapsed += runner.submit(commands); + if i == 0 { + let mut dst: Vec = Default::default(); + out_buf.read(&mut dst); + if let Some(failure) = verify(&dst) { + result.fail(format!("failure at {}", failure)); + } + } + } + result.timing(total_elapsed, n_elements * n_iter); + result +} + +impl PrefixTreeCode { + unsafe fn new(runner: &mut Runner) -> PrefixTreeCode { + let reduce_code = include_shader!(&runner.session, "../shader/gen/prefix_reduce"); + let reduce_pipeline = runner + .session + .create_simple_compute_pipeline(reduce_code, 2) + .unwrap(); + let scan_code = include_shader!(&runner.session, "../shader/gen/prefix_scan"); + let scan_pipeline = runner + .session + .create_simple_compute_pipeline(scan_code, 2) + .unwrap(); + let root_code = include_shader!(&runner.session, "../shader/gen/prefix_root"); + let root_pipeline = runner + .session + .create_simple_compute_pipeline(root_code, 1) + .unwrap(); + PrefixTreeCode { + reduce_pipeline, + scan_pipeline, + root_pipeline, + } + } +} + +impl PrefixTreeStage { + unsafe fn new(runner: &mut Runner, n_elements: u64) -> PrefixTreeStage { + let mut size = n_elements; + let mut sizes = vec![size]; + let mut tmp_bufs = Vec::new(); + while size > ELEMENTS_PER_WG { + size = (size + ELEMENTS_PER_WG - 1) / ELEMENTS_PER_WG; + sizes.push(size); + let buf = runner + .session + .create_buffer(4 * size, BufferUsage::STORAGE) + .unwrap(); + tmp_bufs.push(buf); + } + PrefixTreeStage { sizes, tmp_bufs } + } + + unsafe fn bind( + &self, + runner: &mut Runner, + code: &PrefixTreeCode, + data_buf: &Buffer, + ) -> PrefixTreeBinding { + let mut descriptor_sets = Vec::with_capacity(2 * self.tmp_bufs.len() + 1); + for i in 0..self.tmp_bufs.len() { + let buf0 = if i == 0 { + data_buf + } else { + &self.tmp_bufs[i - 1] + }; + let buf1 = &self.tmp_bufs[i]; + let descriptor_set = runner + .session + .create_simple_descriptor_set(&code.reduce_pipeline, &[buf0, buf1]) + .unwrap(); + descriptor_sets.push(descriptor_set); + } + let buf0 = self.tmp_bufs.last().unwrap_or(data_buf); + let descriptor_set = runner + .session + .create_simple_descriptor_set(&code.root_pipeline, &[buf0]) + .unwrap(); + descriptor_sets.push(descriptor_set); + for i in (0..self.tmp_bufs.len()).rev() { + let buf0 = if i == 0 { + data_buf + } else { + &self.tmp_bufs[i - 1] + }; + let buf1 = &self.tmp_bufs[i]; + let descriptor_set = runner + .session + .create_simple_descriptor_set(&code.scan_pipeline, &[buf0, buf1]) + .unwrap(); + descriptor_sets.push(descriptor_set); + } + PrefixTreeBinding { descriptor_sets } + } + + unsafe fn record( + &self, + commands: &mut Commands, + code: &PrefixTreeCode, + bindings: &PrefixTreeBinding, + ) { + let n = self.tmp_bufs.len(); + for i in 0..n { + let n_workgroups = self.sizes[i + 1]; + commands.cmd_buf.dispatch( + &code.reduce_pipeline, + &bindings.descriptor_sets[i], + (n_workgroups as u32, 1, 1), + (WG_SIZE as u32, 1, 1), + ); + commands.cmd_buf.memory_barrier(); + } + commands.cmd_buf.dispatch( + &code.root_pipeline, + &bindings.descriptor_sets[n], + (1, 1, 1), + (WG_SIZE as u32, 1, 1), + ); + for i in (0..n).rev() { + commands.cmd_buf.memory_barrier(); + let n_workgroups = self.sizes[i + 1]; + commands.cmd_buf.dispatch( + &code.scan_pipeline, + &bindings.descriptor_sets[2 * n - i], + (n_workgroups as u32, 1, 1), + (WG_SIZE as u32, 1, 1), + ); + } + } +} + +// Verify that the data is OEIS A000217 +fn verify(data: &[u32]) -> Option { + data.iter() + .enumerate() + .position(|(i, val)| ((i * (i + 1)) / 2) as u32 != *val) +} diff --git a/tests/src/runner.rs b/tests/src/runner.rs new file mode 100644 index 0000000..9bfde3b --- /dev/null +++ b/tests/src/runner.rs @@ -0,0 +1,138 @@ +// Copyright 2021 The piet-gpu authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Also licensed under MIT license, at your choice. + +//! Test runner intended to make it easy to write tests. + +use piet_gpu_hal::{Buffer, BufferUsage, CmdBuf, Instance, PlainData, QueryPool, Session}; + +pub struct Runner { + #[allow(unused)] + instance: Instance, + pub session: Session, + cmd_buf_pool: Vec, +} + +/// A wrapper around command buffers. +pub struct Commands { + pub cmd_buf: CmdBuf, + query_pool: QueryPool, +} + +/// Buffer for uploading data to GPU. +#[allow(unused)] +pub struct BufUp { + pub stage_buf: Buffer, + pub dev_buf: Buffer, +} + +/// Buffer for downloading data from GPU. +pub struct BufDown { + pub stage_buf: Buffer, + pub dev_buf: Buffer, +} + +impl Runner { + pub unsafe fn new() -> Runner { + let (instance, _) = Instance::new(None).unwrap(); + let device = instance.device(None).unwrap(); + let session = Session::new(device); + let cmd_buf_pool = Vec::new(); + Runner { + instance, + session, + cmd_buf_pool, + } + } + + pub unsafe fn commands(&mut self) -> Commands { + let mut cmd_buf = self + .cmd_buf_pool + .pop() + .unwrap_or_else(|| self.session.cmd_buf().unwrap()); + cmd_buf.begin(); + // TODO: also pool these. But we might sort by size, as + // we might not always be doing two. + let query_pool = self.session.create_query_pool(2).unwrap(); + cmd_buf.reset_query_pool(&query_pool); + Commands { + cmd_buf, + query_pool, + } + } + + pub unsafe fn submit(&mut self, commands: Commands) -> f64 { + let mut cmd_buf = commands.cmd_buf; + let query_pool = commands.query_pool; + cmd_buf.finish_timestamps(&query_pool); + cmd_buf.host_barrier(); + cmd_buf.finish(); + let submitted = self.session.run_cmd_buf(cmd_buf, &[], &[]).unwrap(); + self.cmd_buf_pool.extend(submitted.wait().unwrap()); + let timestamps = self.session.fetch_query_pool(&query_pool).unwrap(); + timestamps[0] + } + + #[allow(unused)] + pub fn buf_up(&self, size: u64) -> BufUp { + let stage_buf = self + .session + .create_buffer(size, BufferUsage::MAP_WRITE | BufferUsage::COPY_SRC) + .unwrap(); + let dev_buf = self + .session + .create_buffer(size, BufferUsage::COPY_DST | BufferUsage::STORAGE) + .unwrap(); + BufUp { stage_buf, dev_buf } + } + + pub fn buf_down(&self, size: u64) -> BufDown { + let stage_buf = self + .session + .create_buffer(size, BufferUsage::MAP_READ | BufferUsage::COPY_DST) + .unwrap(); + // Note: the COPY_DST isn't needed in all use cases, but I don't think + // making this tighter would help. + let dev_buf = self + .session + .create_buffer( + size, + BufferUsage::COPY_SRC | BufferUsage::COPY_DST | BufferUsage::STORAGE, + ) + .unwrap(); + BufDown { stage_buf, dev_buf } + } +} + +impl Commands { + pub unsafe fn write_timestamp(&mut self, query: u32) { + self.cmd_buf.write_timestamp(&self.query_pool, query); + } + + #[allow(unused)] + pub unsafe fn upload(&mut self, buf: &BufUp) { + self.cmd_buf.copy_buffer(&buf.stage_buf, &buf.dev_buf); + } + + pub unsafe fn download(&mut self, buf: &BufDown) { + self.cmd_buf.copy_buffer(&buf.dev_buf, &buf.stage_buf); + } +} + +impl BufDown { + pub unsafe fn read(&self, dst: &mut Vec) { + self.stage_buf.read(dst).unwrap() + } +} diff --git a/tests/src/test_result.rs b/tests/src/test_result.rs new file mode 100644 index 0000000..84bbc85 --- /dev/null +++ b/tests/src/test_result.rs @@ -0,0 +1,110 @@ +// Copyright 2021 The piet-gpu authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Also licensed under MIT license, at your choice. + +//! Recording of results from tests. + +pub struct TestResult { + name: String, + // TODO: statistics. We're lean and mean for now. + total_time: f64, + n_elements: u64, + failure: Option, +} + +#[derive(Clone, Copy)] +pub enum ReportStyle { + Short, + Verbose, +} + +impl TestResult { + pub fn new(name: &str) -> TestResult { + TestResult { + name: name.to_string(), + total_time: 0.0, + n_elements: 0, + failure: None, + } + } + + pub fn report(&self, style: ReportStyle) { + let fail_string = match &self.failure { + None => "pass".into(), + Some(s) => format!("fail ({})", s), + }; + match style { + ReportStyle::Short => { + let mut timing_string = String::new(); + if self.total_time > 0.0 { + if self.n_elements > 0 { + let throughput = self.n_elements as f64 / self.total_time; + timing_string = format!(" {} elements/s", format_nice(throughput, 1)); + } else { + timing_string = format!(" {}s", format_nice(self.total_time, 1)); + } + } + println!("{}: {}{}", self.name, fail_string, timing_string) + } + ReportStyle::Verbose => { + println!("test {}", self.name); + println!(" {}", fail_string); + if self.total_time > 0.0 { + println!(" {}s total time", format_nice(self.total_time, 1)); + if self.n_elements > 0 { + println!(" {} elements", self.n_elements); + let throughput = self.n_elements as f64 / self.total_time; + println!(" {} elements/s", format_nice(throughput, 1)); + } + } + } + } + } + + pub fn fail(&mut self, explanation: String) { + self.failure = Some(explanation); + } + + pub fn timing(&mut self, total_time: f64, n_elements: u64) { + self.total_time = total_time; + self.n_elements = n_elements; + } +} + +fn format_nice(x: f64, precision: usize) -> String { + // Precision should probably scale; later + let (scale, suffix) = if x >= 1e12 && x < 1e15 { + (1e-12, "T") + } else if x >= 1e9 { + (1e-9, "G") + } else if x >= 1e6 { + (1e-6, "M") + } else if x >= 1e3 { + (1e-3, "k") + } else if x >= 1.0 { + (1.0, "") + } else if x >= 1e-3 { + (1e3, "m") + } else if x >= 1e-6 { + (1e6, "\u{00b5}") + } else if x >= 1e-9 { + (1e9, "n") + } else if x >= 1e-12 { + (1e12, "p") + } else { + return format!("{:.*e}", precision, x); + }; + format!("{:.*}{}", precision, scale * x, suffix) +}