mirror of
https://github.com/italicsjenga/vello.git
synced 2025-01-09 20:31:29 +11:00
Start testing framework
This adds a prefix sum test. This patch is also trying to get a little more serious about structuring both the test runner (toward the goal of collecting proper statistics) and pipeline stages for the tests. Still WIP but giving good results.
This commit is contained in:
parent
b0b0f33c3c
commit
33d7b25a92
15
Cargo.lock
generated
15
Cargo.lock
generated
|
@ -85,6 +85,12 @@ version = "0.1.6"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a"
|
||||
|
||||
[[package]]
|
||||
name = "bytemuck"
|
||||
version = "1.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72957246c41db82b8ef88a5486143830adeb8227ef9837740bdec67724cf2c5b"
|
||||
|
||||
[[package]]
|
||||
name = "byteorder"
|
||||
version = "1.3.4"
|
||||
|
@ -887,6 +893,15 @@ dependencies = [
|
|||
"wio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "piet-gpu-tests"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"clap",
|
||||
"piet-gpu-hal",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "piet-gpu-types"
|
||||
version = "0.0.0"
|
||||
|
|
|
@ -4,5 +4,6 @@ members = [
|
|||
"piet-gpu",
|
||||
"piet-gpu-derive",
|
||||
"piet-gpu-hal",
|
||||
"piet-gpu-types"
|
||||
"piet-gpu-types",
|
||||
"tests"
|
||||
]
|
||||
|
|
14
tests/Cargo.toml
Normal file
14
tests/Cargo.toml
Normal file
|
@ -0,0 +1,14 @@
|
|||
[package]
|
||||
name = "piet-gpu-tests"
|
||||
version = "0.1.0"
|
||||
authors = ["Raph Levien <raph.levien@gmail.com>"]
|
||||
description = "Tests for piet-gpu shaders and generic GPU capabilities."
|
||||
license = "MIT/Apache-2.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
clap = "2.33"
|
||||
bytemuck = "1.7.2"
|
||||
|
||||
[dependencies.piet-gpu-hal]
|
||||
path = "../piet-gpu-hal"
|
19
tests/shader/build.ninja
Normal file
19
tests/shader/build.ninja
Normal file
|
@ -0,0 +1,19 @@
|
|||
# Build file for shaders.
|
||||
|
||||
# You must have Vulkan tools in your path, or patch here.
|
||||
|
||||
glslang_validator = glslangValidator
|
||||
spirv_cross = spirv-cross
|
||||
|
||||
rule glsl
|
||||
command = $glslang_validator -V -o $out $in
|
||||
|
||||
rule hlsl
|
||||
command = $spirv_cross --hlsl $in --output $out
|
||||
|
||||
rule msl
|
||||
command = $spirv_cross --msl $in --output $out
|
||||
|
||||
build gen/prefix.spv: glsl prefix.comp
|
||||
build gen/prefix.hlsl: hlsl gen/prefix.spv
|
||||
build gen/prefix.msl: msl gen/prefix.spv
|
43
tests/shader/gen/collatz.hlsl
Normal file
43
tests/shader/gen/collatz.hlsl
Normal file
|
@ -0,0 +1,43 @@
|
|||
static const uint3 gl_WorkGroupSize = uint3(1u, 1u, 1u);
|
||||
|
||||
RWByteAddressBuffer _53 : register(u1);
|
||||
ByteAddressBuffer _59 : register(t0);
|
||||
|
||||
static uint3 gl_GlobalInvocationID;
|
||||
struct SPIRV_Cross_Input
|
||||
{
|
||||
uint3 gl_GlobalInvocationID : SV_DispatchThreadID;
|
||||
};
|
||||
|
||||
uint collatz_iterations(inout uint n)
|
||||
{
|
||||
uint i = 0u;
|
||||
while (n != 1u)
|
||||
{
|
||||
if ((n % 2u) == 0u)
|
||||
{
|
||||
n /= 2u;
|
||||
}
|
||||
else
|
||||
{
|
||||
n = (3u * n) + 1u;
|
||||
}
|
||||
i++;
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
void comp_main()
|
||||
{
|
||||
uint index = gl_GlobalInvocationID.x;
|
||||
uint param = _59.Load(index * 4 + 0);
|
||||
uint _65 = collatz_iterations(param);
|
||||
_53.Store(index * 4 + 0, _65);
|
||||
}
|
||||
|
||||
[numthreads(1, 1, 1)]
|
||||
void main(SPIRV_Cross_Input stage_input)
|
||||
{
|
||||
gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID;
|
||||
comp_main();
|
||||
}
|
46
tests/shader/gen/collatz.msl
Normal file
46
tests/shader/gen/collatz.msl
Normal file
|
@ -0,0 +1,46 @@
|
|||
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
struct OutBuf
|
||||
{
|
||||
uint out_buf[1];
|
||||
};
|
||||
|
||||
struct InBuf
|
||||
{
|
||||
uint in_buf[1];
|
||||
};
|
||||
|
||||
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);
|
||||
|
||||
static inline __attribute__((always_inline))
|
||||
uint collatz_iterations(thread uint& n)
|
||||
{
|
||||
uint i = 0u;
|
||||
while (n != 1u)
|
||||
{
|
||||
if ((n % 2u) == 0u)
|
||||
{
|
||||
n /= 2u;
|
||||
}
|
||||
else
|
||||
{
|
||||
n = (3u * n) + 1u;
|
||||
}
|
||||
i++;
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
kernel void main0(device OutBuf& _53 [[buffer(0)]], const device InBuf& _59 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
|
||||
{
|
||||
uint index = gl_GlobalInvocationID.x;
|
||||
uint param = _59.in_buf[index];
|
||||
uint _65 = collatz_iterations(param);
|
||||
_53.out_buf[index] = _65;
|
||||
}
|
||||
|
BIN
tests/shader/gen/collatz.spv
Normal file
BIN
tests/shader/gen/collatz.spv
Normal file
Binary file not shown.
223
tests/shader/gen/prefix.hlsl
Normal file
223
tests/shader/gen/prefix.hlsl
Normal file
|
@ -0,0 +1,223 @@
|
|||
struct Monoid
|
||||
{
|
||||
uint element;
|
||||
};
|
||||
|
||||
struct State
|
||||
{
|
||||
uint flag;
|
||||
Monoid aggregate;
|
||||
Monoid prefix;
|
||||
};
|
||||
|
||||
static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u);
|
||||
|
||||
static const Monoid _187 = { 0u };
|
||||
|
||||
globallycoherent RWByteAddressBuffer _43 : register(u2);
|
||||
ByteAddressBuffer _67 : register(t0);
|
||||
RWByteAddressBuffer _374 : register(u1);
|
||||
|
||||
static uint3 gl_LocalInvocationID;
|
||||
struct SPIRV_Cross_Input
|
||||
{
|
||||
uint3 gl_LocalInvocationID : SV_GroupThreadID;
|
||||
};
|
||||
|
||||
groupshared uint sh_part_ix;
|
||||
groupshared Monoid sh_scratch[512];
|
||||
groupshared uint sh_flag;
|
||||
groupshared Monoid sh_prefix;
|
||||
|
||||
Monoid combine_monoid(Monoid a, Monoid b)
|
||||
{
|
||||
Monoid _22 = { a.element + b.element };
|
||||
return _22;
|
||||
}
|
||||
|
||||
void comp_main()
|
||||
{
|
||||
if (gl_LocalInvocationID.x == 0u)
|
||||
{
|
||||
uint _47;
|
||||
_43.InterlockedAdd(0, 1u, _47);
|
||||
sh_part_ix = _47;
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
uint part_ix = sh_part_ix;
|
||||
uint ix = (part_ix * 8192u) + (gl_LocalInvocationID.x * 16u);
|
||||
Monoid _71;
|
||||
_71.element = _67.Load(ix * 4 + 0);
|
||||
Monoid local[16];
|
||||
local[0].element = _71.element;
|
||||
Monoid param_1;
|
||||
for (uint i = 1u; i < 16u; i++)
|
||||
{
|
||||
Monoid param = local[i - 1u];
|
||||
Monoid _94;
|
||||
_94.element = _67.Load((ix + i) * 4 + 0);
|
||||
param_1.element = _94.element;
|
||||
local[i] = combine_monoid(param, param_1);
|
||||
}
|
||||
Monoid agg = local[15];
|
||||
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||
for (uint i_1 = 0u; i_1 < 9u; i_1++)
|
||||
{
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
if (gl_LocalInvocationID.x >= uint(1 << int(i_1)))
|
||||
{
|
||||
Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))];
|
||||
Monoid param_2 = other;
|
||||
Monoid param_3 = agg;
|
||||
agg = combine_monoid(param_2, param_3);
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||
}
|
||||
if (gl_LocalInvocationID.x == 511u)
|
||||
{
|
||||
_43.Store(part_ix * 12 + 8, agg.element);
|
||||
if (part_ix == 0u)
|
||||
{
|
||||
_43.Store(12, agg.element);
|
||||
}
|
||||
}
|
||||
DeviceMemoryBarrier();
|
||||
if (gl_LocalInvocationID.x == 511u)
|
||||
{
|
||||
uint flag = 1u;
|
||||
if (part_ix == 0u)
|
||||
{
|
||||
flag = 2u;
|
||||
}
|
||||
_43.Store(part_ix * 12 + 4, flag);
|
||||
}
|
||||
Monoid exclusive = _187;
|
||||
if (part_ix != 0u)
|
||||
{
|
||||
uint look_back_ix = part_ix - 1u;
|
||||
uint their_ix = 0u;
|
||||
Monoid their_prefix;
|
||||
Monoid their_agg;
|
||||
Monoid m;
|
||||
while (true)
|
||||
{
|
||||
if (gl_LocalInvocationID.x == 511u)
|
||||
{
|
||||
sh_flag = _43.Load(look_back_ix * 12 + 4);
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
DeviceMemoryBarrier();
|
||||
uint flag_1 = sh_flag;
|
||||
if (flag_1 == 2u)
|
||||
{
|
||||
if (gl_LocalInvocationID.x == 511u)
|
||||
{
|
||||
Monoid _225;
|
||||
_225.element = _43.Load(look_back_ix * 12 + 12);
|
||||
their_prefix.element = _225.element;
|
||||
Monoid param_4 = their_prefix;
|
||||
Monoid param_5 = exclusive;
|
||||
exclusive = combine_monoid(param_4, param_5);
|
||||
}
|
||||
break;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (flag_1 == 1u)
|
||||
{
|
||||
if (gl_LocalInvocationID.x == 511u)
|
||||
{
|
||||
Monoid _247;
|
||||
_247.element = _43.Load(look_back_ix * 12 + 8);
|
||||
their_agg.element = _247.element;
|
||||
Monoid param_6 = their_agg;
|
||||
Monoid param_7 = exclusive;
|
||||
exclusive = combine_monoid(param_6, param_7);
|
||||
}
|
||||
look_back_ix--;
|
||||
their_ix = 0u;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (gl_LocalInvocationID.x == 511u)
|
||||
{
|
||||
Monoid _269;
|
||||
_269.element = _67.Load(((look_back_ix * 8192u) + their_ix) * 4 + 0);
|
||||
m.element = _269.element;
|
||||
if (their_ix == 0u)
|
||||
{
|
||||
their_agg = m;
|
||||
}
|
||||
else
|
||||
{
|
||||
Monoid param_8 = their_agg;
|
||||
Monoid param_9 = m;
|
||||
their_agg = combine_monoid(param_8, param_9);
|
||||
}
|
||||
their_ix++;
|
||||
if (their_ix == 8192u)
|
||||
{
|
||||
Monoid param_10 = their_agg;
|
||||
Monoid param_11 = exclusive;
|
||||
exclusive = combine_monoid(param_10, param_11);
|
||||
if (look_back_ix == 0u)
|
||||
{
|
||||
sh_flag = 2u;
|
||||
}
|
||||
else
|
||||
{
|
||||
look_back_ix--;
|
||||
their_ix = 0u;
|
||||
}
|
||||
}
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
flag_1 = sh_flag;
|
||||
if (flag_1 == 2u)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (gl_LocalInvocationID.x == 511u)
|
||||
{
|
||||
Monoid param_12 = exclusive;
|
||||
Monoid param_13 = agg;
|
||||
Monoid inclusive_prefix = combine_monoid(param_12, param_13);
|
||||
sh_prefix = exclusive;
|
||||
_43.Store(part_ix * 12 + 12, inclusive_prefix.element);
|
||||
}
|
||||
DeviceMemoryBarrier();
|
||||
if (gl_LocalInvocationID.x == 511u)
|
||||
{
|
||||
_43.Store(part_ix * 12 + 4, 2u);
|
||||
}
|
||||
}
|
||||
GroupMemoryBarrierWithGroupSync();
|
||||
if (part_ix != 0u)
|
||||
{
|
||||
exclusive = sh_prefix;
|
||||
}
|
||||
Monoid row = exclusive;
|
||||
if (gl_LocalInvocationID.x > 0u)
|
||||
{
|
||||
Monoid other_1 = sh_scratch[gl_LocalInvocationID.x - 1u];
|
||||
Monoid param_14 = row;
|
||||
Monoid param_15 = other_1;
|
||||
row = combine_monoid(param_14, param_15);
|
||||
}
|
||||
for (uint i_2 = 0u; i_2 < 16u; i_2++)
|
||||
{
|
||||
Monoid param_16 = row;
|
||||
Monoid param_17 = local[i_2];
|
||||
Monoid m_1 = combine_monoid(param_16, param_17);
|
||||
_374.Store((ix + i_2) * 4 + 0, m_1.element);
|
||||
}
|
||||
}
|
||||
|
||||
[numthreads(512, 1, 1)]
|
||||
void main(SPIRV_Cross_Input stage_input)
|
||||
{
|
||||
gl_LocalInvocationID = stage_input.gl_LocalInvocationID;
|
||||
comp_main();
|
||||
}
|
262
tests/shader/gen/prefix.msl
Normal file
262
tests/shader/gen/prefix.msl
Normal file
|
@ -0,0 +1,262 @@
|
|||
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
||||
#pragma clang diagnostic ignored "-Wmissing-braces"
|
||||
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include <simd/simd.h>
|
||||
#include <metal_atomic>
|
||||
|
||||
using namespace metal;
|
||||
|
||||
template<typename T, size_t Num>
|
||||
struct spvUnsafeArray
|
||||
{
|
||||
T elements[Num ? Num : 1];
|
||||
|
||||
thread T& operator [] (size_t pos) thread
|
||||
{
|
||||
return elements[pos];
|
||||
}
|
||||
constexpr const thread T& operator [] (size_t pos) const thread
|
||||
{
|
||||
return elements[pos];
|
||||
}
|
||||
|
||||
device T& operator [] (size_t pos) device
|
||||
{
|
||||
return elements[pos];
|
||||
}
|
||||
constexpr const device T& operator [] (size_t pos) const device
|
||||
{
|
||||
return elements[pos];
|
||||
}
|
||||
|
||||
constexpr const constant T& operator [] (size_t pos) const constant
|
||||
{
|
||||
return elements[pos];
|
||||
}
|
||||
|
||||
threadgroup T& operator [] (size_t pos) threadgroup
|
||||
{
|
||||
return elements[pos];
|
||||
}
|
||||
constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
|
||||
{
|
||||
return elements[pos];
|
||||
}
|
||||
};
|
||||
|
||||
struct Monoid
|
||||
{
|
||||
uint element;
|
||||
};
|
||||
|
||||
struct Monoid_1
|
||||
{
|
||||
uint element;
|
||||
};
|
||||
|
||||
struct State
|
||||
{
|
||||
uint flag;
|
||||
Monoid_1 aggregate;
|
||||
Monoid_1 prefix;
|
||||
};
|
||||
|
||||
struct StateBuf
|
||||
{
|
||||
uint part_counter;
|
||||
State state[1];
|
||||
};
|
||||
|
||||
struct InBuf
|
||||
{
|
||||
Monoid_1 inbuf[1];
|
||||
};
|
||||
|
||||
struct OutBuf
|
||||
{
|
||||
Monoid_1 outbuf[1];
|
||||
};
|
||||
|
||||
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u);
|
||||
|
||||
static inline __attribute__((always_inline))
|
||||
Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b)
|
||||
{
|
||||
return Monoid{ a.element + b.element };
|
||||
}
|
||||
|
||||
kernel void main0(volatile device StateBuf& _43 [[buffer(0)]], const device InBuf& _67 [[buffer(1)]], device OutBuf& _374 [[buffer(2)]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]])
|
||||
{
|
||||
threadgroup uint sh_part_ix;
|
||||
threadgroup Monoid sh_scratch[512];
|
||||
threadgroup uint sh_flag;
|
||||
threadgroup Monoid sh_prefix;
|
||||
if (gl_LocalInvocationID.x == 0u)
|
||||
{
|
||||
uint _47 = atomic_fetch_add_explicit((volatile device atomic_uint*)&_43.part_counter, 1u, memory_order_relaxed);
|
||||
sh_part_ix = _47;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
uint part_ix = sh_part_ix;
|
||||
uint ix = (part_ix * 8192u) + (gl_LocalInvocationID.x * 16u);
|
||||
spvUnsafeArray<Monoid, 16> local;
|
||||
local[0].element = _67.inbuf[ix].element;
|
||||
Monoid param_1;
|
||||
for (uint i = 1u; i < 16u; i++)
|
||||
{
|
||||
Monoid param = local[i - 1u];
|
||||
param_1.element = _67.inbuf[ix + i].element;
|
||||
local[i] = combine_monoid(param, param_1);
|
||||
}
|
||||
Monoid agg = local[15];
|
||||
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||
for (uint i_1 = 0u; i_1 < 9u; i_1++)
|
||||
{
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (gl_LocalInvocationID.x >= uint(1 << int(i_1)))
|
||||
{
|
||||
Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))];
|
||||
Monoid param_2 = other;
|
||||
Monoid param_3 = agg;
|
||||
agg = combine_monoid(param_2, param_3);
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||
}
|
||||
if (gl_LocalInvocationID.x == 511u)
|
||||
{
|
||||
_43.state[part_ix].aggregate.element = agg.element;
|
||||
if (part_ix == 0u)
|
||||
{
|
||||
_43.state[0].prefix.element = agg.element;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_device);
|
||||
if (gl_LocalInvocationID.x == 511u)
|
||||
{
|
||||
uint flag = 1u;
|
||||
if (part_ix == 0u)
|
||||
{
|
||||
flag = 2u;
|
||||
}
|
||||
_43.state[part_ix].flag = flag;
|
||||
}
|
||||
Monoid exclusive = Monoid{ 0u };
|
||||
if (part_ix != 0u)
|
||||
{
|
||||
uint look_back_ix = part_ix - 1u;
|
||||
uint their_ix = 0u;
|
||||
Monoid their_prefix;
|
||||
Monoid their_agg;
|
||||
Monoid m;
|
||||
while (true)
|
||||
{
|
||||
if (gl_LocalInvocationID.x == 511u)
|
||||
{
|
||||
sh_flag = _43.state[look_back_ix].flag;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
threadgroup_barrier(mem_flags::mem_device);
|
||||
uint flag_1 = sh_flag;
|
||||
if (flag_1 == 2u)
|
||||
{
|
||||
if (gl_LocalInvocationID.x == 511u)
|
||||
{
|
||||
their_prefix.element = _43.state[look_back_ix].prefix.element;
|
||||
Monoid param_4 = their_prefix;
|
||||
Monoid param_5 = exclusive;
|
||||
exclusive = combine_monoid(param_4, param_5);
|
||||
}
|
||||
break;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (flag_1 == 1u)
|
||||
{
|
||||
if (gl_LocalInvocationID.x == 511u)
|
||||
{
|
||||
their_agg.element = _43.state[look_back_ix].aggregate.element;
|
||||
Monoid param_6 = their_agg;
|
||||
Monoid param_7 = exclusive;
|
||||
exclusive = combine_monoid(param_6, param_7);
|
||||
}
|
||||
look_back_ix--;
|
||||
their_ix = 0u;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (gl_LocalInvocationID.x == 511u)
|
||||
{
|
||||
m.element = _67.inbuf[(look_back_ix * 8192u) + their_ix].element;
|
||||
if (their_ix == 0u)
|
||||
{
|
||||
their_agg = m;
|
||||
}
|
||||
else
|
||||
{
|
||||
Monoid param_8 = their_agg;
|
||||
Monoid param_9 = m;
|
||||
their_agg = combine_monoid(param_8, param_9);
|
||||
}
|
||||
their_ix++;
|
||||
if (their_ix == 8192u)
|
||||
{
|
||||
Monoid param_10 = their_agg;
|
||||
Monoid param_11 = exclusive;
|
||||
exclusive = combine_monoid(param_10, param_11);
|
||||
if (look_back_ix == 0u)
|
||||
{
|
||||
sh_flag = 2u;
|
||||
}
|
||||
else
|
||||
{
|
||||
look_back_ix--;
|
||||
their_ix = 0u;
|
||||
}
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
flag_1 = sh_flag;
|
||||
if (flag_1 == 2u)
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (gl_LocalInvocationID.x == 511u)
|
||||
{
|
||||
Monoid param_12 = exclusive;
|
||||
Monoid param_13 = agg;
|
||||
Monoid inclusive_prefix = combine_monoid(param_12, param_13);
|
||||
sh_prefix = exclusive;
|
||||
_43.state[part_ix].prefix.element = inclusive_prefix.element;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_device);
|
||||
if (gl_LocalInvocationID.x == 511u)
|
||||
{
|
||||
_43.state[part_ix].flag = 2u;
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (part_ix != 0u)
|
||||
{
|
||||
exclusive = sh_prefix;
|
||||
}
|
||||
Monoid row = exclusive;
|
||||
if (gl_LocalInvocationID.x > 0u)
|
||||
{
|
||||
Monoid other_1 = sh_scratch[gl_LocalInvocationID.x - 1u];
|
||||
Monoid param_14 = row;
|
||||
Monoid param_15 = other_1;
|
||||
row = combine_monoid(param_14, param_15);
|
||||
}
|
||||
for (uint i_2 = 0u; i_2 < 16u; i_2++)
|
||||
{
|
||||
Monoid param_16 = row;
|
||||
Monoid param_17 = local[i_2];
|
||||
Monoid m_1 = combine_monoid(param_16, param_17);
|
||||
_374.outbuf[ix + i_2].element = m_1.element;
|
||||
}
|
||||
}
|
||||
|
BIN
tests/shader/gen/prefix.spv
Normal file
BIN
tests/shader/gen/prefix.spv
Normal file
Binary file not shown.
188
tests/shader/prefix.comp
Normal file
188
tests/shader/prefix.comp
Normal file
|
@ -0,0 +1,188 @@
|
|||
// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense
|
||||
|
||||
// A prefix sum.
|
||||
|
||||
#version 450
|
||||
|
||||
#define N_ROWS 16
|
||||
#define LG_WG_SIZE 9
|
||||
#define WG_SIZE (1 << LG_WG_SIZE)
|
||||
#define PARTITION_SIZE (WG_SIZE * N_ROWS)
|
||||
|
||||
layout(local_size_x = WG_SIZE, local_size_y = 1) in;
|
||||
|
||||
struct Monoid {
|
||||
uint element;
|
||||
};
|
||||
|
||||
layout(set = 0, binding = 0) readonly buffer InBuf {
|
||||
Monoid[] inbuf;
|
||||
};
|
||||
|
||||
layout(set = 0, binding = 1) buffer OutBuf {
|
||||
Monoid[] outbuf;
|
||||
};
|
||||
|
||||
// These correspond to X, A, P respectively in the prefix sum paper.
|
||||
#define FLAG_NOT_READY 0
|
||||
#define FLAG_AGGREGATE_READY 1
|
||||
#define FLAG_PREFIX_READY 2
|
||||
|
||||
struct State {
|
||||
uint flag;
|
||||
Monoid aggregate;
|
||||
Monoid prefix;
|
||||
};
|
||||
|
||||
layout(set = 0, binding = 2) volatile buffer StateBuf {
|
||||
uint part_counter;
|
||||
State[] state;
|
||||
};
|
||||
|
||||
shared Monoid sh_scratch[WG_SIZE];
|
||||
|
||||
Monoid combine_monoid(Monoid a, Monoid b) {
|
||||
return Monoid(a.element + b.element);
|
||||
}
|
||||
|
||||
shared uint sh_part_ix;
|
||||
shared Monoid sh_prefix;
|
||||
shared uint sh_flag;
|
||||
|
||||
void main() {
|
||||
Monoid local[N_ROWS];
|
||||
// Determine partition to process by atomic counter (described in Section
|
||||
// 4.4 of prefix sum paper).
|
||||
if (gl_LocalInvocationID.x == 0) {
|
||||
sh_part_ix = atomicAdd(part_counter, 1);
|
||||
}
|
||||
barrier();
|
||||
uint part_ix = sh_part_ix;
|
||||
|
||||
uint ix = part_ix * PARTITION_SIZE + gl_LocalInvocationID.x * N_ROWS;
|
||||
|
||||
// TODO: gate buffer read? (evaluate whether shader check or
|
||||
// CPU-side padding is better)
|
||||
local[0] = inbuf[ix];
|
||||
for (uint i = 1; i < N_ROWS; i++) {
|
||||
local[i] = combine_monoid(local[i - 1], inbuf[ix + i]);
|
||||
}
|
||||
Monoid agg = local[N_ROWS - 1];
|
||||
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||
for (uint i = 0; i < LG_WG_SIZE; i++) {
|
||||
barrier();
|
||||
if (gl_LocalInvocationID.x >= (1 << i)) {
|
||||
Monoid other = sh_scratch[gl_LocalInvocationID.x - (1 << i)];
|
||||
agg = combine_monoid(other, agg);
|
||||
}
|
||||
barrier();
|
||||
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||
}
|
||||
|
||||
// Publish aggregate for this partition
|
||||
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
|
||||
state[part_ix].aggregate = agg;
|
||||
if (part_ix == 0) {
|
||||
state[0].prefix = agg;
|
||||
}
|
||||
}
|
||||
// Write flag with release semantics; this is done portably with a barrier.
|
||||
memoryBarrierBuffer();
|
||||
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
|
||||
uint flag = FLAG_AGGREGATE_READY;
|
||||
if (part_ix == 0) {
|
||||
flag = FLAG_PREFIX_READY;
|
||||
}
|
||||
state[part_ix].flag = flag;
|
||||
}
|
||||
|
||||
Monoid exclusive = Monoid(0);
|
||||
if (part_ix != 0) {
|
||||
// step 4 of paper: decoupled lookback
|
||||
uint look_back_ix = part_ix - 1;
|
||||
|
||||
Monoid their_agg;
|
||||
uint their_ix = 0;
|
||||
while (true) {
|
||||
// Read flag with acquire semantics.
|
||||
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
|
||||
sh_flag = state[look_back_ix].flag;
|
||||
}
|
||||
// The flag load is done only in the last thread. However, because the
|
||||
// translation of memoryBarrierBuffer to Metal requires uniform control
|
||||
// flow, we broadcast it to all threads.
|
||||
barrier();
|
||||
memoryBarrierBuffer();
|
||||
uint flag = sh_flag;
|
||||
|
||||
if (flag == FLAG_PREFIX_READY) {
|
||||
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
|
||||
Monoid their_prefix = state[look_back_ix].prefix;
|
||||
exclusive = combine_monoid(their_prefix, exclusive);
|
||||
}
|
||||
break;
|
||||
} else if (flag == FLAG_AGGREGATE_READY) {
|
||||
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
|
||||
their_agg = state[look_back_ix].aggregate;
|
||||
exclusive = combine_monoid(their_agg, exclusive);
|
||||
}
|
||||
look_back_ix--;
|
||||
their_ix = 0;
|
||||
continue;
|
||||
}
|
||||
// else spin
|
||||
|
||||
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
|
||||
// Unfortunately there's no guarantee of forward progress of other
|
||||
// workgroups, so compute a bit of the aggregate before trying again.
|
||||
// In the worst case, spinning stops when the aggregate is complete.
|
||||
Monoid m = inbuf[look_back_ix * PARTITION_SIZE + their_ix];
|
||||
if (their_ix == 0) {
|
||||
their_agg = m;
|
||||
} else {
|
||||
their_agg = combine_monoid(their_agg, m);
|
||||
}
|
||||
their_ix++;
|
||||
if (their_ix == PARTITION_SIZE) {
|
||||
exclusive = combine_monoid(their_agg, exclusive);
|
||||
if (look_back_ix == 0) {
|
||||
sh_flag = FLAG_PREFIX_READY;
|
||||
} else {
|
||||
look_back_ix--;
|
||||
their_ix = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
flag = sh_flag;
|
||||
if (flag == FLAG_PREFIX_READY) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// step 5 of paper: compute inclusive prefix
|
||||
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
|
||||
Monoid inclusive_prefix = combine_monoid(exclusive, agg);
|
||||
sh_prefix = exclusive;
|
||||
state[part_ix].prefix = inclusive_prefix;
|
||||
}
|
||||
memoryBarrierBuffer();
|
||||
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
|
||||
state[part_ix].flag = FLAG_PREFIX_READY;
|
||||
}
|
||||
}
|
||||
barrier();
|
||||
if (part_ix != 0) {
|
||||
exclusive = sh_prefix;
|
||||
}
|
||||
|
||||
Monoid row = exclusive;
|
||||
if (gl_LocalInvocationID.x > 0) {
|
||||
Monoid other = sh_scratch[gl_LocalInvocationID.x - 1];
|
||||
row = combine_monoid(row, other);
|
||||
}
|
||||
for (uint i = 0; i < N_ROWS; i++) {
|
||||
Monoid m = combine_monoid(row, local[i]);
|
||||
// Make sure buffer allocation is padded appropriately.
|
||||
outbuf[ix + i] = m;
|
||||
}
|
||||
}
|
29
tests/src/main.rs
Normal file
29
tests/src/main.rs
Normal file
|
@ -0,0 +1,29 @@
|
|||
// Copyright 2021 The piet-gpu authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Also licensed under MIT license, at your choice.
|
||||
|
||||
//! Tests for piet-gpu shaders and GPU capabilities.
|
||||
|
||||
mod prefix;
|
||||
mod runner;
|
||||
|
||||
use runner::Runner;
|
||||
|
||||
fn main() {
|
||||
unsafe {
|
||||
let mut runner = Runner::new();
|
||||
prefix::run_prefix_test(&mut runner);
|
||||
}
|
||||
}
|
142
tests/src/prefix.rs
Normal file
142
tests/src/prefix.rs
Normal file
|
@ -0,0 +1,142 @@
|
|||
// Copyright 2021 The piet-gpu authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Also licensed under MIT license, at your choice.
|
||||
|
||||
use piet_gpu_hal::{include_shader, BufferUsage, DescriptorSet};
|
||||
use piet_gpu_hal::{Buffer, Pipeline};
|
||||
|
||||
use crate::runner::{Commands, Runner};
|
||||
|
||||
const WG_SIZE: u64 = 512;
|
||||
const N_ROWS: u64 = 16;
|
||||
const ELEMENTS_PER_WG: u64 = WG_SIZE * N_ROWS;
|
||||
|
||||
/// The shader code for the prefix sum example.
|
||||
///
|
||||
/// A code struct can be created once and reused any number of times.
|
||||
struct PrefixCode {
|
||||
pipeline: Pipeline,
|
||||
}
|
||||
|
||||
/// The stage resources for the prefix sum example.
|
||||
///
|
||||
/// A stage resources struct is specific to a particular problem size
|
||||
/// and queue.
|
||||
struct PrefixStage {
|
||||
// This is the actual problem size but perhaps it would be better to
|
||||
// treat it as a capacity.
|
||||
n_elements: u64,
|
||||
state_buf: Buffer,
|
||||
}
|
||||
|
||||
/// The binding for the prefix sum example.
|
||||
struct PrefixBinding {
|
||||
descriptor_set: DescriptorSet,
|
||||
}
|
||||
|
||||
pub unsafe fn run_prefix_test(runner: &mut Runner) {
|
||||
// This will be configurable.
|
||||
let n_elements: u64 = 1 << 23;
|
||||
let data: Vec<u32> = (0..n_elements as u32).collect();
|
||||
let data_buf = runner
|
||||
.session
|
||||
.create_buffer_init(&data, BufferUsage::STORAGE)
|
||||
.unwrap();
|
||||
let out_buf = runner.buf_down(data_buf.size());
|
||||
let code = PrefixCode::new(runner);
|
||||
let stage = PrefixStage::new(runner, n_elements);
|
||||
let binding = stage.bind(runner, &code, &data_buf, &out_buf.dev_buf);
|
||||
// Also will be configurable of course.
|
||||
let n_iter = 5000;
|
||||
let mut total_elapsed = 0.0;
|
||||
for i in 0..n_iter {
|
||||
let mut commands = runner.commands();
|
||||
commands.write_timestamp(0);
|
||||
stage.record(&mut commands, &code, &binding);
|
||||
commands.write_timestamp(1);
|
||||
if i == 0 {
|
||||
commands.cmd_buf.memory_barrier();
|
||||
commands.download(&out_buf);
|
||||
}
|
||||
total_elapsed += runner.submit(commands);
|
||||
if i == 0 {
|
||||
let mut dst: Vec<u32> = Default::default();
|
||||
out_buf.read(&mut dst);
|
||||
println!("failures: {:?}", verify(&dst));
|
||||
}
|
||||
}
|
||||
let throughput = (n_elements * n_iter) as f64 / total_elapsed;
|
||||
println!(
|
||||
"total {:?}ms, throughput = {}G el/s",
|
||||
total_elapsed * 1e3,
|
||||
throughput * 1e-9
|
||||
);
|
||||
}
|
||||
|
||||
impl PrefixCode {
|
||||
unsafe fn new(runner: &mut Runner) -> PrefixCode {
|
||||
let code = include_shader!(&runner.session, "../shader/gen/prefix");
|
||||
let pipeline = runner
|
||||
.session
|
||||
.create_simple_compute_pipeline(code, 3)
|
||||
.unwrap();
|
||||
PrefixCode { pipeline }
|
||||
}
|
||||
}
|
||||
|
||||
impl PrefixStage {
|
||||
unsafe fn new(runner: &mut Runner, n_elements: u64) -> PrefixStage {
|
||||
let n_workgroups = (n_elements + ELEMENTS_PER_WG - 1) / ELEMENTS_PER_WG;
|
||||
let state_buf_size = 4 + 12 * n_workgroups;
|
||||
let state_buf = runner
|
||||
.session
|
||||
.create_buffer(state_buf_size, BufferUsage::STORAGE | BufferUsage::COPY_DST)
|
||||
.unwrap();
|
||||
PrefixStage {
|
||||
n_elements,
|
||||
state_buf,
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn bind(&self, runner: &mut Runner, code: &PrefixCode, in_buf: &Buffer, out_buf: &Buffer) -> PrefixBinding {
|
||||
let descriptor_set = runner
|
||||
.session
|
||||
.create_simple_descriptor_set(&code.pipeline, &[in_buf, out_buf, &self.state_buf])
|
||||
.unwrap();
|
||||
PrefixBinding { descriptor_set }
|
||||
}
|
||||
|
||||
unsafe fn record(&self, commands: &mut Commands, code: &PrefixCode, bindings: &PrefixBinding) {
|
||||
let n_workgroups = (self.n_elements + ELEMENTS_PER_WG - 1) / ELEMENTS_PER_WG;
|
||||
commands.cmd_buf.clear_buffer(&self.state_buf, None);
|
||||
commands.cmd_buf.memory_barrier();
|
||||
commands.cmd_buf.dispatch(
|
||||
&code.pipeline,
|
||||
&bindings.descriptor_set,
|
||||
(n_workgroups as u32, 1, 1),
|
||||
(WG_SIZE as u32, 1, 1),
|
||||
);
|
||||
// One thing that's missing here is registering the buffers so
|
||||
// they can be safely dropped by Rust code before the execution
|
||||
// of the command buffer completes.
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that the data is OEIS A000217
|
||||
fn verify(data: &[u32]) -> Option<usize> {
|
||||
data.iter()
|
||||
.enumerate()
|
||||
.position(|(i, val)| ((i * (i + 1)) / 2) as u32 != *val)
|
||||
}
|
132
tests/src/runner.rs
Normal file
132
tests/src/runner.rs
Normal file
|
@ -0,0 +1,132 @@
|
|||
// Copyright 2021 The piet-gpu authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Also licensed under MIT license, at your choice.
|
||||
|
||||
//! Test runner intended to make it easy to write tests.
|
||||
|
||||
use piet_gpu_hal::{Buffer, BufferUsage, CmdBuf, Instance, PlainData, QueryPool, Session};
|
||||
|
||||
pub struct Runner {
|
||||
#[allow(unused)]
|
||||
instance: Instance,
|
||||
pub session: Session,
|
||||
cmd_buf_pool: Vec<CmdBuf>,
|
||||
}
|
||||
|
||||
/// A wrapper around command buffers.
|
||||
pub struct Commands {
|
||||
pub cmd_buf: CmdBuf,
|
||||
query_pool: QueryPool,
|
||||
}
|
||||
|
||||
/// Buffer for uploading data to GPU.
|
||||
#[allow(unused)]
|
||||
pub struct BufUp {
|
||||
pub stage_buf: Buffer,
|
||||
pub dev_buf: Buffer,
|
||||
}
|
||||
|
||||
/// Buffer for downloading data from GPU.
|
||||
pub struct BufDown {
|
||||
pub stage_buf: Buffer,
|
||||
pub dev_buf: Buffer,
|
||||
}
|
||||
|
||||
impl Runner {
|
||||
pub unsafe fn new() -> Runner {
|
||||
let (instance, _) = Instance::new(None).unwrap();
|
||||
let device = instance.device(None).unwrap();
|
||||
let session = Session::new(device);
|
||||
let cmd_buf_pool = Vec::new();
|
||||
Runner {
|
||||
instance,
|
||||
session,
|
||||
cmd_buf_pool,
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn commands(&mut self) -> Commands {
|
||||
let mut cmd_buf = self
|
||||
.cmd_buf_pool
|
||||
.pop()
|
||||
.unwrap_or_else(|| self.session.cmd_buf().unwrap());
|
||||
cmd_buf.begin();
|
||||
// TODO: also pool these. But we might sort by size, as
|
||||
// we might not always be doing two.
|
||||
let query_pool = self.session.create_query_pool(2).unwrap();
|
||||
cmd_buf.reset_query_pool(&query_pool);
|
||||
Commands {
|
||||
cmd_buf,
|
||||
query_pool,
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn submit(&mut self, commands: Commands) -> f64 {
|
||||
let mut cmd_buf = commands.cmd_buf;
|
||||
let query_pool = commands.query_pool;
|
||||
cmd_buf.host_barrier();
|
||||
cmd_buf.finish();
|
||||
let submitted = self.session.run_cmd_buf(cmd_buf, &[], &[]).unwrap();
|
||||
self.cmd_buf_pool.extend(submitted.wait().unwrap());
|
||||
let timestamps = self.session.fetch_query_pool(&query_pool).unwrap();
|
||||
timestamps[0]
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn buf_up(&self, size: u64) -> BufUp {
|
||||
let stage_buf = self
|
||||
.session
|
||||
.create_buffer(size, BufferUsage::MAP_WRITE | BufferUsage::COPY_SRC)
|
||||
.unwrap();
|
||||
let dev_buf = self
|
||||
.session
|
||||
.create_buffer(size, BufferUsage::COPY_DST | BufferUsage::STORAGE)
|
||||
.unwrap();
|
||||
BufUp { stage_buf, dev_buf }
|
||||
}
|
||||
|
||||
pub fn buf_down(&self, size: u64) -> BufDown {
|
||||
let stage_buf = self
|
||||
.session
|
||||
.create_buffer(size, BufferUsage::MAP_READ | BufferUsage::COPY_DST)
|
||||
.unwrap();
|
||||
let dev_buf = self
|
||||
.session
|
||||
.create_buffer(size, BufferUsage::COPY_SRC | BufferUsage::STORAGE)
|
||||
.unwrap();
|
||||
BufDown { stage_buf, dev_buf }
|
||||
}
|
||||
}
|
||||
|
||||
impl Commands {
|
||||
pub unsafe fn write_timestamp(&mut self, query: u32) {
|
||||
self.cmd_buf.write_timestamp(&self.query_pool, query);
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub unsafe fn upload(&mut self, buf: &BufUp) {
|
||||
self.cmd_buf.copy_buffer(&buf.stage_buf, &buf.dev_buf);
|
||||
}
|
||||
|
||||
pub unsafe fn download(&mut self, buf: &BufDown) {
|
||||
self.cmd_buf.copy_buffer(&buf.dev_buf, &buf.stage_buf);
|
||||
}
|
||||
}
|
||||
|
||||
impl BufDown {
|
||||
pub unsafe fn read(&self, dst: &mut Vec<impl PlainData>) {
|
||||
self.stage_buf.read(dst).unwrap()
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue