Merge pull request #123 from linebender/tests

Start testing framework, with prefix sum
This commit is contained in:
Raph Levien 2021-11-10 11:10:45 -08:00 committed by GitHub
commit 19fedf36db
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
29 changed files with 2260 additions and 1 deletions

15
Cargo.lock generated
View file

@ -85,6 +85,12 @@ version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a"
[[package]]
name = "bytemuck"
version = "1.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72957246c41db82b8ef88a5486143830adeb8227ef9837740bdec67724cf2c5b"
[[package]]
name = "byteorder"
version = "1.3.4"
@ -897,6 +903,15 @@ dependencies = [
"wio",
]
[[package]]
name = "piet-gpu-tests"
version = "0.1.0"
dependencies = [
"bytemuck",
"clap",
"piet-gpu-hal",
]
[[package]]
name = "piet-gpu-types"
version = "0.0.0"

View file

@ -4,5 +4,6 @@ members = [
"piet-gpu",
"piet-gpu-derive",
"piet-gpu-hal",
"piet-gpu-types"
"piet-gpu-types",
"tests"
]

View file

@ -21,6 +21,7 @@ fn main() {
cmd_buf.write_timestamp(&query_pool, 0);
cmd_buf.dispatch(&pipeline, &descriptor_set, (256, 1, 1), (1, 1, 1));
cmd_buf.write_timestamp(&query_pool, 1);
cmd_buf.finish_timestamps(&query_pool);
cmd_buf.host_barrier();
cmd_buf.finish();
let submitted = session.run_cmd_buf(cmd_buf, &[], &[]).unwrap();

14
tests/Cargo.toml Normal file
View file

@ -0,0 +1,14 @@
[package]
name = "piet-gpu-tests"
version = "0.1.0"
authors = ["Raph Levien <raph.levien@gmail.com>"]
description = "Tests for piet-gpu shaders and generic GPU capabilities."
license = "MIT/Apache-2.0"
edition = "2021"
[dependencies]
clap = "2.33"
bytemuck = "1.7.2"
[dependencies.piet-gpu-hal]
path = "../piet-gpu-hal"

32
tests/shader/build.ninja Normal file
View file

@ -0,0 +1,32 @@
# Build file for shaders.
# You must have Vulkan tools in your path, or patch here.
glslang_validator = glslangValidator
spirv_cross = spirv-cross
rule glsl
command = $glslang_validator $flags -V -o $out $in
rule hlsl
command = $spirv_cross --hlsl $in --output $out
rule msl
command = $spirv_cross --msl $in --output $out
build gen/prefix.spv: glsl prefix.comp
build gen/prefix.hlsl: hlsl gen/prefix.spv
build gen/prefix.msl: msl gen/prefix.spv
build gen/prefix_reduce.spv: glsl prefix_reduce.comp
build gen/prefix_reduce.hlsl: hlsl gen/prefix_reduce.spv
build gen/prefix_reduce.msl: msl gen/prefix_reduce.spv
build gen/prefix_root.spv: glsl prefix_scan.comp
flags = -DROOT
build gen/prefix_root.hlsl: hlsl gen/prefix_root.spv
build gen/prefix_root.msl: msl gen/prefix_root.spv
build gen/prefix_scan.spv: glsl prefix_scan.comp
build gen/prefix_scan.hlsl: hlsl gen/prefix_scan.spv
build gen/prefix_scan.msl: msl gen/prefix_scan.spv

View file

@ -0,0 +1,43 @@
static const uint3 gl_WorkGroupSize = uint3(1u, 1u, 1u);
RWByteAddressBuffer _53 : register(u1);
ByteAddressBuffer _59 : register(t0);
static uint3 gl_GlobalInvocationID;
struct SPIRV_Cross_Input
{
uint3 gl_GlobalInvocationID : SV_DispatchThreadID;
};
uint collatz_iterations(inout uint n)
{
uint i = 0u;
while (n != 1u)
{
if ((n % 2u) == 0u)
{
n /= 2u;
}
else
{
n = (3u * n) + 1u;
}
i++;
}
return i;
}
void comp_main()
{
uint index = gl_GlobalInvocationID.x;
uint param = _59.Load(index * 4 + 0);
uint _65 = collatz_iterations(param);
_53.Store(index * 4 + 0, _65);
}
[numthreads(1, 1, 1)]
void main(SPIRV_Cross_Input stage_input)
{
gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID;
comp_main();
}

View file

@ -0,0 +1,46 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct OutBuf
{
uint out_buf[1];
};
struct InBuf
{
uint in_buf[1];
};
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);
static inline __attribute__((always_inline))
uint collatz_iterations(thread uint& n)
{
uint i = 0u;
while (n != 1u)
{
if ((n % 2u) == 0u)
{
n /= 2u;
}
else
{
n = (3u * n) + 1u;
}
i++;
}
return i;
}
kernel void main0(device OutBuf& _53 [[buffer(0)]], const device InBuf& _59 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
{
uint index = gl_GlobalInvocationID.x;
uint param = _59.in_buf[index];
uint _65 = collatz_iterations(param);
_53.out_buf[index] = _65;
}

Binary file not shown.

View file

@ -0,0 +1,223 @@
struct Monoid
{
uint element;
};
struct State
{
uint flag;
Monoid aggregate;
Monoid prefix;
};
static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u);
static const Monoid _187 = { 0u };
globallycoherent RWByteAddressBuffer _43 : register(u2);
ByteAddressBuffer _67 : register(t0);
RWByteAddressBuffer _374 : register(u1);
static uint3 gl_LocalInvocationID;
struct SPIRV_Cross_Input
{
uint3 gl_LocalInvocationID : SV_GroupThreadID;
};
groupshared uint sh_part_ix;
groupshared Monoid sh_scratch[512];
groupshared uint sh_flag;
groupshared Monoid sh_prefix;
Monoid combine_monoid(Monoid a, Monoid b)
{
Monoid _22 = { a.element + b.element };
return _22;
}
void comp_main()
{
if (gl_LocalInvocationID.x == 0u)
{
uint _47;
_43.InterlockedAdd(0, 1u, _47);
sh_part_ix = _47;
}
GroupMemoryBarrierWithGroupSync();
uint part_ix = sh_part_ix;
uint ix = (part_ix * 8192u) + (gl_LocalInvocationID.x * 16u);
Monoid _71;
_71.element = _67.Load(ix * 4 + 0);
Monoid local[16];
local[0].element = _71.element;
Monoid param_1;
for (uint i = 1u; i < 16u; i++)
{
Monoid param = local[i - 1u];
Monoid _94;
_94.element = _67.Load((ix + i) * 4 + 0);
param_1.element = _94.element;
local[i] = combine_monoid(param, param_1);
}
Monoid agg = local[15];
sh_scratch[gl_LocalInvocationID.x] = agg;
for (uint i_1 = 0u; i_1 < 9u; i_1++)
{
GroupMemoryBarrierWithGroupSync();
if (gl_LocalInvocationID.x >= uint(1 << int(i_1)))
{
Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))];
Monoid param_2 = other;
Monoid param_3 = agg;
agg = combine_monoid(param_2, param_3);
}
GroupMemoryBarrierWithGroupSync();
sh_scratch[gl_LocalInvocationID.x] = agg;
}
if (gl_LocalInvocationID.x == 511u)
{
_43.Store(part_ix * 12 + 8, agg.element);
if (part_ix == 0u)
{
_43.Store(12, agg.element);
}
}
DeviceMemoryBarrier();
if (gl_LocalInvocationID.x == 511u)
{
uint flag = 1u;
if (part_ix == 0u)
{
flag = 2u;
}
_43.Store(part_ix * 12 + 4, flag);
}
Monoid exclusive = _187;
if (part_ix != 0u)
{
uint look_back_ix = part_ix - 1u;
uint their_ix = 0u;
Monoid their_prefix;
Monoid their_agg;
Monoid m;
while (true)
{
if (gl_LocalInvocationID.x == 511u)
{
sh_flag = _43.Load(look_back_ix * 12 + 4);
}
GroupMemoryBarrierWithGroupSync();
DeviceMemoryBarrier();
uint flag_1 = sh_flag;
if (flag_1 == 2u)
{
if (gl_LocalInvocationID.x == 511u)
{
Monoid _225;
_225.element = _43.Load(look_back_ix * 12 + 12);
their_prefix.element = _225.element;
Monoid param_4 = their_prefix;
Monoid param_5 = exclusive;
exclusive = combine_monoid(param_4, param_5);
}
break;
}
else
{
if (flag_1 == 1u)
{
if (gl_LocalInvocationID.x == 511u)
{
Monoid _247;
_247.element = _43.Load(look_back_ix * 12 + 8);
their_agg.element = _247.element;
Monoid param_6 = their_agg;
Monoid param_7 = exclusive;
exclusive = combine_monoid(param_6, param_7);
}
look_back_ix--;
their_ix = 0u;
continue;
}
}
if (gl_LocalInvocationID.x == 511u)
{
Monoid _269;
_269.element = _67.Load(((look_back_ix * 8192u) + their_ix) * 4 + 0);
m.element = _269.element;
if (their_ix == 0u)
{
their_agg = m;
}
else
{
Monoid param_8 = their_agg;
Monoid param_9 = m;
their_agg = combine_monoid(param_8, param_9);
}
their_ix++;
if (their_ix == 8192u)
{
Monoid param_10 = their_agg;
Monoid param_11 = exclusive;
exclusive = combine_monoid(param_10, param_11);
if (look_back_ix == 0u)
{
sh_flag = 2u;
}
else
{
look_back_ix--;
their_ix = 0u;
}
}
}
GroupMemoryBarrierWithGroupSync();
flag_1 = sh_flag;
if (flag_1 == 2u)
{
break;
}
}
if (gl_LocalInvocationID.x == 511u)
{
Monoid param_12 = exclusive;
Monoid param_13 = agg;
Monoid inclusive_prefix = combine_monoid(param_12, param_13);
sh_prefix = exclusive;
_43.Store(part_ix * 12 + 12, inclusive_prefix.element);
}
DeviceMemoryBarrier();
if (gl_LocalInvocationID.x == 511u)
{
_43.Store(part_ix * 12 + 4, 2u);
}
}
GroupMemoryBarrierWithGroupSync();
if (part_ix != 0u)
{
exclusive = sh_prefix;
}
Monoid row = exclusive;
if (gl_LocalInvocationID.x > 0u)
{
Monoid other_1 = sh_scratch[gl_LocalInvocationID.x - 1u];
Monoid param_14 = row;
Monoid param_15 = other_1;
row = combine_monoid(param_14, param_15);
}
for (uint i_2 = 0u; i_2 < 16u; i_2++)
{
Monoid param_16 = row;
Monoid param_17 = local[i_2];
Monoid m_1 = combine_monoid(param_16, param_17);
_374.Store((ix + i_2) * 4 + 0, m_1.element);
}
}
[numthreads(512, 1, 1)]
void main(SPIRV_Cross_Input stage_input)
{
gl_LocalInvocationID = stage_input.gl_LocalInvocationID;
comp_main();
}

