From 33d7b25a929fb1a97875d27eaa81cfaa84710b3b Mon Sep 17 00:00:00 2001 From: Raph Levien Date: Fri, 5 Nov 2021 16:52:07 -0700 Subject: [PATCH 1/5] Start testing framework This adds a prefix sum test. This patch is also trying to get a little more serious about structuring both the test runner (toward the goal of collecting proper statistics) and pipeline stages for the tests. Still WIP but giving good results. --- Cargo.lock | 15 ++ Cargo.toml | 3 +- tests/Cargo.toml | 14 ++ tests/shader/build.ninja | 19 +++ tests/shader/gen/collatz.hlsl | 43 ++++++ tests/shader/gen/collatz.msl | 46 ++++++ tests/shader/gen/collatz.spv | Bin 0 -> 1748 bytes tests/shader/gen/prefix.hlsl | 223 +++++++++++++++++++++++++++++ tests/shader/gen/prefix.msl | 262 ++++++++++++++++++++++++++++++++++ tests/shader/gen/prefix.spv | Bin 0 -> 9792 bytes tests/shader/prefix.comp | 188 ++++++++++++++++++++++++ tests/src/main.rs | 29 ++++ tests/src/prefix.rs | 142 ++++++++++++++++++ tests/src/runner.rs | 132 +++++++++++++++++ 14 files changed, 1115 insertions(+), 1 deletion(-) create mode 100644 tests/Cargo.toml create mode 100644 tests/shader/build.ninja create mode 100644 tests/shader/gen/collatz.hlsl create mode 100644 tests/shader/gen/collatz.msl create mode 100644 tests/shader/gen/collatz.spv create mode 100644 tests/shader/gen/prefix.hlsl create mode 100644 tests/shader/gen/prefix.msl create mode 100644 tests/shader/gen/prefix.spv create mode 100644 tests/shader/prefix.comp create mode 100644 tests/src/main.rs create mode 100644 tests/src/prefix.rs create mode 100644 tests/src/runner.rs diff --git a/Cargo.lock b/Cargo.lock index 59604da..b612c40 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -85,6 +85,12 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a" +[[package]] +name = "bytemuck" +version = "1.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72957246c41db82b8ef88a5486143830adeb8227ef9837740bdec67724cf2c5b" + [[package]] name = "byteorder" version = "1.3.4" @@ -887,6 +893,15 @@ dependencies = [ "wio", ] +[[package]] +name = "piet-gpu-tests" +version = "0.1.0" +dependencies = [ + "bytemuck", + "clap", + "piet-gpu-hal", +] + [[package]] name = "piet-gpu-types" version = "0.0.0" diff --git a/Cargo.toml b/Cargo.toml index f71f2de..bfa0030 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,5 +4,6 @@ members = [ "piet-gpu", "piet-gpu-derive", "piet-gpu-hal", - "piet-gpu-types" + "piet-gpu-types", + "tests" ] diff --git a/tests/Cargo.toml b/tests/Cargo.toml new file mode 100644 index 0000000..a987c9e --- /dev/null +++ b/tests/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "piet-gpu-tests" +version = "0.1.0" +authors = ["Raph Levien "] +description = "Tests for piet-gpu shaders and generic GPU capabilities." +license = "MIT/Apache-2.0" +edition = "2021" + +[dependencies] +clap = "2.33" +bytemuck = "1.7.2" + +[dependencies.piet-gpu-hal] +path = "../piet-gpu-hal" diff --git a/tests/shader/build.ninja b/tests/shader/build.ninja new file mode 100644 index 0000000..c5ecb06 --- /dev/null +++ b/tests/shader/build.ninja @@ -0,0 +1,19 @@ +# Build file for shaders. + +# You must have Vulkan tools in your path, or patch here. + +glslang_validator = glslangValidator +spirv_cross = spirv-cross + +rule glsl + command = $glslang_validator -V -o $out $in + +rule hlsl + command = $spirv_cross --hlsl $in --output $out + +rule msl + command = $spirv_cross --msl $in --output $out + +build gen/prefix.spv: glsl prefix.comp +build gen/prefix.hlsl: hlsl gen/prefix.spv +build gen/prefix.msl: msl gen/prefix.spv diff --git a/tests/shader/gen/collatz.hlsl b/tests/shader/gen/collatz.hlsl new file mode 100644 index 0000000..2f4861f --- /dev/null +++ b/tests/shader/gen/collatz.hlsl @@ -0,0 +1,43 @@ +static const uint3 gl_WorkGroupSize = uint3(1u, 1u, 1u); + +RWByteAddressBuffer _53 : register(u1); +ByteAddressBuffer _59 : register(t0); + +static uint3 gl_GlobalInvocationID; +struct SPIRV_Cross_Input +{ + uint3 gl_GlobalInvocationID : SV_DispatchThreadID; +}; + +uint collatz_iterations(inout uint n) +{ + uint i = 0u; + while (n != 1u) + { + if ((n % 2u) == 0u) + { + n /= 2u; + } + else + { + n = (3u * n) + 1u; + } + i++; + } + return i; +} + +void comp_main() +{ + uint index = gl_GlobalInvocationID.x; + uint param = _59.Load(index * 4 + 0); + uint _65 = collatz_iterations(param); + _53.Store(index * 4 + 0, _65); +} + +[numthreads(1, 1, 1)] +void main(SPIRV_Cross_Input stage_input) +{ + gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID; + comp_main(); +} diff --git a/tests/shader/gen/collatz.msl b/tests/shader/gen/collatz.msl new file mode 100644 index 0000000..87cc7b5 --- /dev/null +++ b/tests/shader/gen/collatz.msl @@ -0,0 +1,46 @@ +#pragma clang diagnostic ignored "-Wmissing-prototypes" + +#include +#include + +using namespace metal; + +struct OutBuf +{ + uint out_buf[1]; +}; + +struct InBuf +{ + uint in_buf[1]; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u); + +static inline __attribute__((always_inline)) +uint collatz_iterations(thread uint& n) +{ + uint i = 0u; + while (n != 1u) + { + if ((n % 2u) == 0u) + { + n /= 2u; + } + else + { + n = (3u * n) + 1u; + } + i++; + } + return i; +} + +kernel void main0(device OutBuf& _53 [[buffer(0)]], const device InBuf& _59 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]]) +{ + uint index = gl_GlobalInvocationID.x; + uint param = _59.in_buf[index]; + uint _65 = collatz_iterations(param); + _53.out_buf[index] = _65; +} + diff --git a/tests/shader/gen/collatz.spv b/tests/shader/gen/collatz.spv new file mode 100644 index 0000000000000000000000000000000000000000..638894c5bd2d759593614861db3ddde16d088441 GIT binary patch literal 1748 zcmYk6*-jKe6ozXy96?zG#0{Kb6iL8Qlto0)L6dP35-)s!Nwa7sX}S&c7~_=}#>7|A zxA39d5{>^?Qzv#MC+9!gU#Cu0cd6Jo<6OZFyQ+KU25Z6<<#BGr4SC((d%4#<%?_KZ zYj?>QbtRuj&X{s4(q+kU)Jq67CFAnrn~;!aQc`VE_BS~y%E;GM)Wfi12WcEf+52uU zJ35K7UYeXP_E)yZFS)advkCZ|bkF0mTLLGCNAJ8{^*y|fyX`pLkK#`9Ha+l69nQ$U z_oZL-v)%qH=TCsmO8Z%NUrzYf1}>k!DcnhXy3faE^%9@%^VOnX@NWyfiB6(p>9%ZU zG9tMa?0eEf63+iz)wX{_Hmj}u@iA%Y>1AHRGP^F!T(KWW|K}?^>+h2M@SC0I$7lTX zp6ADgp9TL@Y5Kn`nbZw(hjmp`51v*I``09_nruBRFZk=StI}ZnOCkTfFw5*IVa_xq z*6Ey|bFh!))r8rbeQ!!*Q%4+oA=oprmxbAveEPZ_V)(a$&HYeM{7%5s%&N-nYEGCt z!6t82nx62%gUvlzpY*XVVZmc}1P`;vhUcL)HTdAcX8!ny?*vRg%Q^{tb05sgWLcZg7<4@=ifmLE-+{QrJPY|a?#&4gK=s8VJzYoq=}>LJd?ug@l`$WW-m(caW}uk zfjuUHi(0VN&InWcIrPu>g^&0CCGcGl#!rtw)Gz-K7Tmfy2ax-dB(L;bvPO@jYJ@HK?-b5`TTm*<4rIBy6uZ>zs0jBi0=cY%+3 X=1xy&Y`YugXnc#p@cmUB%v<;eX3uOe literal 0 HcmV?d00001 diff --git a/tests/shader/gen/prefix.hlsl b/tests/shader/gen/prefix.hlsl new file mode 100644 index 0000000..c0600e2 --- /dev/null +++ b/tests/shader/gen/prefix.hlsl @@ -0,0 +1,223 @@ +struct Monoid +{ + uint element; +}; + +struct State +{ + uint flag; + Monoid aggregate; + Monoid prefix; +}; + +static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u); + +static const Monoid _187 = { 0u }; + +globallycoherent RWByteAddressBuffer _43 : register(u2); +ByteAddressBuffer _67 : register(t0); +RWByteAddressBuffer _374 : register(u1); + +static uint3 gl_LocalInvocationID; +struct SPIRV_Cross_Input +{ + uint3 gl_LocalInvocationID : SV_GroupThreadID; +}; + +groupshared uint sh_part_ix; +groupshared Monoid sh_scratch[512]; +groupshared uint sh_flag; +groupshared Monoid sh_prefix; + +Monoid combine_monoid(Monoid a, Monoid b) +{ + Monoid _22 = { a.element + b.element }; + return _22; +} + +void comp_main() +{ + if (gl_LocalInvocationID.x == 0u) + { + uint _47; + _43.InterlockedAdd(0, 1u, _47); + sh_part_ix = _47; + } + GroupMemoryBarrierWithGroupSync(); + uint part_ix = sh_part_ix; + uint ix = (part_ix * 8192u) + (gl_LocalInvocationID.x * 16u); + Monoid _71; + _71.element = _67.Load(ix * 4 + 0); + Monoid local[16]; + local[0].element = _71.element; + Monoid param_1; + for (uint i = 1u; i < 16u; i++) + { + Monoid param = local[i - 1u]; + Monoid _94; + _94.element = _67.Load((ix + i) * 4 + 0); + param_1.element = _94.element; + local[i] = combine_monoid(param, param_1); + } + Monoid agg = local[15]; + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i_1 = 0u; i_1 < 9u; i_1++) + { + GroupMemoryBarrierWithGroupSync(); + if (gl_LocalInvocationID.x >= uint(1 << int(i_1))) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))]; + Monoid param_2 = other; + Monoid param_3 = agg; + agg = combine_monoid(param_2, param_3); + } + GroupMemoryBarrierWithGroupSync(); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + if (gl_LocalInvocationID.x == 511u) + { + _43.Store(part_ix * 12 + 8, agg.element); + if (part_ix == 0u) + { + _43.Store(12, agg.element); + } + } + DeviceMemoryBarrier(); + if (gl_LocalInvocationID.x == 511u) + { + uint flag = 1u; + if (part_ix == 0u) + { + flag = 2u; + } + _43.Store(part_ix * 12 + 4, flag); + } + Monoid exclusive = _187; + if (part_ix != 0u) + { + uint look_back_ix = part_ix - 1u; + uint their_ix = 0u; + Monoid their_prefix; + Monoid their_agg; + Monoid m; + while (true) + { + if (gl_LocalInvocationID.x == 511u) + { + sh_flag = _43.Load(look_back_ix * 12 + 4); + } + GroupMemoryBarrierWithGroupSync(); + DeviceMemoryBarrier(); + uint flag_1 = sh_flag; + if (flag_1 == 2u) + { + if (gl_LocalInvocationID.x == 511u) + { + Monoid _225; + _225.element = _43.Load(look_back_ix * 12 + 12); + their_prefix.element = _225.element; + Monoid param_4 = their_prefix; + Monoid param_5 = exclusive; + exclusive = combine_monoid(param_4, param_5); + } + break; + } + else + { + if (flag_1 == 1u) + { + if (gl_LocalInvocationID.x == 511u) + { + Monoid _247; + _247.element = _43.Load(look_back_ix * 12 + 8); + their_agg.element = _247.element; + Monoid param_6 = their_agg; + Monoid param_7 = exclusive; + exclusive = combine_monoid(param_6, param_7); + } + look_back_ix--; + their_ix = 0u; + continue; + } + } + if (gl_LocalInvocationID.x == 511u) + { + Monoid _269; + _269.element = _67.Load(((look_back_ix * 8192u) + their_ix) * 4 + 0); + m.element = _269.element; + if (their_ix == 0u) + { + their_agg = m; + } + else + { + Monoid param_8 = their_agg; + Monoid param_9 = m; + their_agg = combine_monoid(param_8, param_9); + } + their_ix++; + if (their_ix == 8192u) + { + Monoid param_10 = their_agg; + Monoid param_11 = exclusive; + exclusive = combine_monoid(param_10, param_11); + if (look_back_ix == 0u) + { + sh_flag = 2u; + } + else + { + look_back_ix--; + their_ix = 0u; + } + } + } + GroupMemoryBarrierWithGroupSync(); + flag_1 = sh_flag; + if (flag_1 == 2u) + { + break; + } + } + if (gl_LocalInvocationID.x == 511u) + { + Monoid param_12 = exclusive; + Monoid param_13 = agg; + Monoid inclusive_prefix = combine_monoid(param_12, param_13); + sh_prefix = exclusive; + _43.Store(part_ix * 12 + 12, inclusive_prefix.element); + } + DeviceMemoryBarrier(); + if (gl_LocalInvocationID.x == 511u) + { + _43.Store(part_ix * 12 + 4, 2u); + } + } + GroupMemoryBarrierWithGroupSync(); + if (part_ix != 0u) + { + exclusive = sh_prefix; + } + Monoid row = exclusive; + if (gl_LocalInvocationID.x > 0u) + { + Monoid other_1 = sh_scratch[gl_LocalInvocationID.x - 1u]; + Monoid param_14 = row; + Monoid param_15 = other_1; + row = combine_monoid(param_14, param_15); + } + for (uint i_2 = 0u; i_2 < 16u; i_2++) + { + Monoid param_16 = row; + Monoid param_17 = local[i_2]; + Monoid m_1 = combine_monoid(param_16, param_17); + _374.Store((ix + i_2) * 4 + 0, m_1.element); + } +} + +[numthreads(512, 1, 1)] +void main(SPIRV_Cross_Input stage_input) +{ + gl_LocalInvocationID = stage_input.gl_LocalInvocationID; + comp_main(); +} diff --git a/tests/shader/gen/prefix.msl b/tests/shader/gen/prefix.msl new file mode 100644 index 0000000..ecdf8bd --- /dev/null +++ b/tests/shader/gen/prefix.msl @@ -0,0 +1,262 @@ +#pragma clang diagnostic ignored "-Wmissing-prototypes" +#pragma clang diagnostic ignored "-Wmissing-braces" +#pragma clang diagnostic ignored "-Wunused-variable" + +#include +#include +#include + +using namespace metal; + +template +struct spvUnsafeArray +{ + T elements[Num ? Num : 1]; + + thread T& operator [] (size_t pos) thread + { + return elements[pos]; + } + constexpr const thread T& operator [] (size_t pos) const thread + { + return elements[pos]; + } + + device T& operator [] (size_t pos) device + { + return elements[pos]; + } + constexpr const device T& operator [] (size_t pos) const device + { + return elements[pos]; + } + + constexpr const constant T& operator [] (size_t pos) const constant + { + return elements[pos]; + } + + threadgroup T& operator [] (size_t pos) threadgroup + { + return elements[pos]; + } + constexpr const threadgroup T& operator [] (size_t pos) const threadgroup + { + return elements[pos]; + } +}; + +struct Monoid +{ + uint element; +}; + +struct Monoid_1 +{ + uint element; +}; + +struct State +{ + uint flag; + Monoid_1 aggregate; + Monoid_1 prefix; +}; + +struct StateBuf +{ + uint part_counter; + State state[1]; +}; + +struct InBuf +{ + Monoid_1 inbuf[1]; +}; + +struct OutBuf +{ + Monoid_1 outbuf[1]; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u); + +static inline __attribute__((always_inline)) +Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b) +{ + return Monoid{ a.element + b.element }; +} + +kernel void main0(volatile device StateBuf& _43 [[buffer(0)]], const device InBuf& _67 [[buffer(1)]], device OutBuf& _374 [[buffer(2)]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]]) +{ + threadgroup uint sh_part_ix; + threadgroup Monoid sh_scratch[512]; + threadgroup uint sh_flag; + threadgroup Monoid sh_prefix; + if (gl_LocalInvocationID.x == 0u) + { + uint _47 = atomic_fetch_add_explicit((volatile device atomic_uint*)&_43.part_counter, 1u, memory_order_relaxed); + sh_part_ix = _47; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + uint part_ix = sh_part_ix; + uint ix = (part_ix * 8192u) + (gl_LocalInvocationID.x * 16u); + spvUnsafeArray local; + local[0].element = _67.inbuf[ix].element; + Monoid param_1; + for (uint i = 1u; i < 16u; i++) + { + Monoid param = local[i - 1u]; + param_1.element = _67.inbuf[ix + i].element; + local[i] = combine_monoid(param, param_1); + } + Monoid agg = local[15]; + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i_1 = 0u; i_1 < 9u; i_1++) + { + threadgroup_barrier(mem_flags::mem_threadgroup); + if (gl_LocalInvocationID.x >= uint(1 << int(i_1))) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))]; + Monoid param_2 = other; + Monoid param_3 = agg; + agg = combine_monoid(param_2, param_3); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + if (gl_LocalInvocationID.x == 511u) + { + _43.state[part_ix].aggregate.element = agg.element; + if (part_ix == 0u) + { + _43.state[0].prefix.element = agg.element; + } + } + threadgroup_barrier(mem_flags::mem_device); + if (gl_LocalInvocationID.x == 511u) + { + uint flag = 1u; + if (part_ix == 0u) + { + flag = 2u; + } + _43.state[part_ix].flag = flag; + } + Monoid exclusive = Monoid{ 0u }; + if (part_ix != 0u) + { + uint look_back_ix = part_ix - 1u; + uint their_ix = 0u; + Monoid their_prefix; + Monoid their_agg; + Monoid m; + while (true) + { + if (gl_LocalInvocationID.x == 511u) + { + sh_flag = _43.state[look_back_ix].flag; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_device); + uint flag_1 = sh_flag; + if (flag_1 == 2u) + { + if (gl_LocalInvocationID.x == 511u) + { + their_prefix.element = _43.state[look_back_ix].prefix.element; + Monoid param_4 = their_prefix; + Monoid param_5 = exclusive; + exclusive = combine_monoid(param_4, param_5); + } + break; + } + else + { + if (flag_1 == 1u) + { + if (gl_LocalInvocationID.x == 511u) + { + their_agg.element = _43.state[look_back_ix].aggregate.element; + Monoid param_6 = their_agg; + Monoid param_7 = exclusive; + exclusive = combine_monoid(param_6, param_7); + } + look_back_ix--; + their_ix = 0u; + continue; + } + } + if (gl_LocalInvocationID.x == 511u) + { + m.element = _67.inbuf[(look_back_ix * 8192u) + their_ix].element; + if (their_ix == 0u) + { + their_agg = m; + } + else + { + Monoid param_8 = their_agg; + Monoid param_9 = m; + their_agg = combine_monoid(param_8, param_9); + } + their_ix++; + if (their_ix == 8192u) + { + Monoid param_10 = their_agg; + Monoid param_11 = exclusive; + exclusive = combine_monoid(param_10, param_11); + if (look_back_ix == 0u) + { + sh_flag = 2u; + } + else + { + look_back_ix--; + their_ix = 0u; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + flag_1 = sh_flag; + if (flag_1 == 2u) + { + break; + } + } + if (gl_LocalInvocationID.x == 511u) + { + Monoid param_12 = exclusive; + Monoid param_13 = agg; + Monoid inclusive_prefix = combine_monoid(param_12, param_13); + sh_prefix = exclusive; + _43.state[part_ix].prefix.element = inclusive_prefix.element; + } + threadgroup_barrier(mem_flags::mem_device); + if (gl_LocalInvocationID.x == 511u) + { + _43.state[part_ix].flag = 2u; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (part_ix != 0u) + { + exclusive = sh_prefix; + } + Monoid row = exclusive; + if (gl_LocalInvocationID.x > 0u) + { + Monoid other_1 = sh_scratch[gl_LocalInvocationID.x - 1u]; + Monoid param_14 = row; + Monoid param_15 = other_1; + row = combine_monoid(param_14, param_15); + } + for (uint i_2 = 0u; i_2 < 16u; i_2++) + { + Monoid param_16 = row; + Monoid param_17 = local[i_2]; + Monoid m_1 = combine_monoid(param_16, param_17); + _374.outbuf[ix + i_2].element = m_1.element; + } +} + diff --git a/tests/shader/gen/prefix.spv b/tests/shader/gen/prefix.spv new file mode 100644 index 0000000000000000000000000000000000000000..170a56967f9c8b7997b688536dd270fe3d1868bf GIT binary patch literal 9792 zcmZ{p2Y8m%702I@g@OzfCt@NR2SZRq1y@7`6bIs7A%rAEl6;taLELERuC;sEWv8uL zYwgCirfRiz+1Bo1wWYR?-K}*{M}NQXdr!D|`*?lMbN=Ul&OP_sbMO1UAF+1K^u{c! z$+pkNX4ltbl`%OR1CwRrvWALYyllm?*@NZPv*#T)$AAf0Z6(f_iNw^S>yWv+rM^`? z{he)nmDmFY%L79lLWQb&95(*E_R za!;v$=?S=rt)BAPy#+3=otqawpnOX*;@l+q=~b z{3mCB+Pk|4I=e^5+N*T~on6&6>$3x~EBTHe>H;#mvBHTf#dVa1`pcaIRbMsU+=F=y z?t6w#hW$PS9h=fO4?FgEEV$Q(eHx=x=v7_JS z-;CY4p`&+buxGvJQco9$(Y>Y8+O}2g9cwEmEMK4956@CP10!5r_6QvHRZ_bCquAT> zH)M~E%!!k$_FxNqTIZ9a?E38Kk$HBaI%n&sn7ZtRQFd+C$l*p$cLAsQcPrS<1$+7^ zyD>WgyQe?zOWWw$j~o?Smn|G+*JsCI`+VeAmCn^>%i#t}mxJrF^T6quT{3EZUAA(R zU7K}aSI@gPTMLdo_Kk|G%Sr{ijD7l0xpGOpquP$rP`PqRt!sVGEv7EtRlWz*ONehl zJEm7vg^Opl$eBj24sot=KCg0P5%V8_Za{ozEy4Erbxc80d)JCxeO~o%TdO+9wbD7- z>72Bd#q7~_Y7uMjdEWz#~-NwEVtk&qdxm+J1}V+*b=)-?aSxiRF~N@O%IEN2W2uJE33B+y|lUm%O>P zy(@D1XXe;B_CP%wsYz^q5Evi+x!|z`ggqadWFz(hOnWz#d`F}0p`7a-liTvvejM6; zxX)I!wsFU!T~p4wwHM`_{u6WT+&H_{;F`ql278`S=cQniTGv{jb;+5}@1OmacRg+2 zM{@cv%dvAJ-<4qRRoFL!U#5Vl=T>a@5dJ&BOOTo5GtWJU_v1kJ-h_S_iSzgS& zSZ~<=ehAz1^fyG<#-EbdK5OR}Y=0L-y#B7lcD}y@Qvarc?Qen9e^X*R-`@hMe{*8% zA1>JU7i_=v>3qNSslBCOKUuK-#;5UH3%1|))bIB_wRb7le&bVrbHVl-pZfjAr?%hq z)b_ic+J3`R+wXX4`yEg1m5FUX{f?)8zvZdz_dK=zo@=*q%X}wJLws-gEF8_9Ye4Uf z$lpdx6S^6ZGv--t>2$PXEAjgFLFBx`Bii9R3@o>aI$ZN`M9z0*?E46Cao(Vke7 zwz2lzGq88+*mnmw-FGLp+-fAwRW3bO&mh)Y19rV`Bzn>VmUn&^THAP^`F>pn=-+_Wua32^0GlK3-Bn<@ zD-rLV_7Gw&=ZcH(h1lEG_=l@^=!`guWgz z&ql;|;H%IM>(h4wB4>T#sOQz-Vm+^claG2{3r_2K9k%tzM?J3x%g6Kd2C#YLTKkYlrX6U^#s^qP4^K7O>~9?gze9d)|)tw_{)8`yY9)OkDDIZ>y!ah^e}r`C?Xsh^<6+2}ivIG<00i|2DEoP6x_ zGhn$mpU;9F?o;1gh@ATr+q=8bt`)ue960UW=dtCYcV7fM%&G4Sh@3gak#i2%eT7dQ zYkUbD_n`s&WyCsK`Ohz2uzBKLbRSsGp1JNYT26dF+Wgk{0NPw?zYPzfA3_}ZzlGMXwuXn% z-$aa)7suIu8(cj5@4(5&*?$)-7ia%{u*2N?zK6(}TO9lQ0eBdR-v1CBz1L^&BlbsN z{qFS=f}i3-#m|^wZr!luw3lxr(iknlh5wY(GK^e?`OzUi2D*p z&wc?e_UxB%@{#*jV7choZ@><7>-#k#XKr!S_*?KWV$V)QKaNDt^!Z*ewgvqJ@;DOj ziQnZMs#5DD^zV_!khtgC#+joL{Rc$9{o0b-;rk=@=G^xrT06d7e*%w3oTJbA#>#tc zu{X8$gJ}0^KVz?d2B+uo7i_uNi0AP)w8Nh1`zs=69&wz<-@(Q6_y?SPoX0=Ga&aF2 z20P5H?_Y?Vxy4cAf55{?oX0cZI1hcR5o24>&m#KcJf6!r)W~@}4}Kbn^UyXfdj8+s zZ$Gz!wZr!!IL_k*uy&lsOJL8#Ir`!}k29$Q%f*>AfE}KRzIsH?+~TO)zuOHH7yrK84qI+M zT*PdTZO!qUWgNC~^6_t)@!&pAWK))(hE0W9Zme?C_evHhDzoTIjJvHm2m z>&IRvW6Mp2^KS11cDPr4J0epM_bN7z&*#oy*YOOzH`?~x`NldYVs`}_W2}FF(GK74 z;K<|OVzlG#>;d-fI7eTsC7-V4nz6q;sNH`Zpx)@a+qZ=VLlpJNCFA*gZN&pL;Y`K5A?M55q-``(w++nH>mr zSfjoJkb@9w6r1xA{BscdX1qRQVvQN#Vf^vT&BT_AzUh-QSEYy8X)beWM{m^H@tmlO z-^o~e7C!4U_oHa-@XZFx#oiCbmUAzj&s?y>z3ZEU$hkkUH9Ux4TV6Zf6Kd@Ta=&r$ zo;Z});(KBqoP5-B7+5ae6NiHx)}n7dBInv-_vEv)037$hzaMIc?`;6ar zr(nmv^hLe$@k~S>we{c2eDkVb>E13Q)^+0d_Nmx%@m!t`c37{z(-1lHisKy40DD%R zr)Q^a4H2u3*fYU#zRSVd(bE;+Vo%S4lfNVXH^JFpxv246u)`YlorB0(qd0P(4=(1u z08T#M85e@({GH^y%g`5L_o4O09xevUyWjKB+Tm*hyDxp0ptZx-4mMujO0;%78>_%R z8_v;ZjCthkW$a0X2tTpP{G3O&@Cv4}YF*{?+M-QifohN?=_rZF0LF8XzO|iB950*QwWB>pF literal 0 HcmV?d00001 diff --git a/tests/shader/prefix.comp b/tests/shader/prefix.comp new file mode 100644 index 0000000..ed5bcbc --- /dev/null +++ b/tests/shader/prefix.comp @@ -0,0 +1,188 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense + +// A prefix sum. + +#version 450 + +#define N_ROWS 16 +#define LG_WG_SIZE 9 +#define WG_SIZE (1 << LG_WG_SIZE) +#define PARTITION_SIZE (WG_SIZE * N_ROWS) + +layout(local_size_x = WG_SIZE, local_size_y = 1) in; + +struct Monoid { + uint element; +}; + +layout(set = 0, binding = 0) readonly buffer InBuf { + Monoid[] inbuf; +}; + +layout(set = 0, binding = 1) buffer OutBuf { + Monoid[] outbuf; +}; + +// These correspond to X, A, P respectively in the prefix sum paper. +#define FLAG_NOT_READY 0 +#define FLAG_AGGREGATE_READY 1 +#define FLAG_PREFIX_READY 2 + +struct State { + uint flag; + Monoid aggregate; + Monoid prefix; +}; + +layout(set = 0, binding = 2) volatile buffer StateBuf { + uint part_counter; + State[] state; +}; + +shared Monoid sh_scratch[WG_SIZE]; + +Monoid combine_monoid(Monoid a, Monoid b) { + return Monoid(a.element + b.element); +} + +shared uint sh_part_ix; +shared Monoid sh_prefix; +shared uint sh_flag; + +void main() { + Monoid local[N_ROWS]; + // Determine partition to process by atomic counter (described in Section + // 4.4 of prefix sum paper). + if (gl_LocalInvocationID.x == 0) { + sh_part_ix = atomicAdd(part_counter, 1); + } + barrier(); + uint part_ix = sh_part_ix; + + uint ix = part_ix * PARTITION_SIZE + gl_LocalInvocationID.x * N_ROWS; + + // TODO: gate buffer read? (evaluate whether shader check or + // CPU-side padding is better) + local[0] = inbuf[ix]; + for (uint i = 1; i < N_ROWS; i++) { + local[i] = combine_monoid(local[i - 1], inbuf[ix + i]); + } + Monoid agg = local[N_ROWS - 1]; + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i = 0; i < LG_WG_SIZE; i++) { + barrier(); + if (gl_LocalInvocationID.x >= (1 << i)) { + Monoid other = sh_scratch[gl_LocalInvocationID.x - (1 << i)]; + agg = combine_monoid(other, agg); + } + barrier(); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + + // Publish aggregate for this partition + if (gl_LocalInvocationID.x == WG_SIZE - 1) { + state[part_ix].aggregate = agg; + if (part_ix == 0) { + state[0].prefix = agg; + } + } + // Write flag with release semantics; this is done portably with a barrier. + memoryBarrierBuffer(); + if (gl_LocalInvocationID.x == WG_SIZE - 1) { + uint flag = FLAG_AGGREGATE_READY; + if (part_ix == 0) { + flag = FLAG_PREFIX_READY; + } + state[part_ix].flag = flag; + } + + Monoid exclusive = Monoid(0); + if (part_ix != 0) { + // step 4 of paper: decoupled lookback + uint look_back_ix = part_ix - 1; + + Monoid their_agg; + uint their_ix = 0; + while (true) { + // Read flag with acquire semantics. + if (gl_LocalInvocationID.x == WG_SIZE - 1) { + sh_flag = state[look_back_ix].flag; + } + // The flag load is done only in the last thread. However, because the + // translation of memoryBarrierBuffer to Metal requires uniform control + // flow, we broadcast it to all threads. + barrier(); + memoryBarrierBuffer(); + uint flag = sh_flag; + + if (flag == FLAG_PREFIX_READY) { + if (gl_LocalInvocationID.x == WG_SIZE - 1) { + Monoid their_prefix = state[look_back_ix].prefix; + exclusive = combine_monoid(their_prefix, exclusive); + } + break; + } else if (flag == FLAG_AGGREGATE_READY) { + if (gl_LocalInvocationID.x == WG_SIZE - 1) { + their_agg = state[look_back_ix].aggregate; + exclusive = combine_monoid(their_agg, exclusive); + } + look_back_ix--; + their_ix = 0; + continue; + } + // else spin + + if (gl_LocalInvocationID.x == WG_SIZE - 1) { + // Unfortunately there's no guarantee of forward progress of other + // workgroups, so compute a bit of the aggregate before trying again. + // In the worst case, spinning stops when the aggregate is complete. + Monoid m = inbuf[look_back_ix * PARTITION_SIZE + their_ix]; + if (their_ix == 0) { + their_agg = m; + } else { + their_agg = combine_monoid(their_agg, m); + } + their_ix++; + if (their_ix == PARTITION_SIZE) { + exclusive = combine_monoid(their_agg, exclusive); + if (look_back_ix == 0) { + sh_flag = FLAG_PREFIX_READY; + } else { + look_back_ix--; + their_ix = 0; + } + } + } + barrier(); + flag = sh_flag; + if (flag == FLAG_PREFIX_READY) { + break; + } + } + // step 5 of paper: compute inclusive prefix + if (gl_LocalInvocationID.x == WG_SIZE - 1) { + Monoid inclusive_prefix = combine_monoid(exclusive, agg); + sh_prefix = exclusive; + state[part_ix].prefix = inclusive_prefix; + } + memoryBarrierBuffer(); + if (gl_LocalInvocationID.x == WG_SIZE - 1) { + state[part_ix].flag = FLAG_PREFIX_READY; + } + } + barrier(); + if (part_ix != 0) { + exclusive = sh_prefix; + } + + Monoid row = exclusive; + if (gl_LocalInvocationID.x > 0) { + Monoid other = sh_scratch[gl_LocalInvocationID.x - 1]; + row = combine_monoid(row, other); + } + for (uint i = 0; i < N_ROWS; i++) { + Monoid m = combine_monoid(row, local[i]); + // Make sure buffer allocation is padded appropriately. + outbuf[ix + i] = m; + } +} diff --git a/tests/src/main.rs b/tests/src/main.rs new file mode 100644 index 0000000..2fe2a3d --- /dev/null +++ b/tests/src/main.rs @@ -0,0 +1,29 @@ +// Copyright 2021 The piet-gpu authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Also licensed under MIT license, at your choice. + +//! Tests for piet-gpu shaders and GPU capabilities. + +mod prefix; +mod runner; + +use runner::Runner; + +fn main() { + unsafe { + let mut runner = Runner::new(); + prefix::run_prefix_test(&mut runner); + } +} diff --git a/tests/src/prefix.rs b/tests/src/prefix.rs new file mode 100644 index 0000000..2a52f75 --- /dev/null +++ b/tests/src/prefix.rs @@ -0,0 +1,142 @@ +// Copyright 2021 The piet-gpu authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Also licensed under MIT license, at your choice. + +use piet_gpu_hal::{include_shader, BufferUsage, DescriptorSet}; +use piet_gpu_hal::{Buffer, Pipeline}; + +use crate::runner::{Commands, Runner}; + +const WG_SIZE: u64 = 512; +const N_ROWS: u64 = 16; +const ELEMENTS_PER_WG: u64 = WG_SIZE * N_ROWS; + +/// The shader code for the prefix sum example. +/// +/// A code struct can be created once and reused any number of times. +struct PrefixCode { + pipeline: Pipeline, +} + +/// The stage resources for the prefix sum example. +/// +/// A stage resources struct is specific to a particular problem size +/// and queue. +struct PrefixStage { + // This is the actual problem size but perhaps it would be better to + // treat it as a capacity. + n_elements: u64, + state_buf: Buffer, +} + +/// The binding for the prefix sum example. +struct PrefixBinding { + descriptor_set: DescriptorSet, +} + +pub unsafe fn run_prefix_test(runner: &mut Runner) { + // This will be configurable. + let n_elements: u64 = 1 << 23; + let data: Vec = (0..n_elements as u32).collect(); + let data_buf = runner + .session + .create_buffer_init(&data, BufferUsage::STORAGE) + .unwrap(); + let out_buf = runner.buf_down(data_buf.size()); + let code = PrefixCode::new(runner); + let stage = PrefixStage::new(runner, n_elements); + let binding = stage.bind(runner, &code, &data_buf, &out_buf.dev_buf); + // Also will be configurable of course. + let n_iter = 5000; + let mut total_elapsed = 0.0; + for i in 0..n_iter { + let mut commands = runner.commands(); + commands.write_timestamp(0); + stage.record(&mut commands, &code, &binding); + commands.write_timestamp(1); + if i == 0 { + commands.cmd_buf.memory_barrier(); + commands.download(&out_buf); + } + total_elapsed += runner.submit(commands); + if i == 0 { + let mut dst: Vec = Default::default(); + out_buf.read(&mut dst); + println!("failures: {:?}", verify(&dst)); + } + } + let throughput = (n_elements * n_iter) as f64 / total_elapsed; + println!( + "total {:?}ms, throughput = {}G el/s", + total_elapsed * 1e3, + throughput * 1e-9 + ); +} + +impl PrefixCode { + unsafe fn new(runner: &mut Runner) -> PrefixCode { + let code = include_shader!(&runner.session, "../shader/gen/prefix"); + let pipeline = runner + .session + .create_simple_compute_pipeline(code, 3) + .unwrap(); + PrefixCode { pipeline } + } +} + +impl PrefixStage { + unsafe fn new(runner: &mut Runner, n_elements: u64) -> PrefixStage { + let n_workgroups = (n_elements + ELEMENTS_PER_WG - 1) / ELEMENTS_PER_WG; + let state_buf_size = 4 + 12 * n_workgroups; + let state_buf = runner + .session + .create_buffer(state_buf_size, BufferUsage::STORAGE | BufferUsage::COPY_DST) + .unwrap(); + PrefixStage { + n_elements, + state_buf, + } + } + + unsafe fn bind(&self, runner: &mut Runner, code: &PrefixCode, in_buf: &Buffer, out_buf: &Buffer) -> PrefixBinding { + let descriptor_set = runner + .session + .create_simple_descriptor_set(&code.pipeline, &[in_buf, out_buf, &self.state_buf]) + .unwrap(); + PrefixBinding { descriptor_set } + } + + unsafe fn record(&self, commands: &mut Commands, code: &PrefixCode, bindings: &PrefixBinding) { + let n_workgroups = (self.n_elements + ELEMENTS_PER_WG - 1) / ELEMENTS_PER_WG; + commands.cmd_buf.clear_buffer(&self.state_buf, None); + commands.cmd_buf.memory_barrier(); + commands.cmd_buf.dispatch( + &code.pipeline, + &bindings.descriptor_set, + (n_workgroups as u32, 1, 1), + (WG_SIZE as u32, 1, 1), + ); + // One thing that's missing here is registering the buffers so + // they can be safely dropped by Rust code before the execution + // of the command buffer completes. + } +} + +// Verify that the data is OEIS A000217 +fn verify(data: &[u32]) -> Option { + data.iter() + .enumerate() + .position(|(i, val)| ((i * (i + 1)) / 2) as u32 != *val) +} diff --git a/tests/src/runner.rs b/tests/src/runner.rs new file mode 100644 index 0000000..be42b30 --- /dev/null +++ b/tests/src/runner.rs @@ -0,0 +1,132 @@ +// Copyright 2021 The piet-gpu authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Also licensed under MIT license, at your choice. + +//! Test runner intended to make it easy to write tests. + +use piet_gpu_hal::{Buffer, BufferUsage, CmdBuf, Instance, PlainData, QueryPool, Session}; + +pub struct Runner { + #[allow(unused)] + instance: Instance, + pub session: Session, + cmd_buf_pool: Vec, +} + +/// A wrapper around command buffers. +pub struct Commands { + pub cmd_buf: CmdBuf, + query_pool: QueryPool, +} + +/// Buffer for uploading data to GPU. +#[allow(unused)] +pub struct BufUp { + pub stage_buf: Buffer, + pub dev_buf: Buffer, +} + +/// Buffer for downloading data from GPU. +pub struct BufDown { + pub stage_buf: Buffer, + pub dev_buf: Buffer, +} + +impl Runner { + pub unsafe fn new() -> Runner { + let (instance, _) = Instance::new(None).unwrap(); + let device = instance.device(None).unwrap(); + let session = Session::new(device); + let cmd_buf_pool = Vec::new(); + Runner { + instance, + session, + cmd_buf_pool, + } + } + + pub unsafe fn commands(&mut self) -> Commands { + let mut cmd_buf = self + .cmd_buf_pool + .pop() + .unwrap_or_else(|| self.session.cmd_buf().unwrap()); + cmd_buf.begin(); + // TODO: also pool these. But we might sort by size, as + // we might not always be doing two. + let query_pool = self.session.create_query_pool(2).unwrap(); + cmd_buf.reset_query_pool(&query_pool); + Commands { + cmd_buf, + query_pool, + } + } + + pub unsafe fn submit(&mut self, commands: Commands) -> f64 { + let mut cmd_buf = commands.cmd_buf; + let query_pool = commands.query_pool; + cmd_buf.host_barrier(); + cmd_buf.finish(); + let submitted = self.session.run_cmd_buf(cmd_buf, &[], &[]).unwrap(); + self.cmd_buf_pool.extend(submitted.wait().unwrap()); + let timestamps = self.session.fetch_query_pool(&query_pool).unwrap(); + timestamps[0] + } + + #[allow(unused)] + pub fn buf_up(&self, size: u64) -> BufUp { + let stage_buf = self + .session + .create_buffer(size, BufferUsage::MAP_WRITE | BufferUsage::COPY_SRC) + .unwrap(); + let dev_buf = self + .session + .create_buffer(size, BufferUsage::COPY_DST | BufferUsage::STORAGE) + .unwrap(); + BufUp { stage_buf, dev_buf } + } + + pub fn buf_down(&self, size: u64) -> BufDown { + let stage_buf = self + .session + .create_buffer(size, BufferUsage::MAP_READ | BufferUsage::COPY_DST) + .unwrap(); + let dev_buf = self + .session + .create_buffer(size, BufferUsage::COPY_SRC | BufferUsage::STORAGE) + .unwrap(); + BufDown { stage_buf, dev_buf } + } +} + +impl Commands { + pub unsafe fn write_timestamp(&mut self, query: u32) { + self.cmd_buf.write_timestamp(&self.query_pool, query); + } + + #[allow(unused)] + pub unsafe fn upload(&mut self, buf: &BufUp) { + self.cmd_buf.copy_buffer(&buf.stage_buf, &buf.dev_buf); + } + + pub unsafe fn download(&mut self, buf: &BufDown) { + self.cmd_buf.copy_buffer(&buf.dev_buf, &buf.stage_buf); + } +} + +impl BufDown { + pub unsafe fn read(&self, dst: &mut Vec) { + self.stage_buf.read(dst).unwrap() + } +} From 4ed339d4343ef3a2eda5b1be193b04384631da35 Mon Sep 17 00:00:00 2001 From: Raph Levien Date: Sat, 6 Nov 2021 16:08:43 -0700 Subject: [PATCH 2/5] Add tree reduction prefix sum test Do a tree reduction in addition to the existing decoupled look-back, to explore the tradeoff between performance and compatibility. --- tests/shader/build.ninja | 15 ++- tests/shader/prefix_reduce.comp | 53 ++++++++ tests/src/main.rs | 2 + tests/src/prefix.rs | 10 +- tests/src/prefix_tree.rs | 210 ++++++++++++++++++++++++++++++++ tests/src/runner.rs | 7 +- 6 files changed, 293 insertions(+), 4 deletions(-) create mode 100644 tests/shader/prefix_reduce.comp create mode 100644 tests/src/prefix_tree.rs diff --git a/tests/shader/build.ninja b/tests/shader/build.ninja index c5ecb06..93a0b66 100644 --- a/tests/shader/build.ninja +++ b/tests/shader/build.ninja @@ -6,7 +6,7 @@ glslang_validator = glslangValidator spirv_cross = spirv-cross rule glsl - command = $glslang_validator -V -o $out $in + command = $glslang_validator $flags -V -o $out $in rule hlsl command = $spirv_cross --hlsl $in --output $out @@ -17,3 +17,16 @@ rule msl build gen/prefix.spv: glsl prefix.comp build gen/prefix.hlsl: hlsl gen/prefix.spv build gen/prefix.msl: msl gen/prefix.spv + +build gen/prefix_reduce.spv: glsl prefix_reduce.comp +build gen/prefix_reduce.hlsl: hlsl gen/prefix_reduce.spv +build gen/prefix_reduce.msl: msl gen/prefix_reduce.spv + +build gen/prefix_root.spv: glsl prefix_scan.comp + flags = -DROOT +build gen/prefix_root.hlsl: hlsl gen/prefix_root.spv +build gen/prefix_root.msl: msl gen/prefix_root.spv + +build gen/prefix_scan.spv: glsl prefix_scan.comp +build gen/prefix_scan.hlsl: hlsl gen/prefix_scan.spv +build gen/prefix_scan.msl: msl gen/prefix_scan.spv diff --git a/tests/shader/prefix_reduce.comp b/tests/shader/prefix_reduce.comp new file mode 100644 index 0000000..378da88 --- /dev/null +++ b/tests/shader/prefix_reduce.comp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense + +// The reduction phase for prefix sum implemented as a tree reduction. + +#version 450 + +#define N_ROWS 8 +#define LG_WG_SIZE 9 +#define WG_SIZE (1 << LG_WG_SIZE) +#define PARTITION_SIZE (WG_SIZE * N_ROWS) + +layout(local_size_x = WG_SIZE, local_size_y = 1) in; + +struct Monoid { + uint element; +}; + +layout(set = 0, binding = 0) readonly buffer InBuf { + Monoid[] inbuf; +}; + +layout(set = 0, binding = 1) buffer OutBuf { + Monoid[] outbuf; +}; + +shared Monoid sh_scratch[WG_SIZE]; + +Monoid combine_monoid(Monoid a, Monoid b) { + return Monoid(a.element + b.element); +} + +void main() { + uint ix = gl_GlobalInvocationID.x * N_ROWS; + // TODO: gate buffer read + Monoid agg = inbuf[ix]; + for (uint i = 1; i < N_ROWS; i++) { + agg = combine_monoid(agg, inbuf[ix + i]); + } + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i = 0; i < LG_WG_SIZE; i++) { + barrier(); + // We could make this predicate tighter, but would it help? + if (gl_LocalInvocationID.x + (1 << i) < WG_SIZE) { + Monoid other = sh_scratch[gl_LocalInvocationID.x + (1 << i)]; + agg = combine_monoid(agg, other); + } + barrier(); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + if (gl_LocalInvocationID.x == 0) { + outbuf[gl_WorkGroupID.x] = agg; + } +} diff --git a/tests/src/main.rs b/tests/src/main.rs index 2fe2a3d..85d8a66 100644 --- a/tests/src/main.rs +++ b/tests/src/main.rs @@ -17,6 +17,7 @@ //! Tests for piet-gpu shaders and GPU capabilities. mod prefix; +mod prefix_tree; mod runner; use runner::Runner; @@ -25,5 +26,6 @@ fn main() { unsafe { let mut runner = Runner::new(); prefix::run_prefix_test(&mut runner); + prefix_tree::run_prefix_test(&mut runner); } } diff --git a/tests/src/prefix.rs b/tests/src/prefix.rs index 2a52f75..f95470a 100644 --- a/tests/src/prefix.rs +++ b/tests/src/prefix.rs @@ -59,7 +59,7 @@ pub unsafe fn run_prefix_test(runner: &mut Runner) { let stage = PrefixStage::new(runner, n_elements); let binding = stage.bind(runner, &code, &data_buf, &out_buf.dev_buf); // Also will be configurable of course. - let n_iter = 5000; + let n_iter = 1000; let mut total_elapsed = 0.0; for i in 0..n_iter { let mut commands = runner.commands(); @@ -110,7 +110,13 @@ impl PrefixStage { } } - unsafe fn bind(&self, runner: &mut Runner, code: &PrefixCode, in_buf: &Buffer, out_buf: &Buffer) -> PrefixBinding { + unsafe fn bind( + &self, + runner: &mut Runner, + code: &PrefixCode, + in_buf: &Buffer, + out_buf: &Buffer, + ) -> PrefixBinding { let descriptor_set = runner .session .create_simple_descriptor_set(&code.pipeline, &[in_buf, out_buf, &self.state_buf]) diff --git a/tests/src/prefix_tree.rs b/tests/src/prefix_tree.rs new file mode 100644 index 0000000..7b9743a --- /dev/null +++ b/tests/src/prefix_tree.rs @@ -0,0 +1,210 @@ +// Copyright 2021 The piet-gpu authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Also licensed under MIT license, at your choice. + +use piet_gpu_hal::{include_shader, BufferUsage, DescriptorSet}; +use piet_gpu_hal::{Buffer, Pipeline}; + +use crate::runner::{Commands, Runner}; + +const WG_SIZE: u64 = 512; +const N_ROWS: u64 = 8; +const ELEMENTS_PER_WG: u64 = WG_SIZE * N_ROWS; + +struct PrefixTreeCode { + reduce_pipeline: Pipeline, + scan_pipeline: Pipeline, + root_pipeline: Pipeline, +} + +struct PrefixTreeStage { + sizes: Vec, + tmp_bufs: Vec, +} + +struct PrefixTreeBinding { + // All but the first and last can be moved to stage. + descriptor_sets: Vec, +} + +pub unsafe fn run_prefix_test(runner: &mut Runner) { + // This will be configurable. Note though that the current code is + // prone to reading and writing past the end of buffers if this is + // not a power of the number of elements processed in a workgroup. + let n_elements: u64 = 1 << 24; + let data: Vec = (0..n_elements as u32).collect(); + let data_buf = runner + .session + .create_buffer_init(&data, BufferUsage::STORAGE) + .unwrap(); + let out_buf = runner.buf_down(data_buf.size()); + let code = PrefixTreeCode::new(runner); + let stage = PrefixTreeStage::new(runner, n_elements); + let binding = stage.bind(runner, &code, &out_buf.dev_buf); + // Also will be configurable of course. + let n_iter = 1000; + let mut total_elapsed = 0.0; + for i in 0..n_iter { + let mut commands = runner.commands(); + commands.cmd_buf.copy_buffer(&data_buf, &out_buf.dev_buf); + commands.cmd_buf.memory_barrier(); + commands.write_timestamp(0); + stage.record(&mut commands, &code, &binding); + commands.write_timestamp(1); + if i == 0 { + commands.cmd_buf.memory_barrier(); + commands.download(&out_buf); + } + total_elapsed += runner.submit(commands); + if i == 0 { + let mut dst: Vec = Default::default(); + out_buf.read(&mut dst); + println!("failures: {:?}", verify(&dst)); + } + } + let throughput = (n_elements * n_iter) as f64 / total_elapsed; + println!( + "total {:?}ms, throughput = {}G el/s", + total_elapsed * 1e3, + throughput * 1e-9 + ); +} + +impl PrefixTreeCode { + unsafe fn new(runner: &mut Runner) -> PrefixTreeCode { + let reduce_code = include_shader!(&runner.session, "../shader/gen/prefix_reduce"); + let reduce_pipeline = runner + .session + .create_simple_compute_pipeline(reduce_code, 2) + .unwrap(); + let scan_code = include_shader!(&runner.session, "../shader/gen/prefix_scan"); + let scan_pipeline = runner + .session + .create_simple_compute_pipeline(scan_code, 2) + .unwrap(); + let root_code = include_shader!(&runner.session, "../shader/gen/prefix_root"); + let root_pipeline = runner + .session + .create_simple_compute_pipeline(root_code, 1) + .unwrap(); + PrefixTreeCode { + reduce_pipeline, + scan_pipeline, + root_pipeline, + } + } +} + +impl PrefixTreeStage { + unsafe fn new(runner: &mut Runner, n_elements: u64) -> PrefixTreeStage { + let mut size = n_elements; + let mut sizes = vec![size]; + let mut tmp_bufs = Vec::new(); + while size > ELEMENTS_PER_WG { + size = (size + ELEMENTS_PER_WG - 1) / ELEMENTS_PER_WG; + sizes.push(size); + let buf = runner + .session + .create_buffer(4 * size, BufferUsage::STORAGE) + .unwrap(); + tmp_bufs.push(buf); + } + PrefixTreeStage { sizes, tmp_bufs } + } + + unsafe fn bind( + &self, + runner: &mut Runner, + code: &PrefixTreeCode, + data_buf: &Buffer, + ) -> PrefixTreeBinding { + let mut descriptor_sets = Vec::with_capacity(2 * self.tmp_bufs.len() + 1); + for i in 0..self.tmp_bufs.len() { + let buf0 = if i == 0 { + data_buf + } else { + &self.tmp_bufs[i - 1] + }; + let buf1 = &self.tmp_bufs[i]; + let descriptor_set = runner + .session + .create_simple_descriptor_set(&code.reduce_pipeline, &[buf0, buf1]) + .unwrap(); + descriptor_sets.push(descriptor_set); + } + let buf0 = self.tmp_bufs.last().unwrap_or(data_buf); + let descriptor_set = runner + .session + .create_simple_descriptor_set(&code.root_pipeline, &[buf0]) + .unwrap(); + descriptor_sets.push(descriptor_set); + for i in (0..self.tmp_bufs.len()).rev() { + let buf0 = if i == 0 { + data_buf + } else { + &self.tmp_bufs[i - 1] + }; + let buf1 = &self.tmp_bufs[i]; + let descriptor_set = runner + .session + .create_simple_descriptor_set(&code.scan_pipeline, &[buf0, buf1]) + .unwrap(); + descriptor_sets.push(descriptor_set); + } + PrefixTreeBinding { descriptor_sets } + } + + unsafe fn record( + &self, + commands: &mut Commands, + code: &PrefixTreeCode, + bindings: &PrefixTreeBinding, + ) { + let n = self.tmp_bufs.len(); + for i in 0..n { + let n_workgroups = self.sizes[i + 1]; + commands.cmd_buf.dispatch( + &code.reduce_pipeline, + &bindings.descriptor_sets[i], + (n_workgroups as u32, 1, 1), + (WG_SIZE as u32, 1, 1), + ); + commands.cmd_buf.memory_barrier(); + } + commands.cmd_buf.dispatch( + &code.root_pipeline, + &bindings.descriptor_sets[n], + (1, 1, 1), + (WG_SIZE as u32, 1, 1), + ); + for i in (0..n).rev() { + commands.cmd_buf.memory_barrier(); + let n_workgroups = self.sizes[i + 1]; + commands.cmd_buf.dispatch( + &code.scan_pipeline, + &bindings.descriptor_sets[2 * n - i], + (n_workgroups as u32, 1, 1), + (WG_SIZE as u32, 1, 1), + ); + } + } +} + +// Verify that the data is OEIS A000217 +fn verify(data: &[u32]) -> Option { + data.iter() + .enumerate() + .position(|(i, val)| ((i * (i + 1)) / 2) as u32 != *val) +} diff --git a/tests/src/runner.rs b/tests/src/runner.rs index be42b30..ce89961 100644 --- a/tests/src/runner.rs +++ b/tests/src/runner.rs @@ -102,9 +102,14 @@ impl Runner { .session .create_buffer(size, BufferUsage::MAP_READ | BufferUsage::COPY_DST) .unwrap(); + // Note: the COPY_DST isn't needed in all use cases, but I don't think + // making this tighter would help. let dev_buf = self .session - .create_buffer(size, BufferUsage::COPY_SRC | BufferUsage::STORAGE) + .create_buffer( + size, + BufferUsage::COPY_SRC | BufferUsage::COPY_DST | BufferUsage::STORAGE, + ) .unwrap(); BufDown { stage_buf, dev_buf } } From b36ca7fc2e75f25abd752477e7766a47a3280968 Mon Sep 17 00:00:00 2001 From: Raph Levien Date: Sat, 6 Nov 2021 16:25:56 -0700 Subject: [PATCH 3/5] Add generated shaders --- tests/shader/gen/prefix_reduce.hlsl | 72 ++++++++++++++++ tests/shader/gen/prefix_reduce.msl | 68 +++++++++++++++ tests/shader/gen/prefix_reduce.spv | Bin 0 -> 3504 bytes tests/shader/gen/prefix_root.hlsl | 80 ++++++++++++++++++ tests/shader/gen/prefix_root.msl | 112 +++++++++++++++++++++++++ tests/shader/gen/prefix_root.spv | Bin 0 -> 4104 bytes tests/shader/gen/prefix_scan.hlsl | 92 +++++++++++++++++++++ tests/shader/gen/prefix_scan.msl | 123 ++++++++++++++++++++++++++++ tests/shader/gen/prefix_scan.spv | Bin 0 -> 4736 bytes 9 files changed, 547 insertions(+) create mode 100644 tests/shader/gen/prefix_reduce.hlsl create mode 100644 tests/shader/gen/prefix_reduce.msl create mode 100644 tests/shader/gen/prefix_reduce.spv create mode 100644 tests/shader/gen/prefix_root.hlsl create mode 100644 tests/shader/gen/prefix_root.msl create mode 100644 tests/shader/gen/prefix_root.spv create mode 100644 tests/shader/gen/prefix_scan.hlsl create mode 100644 tests/shader/gen/prefix_scan.msl create mode 100644 tests/shader/gen/prefix_scan.spv diff --git a/tests/shader/gen/prefix_reduce.hlsl b/tests/shader/gen/prefix_reduce.hlsl new file mode 100644 index 0000000..837a75a --- /dev/null +++ b/tests/shader/gen/prefix_reduce.hlsl @@ -0,0 +1,72 @@ +struct Monoid +{ + uint element; +}; + +static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u); + +ByteAddressBuffer _40 : register(t0); +RWByteAddressBuffer _129 : register(u1); + +static uint3 gl_WorkGroupID; +static uint3 gl_LocalInvocationID; +static uint3 gl_GlobalInvocationID; +struct SPIRV_Cross_Input +{ + uint3 gl_WorkGroupID : SV_GroupID; + uint3 gl_LocalInvocationID : SV_GroupThreadID; + uint3 gl_GlobalInvocationID : SV_DispatchThreadID; +}; + +groupshared Monoid sh_scratch[512]; + +Monoid combine_monoid(Monoid a, Monoid b) +{ + Monoid _22 = { a.element + b.element }; + return _22; +} + +void comp_main() +{ + uint ix = gl_GlobalInvocationID.x * 8u; + Monoid _44; + _44.element = _40.Load(ix * 4 + 0); + Monoid agg; + agg.element = _44.element; + Monoid param_1; + for (uint i = 1u; i < 8u; i++) + { + Monoid param = agg; + Monoid _64; + _64.element = _40.Load((ix + i) * 4 + 0); + param_1.element = _64.element; + agg = combine_monoid(param, param_1); + } + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i_1 = 0u; i_1 < 9u; i_1++) + { + GroupMemoryBarrierWithGroupSync(); + if ((gl_LocalInvocationID.x + uint(1 << int(i_1))) < 512u) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x + uint(1 << int(i_1))]; + Monoid param_2 = agg; + Monoid param_3 = other; + agg = combine_monoid(param_2, param_3); + } + GroupMemoryBarrierWithGroupSync(); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + if (gl_LocalInvocationID.x == 0u) + { + _129.Store(gl_WorkGroupID.x * 4 + 0, agg.element); + } +} + +[numthreads(512, 1, 1)] +void main(SPIRV_Cross_Input stage_input) +{ + gl_WorkGroupID = stage_input.gl_WorkGroupID; + gl_LocalInvocationID = stage_input.gl_LocalInvocationID; + gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID; + comp_main(); +} diff --git a/tests/shader/gen/prefix_reduce.msl b/tests/shader/gen/prefix_reduce.msl new file mode 100644 index 0000000..e1ed0ce --- /dev/null +++ b/tests/shader/gen/prefix_reduce.msl @@ -0,0 +1,68 @@ +#pragma clang diagnostic ignored "-Wmissing-prototypes" + +#include +#include + +using namespace metal; + +struct Monoid +{ + uint element; +}; + +struct Monoid_1 +{ + uint element; +}; + +struct InBuf +{ + Monoid_1 inbuf[1]; +}; + +struct OutBuf +{ + Monoid_1 outbuf[1]; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u); + +static inline __attribute__((always_inline)) +Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b) +{ + return Monoid{ a.element + b.element }; +} + +kernel void main0(const device InBuf& _40 [[buffer(0)]], device OutBuf& _129 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]]) +{ + threadgroup Monoid sh_scratch[512]; + uint ix = gl_GlobalInvocationID.x * 8u; + Monoid agg; + agg.element = _40.inbuf[ix].element; + Monoid param_1; + for (uint i = 1u; i < 8u; i++) + { + Monoid param = agg; + param_1.element = _40.inbuf[ix + i].element; + agg = combine_monoid(param, param_1); + } + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i_1 = 0u; i_1 < 9u; i_1++) + { + threadgroup_barrier(mem_flags::mem_threadgroup); + if ((gl_LocalInvocationID.x + uint(1 << int(i_1))) < 512u) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x + uint(1 << int(i_1))]; + Monoid param_2 = agg; + Monoid param_3 = other; + agg = combine_monoid(param_2, param_3); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + if (gl_LocalInvocationID.x == 0u) + { + _129.outbuf[gl_WorkGroupID.x].element = agg.element; + } +} + diff --git a/tests/shader/gen/prefix_reduce.spv b/tests/shader/gen/prefix_reduce.spv new file mode 100644 index 0000000000000000000000000000000000000000..d1db3aab8ab019deb8a9f24a8d497a576e57c22e GIT binary patch literal 3504 zcmZ{ldsmco5XT?6iDCk&UDZI!D21fFMpBps6ozJHcU)kFP1vPn*Q_jstgPNifAtzV zy^vl-w{z=NRo6V9Vwrf zzBoO+R$CZ8dhCcEJCoMbQJ-$&+Tm@`N}*KNuorq6+JJhoX|X6$9X0b%V(cY#_iRjOAfs+IbxL!96qcg6Y3C-IEAIa`lDkH9;iHfWOnV{Z4Q z`sN%rV_H)mx72uNdI0i--rrDD-8osqgPHwv_~w0^z2)=UZ{)Yne13KFZ-KmT18)8Y zGX2|3ch`3DD`&f%(i8^xMciJcpqd9hcdkz z+0%#K1CN~dqFbB!twFo9Fjjmt(|x1Jr|&+rZs&}b@796 z-`5O#9=RY;4@&2g!SP6L})T?fr!e55Ea?lkX+ZB%-~GnRg1_*saFB z>i(X@qIY-1d-R{lZ03mkZzB5+p}&o6FQLDK+)09c^z42`|L&2xd2hju5%XU4bx17g z*gz&t>VGG*IoHEX*SB)F?8~`G$Jr0#?ZxB!{sz(xH>Z2x9@>k#?=_Iqwfh~)+x-^h z^|73;zu%R-z1XDtZOPkLGku(@+@psecg_7eO-=UK2Z`UtNBscQ2lYc^bY%}>;r|%2 z*f`$q%^}EL@;9JOKi}Xmr0qhMcPPVP#I~pN8Af&{#(I~!e#*|~2=qHM9f3Pr`7h*u z3VsywXg>zmE{{5oBb(y~?s9)mLSpvy2RYQAgv@2Ea>SlOw#U$)MmE>2td~(_eZ=)~ z=BFWhv3B24J!bSQa?HqBW1P`*khYl7c!t6Lnb8?!_aJ7ZuHPx<@4e4M{=N22{k-2g z#GQ?EvOjtBa~3(@&x`0{{iQ_V*XHuRSK-zx_pM)pzYck{zX8`SkA5y8n_UH~MdUzrR9Xhq@ts<6HDKvc4Yk z-+_Mv>94Q7s_SR3QIGsj{*s=A{2RFy-FNkDgH90F3E!TRyO8bWF1UB~z= uint(1 << int(i_1))) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))]; + Monoid param_2 = other; + Monoid param_3 = agg; + agg = combine_monoid(param_2, param_3); + } + GroupMemoryBarrierWithGroupSync(); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + GroupMemoryBarrierWithGroupSync(); + Monoid row = _133; + if (gl_LocalInvocationID.x > 0u) + { + row = sh_scratch[gl_LocalInvocationID.x - 1u]; + } + for (uint i_2 = 0u; i_2 < 8u; i_2++) + { + Monoid param_4 = row; + Monoid param_5 = local[i_2]; + Monoid m = combine_monoid(param_4, param_5); + _42.Store((ix + i_2) * 4 + 0, m.element); + } +} + +[numthreads(512, 1, 1)] +void main(SPIRV_Cross_Input stage_input) +{ + gl_LocalInvocationID = stage_input.gl_LocalInvocationID; + gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID; + comp_main(); +} diff --git a/tests/shader/gen/prefix_root.msl b/tests/shader/gen/prefix_root.msl new file mode 100644 index 0000000..ff02287 --- /dev/null +++ b/tests/shader/gen/prefix_root.msl @@ -0,0 +1,112 @@ +#pragma clang diagnostic ignored "-Wmissing-prototypes" +#pragma clang diagnostic ignored "-Wmissing-braces" + +#include +#include + +using namespace metal; + +template +struct spvUnsafeArray +{ + T elements[Num ? Num : 1]; + + thread T& operator [] (size_t pos) thread + { + return elements[pos]; + } + constexpr const thread T& operator [] (size_t pos) const thread + { + return elements[pos]; + } + + device T& operator [] (size_t pos) device + { + return elements[pos]; + } + constexpr const device T& operator [] (size_t pos) const device + { + return elements[pos]; + } + + constexpr const constant T& operator [] (size_t pos) const constant + { + return elements[pos]; + } + + threadgroup T& operator [] (size_t pos) threadgroup + { + return elements[pos]; + } + constexpr const threadgroup T& operator [] (size_t pos) const threadgroup + { + return elements[pos]; + } +}; + +struct Monoid +{ + uint element; +}; + +struct Monoid_1 +{ + uint element; +}; + +struct DataBuf +{ + Monoid_1 data[1]; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u); + +static inline __attribute__((always_inline)) +Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b) +{ + return Monoid{ a.element + b.element }; +} + +kernel void main0(device DataBuf& _42 [[buffer(0)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]]) +{ + threadgroup Monoid sh_scratch[512]; + uint ix = gl_GlobalInvocationID.x * 8u; + spvUnsafeArray local; + local[0].element = _42.data[ix].element; + Monoid param_1; + for (uint i = 1u; i < 8u; i++) + { + Monoid param = local[i - 1u]; + param_1.element = _42.data[ix + i].element; + local[i] = combine_monoid(param, param_1); + } + Monoid agg = local[7]; + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i_1 = 0u; i_1 < 9u; i_1++) + { + threadgroup_barrier(mem_flags::mem_threadgroup); + if (gl_LocalInvocationID.x >= uint(1 << int(i_1))) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))]; + Monoid param_2 = other; + Monoid param_3 = agg; + agg = combine_monoid(param_2, param_3); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + Monoid row = Monoid{ 0u }; + if (gl_LocalInvocationID.x > 0u) + { + row = sh_scratch[gl_LocalInvocationID.x - 1u]; + } + for (uint i_2 = 0u; i_2 < 8u; i_2++) + { + Monoid param_4 = row; + Monoid param_5 = local[i_2]; + Monoid m = combine_monoid(param_4, param_5); + _42.data[ix + i_2].element = m.element; + } +} + diff --git a/tests/shader/gen/prefix_root.spv b/tests/shader/gen/prefix_root.spv new file mode 100644 index 0000000000000000000000000000000000000000..70ba31c92ccab887aaf9be7a7fc71ee83cc25c2a GIT binary patch literal 4104 zcmZ{miE>m`5QcA@B!oo*0$M0w0tm<^D2s?7AR5phAnse3EF`0oOiU&$ii#NA7u*q1 zal^GNpTY7$dC^WPg^LzUNRozRQqq!qm(-r=$we?p zGC65Zbw}sE&XuFp-j!?Ct}F$cL9nZ|ng=?E#)T&^kb z519M}Xd)uNR!l+3uF6oQ)SHl@DaBApU%79vZ>S2>noI%rR0g|CLw#L?Y3!2G>d076 zbwzFEim}zJH~yDlZGp2UNfFVWjPB0t+33HgIm3(4`I^N(=WA|EmVmbxtHo_&2hwwieR|>4*0WTDNBdfIDYYAuZRp|R zNO3U3w_Jb=$!>75zn^tY$vxPk16`v%BgJaZKk`itW<|o`#N`mw_(_o8KDrTdQm3H)hzL zMLuKaB0k2?L+g`wKXp4|J#)bJu2yd?eqn~?J>Md9EoUL^zgT_=dNJZ2SEKEZoHdE< zy`28#neE!B;W}*Vnn~0R8E-W15*n8Yohy6C>{eBw)yWV#_ud#%DY&c`^5-L2>l|R_y4%`P|F? zcOd$r*E=!}^>43tVmrgpYq4<~xLfDuHpJh5_cShQydB#b?TIysjkWLIfxT15xjV7* z=e`3h=Qk+sRW5(8-a+iQ7u)^zAm+5TI}v%;@6NFC&iQ?a-=F-P-=FcZ{sFM-T;uoY zA;fQ(apn{oqwRZaF6(;)?fKNs-F|ct@zLLn)~}B9_h6eN=B^K0t`~9c#9hdK#I@SL zz#l*>{Yac`0K5KdB{=!m?@?^``yeLo#<1n%eL9G39{ISVL)hMtzPQK3*z(@f zFj^eGC$Q!89YKr3cNE+E*XMUv9CQB^_7uc5`iwDF-aC#ysQ>Stp2k<-muKMQqlRPH z=8f|`i|zT0b5HRx#66ucd#`rAcX%AJukk&3F5~R2>-EXy=XO81zGh>`;ZGoOug_!G z*Z%^X{P)?+yofC~7jb4@M*BE3`d&ihtXn&Bzk*$#`&Bskxc}F%<^0}9e_qG7Kl-AF zZ(z&E8N{A9_EqaUiMCFCZ=$V9?K}K7`W?ha|GQ}Y>ge@**yf0L_= uint(1 << int(i_1))) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))]; + Monoid param_2 = other; + Monoid param_3 = agg; + agg = combine_monoid(param_2, param_3); + } + GroupMemoryBarrierWithGroupSync(); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + GroupMemoryBarrierWithGroupSync(); + Monoid row = _133; + if (gl_WorkGroupID.x > 0u) + { + Monoid _148; + _148.element = _143.Load((gl_WorkGroupID.x - 1u) * 4 + 0); + row.element = _148.element; + } + if (gl_LocalInvocationID.x > 0u) + { + Monoid param_4 = row; + Monoid param_5 = sh_scratch[gl_LocalInvocationID.x - 1u]; + row = combine_monoid(param_4, param_5); + } + for (uint i_2 = 0u; i_2 < 8u; i_2++) + { + Monoid param_6 = row; + Monoid param_7 = local[i_2]; + Monoid m = combine_monoid(param_6, param_7); + _42.Store((ix + i_2) * 4 + 0, m.element); + } +} + +[numthreads(512, 1, 1)] +void main(SPIRV_Cross_Input stage_input) +{ + gl_WorkGroupID = stage_input.gl_WorkGroupID; + gl_LocalInvocationID = stage_input.gl_LocalInvocationID; + gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID; + comp_main(); +} diff --git a/tests/shader/gen/prefix_scan.msl b/tests/shader/gen/prefix_scan.msl new file mode 100644 index 0000000..c1efb22 --- /dev/null +++ b/tests/shader/gen/prefix_scan.msl @@ -0,0 +1,123 @@ +#pragma clang diagnostic ignored "-Wmissing-prototypes" +#pragma clang diagnostic ignored "-Wmissing-braces" + +#include +#include + +using namespace metal; + +template +struct spvUnsafeArray +{ + T elements[Num ? Num : 1]; + + thread T& operator [] (size_t pos) thread + { + return elements[pos]; + } + constexpr const thread T& operator [] (size_t pos) const thread + { + return elements[pos]; + } + + device T& operator [] (size_t pos) device + { + return elements[pos]; + } + constexpr const device T& operator [] (size_t pos) const device + { + return elements[pos]; + } + + constexpr const constant T& operator [] (size_t pos) const constant + { + return elements[pos]; + } + + threadgroup T& operator [] (size_t pos) threadgroup + { + return elements[pos]; + } + constexpr const threadgroup T& operator [] (size_t pos) const threadgroup + { + return elements[pos]; + } +}; + +struct Monoid +{ + uint element; +}; + +struct Monoid_1 +{ + uint element; +}; + +struct DataBuf +{ + Monoid_1 data[1]; +}; + +struct ParentBuf +{ + Monoid_1 parent[1]; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u); + +static inline __attribute__((always_inline)) +Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b) +{ + return Monoid{ a.element + b.element }; +} + +kernel void main0(device DataBuf& _42 [[buffer(0)]], device ParentBuf& _143 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]]) +{ + threadgroup Monoid sh_scratch[512]; + uint ix = gl_GlobalInvocationID.x * 8u; + spvUnsafeArray local; + local[0].element = _42.data[ix].element; + Monoid param_1; + for (uint i = 1u; i < 8u; i++) + { + Monoid param = local[i - 1u]; + param_1.element = _42.data[ix + i].element; + local[i] = combine_monoid(param, param_1); + } + Monoid agg = local[7]; + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i_1 = 0u; i_1 < 9u; i_1++) + { + threadgroup_barrier(mem_flags::mem_threadgroup); + if (gl_LocalInvocationID.x >= uint(1 << int(i_1))) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))]; + Monoid param_2 = other; + Monoid param_3 = agg; + agg = combine_monoid(param_2, param_3); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + Monoid row = Monoid{ 0u }; + if (gl_WorkGroupID.x > 0u) + { + row.element = _143.parent[gl_WorkGroupID.x - 1u].element; + } + if (gl_LocalInvocationID.x > 0u) + { + Monoid param_4 = row; + Monoid param_5 = sh_scratch[gl_LocalInvocationID.x - 1u]; + row = combine_monoid(param_4, param_5); + } + for (uint i_2 = 0u; i_2 < 8u; i_2++) + { + Monoid param_6 = row; + Monoid param_7 = local[i_2]; + Monoid m = combine_monoid(param_6, param_7); + _42.data[ix + i_2].element = m.element; + } +} + diff --git a/tests/shader/gen/prefix_scan.spv b/tests/shader/gen/prefix_scan.spv new file mode 100644 index 0000000000000000000000000000000000000000..d4216e95f91038753486e70cb2a8db3ac3beca00 GIT binary patch literal 4736 zcmZ{mhjUa_6vm%y5`!QC5giMfC?L`l6)ac)(TE0#Vn+>|LK01~akH@(#29<;6;ZL% z?22t1{}BHQ+l-FC-`jWeWgTzk%=y0WoO|!N=iaxI!lYyOPm=wT1Cqw%hom}YB$H5* zWJ=PI`qs7$ZA(Tf-AhhccA^>6l0uqjP80L$@paf>v46fbmHN9AR;W)oQqnWfGuSg!L1{{+!n;a?o&7^S9fN7^{Et@p_@xx#Y%r^ zXziLbw=S8BA3z+)_&j{RX1U|vn(LDJ@HNFsarNlt^jvaBH=5ph76=&Plkoj1u1i+q zhl}OnV8+jzzzfNGc(J#adG*N^;E}$Lk*;#F($!Zzk9Wd(+AO)|IO$&35s$rw!KF%H zPZ{rh+=kEZea8e|NXFph(kv-+Ma-a|Cax(v;hO_dFttk7xYwj6aIc?-#wv@23{~*{A$| za`)R0bN}!0?&q7H$=5as-@Ck##Cqn=#vJA!g*T?|e)<;7dX9wKt7^T~ z{5cs{_k8p4)in$0{mRwnyLw7e>OexeMY~0hjRCQ9PgTH&XfBVqhI?n+<;$1_q}W1rsuS8+H&5P z`JT<;`}hr|9`BmIBj2$#%IDzc=D2ZxDwFAaDM~y@jGhx zSPl2LE1$nT$K7v7j$7}}8os-R`y1wZtLTc~!O@uCiuZLEJ+Z&dnEDtw@?)@OtOYxp zuG&Mj$X^InTSd0-dNJl(_Zv25o_D?kGv?mbdIDzu&ATMq>qM}54)Z+YQp{ZQ++S{< z*!%MCf9IZ0!F%8OU)Xaw{#49i{4~6AebjzB*fsXBm*3n9OwF_Y!5Z>qm}@y#9CObE zJ2&`Q8Q+?nNzt(dXs^#z%R`={4y!M^e6wcNbZxm(}Oh1gWgJ#(?2wfwz$2eIEJVE1dsT+`Yv#nhd@KI7*5&acA!oy))Tt1~_3 zUjuiZb2j3y#r&-_&o$-di2cr8%lfXzdp^DIZWF$UIgEGWjqBt5U0~OUch>_}>&AR{ z@(yeh=3Md5=$r9cFBWI(1J|CdA5A^>y8-NeKXJBc_!}{Mt$ru#HsA*^^o-&$x)H4Fk*HeL8k9yQI3RaKr>1MF&sK*`M0``uK z#Xa(`ccSh+4ddkzyB(}%>^8hSVt0bQe`EeG%j4bO1)hpI$Cx?hs(Z)L2mSxu)7`{s z`*IJOdepE5?7DHjd%>R1Jol7u!Q9g~X7BaR_YSvW_BDP_?#ncL>wIHs`FFbsUR$%d zW9au|ajy@6$FbR1?Drs8{iE#7JOoxV$2ao`-r<`u_AsVq-QrmLNpS7jPob&D{XY#> z^Y=b}>z)BWfi+>qt!p0sSb$H_zTgjryqZ1#s>Cy@;kB z=YI*T7UzEj?C|`?UdGfszc|)@6y7Qedro8T;60Px-@y0q?_&<*AK;Da;}v2LnHoUY%b=U$p0kMJ;z?KJo@k{ z*gm+9F>`j})not9z=!4d=V0rJ@9qn*IqI?Jmtfb4xBC@%0p|NNKi=xsU~?UwVITe* z%=~z(a`Wu1eYQXPsP9{FzMk*kYVkW@Of7$R_W4i#Q$G#!f9d_<_Q5e3JC#1z&jWIJ z3fR35#9OQ1)eOuWYmIsi%Jj&Y33qNj=U}*c^l%o~dFt)F2kSWmQ~!%S#n$>avB;sb literal 0 HcmV?d00001 From 3820e4b2f414955003c99240256bec054d3259df Mon Sep 17 00:00:00 2001 From: Raph Levien Date: Sat, 6 Nov 2021 21:43:09 -0700 Subject: [PATCH 4/5] Add missing file Also add finish_timestamps call, which is needed for DX12 (there are other issues but this is an easy fix for that one). --- piet-gpu-hal/examples/collatz.rs | 1 + tests/shader/prefix_scan.comp | 77 ++++++++++++++++++++++++++++++++ tests/src/runner.rs | 1 + 3 files changed, 79 insertions(+) create mode 100644 tests/shader/prefix_scan.comp diff --git a/piet-gpu-hal/examples/collatz.rs b/piet-gpu-hal/examples/collatz.rs index e974cde..cad508e 100644 --- a/piet-gpu-hal/examples/collatz.rs +++ b/piet-gpu-hal/examples/collatz.rs @@ -21,6 +21,7 @@ fn main() { cmd_buf.write_timestamp(&query_pool, 0); cmd_buf.dispatch(&pipeline, &descriptor_set, (256, 1, 1), (1, 1, 1)); cmd_buf.write_timestamp(&query_pool, 1); + cmd_buf.finish_timestamps(&query_pool); cmd_buf.host_barrier(); cmd_buf.finish(); let submitted = session.run_cmd_buf(cmd_buf, &[], &[]).unwrap(); diff --git a/tests/shader/prefix_scan.comp b/tests/shader/prefix_scan.comp new file mode 100644 index 0000000..59903ab --- /dev/null +++ b/tests/shader/prefix_scan.comp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense + +// A scan for a tree reduction prefix scan (either root or not, by ifdef). + +#version 450 + +#define N_ROWS 8 +#define LG_WG_SIZE 9 +#define WG_SIZE (1 << LG_WG_SIZE) +#define PARTITION_SIZE (WG_SIZE * N_ROWS) + +layout(local_size_x = WG_SIZE, local_size_y = 1) in; + +struct Monoid { + uint element; +}; + +layout(set = 0, binding = 0) buffer DataBuf { + Monoid[] data; +}; + +#ifndef ROOT +layout(set = 0, binding = 1) buffer ParentBuf { + Monoid[] parent; +}; +#endif + +shared Monoid sh_scratch[WG_SIZE]; + +Monoid combine_monoid(Monoid a, Monoid b) { + return Monoid(a.element + b.element); +} + +void main() { + Monoid local[N_ROWS]; + + uint ix = gl_GlobalInvocationID.x * N_ROWS; + + // TODO: gate buffer read + local[0] = data[ix]; + for (uint i = 1; i < N_ROWS; i++) { + local[i] = combine_monoid(local[i - 1], data[ix + i]); + } + Monoid agg = local[N_ROWS - 1]; + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i = 0; i < LG_WG_SIZE; i++) { + barrier(); + if (gl_LocalInvocationID.x >= (1 << i)) { + Monoid other = sh_scratch[gl_LocalInvocationID.x - (1 << i)]; + agg = combine_monoid(other, agg); + } + barrier(); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + + barrier(); + // This could be a semigroup instead of a monoid if we reworked the + // conditional logic, but that might impact performance. + Monoid row = Monoid(0); +#ifdef ROOT + if (gl_LocalInvocationID.x > 0) { + row = sh_scratch[gl_LocalInvocationID.x - 1]; + } +#else + if (gl_WorkGroupID.x > 0) { + row = parent[gl_WorkGroupID.x - 1]; + } + if (gl_LocalInvocationID.x > 0) { + row = combine_monoid(row, sh_scratch[gl_LocalInvocationID.x - 1]); + } +#endif + for (uint i = 0; i < N_ROWS; i++) { + Monoid m = combine_monoid(row, local[i]); + // TODO: gate buffer write + data[ix + i] = m; + } +} diff --git a/tests/src/runner.rs b/tests/src/runner.rs index ce89961..9bfde3b 100644 --- a/tests/src/runner.rs +++ b/tests/src/runner.rs @@ -76,6 +76,7 @@ impl Runner { pub unsafe fn submit(&mut self, commands: Commands) -> f64 { let mut cmd_buf = commands.cmd_buf; let query_pool = commands.query_pool; + cmd_buf.finish_timestamps(&query_pool); cmd_buf.host_barrier(); cmd_buf.finish(); let submitted = self.session.run_cmd_buf(cmd_buf, &[], &[]).unwrap(); From bd39d26bce7cff8cfa5af12b829a30c3b2199e46 Mon Sep 17 00:00:00 2001 From: Raph Levien Date: Tue, 9 Nov 2021 14:04:58 -0800 Subject: [PATCH 5/5] Improve collection and reporting of test results Have a structured way of gathering test results, rather than the existing ad hoc approach of just printing stuff. The details are still pretty primitive, but there's room to grow. --- tests/src/config.rs | 72 +++++++++++++++++++++++++ tests/src/main.rs | 52 ++++++++++++++++-- tests/src/prefix.rs | 19 +++---- tests/src/prefix_tree.rs | 19 +++---- tests/src/test_result.rs | 110 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 251 insertions(+), 21 deletions(-) create mode 100644 tests/src/config.rs create mode 100644 tests/src/test_result.rs diff --git a/tests/src/config.rs b/tests/src/config.rs new file mode 100644 index 0000000..50bd3be --- /dev/null +++ b/tests/src/config.rs @@ -0,0 +1,72 @@ +// Copyright 2021 The piet-gpu authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Also licensed under MIT license, at your choice. + +//! Test config parameters. + +use clap::ArgMatches; + +pub struct Config { + pub groups: Groups, + pub size: Size, +} + +pub struct Groups(String); + +pub enum Size { + Small, + Medium, + Large, +} + +impl Config { + pub fn from_matches(matches: &ArgMatches) -> Config { + let groups = Groups::from_str(matches.value_of("groups").unwrap_or("all")); + let size = Size::from_str(matches.value_of("size").unwrap_or("m")); + Config { + groups, size + } + } +} + +impl Groups { + pub fn from_str(s: &str) -> Groups { + Groups(s.to_string()) + } + + pub fn matches(&self, group_name: &str) -> bool { + self.0 == "all" || self.0 == group_name + } +} + +impl Size { + fn from_str(s: &str) -> Size { + if s == "small" || s == "s" { + Size::Small + } else if s == "large" || s == "l" { + Size::Large + } else { + Size::Medium + } + } + + pub fn choose(&self, small: T, medium: T, large: T) -> T { + match self { + Size::Small => small, + Size::Medium => medium, + Size::Large => large, + } + } +} diff --git a/tests/src/main.rs b/tests/src/main.rs index 85d8a66..b7bc1d9 100644 --- a/tests/src/main.rs +++ b/tests/src/main.rs @@ -16,16 +16,62 @@ //! Tests for piet-gpu shaders and GPU capabilities. +mod config; mod prefix; mod prefix_tree; mod runner; +mod test_result; -use runner::Runner; +use clap::{App, Arg}; + +use crate::config::Config; +use crate::runner::Runner; +use crate::test_result::{ReportStyle, TestResult}; fn main() { + let matches = App::new("piet-gpu-tests") + .arg( + Arg::with_name("verbose") + .short("v") + .long("verbose") + .help("Verbose reporting of results"), + ) + .arg( + Arg::with_name("groups") + .short("g") + .long("groups") + .help("Groups to run") + .takes_value(true) + ) + .arg( + Arg::with_name("size") + .short("s") + .long("size") + .help("Size of tests") + .takes_value(true) + ) + .arg( + Arg::with_name("n_iter") + .short("n") + .long("n_iter") + .help("Number of iterations") + .takes_value(true) + ) + .get_matches(); + let style = if matches.is_present("verbose") { + ReportStyle::Verbose + } else { + ReportStyle::Short + }; + let config = Config::from_matches(&matches); unsafe { + let report = |test_result: &TestResult| { + test_result.report(style); + }; let mut runner = Runner::new(); - prefix::run_prefix_test(&mut runner); - prefix_tree::run_prefix_test(&mut runner); + if config.groups.matches("prefix") { + report(&prefix::run_prefix_test(&mut runner, &config)); + report(&prefix_tree::run_prefix_test(&mut runner, &config)); + } } } diff --git a/tests/src/prefix.rs b/tests/src/prefix.rs index f95470a..adc58b4 100644 --- a/tests/src/prefix.rs +++ b/tests/src/prefix.rs @@ -17,7 +17,9 @@ use piet_gpu_hal::{include_shader, BufferUsage, DescriptorSet}; use piet_gpu_hal::{Buffer, Pipeline}; +use crate::config::Config; use crate::runner::{Commands, Runner}; +use crate::test_result::TestResult; const WG_SIZE: u64 = 512; const N_ROWS: u64 = 16; @@ -46,9 +48,10 @@ struct PrefixBinding { descriptor_set: DescriptorSet, } -pub unsafe fn run_prefix_test(runner: &mut Runner) { +pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResult { + let mut result = TestResult::new("prefix sum, decoupled look-back"); // This will be configurable. - let n_elements: u64 = 1 << 23; + let n_elements: u64 = config.size.choose(1 << 12, 1 << 24, 1 << 25); let data: Vec = (0..n_elements as u32).collect(); let data_buf = runner .session @@ -74,15 +77,13 @@ pub unsafe fn run_prefix_test(runner: &mut Runner) { if i == 0 { let mut dst: Vec = Default::default(); out_buf.read(&mut dst); - println!("failures: {:?}", verify(&dst)); + if let Some(failure) = verify(&dst) { + result.fail(format!("failure at {}", failure)); + } } } - let throughput = (n_elements * n_iter) as f64 / total_elapsed; - println!( - "total {:?}ms, throughput = {}G el/s", - total_elapsed * 1e3, - throughput * 1e-9 - ); + result.timing(total_elapsed, n_elements * n_iter); + result } impl PrefixCode { diff --git a/tests/src/prefix_tree.rs b/tests/src/prefix_tree.rs index 7b9743a..1f78202 100644 --- a/tests/src/prefix_tree.rs +++ b/tests/src/prefix_tree.rs @@ -17,7 +17,9 @@ use piet_gpu_hal::{include_shader, BufferUsage, DescriptorSet}; use piet_gpu_hal::{Buffer, Pipeline}; +use crate::config::Config; use crate::runner::{Commands, Runner}; +use crate::test_result::TestResult; const WG_SIZE: u64 = 512; const N_ROWS: u64 = 8; @@ -39,11 +41,12 @@ struct PrefixTreeBinding { descriptor_sets: Vec, } -pub unsafe fn run_prefix_test(runner: &mut Runner) { +pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResult { + let mut result = TestResult::new("prefix sum, tree reduction"); // This will be configurable. Note though that the current code is // prone to reading and writing past the end of buffers if this is // not a power of the number of elements processed in a workgroup. - let n_elements: u64 = 1 << 24; + let n_elements: u64 = config.size.choose(1 << 12, 1 << 24, 1 << 24); let data: Vec = (0..n_elements as u32).collect(); let data_buf = runner .session @@ -71,15 +74,13 @@ pub unsafe fn run_prefix_test(runner: &mut Runner) { if i == 0 { let mut dst: Vec = Default::default(); out_buf.read(&mut dst); - println!("failures: {:?}", verify(&dst)); + if let Some(failure) = verify(&dst) { + result.fail(format!("failure at {}", failure)); + } } } - let throughput = (n_elements * n_iter) as f64 / total_elapsed; - println!( - "total {:?}ms, throughput = {}G el/s", - total_elapsed * 1e3, - throughput * 1e-9 - ); + result.timing(total_elapsed, n_elements * n_iter); + result } impl PrefixTreeCode { diff --git a/tests/src/test_result.rs b/tests/src/test_result.rs new file mode 100644 index 0000000..84bbc85 --- /dev/null +++ b/tests/src/test_result.rs @@ -0,0 +1,110 @@ +// Copyright 2021 The piet-gpu authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Also licensed under MIT license, at your choice. + +//! Recording of results from tests. + +pub struct TestResult { + name: String, + // TODO: statistics. We're lean and mean for now. + total_time: f64, + n_elements: u64, + failure: Option, +} + +#[derive(Clone, Copy)] +pub enum ReportStyle { + Short, + Verbose, +} + +impl TestResult { + pub fn new(name: &str) -> TestResult { + TestResult { + name: name.to_string(), + total_time: 0.0, + n_elements: 0, + failure: None, + } + } + + pub fn report(&self, style: ReportStyle) { + let fail_string = match &self.failure { + None => "pass".into(), + Some(s) => format!("fail ({})", s), + }; + match style { + ReportStyle::Short => { + let mut timing_string = String::new(); + if self.total_time > 0.0 { + if self.n_elements > 0 { + let throughput = self.n_elements as f64 / self.total_time; + timing_string = format!(" {} elements/s", format_nice(throughput, 1)); + } else { + timing_string = format!(" {}s", format_nice(self.total_time, 1)); + } + } + println!("{}: {}{}", self.name, fail_string, timing_string) + } + ReportStyle::Verbose => { + println!("test {}", self.name); + println!(" {}", fail_string); + if self.total_time > 0.0 { + println!(" {}s total time", format_nice(self.total_time, 1)); + if self.n_elements > 0 { + println!(" {} elements", self.n_elements); + let throughput = self.n_elements as f64 / self.total_time; + println!(" {} elements/s", format_nice(throughput, 1)); + } + } + } + } + } + + pub fn fail(&mut self, explanation: String) { + self.failure = Some(explanation); + } + + pub fn timing(&mut self, total_time: f64, n_elements: u64) { + self.total_time = total_time; + self.n_elements = n_elements; + } +} + +fn format_nice(x: f64, precision: usize) -> String { + // Precision should probably scale; later + let (scale, suffix) = if x >= 1e12 && x < 1e15 { + (1e-12, "T") + } else if x >= 1e9 { + (1e-9, "G") + } else if x >= 1e6 { + (1e-6, "M") + } else if x >= 1e3 { + (1e-3, "k") + } else if x >= 1.0 { + (1.0, "") + } else if x >= 1e-3 { + (1e3, "m") + } else if x >= 1e-6 { + (1e6, "\u{00b5}") + } else if x >= 1e-9 { + (1e9, "n") + } else if x >= 1e-12 { + (1e12, "p") + } else { + return format!("{:.*e}", precision, x); + }; + format!("{:.*}{}", precision, scale * x, suffix) +}