262
tests/shader/gen/prefix.msl Normal file
View file

@ -0,0 +1,262 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wmissing-braces"
#pragma clang diagnostic ignored "-Wunused-variable"
#include <metal_stdlib>
#include <simd/simd.h>
#include <metal_atomic>
using namespace metal;
template<typename T, size_t Num>
struct spvUnsafeArray
{
T elements[Num ? Num : 1];
thread T& operator [] (size_t pos) thread
{
return elements[pos];
}
constexpr const thread T& operator [] (size_t pos) const thread
{
return elements[pos];
}
device T& operator [] (size_t pos) device
{
return elements[pos];
}
constexpr const device T& operator [] (size_t pos) const device
{
return elements[pos];
}
constexpr const constant T& operator [] (size_t pos) const constant
{
return elements[pos];
}
threadgroup T& operator [] (size_t pos) threadgroup
{
return elements[pos];
}
constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
{
return elements[pos];
}
};
struct Monoid
{
uint element;
};
struct Monoid_1
{
uint element;
};
struct State
{
uint flag;
Monoid_1 aggregate;
Monoid_1 prefix;
};
struct StateBuf
{
uint part_counter;
State state[1];
};
struct InBuf
{
Monoid_1 inbuf[1];
};
struct OutBuf
{
Monoid_1 outbuf[1];
};
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u);
static inline __attribute__((always_inline))
Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b)
{
return Monoid{ a.element + b.element };
}
kernel void main0(volatile device StateBuf& _43 [[buffer(0)]], const device InBuf& _67 [[buffer(1)]], device OutBuf& _374 [[buffer(2)]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]])
{
threadgroup uint sh_part_ix;
threadgroup Monoid sh_scratch[512];
threadgroup uint sh_flag;
threadgroup Monoid sh_prefix;
if (gl_LocalInvocationID.x == 0u)
{
uint _47 = atomic_fetch_add_explicit((volatile device atomic_uint*)&_43.part_counter, 1u, memory_order_relaxed);
sh_part_ix = _47;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
uint part_ix = sh_part_ix;
uint ix = (part_ix * 8192u) + (gl_LocalInvocationID.x * 16u);
spvUnsafeArray<Monoid, 16> local;
local[0].element = _67.inbuf[ix].element;
Monoid param_1;
for (uint i = 1u; i < 16u; i++)
{
Monoid param = local[i - 1u];
param_1.element = _67.inbuf[ix + i].element;
local[i] = combine_monoid(param, param_1);
}
Monoid agg = local[15];
sh_scratch[gl_LocalInvocationID.x] = agg;
for (uint i_1 = 0u; i_1 < 9u; i_1++)
{
threadgroup_barrier(mem_flags::mem_threadgroup);
if (gl_LocalInvocationID.x >= uint(1 << int(i_1)))
{
Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))];
Monoid param_2 = other;
Monoid param_3 = agg;
agg = combine_monoid(param_2, param_3);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sh_scratch[gl_LocalInvocationID.x] = agg;
}
if (gl_LocalInvocationID.x == 511u)
{
_43.state[part_ix].aggregate.element = agg.element;
if (part_ix == 0u)
{
_43.state[0].prefix.element = agg.element;
}
}
threadgroup_barrier(mem_flags::mem_device);
if (gl_LocalInvocationID.x == 511u)
{
uint flag = 1u;
if (part_ix == 0u)
{
flag = 2u;
}
_43.state[part_ix].flag = flag;
}
Monoid exclusive = Monoid{ 0u };
if (part_ix != 0u)
{
uint look_back_ix = part_ix - 1u;
uint their_ix = 0u;
Monoid their_prefix;
Monoid their_agg;
Monoid m;
while (true)
{
if (gl_LocalInvocationID.x == 511u)
{
sh_flag = _43.state[look_back_ix].flag;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_device);
uint flag_1 = sh_flag;
if (flag_1 == 2u)
{
if (gl_LocalInvocationID.x == 511u)
{
their_prefix.element = _43.state[look_back_ix].prefix.element;
Monoid param_4 = their_prefix;
Monoid param_5 = exclusive;
exclusive = combine_monoid(param_4, param_5);
}
break;
}
else
{
if (flag_1 == 1u)
{
if (gl_LocalInvocationID.x == 511u)
{
their_agg.element = _43.state[look_back_ix].aggregate.element;
Monoid param_6 = their_agg;
Monoid param_7 = exclusive;
exclusive = combine_monoid(param_6, param_7);
}
look_back_ix--;
their_ix = 0u;
continue;
}
}
if (gl_LocalInvocationID.x == 511u)
{
m.element = _67.inbuf[(look_back_ix * 8192u) + their_ix].element;
if (their_ix == 0u)
{
their_agg = m;
}
else
{
Monoid param_8 = their_agg;
Monoid param_9 = m;
their_agg = combine_monoid(param_8, param_9);
}
their_ix++;
if (their_ix == 8192u)
{
Monoid param_10 = their_agg;
Monoid param_11 = exclusive;
exclusive = combine_monoid(param_10, param_11);
if (look_back_ix == 0u)
{
sh_flag = 2u;
}
else
{
look_back_ix--;
their_ix = 0u;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
flag_1 = sh_flag;
if (flag_1 == 2u)
{
break;
}
}
if (gl_LocalInvocationID.x == 511u)
{
Monoid param_12 = exclusive;
Monoid param_13 = agg;
Monoid inclusive_prefix = combine_monoid(param_12, param_13);
sh_prefix = exclusive;
_43.state[part_ix].prefix.element = inclusive_prefix.element;
}
threadgroup_barrier(mem_flags::mem_device);
if (gl_LocalInvocationID.x == 511u)
{
_43.state[part_ix].flag = 2u;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (part_ix != 0u)
{
exclusive = sh_prefix;
}
Monoid row = exclusive;
if (gl_LocalInvocationID.x > 0u)
{
Monoid other_1 = sh_scratch[gl_LocalInvocationID.x - 1u];
Monoid param_14 = row;
Monoid param_15 = other_1;
row = combine_monoid(param_14, param_15);
}
for (uint i_2 = 0u; i_2 < 16u; i_2++)
{
Monoid param_16 = row;
Monoid param_17 = local[i_2];
Monoid m_1 = combine_monoid(param_16, param_17);
_374.outbuf[ix + i_2].element = m_1.element;
}
}

BIN
tests/shader/gen/prefix.spv Normal file

Binary file not shown.

View file

@ -0,0 +1,72 @@
struct Monoid
{
uint element;
};
static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u);
ByteAddressBuffer _40 : register(t0);
RWByteAddressBuffer _129 : register(u1);
static uint3 gl_WorkGroupID;
static uint3 gl_LocalInvocationID;
static uint3 gl_GlobalInvocationID;
struct SPIRV_Cross_Input
{
uint3 gl_WorkGroupID : SV_GroupID;
uint3 gl_LocalInvocationID : SV_GroupThreadID;
uint3 gl_GlobalInvocationID : SV_DispatchThreadID;
};
groupshared Monoid sh_scratch[512];
Monoid combine_monoid(Monoid a, Monoid b)
{
Monoid _22 = { a.element + b.element };
return _22;
}
void comp_main()
{
uint ix = gl_GlobalInvocationID.x * 8u;
Monoid _44;
_44.element = _40.Load(ix * 4 + 0);
Monoid agg;
agg.element = _44.element;
Monoid param_1;
for (uint i = 1u; i < 8u; i++)
{
Monoid param = agg;
Monoid _64;
_64.element = _40.Load((ix + i) * 4 + 0);
param_1.element = _64.element;
agg = combine_monoid(param, param_1);
}
sh_scratch[gl_LocalInvocationID.x] = agg;
for (uint i_1 = 0u; i_1 < 9u; i_1++)
{
GroupMemoryBarrierWithGroupSync();
if ((gl_LocalInvocationID.x + uint(1 << int(i_1))) < 512u)
{
Monoid other = sh_scratch[gl_LocalInvocationID.x + uint(1 << int(i_1))];
Monoid param_2 = agg;
Monoid param_3 = other;
agg = combine_monoid(param_2, param_3);
}
GroupMemoryBarrierWithGroupSync();
sh_scratch[gl_LocalInvocationID.x] = agg;
}
if (gl_LocalInvocationID.x == 0u)
{
_129.Store(gl_WorkGroupID.x * 4 + 0, agg.element);
}
}
[numthreads(512, 1, 1)]
void main(SPIRV_Cross_Input stage_input)
{
gl_WorkGroupID = stage_input.gl_WorkGroupID;
gl_LocalInvocationID = stage_input.gl_LocalInvocationID;
gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID;
comp_main();
}

View file

@ -0,0 +1,68 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
struct Monoid
{
uint element;
};
struct Monoid_1
{
uint element;
};
struct InBuf
{
Monoid_1 inbuf[1];
};
struct OutBuf
{
Monoid_1 outbuf[1];
};
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u);
static inline __attribute__((always_inline))
Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b)
{
return Monoid{ a.element + b.element };
}
kernel void main0(const device InBuf& _40 [[buffer(0)]], device OutBuf& _129 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]])
{
threadgroup Monoid sh_scratch[512];
uint ix = gl_GlobalInvocationID.x * 8u;
Monoid agg;
agg.element = _40.inbuf[ix].element;
Monoid param_1;
for (uint i = 1u; i < 8u; i++)
{
Monoid param = agg;
param_1.element = _40.inbuf[ix + i].element;
agg = combine_monoid(param, param_1);
}
sh_scratch[gl_LocalInvocationID.x] = agg;
for (uint i_1 = 0u; i_1 < 9u; i_1++)
{
threadgroup_barrier(mem_flags::mem_threadgroup);
if ((gl_LocalInvocationID.x + uint(1 << int(i_1))) < 512u)
{
Monoid other = sh_scratch[gl_LocalInvocationID.x + uint(1 << int(i_1))];
Monoid param_2 = agg;
Monoid param_3 = other;
agg = combine_monoid(param_2, param_3);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sh_scratch[gl_LocalInvocationID.x] = agg;
}
if (gl_LocalInvocationID.x == 0u)
{
_129.outbuf[gl_WorkGroupID.x].element = agg.element;
}
}

Binary file not shown.

View file

@ -0,0 +1,80 @@
struct Monoid
{
uint element;
};
static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u);
static const Monoid _133 = { 0u };
RWByteAddressBuffer _42 : register(u0);
static uint3 gl_LocalInvocationID;
static uint3 gl_GlobalInvocationID;
struct SPIRV_Cross_Input
{
uint3 gl_LocalInvocationID : SV_GroupThreadID;
uint3 gl_GlobalInvocationID : SV_DispatchThreadID;
};
groupshared Monoid sh_scratch[512];
Monoid combine_monoid(Monoid a, Monoid b)
{
Monoid _22 = { a.element + b.element };
return _22;
}
void comp_main()
{
uint ix = gl_GlobalInvocationID.x * 8u;
Monoid _46;
_46.element = _42.Load(ix * 4 + 0);
Monoid local[8];
local[0].element = _46.element;
Monoid param_1;
for (uint i = 1u; i < 8u; i++)
{
Monoid param = local[i - 1u];
Monoid _71;
_71.element = _42.Load((ix + i) * 4 + 0);
param_1.element = _71.element;
local[i] = combine_monoid(param, param_1);
}
Monoid agg = local[7];
sh_scratch[gl_LocalInvocationID.x] = agg;
for (uint i_1 = 0u; i_1 < 9u; i_1++)
{
GroupMemoryBarrierWithGroupSync();
if (gl_LocalInvocationID.x >= uint(1 << int(i_1)))
{
Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))];
Monoid param_2 = other;
Monoid param_3 = agg;
agg = combine_monoid(param_2, param_3);
}
GroupMemoryBarrierWithGroupSync();
sh_scratch[gl_LocalInvocationID.x] = agg;
}
GroupMemoryBarrierWithGroupSync();
Monoid row = _133;
if (gl_LocalInvocationID.x > 0u)
{
row = sh_scratch[gl_LocalInvocationID.x - 1u];
}
for (uint i_2 = 0u; i_2 < 8u; i_2++)
{
Monoid param_4 = row;
Monoid param_5 = local[i_2];
Monoid m = combine_monoid(param_4, param_5);
_42.Store((ix + i_2) * 4 + 0, m.element);
}
}
[numthreads(512, 1, 1)]
void main(SPIRV_Cross_Input stage_input)
{
gl_LocalInvocationID = stage_input.gl_LocalInvocationID;
gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID;
comp_main();
}

View file

@ -0,0 +1,112 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wmissing-braces"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
template<typename T, size_t Num>
struct spvUnsafeArray
{
T elements[Num ? Num : 1];
thread T& operator [] (size_t pos) thread
{
return elements[pos];
}
constexpr const thread T& operator [] (size_t pos) const thread
{
return elements[pos];
}
device T& operator [] (size_t pos) device
{
return elements[pos];
}
constexpr const device T& operator [] (size_t pos) const device
{
return elements[pos];
}
constexpr const constant T& operator [] (size_t pos) const constant
{
return elements[pos];
}
threadgroup T& operator [] (size_t pos) threadgroup
{
return elements[pos];
}
constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
{
return elements[pos];
}
};
struct Monoid
{
uint element;
};
struct Monoid_1
{
uint element;
};
struct DataBuf
{
Monoid_1 data[1];
};
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u);
static inline __attribute__((always_inline))
Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b)
{
return Monoid{ a.element + b.element };
}
kernel void main0(device DataBuf& _42 [[buffer(0)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]])
{
threadgroup Monoid sh_scratch[512];
uint ix = gl_GlobalInvocationID.x * 8u;
spvUnsafeArray<Monoid, 8> local;
local[0].element = _42.data[ix].element;
Monoid param_1;
for (uint i = 1u; i < 8u; i++)
{
Monoid param = local[i - 1u];
param_1.element = _42.data[ix + i].element;
local[i] = combine_monoid(param, param_1);
}
Monoid agg = local[7];
sh_scratch[gl_LocalInvocationID.x] = agg;
for (uint i_1 = 0u; i_1 < 9u; i_1++)
{
threadgroup_barrier(mem_flags::mem_threadgroup);
if (gl_LocalInvocationID.x >= uint(1 << int(i_1)))
{
Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))];
Monoid param_2 = other;
Monoid param_3 = agg;
agg = combine_monoid(param_2, param_3);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sh_scratch[gl_LocalInvocationID.x] = agg;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
Monoid row = Monoid{ 0u };
if (gl_LocalInvocationID.x > 0u)
{
row = sh_scratch[gl_LocalInvocationID.x - 1u];
}
for (uint i_2 = 0u; i_2 < 8u; i_2++)
{
Monoid param_4 = row;
Monoid param_5 = local[i_2];
Monoid m = combine_monoid(param_4, param_5);
_42.data[ix + i_2].element = m.element;
}
}

Binary file not shown.

View file

@ -0,0 +1,92 @@
struct Monoid
{
uint element;
};
static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u);
static const Monoid _133 = { 0u };
RWByteAddressBuffer _42 : register(u0);
RWByteAddressBuffer _143 : register(u1);
static uint3 gl_WorkGroupID;
static uint3 gl_LocalInvocationID;
static uint3 gl_GlobalInvocationID;
struct SPIRV_Cross_Input
{
uint3 gl_WorkGroupID : SV_GroupID;
uint3 gl_LocalInvocationID : SV_GroupThreadID;
uint3 gl_GlobalInvocationID : SV_DispatchThreadID;
};
groupshared Monoid sh_scratch[512];
Monoid combine_monoid(Monoid a, Monoid b)
{
Monoid _22 = { a.element + b.element };
return _22;
}
void comp_main()
{
uint ix = gl_GlobalInvocationID.x * 8u;
Monoid _46;
_46.element = _42.Load(ix * 4 + 0);
Monoid local[8];
local[0].element = _46.element;
Monoid param_1;
for (uint i = 1u; i < 8u; i++)
{
Monoid param = local[i - 1u];
Monoid _71;
_71.element = _42.Load((ix + i) * 4 + 0);
param_1.element = _71.element;
local[i] = combine_monoid(param, param_1);
}
Monoid agg = local[7];
sh_scratch[gl_LocalInvocationID.x] = agg;
for (uint i_1 = 0u; i_1 < 9u; i_1++)
{
GroupMemoryBarrierWithGroupSync();
if (gl_LocalInvocationID.x >= uint(1 << int(i_1)))
{
Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))];
Monoid param_2 = other;
Monoid param_3 = agg;
agg = combine_monoid(param_2, param_3);
}
GroupMemoryBarrierWithGroupSync();
sh_scratch[gl_LocalInvocationID.x] = agg;
}
GroupMemoryBarrierWithGroupSync();
Monoid row = _133;
if (gl_WorkGroupID.x > 0u)
{
Monoid _148;
_148.element = _143.Load((gl_WorkGroupID.x - 1u) * 4 + 0);
row.element = _148.element;
}
if (gl_LocalInvocationID.x > 0u)
{
Monoid param_4 = row;
Monoid param_5 = sh_scratch[gl_LocalInvocationID.x - 1u];
row = combine_monoid(param_4, param_5);
}
for (uint i_2 = 0u; i_2 < 8u; i_2++)
{
Monoid param_6 = row;
Monoid param_7 = local[i_2];
Monoid m = combine_monoid(param_6, param_7);
_42.Store((ix + i_2) * 4 + 0, m.element);
}
}
[numthreads(512, 1, 1)]
void main(SPIRV_Cross_Input stage_input)
{
gl_WorkGroupID = stage_input.gl_WorkGroupID;
gl_LocalInvocationID = stage_input.gl_LocalInvocationID;
gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID;
comp_main();
}

View file

@ -0,0 +1,123 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wmissing-braces"
#include <metal_stdlib>
#include <simd/simd.h>
using namespace metal;
template<typename T, size_t Num>
struct spvUnsafeArray
{
T elements[Num ? Num : 1];
thread T& operator [] (size_t pos) thread
{
return elements[pos];
}
constexpr const thread T& operator [] (size_t pos) const thread
{
return elements[pos];
}
device T& operator [] (size_t pos) device
{
return elements[pos];
}
constexpr const device T& operator [] (size_t pos) const device
{
return elements[pos];
}
constexpr const constant T& operator [] (size_t pos) const constant
{
return elements[pos];
}
threadgroup T& operator [] (size_t pos) threadgroup
{
return elements[pos];
}
constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
{
return elements[pos];
}
};
struct Monoid
{
uint element;
};
struct Monoid_1
{
uint element;
};
struct DataBuf
{
Monoid_1 data[1];
};
struct ParentBuf
{
Monoid_1 parent[1];
};
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u);
static inline __attribute__((always_inline))
Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b)
{
return Monoid{ a.element + b.element };
}
kernel void main0(device DataBuf& _42 [[buffer(0)]], device ParentBuf& _143 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]])
{
threadgroup Monoid sh_scratch[512];
uint ix = gl_GlobalInvocationID.x * 8u;
spvUnsafeArray<Monoid, 8> local;
local[0].element = _42.data[ix].element;
Monoid param_1;
for (uint i = 1u; i < 8u; i++)
{
Monoid param = local[i - 1u];
param_1.element = _42.data[ix + i].element;
local[i] = combine_monoid(param, param_1);
}
Monoid agg = local[7];
sh_scratch[gl_LocalInvocationID.x] = agg;
for (uint i_1 = 0u; i_1 < 9u; i_1++)
{
threadgroup_barrier(mem_flags::mem_threadgroup);
if (gl_LocalInvocationID.x >= uint(1 << int(i_1)))
{
Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))];
Monoid param_2 = other;
Monoid param_3 = agg;
agg = combine_monoid(param_2, param_3);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sh_scratch[gl_LocalInvocationID.x] = agg;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
Monoid row = Monoid{ 0u };
if (gl_WorkGroupID.x > 0u)
{
row.element = _143.parent[gl_WorkGroupID.x - 1u].element;
}
if (gl_LocalInvocationID.x > 0u)
{
Monoid param_4 = row;
Monoid param_5 = sh_scratch[gl_LocalInvocationID.x - 1u];
row = combine_monoid(param_4, param_5);
}
for (uint i_2 = 0u; i_2 < 8u; i_2++)
{
Monoid param_6 = row;
Monoid param_7 = local[i_2];
Monoid m = combine_monoid(param_6, param_7);
_42.data[ix + i_2].element = m.element;
}
}

Binary file not shown.

188
tests/shader/prefix.comp Normal file
View file

@ -0,0 +1,188 @@
// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense
// A prefix sum.
#version 450
#define N_ROWS 16
#define LG_WG_SIZE 9
#define WG_SIZE (1 << LG_WG_SIZE)
#define PARTITION_SIZE (WG_SIZE * N_ROWS)
layout(local_size_x = WG_SIZE, local_size_y = 1) in;
struct Monoid {
uint element;
};
layout(set = 0, binding = 0) readonly buffer InBuf {
Monoid[] inbuf;
};
layout(set = 0, binding = 1) buffer OutBuf {
Monoid[] outbuf;
};
// These correspond to X, A, P respectively in the prefix sum paper.
#define FLAG_NOT_READY 0
#define FLAG_AGGREGATE_READY 1
#define FLAG_PREFIX_READY 2
struct State {
uint flag;
Monoid aggregate;
Monoid prefix;
};
layout(set = 0, binding = 2) volatile buffer StateBuf {
uint part_counter;
State[] state;
};
shared Monoid sh_scratch[WG_SIZE];
Monoid combine_monoid(Monoid a, Monoid b) {
return Monoid(a.element + b.element);
}
shared uint sh_part_ix;
shared Monoid sh_prefix;
shared uint sh_flag;
void main() {
Monoid local[N_ROWS];
// Determine partition to process by atomic counter (described in Section
// 4.4 of prefix sum paper).
if (gl_LocalInvocationID.x == 0) {
sh_part_ix = atomicAdd(part_counter, 1);
}
barrier();
uint part_ix = sh_part_ix;
uint ix = part_ix * PARTITION_SIZE + gl_LocalInvocationID.x * N_ROWS;
// TODO: gate buffer read? (evaluate whether shader check or
// CPU-side padding is better)
local[0] = inbuf[ix];
for (uint i = 1; i < N_ROWS; i++) {
local[i] = combine_monoid(local[i - 1], inbuf[ix + i]);
}
Monoid agg = local[N_ROWS - 1];
sh_scratch[gl_LocalInvocationID.x] = agg;
for (uint i = 0; i < LG_WG_SIZE; i++) {
barrier();
if (gl_LocalInvocationID.x >= (1 << i)) {
Monoid other = sh_scratch[gl_LocalInvocationID.x - (1 << i)];
agg = combine_monoid(other, agg);
}
barrier();
sh_scratch[gl_LocalInvocationID.x] = agg;
}
// Publish aggregate for this partition
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
state[part_ix].aggregate = agg;
if (part_ix == 0) {
state[0].prefix = agg;
}
}
// Write flag with release semantics; this is done portably with a barrier.
memoryBarrierBuffer();
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
uint flag = FLAG_AGGREGATE_READY;
if (part_ix == 0) {
flag = FLAG_PREFIX_READY;
}
state[part_ix].flag = flag;
}
Monoid exclusive = Monoid(0);
if (part_ix != 0) {
// step 4 of paper: decoupled lookback
uint look_back_ix = part_ix - 1;
Monoid their_agg;
uint their_ix = 0;
while (true) {
// Read flag with acquire semantics.
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
sh_flag = state[look_back_ix].flag;
}
// The flag load is done only in the last thread. However, because the
// translation of memoryBarrierBuffer to Metal requires uniform control
// flow, we broadcast it to all threads.
barrier();
memoryBarrierBuffer();
uint flag = sh_flag;
if (flag == FLAG_PREFIX_READY) {
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
Monoid their_prefix = state[look_back_ix].prefix;
exclusive = combine_monoid(their_prefix, exclusive);
}
break;
} else if (flag == FLAG_AGGREGATE_READY) {
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
their_agg = state[look_back_ix].aggregate;
exclusive = combine_monoid(their_agg, exclusive);
}
look_back_ix--;
their_ix = 0;
continue;
}
// else spin
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
// Unfortunately there's no guarantee of forward progress of other
// workgroups, so compute a bit of the aggregate before trying again.
// In the worst case, spinning stops when the aggregate is complete.
Monoid m = inbuf[look_back_ix * PARTITION_SIZE + their_ix];
if (their_ix == 0) {
their_agg = m;
} else {
their_agg = combine_monoid(their_agg, m);
}
their_ix++;
if (their_ix == PARTITION_SIZE) {
exclusive = combine_monoid(their_agg, exclusive);
if (look_back_ix == 0) {
sh_flag = FLAG_PREFIX_READY;
} else {
look_back_ix--;
their_ix = 0;
}
}
}
barrier();
flag = sh_flag;
if (flag == FLAG_PREFIX_READY) {
break;
}
}
// step 5 of paper: compute inclusive prefix
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
Monoid inclusive_prefix = combine_monoid(exclusive, agg);
sh_prefix = exclusive;
state[part_ix].prefix = inclusive_prefix;
}
memoryBarrierBuffer();
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
state[part_ix].flag = FLAG_PREFIX_READY;
}
}
barrier();
if (part_ix != 0) {
exclusive = sh_prefix;
}
Monoid row = exclusive;
if (gl_LocalInvocationID.x > 0) {
Monoid other = sh_scratch[gl_LocalInvocationID.x - 1];
row = combine_monoid(row, other);
}
for (uint i = 0; i < N_ROWS; i++) {
Monoid m = combine_monoid(row, local[i]);
// Make sure buffer allocation is padded appropriately.
outbuf[ix + i] = m;
}
}

View file

@ -0,0 +1,53 @@
// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense
// The reduction phase for prefix sum implemented as a tree reduction.
#version 450
#define N_ROWS 8
#define LG_WG_SIZE 9
#define WG_SIZE (1 << LG_WG_SIZE)
#define PARTITION_SIZE (WG_SIZE * N_ROWS)
layout(local_size_x = WG_SIZE, local_size_y = 1) in;
struct Monoid {
uint element;
};
layout(set = 0, binding = 0) readonly buffer InBuf {
Monoid[] inbuf;
};
layout(set = 0, binding = 1) buffer OutBuf {
Monoid[] outbuf;
};
shared Monoid sh_scratch[WG_SIZE];
Monoid combine_monoid(Monoid a, Monoid b) {
return Monoid(a.element + b.element);
}
void main() {
uint ix = gl_GlobalInvocationID.x * N_ROWS;
// TODO: gate buffer read
Monoid agg = inbuf[ix];
for (uint i = 1; i < N_ROWS; i++) {
agg = combine_monoid(agg, inbuf[ix + i]);
}
sh_scratch[gl_LocalInvocationID.x] = agg;
for (uint i = 0; i < LG_WG_SIZE; i++) {
barrier();
// We could make this predicate tighter, but would it help?
if (gl_LocalInvocationID.x + (1 << i) < WG_SIZE) {
Monoid other = sh_scratch[gl_LocalInvocationID.x + (1 << i)];
agg = combine_monoid(agg, other);
}
barrier();
sh_scratch[gl_LocalInvocationID.x] = agg;
}
if (gl_LocalInvocationID.x == 0) {
outbuf[gl_WorkGroupID.x] = agg;
}
}

View file

@ -0,0 +1,77 @@
// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense
// A scan for a tree reduction prefix scan (either root or not, by ifdef).
#version 450
#define N_ROWS 8
#define LG_WG_SIZE 9
#define WG_SIZE (1 << LG_WG_SIZE)
#define PARTITION_SIZE (WG_SIZE * N_ROWS)
layout(local_size_x = WG_SIZE, local_size_y = 1) in;
struct Monoid {
uint element;
};
layout(set = 0, binding = 0) buffer DataBuf {
Monoid[] data;
};
#ifndef ROOT
layout(set = 0, binding = 1) buffer ParentBuf {
Monoid[] parent;
};
#endif
shared Monoid sh_scratch[WG_SIZE];
Monoid combine_monoid(Monoid a, Monoid b) {
return Monoid(a.element + b.element);
}
void main() {
Monoid local[N_ROWS];
uint ix = gl_GlobalInvocationID.x * N_ROWS;
// TODO: gate buffer read
local[0] = data[ix];
for (uint i = 1; i < N_ROWS; i++) {
local[i] = combine_monoid(local[i - 1], data[ix + i]);
}
Monoid agg = local[N_ROWS - 1];
sh_scratch[gl_LocalInvocationID.x] = agg;
for (uint i = 0; i < LG_WG_SIZE; i++) {
barrier();
if (gl_LocalInvocationID.x >= (1 << i)) {
Monoid other = sh_scratch[gl_LocalInvocationID.x - (1 << i)];
agg = combine_monoid(other, agg);
}
barrier();
sh_scratch[gl_LocalInvocationID.x] = agg;
}
barrier();
// This could be a semigroup instead of a monoid if we reworked the
// conditional logic, but that might impact performance.
Monoid row = Monoid(0);
#ifdef ROOT
if (gl_LocalInvocationID.x > 0) {
row = sh_scratch[gl_LocalInvocationID.x - 1];
}
#else
if (gl_WorkGroupID.x > 0) {
row = parent[gl_WorkGroupID.x - 1];
}
if (gl_LocalInvocationID.x > 0) {
row = combine_monoid(row, sh_scratch[gl_LocalInvocationID.x - 1]);
}
#endif
for (uint i = 0; i < N_ROWS; i++) {
Monoid m = combine_monoid(row, local[i]);
// TODO: gate buffer write
data[ix + i] = m;
}
}

72
tests/src/config.rs Normal file
View file

@ -0,0 +1,72 @@
// Copyright 2021 The piet-gpu authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Also licensed under MIT license, at your choice.
//! Test config parameters.
use clap::ArgMatches;
pub struct Config {
pub groups: Groups,
pub size: Size,
}
pub struct Groups(String);
pub enum Size {
Small,
Medium,
Large,
}
impl Config {
pub fn from_matches(matches: &ArgMatches) -> Config {
let groups = Groups::from_str(matches.value_of("groups").unwrap_or("all"));
let size = Size::from_str(matches.value_of("size").unwrap_or("m"));
Config {
groups, size
}
}
}
impl Groups {
pub fn from_str(s: &str) -> Groups {
Groups(s.to_string())
}
pub fn matches(&self, group_name: &str) -> bool {
self.0 == "all" || self.0 == group_name
}
}
impl Size {
fn from_str(s: &str) -> Size {
if s == "small" || s == "s" {
Size::Small
} else if s == "large" || s == "l" {
Size::Large
} else {
Size::Medium
}
}
pub fn choose<T>(&self, small: T, medium: T, large: T) -> T {
match self {
Size::Small => small,
Size::Medium => medium,
Size::Large => large,
}
}
}

77
tests/src/main.rs Normal file
View file

@ -0,0 +1,77 @@
// Copyright 2021 The piet-gpu authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Also licensed under MIT license, at your choice.
//! Tests for piet-gpu shaders and GPU capabilities.
mod config;
mod prefix;
mod prefix_tree;
mod runner;
mod test_result;
use clap::{App, Arg};
use crate::config::Config;
use crate::runner::Runner;
use crate::test_result::{ReportStyle, TestResult};
fn main() {
let matches = App::new("piet-gpu-tests")
.arg(
Arg::with_name("verbose")
.short("v")
.long("verbose")
.help("Verbose reporting of results"),
)
.arg(
Arg::with_name("groups")
.short("g")
.long("groups")
.help("Groups to run")
.takes_value(true)
)
.arg(
Arg::with_name("size")
.short("s")
.long("size")
.help("Size of tests")
.takes_value(true)
)
.arg(
Arg::with_name("n_iter")
.short("n")
.long("n_iter")
.help("Number of iterations")
.takes_value(true)
)
.get_matches();
let style = if matches.is_present("verbose") {
ReportStyle::Verbose
} else {
ReportStyle::Short
};
let config = Config::from_matches(&matches);
unsafe {
let report = |test_result: &TestResult| {
test_result.report(style);
};
let mut runner = Runner::new();
if config.groups.matches("prefix") {
report(&prefix::run_prefix_test(&mut runner, &config));
report(&prefix_tree::run_prefix_test(&mut runner, &config));
}
}
}

149
tests/src/prefix.rs Normal file
View file

@ -0,0 +1,149 @@
// Copyright 2021 The piet-gpu authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Also licensed under MIT license, at your choice.
use piet_gpu_hal::{include_shader, BufferUsage, DescriptorSet};
use piet_gpu_hal::{Buffer, Pipeline};
use crate::config::Config;
use crate::runner::{Commands, Runner};
use crate::test_result::TestResult;
const WG_SIZE: u64 = 512;
const N_ROWS: u64 = 16;
const ELEMENTS_PER_WG: u64 = WG_SIZE * N_ROWS;
/// The shader code for the prefix sum example.
///
/// A code struct can be created once and reused any number of times.
struct PrefixCode {
pipeline: Pipeline,
}
/// The stage resources for the prefix sum example.
///
/// A stage resources struct is specific to a particular problem size
/// and queue.
struct PrefixStage {
// This is the actual problem size but perhaps it would be better to
// treat it as a capacity.
n_elements: u64,
state_buf: Buffer,
}
/// The binding for the prefix sum example.
struct PrefixBinding {
descriptor_set: DescriptorSet,
}
pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResult {
let mut result = TestResult::new("prefix sum, decoupled look-back");
// This will be configurable.
let n_elements: u64 = config.size.choose(1 << 12, 1 << 24, 1 << 25);
let data: Vec<u32> = (0..n_elements as u32).collect();
let data_buf = runner
.session
.create_buffer_init(&data, BufferUsage::STORAGE)
.unwrap();
let out_buf = runner.buf_down(data_buf.size());
let code = PrefixCode::new(runner);
let stage = PrefixStage::new(runner, n_elements);
let binding = stage.bind(runner, &code, &data_buf, &out_buf.dev_buf);
// Also will be configurable of course.
let n_iter = 1000;
let mut total_elapsed = 0.0;
for i in 0..n_iter {
let mut commands = runner.commands();
commands.write_timestamp(0);
stage.record(&mut commands, &code, &binding);
commands.write_timestamp(1);
if i == 0 {
commands.cmd_buf.memory_barrier();
commands.download(&out_buf);
}
total_elapsed += runner.submit(commands);
if i == 0 {
let mut dst: Vec<u32> = Default::default();
out_buf.read(&mut dst);
if let Some(failure) = verify(&dst) {
result.fail(format!("failure at {}", failure));
}
}
}
result.timing(total_elapsed, n_elements * n_iter);
result
}
impl PrefixCode {
unsafe fn new(runner: &mut Runner) -> PrefixCode {
let code = include_shader!(&runner.session, "../shader/gen/prefix");
let pipeline = runner
.session
.create_simple_compute_pipeline(code, 3)
.unwrap();
PrefixCode { pipeline }
}
}
impl PrefixStage {
unsafe fn new(runner: &mut Runner, n_elements: u64) -> PrefixStage {
let n_workgroups = (n_elements + ELEMENTS_PER_WG - 1) / ELEMENTS_PER_WG;
let state_buf_size = 4 + 12 * n_workgroups;
let state_buf = runner
.session
.create_buffer(state_buf_size, BufferUsage::STORAGE | BufferUsage::COPY_DST)
.unwrap();
PrefixStage {
n_elements,
state_buf,
}
}
unsafe fn bind(
&self,
runner: &mut Runner,
code: &PrefixCode,
in_buf: &Buffer,
out_buf: &Buffer,
) -> PrefixBinding {
let descriptor_set = runner
.session
.create_simple_descriptor_set(&code.pipeline, &[in_buf, out_buf, &self.state_buf])
.unwrap();
PrefixBinding { descriptor_set }
}
unsafe fn record(&self, commands: &mut Commands, code: &PrefixCode, bindings: &PrefixBinding) {
let n_workgroups = (self.n_elements + ELEMENTS_PER_WG - 1) / ELEMENTS_PER_WG;
commands.cmd_buf.clear_buffer(&self.state_buf, None);
commands.cmd_buf.memory_barrier();
commands.cmd_buf.dispatch(
&code.pipeline,
&bindings.descriptor_set,
(n_workgroups as u32, 1, 1),
(WG_SIZE as u32, 1, 1),
);
// One thing that's missing here is registering the buffers so
// they can be safely dropped by Rust code before the execution
// of the command buffer completes.
}
}
// Verify that the data is OEIS A000217
fn verify(data: &[u32]) -> Option<usize> {
data.iter()
.enumerate()
.position(|(i, val)| ((i * (i + 1)) / 2) as u32 != *val)
}

211
tests/src/prefix_tree.rs Normal file
View file

@ -0,0 +1,211 @@
// Copyright 2021 The piet-gpu authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Also licensed under MIT license, at your choice.
use piet_gpu_hal::{include_shader, BufferUsage, DescriptorSet};
use piet_gpu_hal::{Buffer, Pipeline};
use crate::config::Config;
use crate::runner::{Commands, Runner};
use crate::test_result::TestResult;
const WG_SIZE: u64 = 512;
const N_ROWS: u64 = 8;
const ELEMENTS_PER_WG: u64 = WG_SIZE * N_ROWS;
struct PrefixTreeCode {
reduce_pipeline: Pipeline,
scan_pipeline: Pipeline,
root_pipeline: Pipeline,
}
struct PrefixTreeStage {
sizes: Vec<u64>,
tmp_bufs: Vec<Buffer>,
}
struct PrefixTreeBinding {
// All but the first and last can be moved to stage.
descriptor_sets: Vec<DescriptorSet>,
}
pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResult {
let mut result = TestResult::new("prefix sum, tree reduction");
// This will be configurable. Note though that the current code is
// prone to reading and writing past the end of buffers if this is
// not a power of the number of elements processed in a workgroup.
let n_elements: u64 = config.size.choose(1 << 12, 1 << 24, 1 << 24);
let data: Vec<u32> = (0..n_elements as u32).collect();
let data_buf = runner
.session
.create_buffer_init(&data, BufferUsage::STORAGE)
.unwrap();
let out_buf = runner.buf_down(data_buf.size());
let code = PrefixTreeCode::new(runner);
let stage = PrefixTreeStage::new(runner, n_elements);
let binding = stage.bind(runner, &code, &out_buf.dev_buf);
// Also will be configurable of course.
let n_iter = 1000;
let mut total_elapsed = 0.0;
for i in 0..n_iter {
let mut commands = runner.commands();
commands.cmd_buf.copy_buffer(&data_buf, &out_buf.dev_buf);
commands.cmd_buf.memory_barrier();
commands.write_timestamp(0);
stage.record(&mut commands, &code, &binding);
commands.write_timestamp(1);
if i == 0 {
commands.cmd_buf.memory_barrier();
commands.download(&out_buf);
}
total_elapsed += runner.submit(commands);
if i == 0 {
let mut dst: Vec<u32> = Default::default();
out_buf.read(&mut dst);
if let Some(failure) = verify(&dst) {
result.fail(format!("failure at {}", failure));
}
}
}
result.timing(total_elapsed, n_elements * n_iter);
result
}
impl PrefixTreeCode {
unsafe fn new(runner: &mut Runner) -> PrefixTreeCode {
let reduce_code = include_shader!(&runner.session, "../shader/gen/prefix_reduce");
let reduce_pipeline = runner
.session
.create_simple_compute_pipeline(reduce_code, 2)
.unwrap();
let scan_code = include_shader!(&runner.session, "../shader/gen/prefix_scan");
let scan_pipeline = runner
.session
.create_simple_compute_pipeline(scan_code, 2)
.unwrap();
let root_code = include_shader!(&runner.session, "../shader/gen/prefix_root");
let root_pipeline = runner
.session
.create_simple_compute_pipeline(root_code, 1)
.unwrap();
PrefixTreeCode {
reduce_pipeline,
scan_pipeline,
root_pipeline,
}
}
}
impl PrefixTreeStage {
unsafe fn new(runner: &mut Runner, n_elements: u64) -> PrefixTreeStage {
let mut size = n_elements;
let mut sizes = vec![size];
let mut tmp_bufs = Vec::new();
while size > ELEMENTS_PER_WG {
size = (size + ELEMENTS_PER_WG - 1) / ELEMENTS_PER_WG;
sizes.push(size);
let buf = runner
.session
.create_buffer(4 * size, BufferUsage::STORAGE)
.unwrap();
tmp_bufs.push(buf);
}
PrefixTreeStage { sizes, tmp_bufs }
}
unsafe fn bind(
&self,
runner: &mut Runner,
code: &PrefixTreeCode,
data_buf: &Buffer,
) -> PrefixTreeBinding {
let mut descriptor_sets = Vec::with_capacity(2 * self.tmp_bufs.len() + 1);
for i in 0..self.tmp_bufs.len() {
let buf0 = if i == 0 {
data_buf
} else {
&self.tmp_bufs[i - 1]
};
let buf1 = &self.tmp_bufs[i];
let descriptor_set = runner
.session
.create_simple_descriptor_set(&code.reduce_pipeline, &[buf0, buf1])
.unwrap();
descriptor_sets.push(descriptor_set);
}
let buf0 = self.tmp_bufs.last().unwrap_or(data_buf);
let descriptor_set = runner
.session
.create_simple_descriptor_set(&code.root_pipeline, &[buf0])
.unwrap();
descriptor_sets.push(descriptor_set);
for i in (0..self.tmp_bufs.len()).rev() {
let buf0 = if i == 0 {
data_buf
} else {
&self.tmp_bufs[i - 1]
};
let buf1 = &self.tmp_bufs[i];
let descriptor_set = runner
.session
.create_simple_descriptor_set(&code.scan_pipeline, &[buf0, buf1])
.unwrap();
descriptor_sets.push(descriptor_set);
}
PrefixTreeBinding { descriptor_sets }
}
unsafe fn record(
&self,
commands: &mut Commands,
code: &PrefixTreeCode,
bindings: &PrefixTreeBinding,
) {
let n = self.tmp_bufs.len();
for i in 0..n {
let n_workgroups = self.sizes[i + 1];
commands.cmd_buf.dispatch(
&code.reduce_pipeline,
&bindings.descriptor_sets[i],
(n_workgroups as u32, 1, 1),
(WG_SIZE as u32, 1, 1),
);
commands.cmd_buf.memory_barrier();
}
commands.cmd_buf.dispatch(
&code.root_pipeline,
&bindings.descriptor_sets[n],
(1, 1, 1),
(WG_SIZE as u32, 1, 1),
);
for i in (0..n).rev() {
commands.cmd_buf.memory_barrier();
let n_workgroups = self.sizes[i + 1];
commands.cmd_buf.dispatch(
&code.scan_pipeline,
&bindings.descriptor_sets[2 * n - i],
(n_workgroups as u32, 1, 1),
(WG_SIZE as u32, 1, 1),
);
}
}
}
// Verify that the data is OEIS A000217
fn verify(data: &[u32]) -> Option<usize> {
data.iter()
.enumerate()
.position(|(i, val)| ((i * (i + 1)) / 2) as u32 != *val)
}

138
tests/src/runner.rs Normal file
View file

@ -0,0 +1,138 @@
// Copyright 2021 The piet-gpu authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Also licensed under MIT license, at your choice.
//! Test runner intended to make it easy to write tests.
use piet_gpu_hal::{Buffer, BufferUsage, CmdBuf, Instance, PlainData, QueryPool, Session};
pub struct Runner {
#[allow(unused)]
instance: Instance,
pub session: Session,
cmd_buf_pool: Vec<CmdBuf>,
}
/// A wrapper around command buffers.
pub struct Commands {
pub cmd_buf: CmdBuf,
query_pool: QueryPool,
}
/// Buffer for uploading data to GPU.
#[allow(unused)]
pub struct BufUp {
pub stage_buf: Buffer,
pub dev_buf: Buffer,
}
/// Buffer for downloading data from GPU.
pub struct BufDown {
pub stage_buf: Buffer,
pub dev_buf: Buffer,
}
impl Runner {
pub unsafe fn new() -> Runner {
let (instance, _) = Instance::new(None).unwrap();
let device = instance.device(None).unwrap();
let session = Session::new(device);
let cmd_buf_pool = Vec::new();
Runner {
instance,
session,
cmd_buf_pool,
}
}
pub unsafe fn commands(&mut self) -> Commands {
let mut cmd_buf = self
.cmd_buf_pool
.pop()
.unwrap_or_else(|| self.session.cmd_buf().unwrap());
cmd_buf.begin();
// TODO: also pool these. But we might sort by size, as
// we might not always be doing two.
let query_pool = self.session.create_query_pool(2).unwrap();
cmd_buf.reset_query_pool(&query_pool);
Commands {
cmd_buf,
query_pool,
}
}
pub unsafe fn submit(&mut self, commands: Commands) -> f64 {
let mut cmd_buf = commands.cmd_buf;
let query_pool = commands.query_pool;
cmd_buf.finish_timestamps(&query_pool);
cmd_buf.host_barrier();
cmd_buf.finish();
let submitted = self.session.run_cmd_buf(cmd_buf, &[], &[]).unwrap();
self.cmd_buf_pool.extend(submitted.wait().unwrap());
let timestamps = self.session.fetch_query_pool(&query_pool).unwrap();
timestamps[0]
}
#[allow(unused)]
pub fn buf_up(&self, size: u64) -> BufUp {
let stage_buf = self
.session
.create_buffer(size, BufferUsage::MAP_WRITE | BufferUsage::COPY_SRC)
.unwrap();
let dev_buf = self
.session
.create_buffer(size, BufferUsage::COPY_DST | BufferUsage::STORAGE)
.unwrap();
BufUp { stage_buf, dev_buf }
}
pub fn buf_down(&self, size: u64) -> BufDown {
let stage_buf = self
.session
.create_buffer(size, BufferUsage::MAP_READ | BufferUsage::COPY_DST)
.unwrap();
// Note: the COPY_DST isn't needed in all use cases, but I don't think
// making this tighter would help.
let dev_buf = self
.session
.create_buffer(
size,
BufferUsage::COPY_SRC | BufferUsage::COPY_DST | BufferUsage::STORAGE,
)
.unwrap();
BufDown { stage_buf, dev_buf }
}
}
impl Commands {
pub unsafe fn write_timestamp(&mut self, query: u32) {
self.cmd_buf.write_timestamp(&self.query_pool, query);
}
#[allow(unused)]
pub unsafe fn upload(&mut self, buf: &BufUp) {
self.cmd_buf.copy_buffer(&buf.stage_buf, &buf.dev_buf);
}
pub unsafe fn download(&mut self, buf: &BufDown) {
self.cmd_buf.copy_buffer(&buf.dev_buf, &buf.stage_buf);
}
}
impl BufDown {
pub unsafe fn read(&self, dst: &mut Vec<impl PlainData>) {
self.stage_buf.read(dst).unwrap()
}
}

110
tests/src/test_result.rs Normal file
View file

@ -0,0 +1,110 @@
// Copyright 2021 The piet-gpu authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Also licensed under MIT license, at your choice.
//! Recording of results from tests.
pub struct TestResult {
name: String,
// TODO: statistics. We're lean and mean for now.
total_time: f64,
n_elements: u64,
failure: Option<String>,
}
#[derive(Clone, Copy)]
pub enum ReportStyle {
Short,
Verbose,
}
impl TestResult {
pub fn new(name: &str) -> TestResult {
TestResult {
name: name.to_string(),
total_time: 0.0,
n_elements: 0,
failure: None,
}
}
pub fn report(&self, style: ReportStyle) {
let fail_string = match &self.failure {
None => "pass".into(),
Some(s) => format!("fail ({})", s),
};
match style {
ReportStyle::Short => {
let mut timing_string = String::new();
if self.total_time > 0.0 {
if self.n_elements > 0 {
let throughput = self.n_elements as f64 / self.total_time;
timing_string = format!(" {} elements/s", format_nice(throughput, 1));
} else {
timing_string = format!(" {}s", format_nice(self.total_time, 1));
}
}
println!("{}: {}{}", self.name, fail_string, timing_string)
}
ReportStyle::Verbose => {
println!("test {}", self.name);
println!(" {}", fail_string);
if self.total_time > 0.0 {
println!(" {}s total time", format_nice(self.total_time, 1));
if self.n_elements > 0 {
println!(" {} elements", self.n_elements);
let throughput = self.n_elements as f64 / self.total_time;
println!(" {} elements/s", format_nice(throughput, 1));
}
}
}
}
}
pub fn fail(&mut self, explanation: String) {
self.failure = Some(explanation);
}
pub fn timing(&mut self, total_time: f64, n_elements: u64) {
self.total_time = total_time;
self.n_elements = n_elements;
}
}
fn format_nice(x: f64, precision: usize) -> String {
// Precision should probably scale; later
let (scale, suffix) = if x >= 1e12 && x < 1e15 {
(1e-12, "T")
} else if x >= 1e9 {
(1e-9, "G")
} else if x >= 1e6 {
(1e-6, "M")
} else if x >= 1e3 {
(1e-3, "k")
} else if x >= 1.0 {
(1.0, "")
} else if x >= 1e-3 {
(1e3, "m")
} else if x >= 1e-6 {
(1e6, "\u{00b5}")
} else if x >= 1e-9 {
(1e9, "n")
} else if x >= 1e-12 {
(1e12, "p")
} else {
return format!("{:.*e}", precision, x);
};
format!("{:.*}{}", precision, scale * x, suffix)
}