mirror of
https://github.com/italicsjenga/vello.git
synced 2025-01-10 20:51:29 +11:00
Merge pull request #123 from linebender/tests
Start testing framework, with prefix sum
This commit is contained in:
commit
19fedf36db
15
Cargo.lock
generated
15
Cargo.lock
generated
|
@ -85,6 +85,12 @@ version = "0.1.6"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a"
|
checksum = "0d8c1fef690941d3e7788d328517591fecc684c084084702d6ff1641e993699a"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bytemuck"
|
||||||
|
version = "1.7.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "72957246c41db82b8ef88a5486143830adeb8227ef9837740bdec67724cf2c5b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "byteorder"
|
name = "byteorder"
|
||||||
version = "1.3.4"
|
version = "1.3.4"
|
||||||
|
@ -897,6 +903,15 @@ dependencies = [
|
||||||
"wio",
|
"wio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "piet-gpu-tests"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"bytemuck",
|
||||||
|
"clap",
|
||||||
|
"piet-gpu-hal",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "piet-gpu-types"
|
name = "piet-gpu-types"
|
||||||
version = "0.0.0"
|
version = "0.0.0"
|
||||||
|
|
|
@ -4,5 +4,6 @@ members = [
|
||||||
"piet-gpu",
|
"piet-gpu",
|
||||||
"piet-gpu-derive",
|
"piet-gpu-derive",
|
||||||
"piet-gpu-hal",
|
"piet-gpu-hal",
|
||||||
"piet-gpu-types"
|
"piet-gpu-types",
|
||||||
|
"tests"
|
||||||
]
|
]
|
||||||
|
|
|
@ -21,6 +21,7 @@ fn main() {
|
||||||
cmd_buf.write_timestamp(&query_pool, 0);
|
cmd_buf.write_timestamp(&query_pool, 0);
|
||||||
cmd_buf.dispatch(&pipeline, &descriptor_set, (256, 1, 1), (1, 1, 1));
|
cmd_buf.dispatch(&pipeline, &descriptor_set, (256, 1, 1), (1, 1, 1));
|
||||||
cmd_buf.write_timestamp(&query_pool, 1);
|
cmd_buf.write_timestamp(&query_pool, 1);
|
||||||
|
cmd_buf.finish_timestamps(&query_pool);
|
||||||
cmd_buf.host_barrier();
|
cmd_buf.host_barrier();
|
||||||
cmd_buf.finish();
|
cmd_buf.finish();
|
||||||
let submitted = session.run_cmd_buf(cmd_buf, &[], &[]).unwrap();
|
let submitted = session.run_cmd_buf(cmd_buf, &[], &[]).unwrap();
|
||||||
|
|
14
tests/Cargo.toml
Normal file
14
tests/Cargo.toml
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
[package]
|
||||||
|
name = "piet-gpu-tests"
|
||||||
|
version = "0.1.0"
|
||||||
|
authors = ["Raph Levien <raph.levien@gmail.com>"]
|
||||||
|
description = "Tests for piet-gpu shaders and generic GPU capabilities."
|
||||||
|
license = "MIT/Apache-2.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
clap = "2.33"
|
||||||
|
bytemuck = "1.7.2"
|
||||||
|
|
||||||
|
[dependencies.piet-gpu-hal]
|
||||||
|
path = "../piet-gpu-hal"
|
32
tests/shader/build.ninja
Normal file
32
tests/shader/build.ninja
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
# Build file for shaders.
|
||||||
|
|
||||||
|
# You must have Vulkan tools in your path, or patch here.
|
||||||
|
|
||||||
|
glslang_validator = glslangValidator
|
||||||
|
spirv_cross = spirv-cross
|
||||||
|
|
||||||
|
rule glsl
|
||||||
|
command = $glslang_validator $flags -V -o $out $in
|
||||||
|
|
||||||
|
rule hlsl
|
||||||
|
command = $spirv_cross --hlsl $in --output $out
|
||||||
|
|
||||||
|
rule msl
|
||||||
|
command = $spirv_cross --msl $in --output $out
|
||||||
|
|
||||||
|
build gen/prefix.spv: glsl prefix.comp
|
||||||
|
build gen/prefix.hlsl: hlsl gen/prefix.spv
|
||||||
|
build gen/prefix.msl: msl gen/prefix.spv
|
||||||
|
|
||||||
|
build gen/prefix_reduce.spv: glsl prefix_reduce.comp
|
||||||
|
build gen/prefix_reduce.hlsl: hlsl gen/prefix_reduce.spv
|
||||||
|
build gen/prefix_reduce.msl: msl gen/prefix_reduce.spv
|
||||||
|
|
||||||
|
build gen/prefix_root.spv: glsl prefix_scan.comp
|
||||||
|
flags = -DROOT
|
||||||
|
build gen/prefix_root.hlsl: hlsl gen/prefix_root.spv
|
||||||
|
build gen/prefix_root.msl: msl gen/prefix_root.spv
|
||||||
|
|
||||||
|
build gen/prefix_scan.spv: glsl prefix_scan.comp
|
||||||
|
build gen/prefix_scan.hlsl: hlsl gen/prefix_scan.spv
|
||||||
|
build gen/prefix_scan.msl: msl gen/prefix_scan.spv
|
43
tests/shader/gen/collatz.hlsl
Normal file
43
tests/shader/gen/collatz.hlsl
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
static const uint3 gl_WorkGroupSize = uint3(1u, 1u, 1u);
|
||||||
|
|
||||||
|
RWByteAddressBuffer _53 : register(u1);
|
||||||
|
ByteAddressBuffer _59 : register(t0);
|
||||||
|
|
||||||
|
static uint3 gl_GlobalInvocationID;
|
||||||
|
struct SPIRV_Cross_Input
|
||||||
|
{
|
||||||
|
uint3 gl_GlobalInvocationID : SV_DispatchThreadID;
|
||||||
|
};
|
||||||
|
|
||||||
|
uint collatz_iterations(inout uint n)
|
||||||
|
{
|
||||||
|
uint i = 0u;
|
||||||
|
while (n != 1u)
|
||||||
|
{
|
||||||
|
if ((n % 2u) == 0u)
|
||||||
|
{
|
||||||
|
n /= 2u;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
n = (3u * n) + 1u;
|
||||||
|
}
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
|
||||||
|
void comp_main()
|
||||||
|
{
|
||||||
|
uint index = gl_GlobalInvocationID.x;
|
||||||
|
uint param = _59.Load(index * 4 + 0);
|
||||||
|
uint _65 = collatz_iterations(param);
|
||||||
|
_53.Store(index * 4 + 0, _65);
|
||||||
|
}
|
||||||
|
|
||||||
|
[numthreads(1, 1, 1)]
|
||||||
|
void main(SPIRV_Cross_Input stage_input)
|
||||||
|
{
|
||||||
|
gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID;
|
||||||
|
comp_main();
|
||||||
|
}
|
46
tests/shader/gen/collatz.msl
Normal file
46
tests/shader/gen/collatz.msl
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
||||||
|
|
||||||
|
#include <metal_stdlib>
|
||||||
|
#include <simd/simd.h>
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
struct OutBuf
|
||||||
|
{
|
||||||
|
uint out_buf[1];
|
||||||
|
};
|
||||||
|
|
||||||
|
struct InBuf
|
||||||
|
{
|
||||||
|
uint in_buf[1];
|
||||||
|
};
|
||||||
|
|
||||||
|
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u);
|
||||||
|
|
||||||
|
static inline __attribute__((always_inline))
|
||||||
|
uint collatz_iterations(thread uint& n)
|
||||||
|
{
|
||||||
|
uint i = 0u;
|
||||||
|
while (n != 1u)
|
||||||
|
{
|
||||||
|
if ((n % 2u) == 0u)
|
||||||
|
{
|
||||||
|
n /= 2u;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
n = (3u * n) + 1u;
|
||||||
|
}
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel void main0(device OutBuf& _53 [[buffer(0)]], const device InBuf& _59 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]])
|
||||||
|
{
|
||||||
|
uint index = gl_GlobalInvocationID.x;
|
||||||
|
uint param = _59.in_buf[index];
|
||||||
|
uint _65 = collatz_iterations(param);
|
||||||
|
_53.out_buf[index] = _65;
|
||||||
|
}
|
||||||
|
|
BIN
tests/shader/gen/collatz.spv
Normal file
BIN
tests/shader/gen/collatz.spv
Normal file
Binary file not shown.
223
tests/shader/gen/prefix.hlsl
Normal file
223
tests/shader/gen/prefix.hlsl
Normal file
|
@ -0,0 +1,223 @@
|
||||||
|
struct Monoid
|
||||||
|
{
|
||||||
|
uint element;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct State
|
||||||
|
{
|
||||||
|
uint flag;
|
||||||
|
Monoid aggregate;
|
||||||
|
Monoid prefix;
|
||||||
|
};
|
||||||
|
|
||||||
|
static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u);
|
||||||
|
|
||||||
|
static const Monoid _187 = { 0u };
|
||||||
|
|
||||||
|
globallycoherent RWByteAddressBuffer _43 : register(u2);
|
||||||
|
ByteAddressBuffer _67 : register(t0);
|
||||||
|
RWByteAddressBuffer _374 : register(u1);
|
||||||
|
|
||||||
|
static uint3 gl_LocalInvocationID;
|
||||||
|
struct SPIRV_Cross_Input
|
||||||
|
{
|
||||||
|
uint3 gl_LocalInvocationID : SV_GroupThreadID;
|
||||||
|
};
|
||||||
|
|
||||||
|
groupshared uint sh_part_ix;
|
||||||
|
groupshared Monoid sh_scratch[512];
|
||||||
|
groupshared uint sh_flag;
|
||||||
|
groupshared Monoid sh_prefix;
|
||||||
|
|
||||||
|
Monoid combine_monoid(Monoid a, Monoid b)
|
||||||
|
{
|
||||||
|
Monoid _22 = { a.element + b.element };
|
||||||
|
return _22;
|
||||||
|
}
|
||||||
|
|
||||||
|
void comp_main()
|
||||||
|
{
|
||||||
|
if (gl_LocalInvocationID.x == 0u)
|
||||||
|
{
|
||||||
|
uint _47;
|
||||||
|
_43.InterlockedAdd(0, 1u, _47);
|
||||||
|
sh_part_ix = _47;
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
uint part_ix = sh_part_ix;
|
||||||
|
uint ix = (part_ix * 8192u) + (gl_LocalInvocationID.x * 16u);
|
||||||
|
Monoid _71;
|
||||||
|
_71.element = _67.Load(ix * 4 + 0);
|
||||||
|
Monoid local[16];
|
||||||
|
local[0].element = _71.element;
|
||||||
|
Monoid param_1;
|
||||||
|
for (uint i = 1u; i < 16u; i++)
|
||||||
|
{
|
||||||
|
Monoid param = local[i - 1u];
|
||||||
|
Monoid _94;
|
||||||
|
_94.element = _67.Load((ix + i) * 4 + 0);
|
||||||
|
param_1.element = _94.element;
|
||||||
|
local[i] = combine_monoid(param, param_1);
|
||||||
|
}
|
||||||
|
Monoid agg = local[15];
|
||||||
|
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||||
|
for (uint i_1 = 0u; i_1 < 9u; i_1++)
|
||||||
|
{
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
if (gl_LocalInvocationID.x >= uint(1 << int(i_1)))
|
||||||
|
{
|
||||||
|
Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))];
|
||||||
|
Monoid param_2 = other;
|
||||||
|
Monoid param_3 = agg;
|
||||||
|
agg = combine_monoid(param_2, param_3);
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||||
|
}
|
||||||
|
if (gl_LocalInvocationID.x == 511u)
|
||||||
|
{
|
||||||
|
_43.Store(part_ix * 12 + 8, agg.element);
|
||||||
|
if (part_ix == 0u)
|
||||||
|
{
|
||||||
|
_43.Store(12, agg.element);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DeviceMemoryBarrier();
|
||||||
|
if (gl_LocalInvocationID.x == 511u)
|
||||||
|
{
|
||||||
|
uint flag = 1u;
|
||||||
|
if (part_ix == 0u)
|
||||||
|
{
|
||||||
|
flag = 2u;
|
||||||
|
}
|
||||||
|
_43.Store(part_ix * 12 + 4, flag);
|
||||||
|
}
|
||||||
|
Monoid exclusive = _187;
|
||||||
|
if (part_ix != 0u)
|
||||||
|
{
|
||||||
|
uint look_back_ix = part_ix - 1u;
|
||||||
|
uint their_ix = 0u;
|
||||||
|
Monoid their_prefix;
|
||||||
|
Monoid their_agg;
|
||||||
|
Monoid m;
|
||||||
|
while (true)
|
||||||
|
{
|
||||||
|
if (gl_LocalInvocationID.x == 511u)
|
||||||
|
{
|
||||||
|
sh_flag = _43.Load(look_back_ix * 12 + 4);
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
DeviceMemoryBarrier();
|
||||||
|
uint flag_1 = sh_flag;
|
||||||
|
if (flag_1 == 2u)
|
||||||
|
{
|
||||||
|
if (gl_LocalInvocationID.x == 511u)
|
||||||
|
{
|
||||||
|
Monoid _225;
|
||||||
|
_225.element = _43.Load(look_back_ix * 12 + 12);
|
||||||
|
their_prefix.element = _225.element;
|
||||||
|
Monoid param_4 = their_prefix;
|
||||||
|
Monoid param_5 = exclusive;
|
||||||
|
exclusive = combine_monoid(param_4, param_5);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
if (flag_1 == 1u)
|
||||||
|
{
|
||||||
|
if (gl_LocalInvocationID.x == 511u)
|
||||||
|
{
|
||||||
|
Monoid _247;
|
||||||
|
_247.element = _43.Load(look_back_ix * 12 + 8);
|
||||||
|
their_agg.element = _247.element;
|
||||||
|
Monoid param_6 = their_agg;
|
||||||
|
Monoid param_7 = exclusive;
|
||||||
|
exclusive = combine_monoid(param_6, param_7);
|
||||||
|
}
|
||||||
|
look_back_ix--;
|
||||||
|
their_ix = 0u;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (gl_LocalInvocationID.x == 511u)
|
||||||
|
{
|
||||||
|
Monoid _269;
|
||||||
|
_269.element = _67.Load(((look_back_ix * 8192u) + their_ix) * 4 + 0);
|
||||||
|
m.element = _269.element;
|
||||||
|
if (their_ix == 0u)
|
||||||
|
{
|
||||||
|
their_agg = m;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
Monoid param_8 = their_agg;
|
||||||
|
Monoid param_9 = m;
|
||||||
|
their_agg = combine_monoid(param_8, param_9);
|
||||||
|
}
|
||||||
|
their_ix++;
|
||||||
|
if (their_ix == 8192u)
|
||||||
|
{
|
||||||
|
Monoid param_10 = their_agg;
|
||||||
|
Monoid param_11 = exclusive;
|
||||||
|
exclusive = combine_monoid(param_10, param_11);
|
||||||
|
if (look_back_ix == 0u)
|
||||||
|
{
|
||||||
|
sh_flag = 2u;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
look_back_ix--;
|
||||||
|
their_ix = 0u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
flag_1 = sh_flag;
|
||||||
|
if (flag_1 == 2u)
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (gl_LocalInvocationID.x == 511u)
|
||||||
|
{
|
||||||
|
Monoid param_12 = exclusive;
|
||||||
|
Monoid param_13 = agg;
|
||||||
|
Monoid inclusive_prefix = combine_monoid(param_12, param_13);
|
||||||
|
sh_prefix = exclusive;
|
||||||
|
_43.Store(part_ix * 12 + 12, inclusive_prefix.element);
|
||||||
|
}
|
||||||
|
DeviceMemoryBarrier();
|
||||||
|
if (gl_LocalInvocationID.x == 511u)
|
||||||
|
{
|
||||||
|
_43.Store(part_ix * 12 + 4, 2u);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
if (part_ix != 0u)
|
||||||
|
{
|
||||||
|
exclusive = sh_prefix;
|
||||||
|
}
|
||||||
|
Monoid row = exclusive;
|
||||||
|
if (gl_LocalInvocationID.x > 0u)
|
||||||
|
{
|
||||||
|
Monoid other_1 = sh_scratch[gl_LocalInvocationID.x - 1u];
|
||||||
|
Monoid param_14 = row;
|
||||||
|
Monoid param_15 = other_1;
|
||||||
|
row = combine_monoid(param_14, param_15);
|
||||||
|
}
|
||||||
|
for (uint i_2 = 0u; i_2 < 16u; i_2++)
|
||||||
|
{
|
||||||
|
Monoid param_16 = row;
|
||||||
|
Monoid param_17 = local[i_2];
|
||||||
|
Monoid m_1 = combine_monoid(param_16, param_17);
|
||||||
|
_374.Store((ix + i_2) * 4 + 0, m_1.element);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[numthreads(512, 1, 1)]
|
||||||
|
void main(SPIRV_Cross_Input stage_input)
|
||||||
|
{
|
||||||
|
gl_LocalInvocationID = stage_input.gl_LocalInvocationID;
|
||||||
|
comp_main();
|
||||||
|
}
|
262
tests/shader/gen/prefix.msl
Normal file
262
tests/shader/gen/prefix.msl
Normal file
|
@ -0,0 +1,262 @@
|
||||||
|
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
||||||
|
#pragma clang diagnostic ignored "-Wmissing-braces"
|
||||||
|
#pragma clang diagnostic ignored "-Wunused-variable"
|
||||||
|
|
||||||
|
#include <metal_stdlib>
|
||||||
|
#include <simd/simd.h>
|
||||||
|
#include <metal_atomic>
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
template<typename T, size_t Num>
|
||||||
|
struct spvUnsafeArray
|
||||||
|
{
|
||||||
|
T elements[Num ? Num : 1];
|
||||||
|
|
||||||
|
thread T& operator [] (size_t pos) thread
|
||||||
|
{
|
||||||
|
return elements[pos];
|
||||||
|
}
|
||||||
|
constexpr const thread T& operator [] (size_t pos) const thread
|
||||||
|
{
|
||||||
|
return elements[pos];
|
||||||
|
}
|
||||||
|
|
||||||
|
device T& operator [] (size_t pos) device
|
||||||
|
{
|
||||||
|
return elements[pos];
|
||||||
|
}
|
||||||
|
constexpr const device T& operator [] (size_t pos) const device
|
||||||
|
{
|
||||||
|
return elements[pos];
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr const constant T& operator [] (size_t pos) const constant
|
||||||
|
{
|
||||||
|
return elements[pos];
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup T& operator [] (size_t pos) threadgroup
|
||||||
|
{
|
||||||
|
return elements[pos];
|
||||||
|
}
|
||||||
|
constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
|
||||||
|
{
|
||||||
|
return elements[pos];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Monoid
|
||||||
|
{
|
||||||
|
uint element;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Monoid_1
|
||||||
|
{
|
||||||
|
uint element;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct State
|
||||||
|
{
|
||||||
|
uint flag;
|
||||||
|
Monoid_1 aggregate;
|
||||||
|
Monoid_1 prefix;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct StateBuf
|
||||||
|
{
|
||||||
|
uint part_counter;
|
||||||
|
State state[1];
|
||||||
|
};
|
||||||
|
|
||||||
|
struct InBuf
|
||||||
|
{
|
||||||
|
Monoid_1 inbuf[1];
|
||||||
|
};
|
||||||
|
|
||||||
|
struct OutBuf
|
||||||
|
{
|
||||||
|
Monoid_1 outbuf[1];
|
||||||
|
};
|
||||||
|
|
||||||
|
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u);
|
||||||
|
|
||||||
|
static inline __attribute__((always_inline))
|
||||||
|
Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b)
|
||||||
|
{
|
||||||
|
return Monoid{ a.element + b.element };
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel void main0(volatile device StateBuf& _43 [[buffer(0)]], const device InBuf& _67 [[buffer(1)]], device OutBuf& _374 [[buffer(2)]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]])
|
||||||
|
{
|
||||||
|
threadgroup uint sh_part_ix;
|
||||||
|
threadgroup Monoid sh_scratch[512];
|
||||||
|
threadgroup uint sh_flag;
|
||||||
|
threadgroup Monoid sh_prefix;
|
||||||
|
if (gl_LocalInvocationID.x == 0u)
|
||||||
|
{
|
||||||
|
uint _47 = atomic_fetch_add_explicit((volatile device atomic_uint*)&_43.part_counter, 1u, memory_order_relaxed);
|
||||||
|
sh_part_ix = _47;
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
uint part_ix = sh_part_ix;
|
||||||
|
uint ix = (part_ix * 8192u) + (gl_LocalInvocationID.x * 16u);
|
||||||
|
spvUnsafeArray<Monoid, 16> local;
|
||||||
|
local[0].element = _67.inbuf[ix].element;
|
||||||
|
Monoid param_1;
|
||||||
|
for (uint i = 1u; i < 16u; i++)
|
||||||
|
{
|
||||||
|
Monoid param = local[i - 1u];
|
||||||
|
param_1.element = _67.inbuf[ix + i].element;
|
||||||
|
local[i] = combine_monoid(param, param_1);
|
||||||
|
}
|
||||||
|
Monoid agg = local[15];
|
||||||
|
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||||
|
for (uint i_1 = 0u; i_1 < 9u; i_1++)
|
||||||
|
{
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
if (gl_LocalInvocationID.x >= uint(1 << int(i_1)))
|
||||||
|
{
|
||||||
|
Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))];
|
||||||
|
Monoid param_2 = other;
|
||||||
|
Monoid param_3 = agg;
|
||||||
|
agg = combine_monoid(param_2, param_3);
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||||
|
}
|
||||||
|
if (gl_LocalInvocationID.x == 511u)
|
||||||
|
{
|
||||||
|
_43.state[part_ix].aggregate.element = agg.element;
|
||||||
|
if (part_ix == 0u)
|
||||||
|
{
|
||||||
|
_43.state[0].prefix.element = agg.element;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_device);
|
||||||
|
if (gl_LocalInvocationID.x == 511u)
|
||||||
|
{
|
||||||
|
uint flag = 1u;
|
||||||
|
if (part_ix == 0u)
|
||||||
|
{
|
||||||
|
flag = 2u;
|
||||||
|
}
|
||||||
|
_43.state[part_ix].flag = flag;
|
||||||
|
}
|
||||||
|
Monoid exclusive = Monoid{ 0u };
|
||||||
|
if (part_ix != 0u)
|
||||||
|
{
|
||||||
|
uint look_back_ix = part_ix - 1u;
|
||||||
|
uint their_ix = 0u;
|
||||||
|
Monoid their_prefix;
|
||||||
|
Monoid their_agg;
|
||||||
|
Monoid m;
|
||||||
|
while (true)
|
||||||
|
{
|
||||||
|
if (gl_LocalInvocationID.x == 511u)
|
||||||
|
{
|
||||||
|
sh_flag = _43.state[look_back_ix].flag;
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
threadgroup_barrier(mem_flags::mem_device);
|
||||||
|
uint flag_1 = sh_flag;
|
||||||
|
if (flag_1 == 2u)
|
||||||
|
{
|
||||||
|
if (gl_LocalInvocationID.x == 511u)
|
||||||
|
{
|
||||||
|
their_prefix.element = _43.state[look_back_ix].prefix.element;
|
||||||
|
Monoid param_4 = their_prefix;
|
||||||
|
Monoid param_5 = exclusive;
|
||||||
|
exclusive = combine_monoid(param_4, param_5);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
if (flag_1 == 1u)
|
||||||
|
{
|
||||||
|
if (gl_LocalInvocationID.x == 511u)
|
||||||
|
{
|
||||||
|
their_agg.element = _43.state[look_back_ix].aggregate.element;
|
||||||
|
Monoid param_6 = their_agg;
|
||||||
|
Monoid param_7 = exclusive;
|
||||||
|
exclusive = combine_monoid(param_6, param_7);
|
||||||
|
}
|
||||||
|
look_back_ix--;
|
||||||
|
their_ix = 0u;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (gl_LocalInvocationID.x == 511u)
|
||||||
|
{
|
||||||
|
m.element = _67.inbuf[(look_back_ix * 8192u) + their_ix].element;
|
||||||
|
if (their_ix == 0u)
|
||||||
|
{
|
||||||
|
their_agg = m;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
Monoid param_8 = their_agg;
|
||||||
|
Monoid param_9 = m;
|
||||||
|
their_agg = combine_monoid(param_8, param_9);
|
||||||
|
}
|
||||||
|
their_ix++;
|
||||||
|
if (their_ix == 8192u)
|
||||||
|
{
|
||||||
|
Monoid param_10 = their_agg;
|
||||||
|
Monoid param_11 = exclusive;
|
||||||
|
exclusive = combine_monoid(param_10, param_11);
|
||||||
|
if (look_back_ix == 0u)
|
||||||
|
{
|
||||||
|
sh_flag = 2u;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
look_back_ix--;
|
||||||
|
their_ix = 0u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
flag_1 = sh_flag;
|
||||||
|
if (flag_1 == 2u)
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (gl_LocalInvocationID.x == 511u)
|
||||||
|
{
|
||||||
|
Monoid param_12 = exclusive;
|
||||||
|
Monoid param_13 = agg;
|
||||||
|
Monoid inclusive_prefix = combine_monoid(param_12, param_13);
|
||||||
|
sh_prefix = exclusive;
|
||||||
|
_43.state[part_ix].prefix.element = inclusive_prefix.element;
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_device);
|
||||||
|
if (gl_LocalInvocationID.x == 511u)
|
||||||
|
{
|
||||||
|
_43.state[part_ix].flag = 2u;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
if (part_ix != 0u)
|
||||||
|
{
|
||||||
|
exclusive = sh_prefix;
|
||||||
|
}
|
||||||
|
Monoid row = exclusive;
|
||||||
|
if (gl_LocalInvocationID.x > 0u)
|
||||||
|
{
|
||||||
|
Monoid other_1 = sh_scratch[gl_LocalInvocationID.x - 1u];
|
||||||
|
Monoid param_14 = row;
|
||||||
|
Monoid param_15 = other_1;
|
||||||
|
row = combine_monoid(param_14, param_15);
|
||||||
|
}
|
||||||
|
for (uint i_2 = 0u; i_2 < 16u; i_2++)
|
||||||
|
{
|
||||||
|
Monoid param_16 = row;
|
||||||
|
Monoid param_17 = local[i_2];
|
||||||
|
Monoid m_1 = combine_monoid(param_16, param_17);
|
||||||
|
_374.outbuf[ix + i_2].element = m_1.element;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
BIN
tests/shader/gen/prefix.spv
Normal file
BIN
tests/shader/gen/prefix.spv
Normal file
Binary file not shown.
72
tests/shader/gen/prefix_reduce.hlsl
Normal file
72
tests/shader/gen/prefix_reduce.hlsl
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
struct Monoid
|
||||||
|
{
|
||||||
|
uint element;
|
||||||
|
};
|
||||||
|
|
||||||
|
static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u);
|
||||||
|
|
||||||
|
ByteAddressBuffer _40 : register(t0);
|
||||||
|
RWByteAddressBuffer _129 : register(u1);
|
||||||
|
|
||||||
|
static uint3 gl_WorkGroupID;
|
||||||
|
static uint3 gl_LocalInvocationID;
|
||||||
|
static uint3 gl_GlobalInvocationID;
|
||||||
|
struct SPIRV_Cross_Input
|
||||||
|
{
|
||||||
|
uint3 gl_WorkGroupID : SV_GroupID;
|
||||||
|
uint3 gl_LocalInvocationID : SV_GroupThreadID;
|
||||||
|
uint3 gl_GlobalInvocationID : SV_DispatchThreadID;
|
||||||
|
};
|
||||||
|
|
||||||
|
groupshared Monoid sh_scratch[512];
|
||||||
|
|
||||||
|
Monoid combine_monoid(Monoid a, Monoid b)
|
||||||
|
{
|
||||||
|
Monoid _22 = { a.element + b.element };
|
||||||
|
return _22;
|
||||||
|
}
|
||||||
|
|
||||||
|
void comp_main()
|
||||||
|
{
|
||||||
|
uint ix = gl_GlobalInvocationID.x * 8u;
|
||||||
|
Monoid _44;
|
||||||
|
_44.element = _40.Load(ix * 4 + 0);
|
||||||
|
Monoid agg;
|
||||||
|
agg.element = _44.element;
|
||||||
|
Monoid param_1;
|
||||||
|
for (uint i = 1u; i < 8u; i++)
|
||||||
|
{
|
||||||
|
Monoid param = agg;
|
||||||
|
Monoid _64;
|
||||||
|
_64.element = _40.Load((ix + i) * 4 + 0);
|
||||||
|
param_1.element = _64.element;
|
||||||
|
agg = combine_monoid(param, param_1);
|
||||||
|
}
|
||||||
|
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||||
|
for (uint i_1 = 0u; i_1 < 9u; i_1++)
|
||||||
|
{
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
if ((gl_LocalInvocationID.x + uint(1 << int(i_1))) < 512u)
|
||||||
|
{
|
||||||
|
Monoid other = sh_scratch[gl_LocalInvocationID.x + uint(1 << int(i_1))];
|
||||||
|
Monoid param_2 = agg;
|
||||||
|
Monoid param_3 = other;
|
||||||
|
agg = combine_monoid(param_2, param_3);
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||||
|
}
|
||||||
|
if (gl_LocalInvocationID.x == 0u)
|
||||||
|
{
|
||||||
|
_129.Store(gl_WorkGroupID.x * 4 + 0, agg.element);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[numthreads(512, 1, 1)]
|
||||||
|
void main(SPIRV_Cross_Input stage_input)
|
||||||
|
{
|
||||||
|
gl_WorkGroupID = stage_input.gl_WorkGroupID;
|
||||||
|
gl_LocalInvocationID = stage_input.gl_LocalInvocationID;
|
||||||
|
gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID;
|
||||||
|
comp_main();
|
||||||
|
}
|
68
tests/shader/gen/prefix_reduce.msl
Normal file
68
tests/shader/gen/prefix_reduce.msl
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
||||||
|
|
||||||
|
#include <metal_stdlib>
|
||||||
|
#include <simd/simd.h>
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
struct Monoid
|
||||||
|
{
|
||||||
|
uint element;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Monoid_1
|
||||||
|
{
|
||||||
|
uint element;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct InBuf
|
||||||
|
{
|
||||||
|
Monoid_1 inbuf[1];
|
||||||
|
};
|
||||||
|
|
||||||
|
struct OutBuf
|
||||||
|
{
|
||||||
|
Monoid_1 outbuf[1];
|
||||||
|
};
|
||||||
|
|
||||||
|
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u);
|
||||||
|
|
||||||
|
static inline __attribute__((always_inline))
|
||||||
|
Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b)
|
||||||
|
{
|
||||||
|
return Monoid{ a.element + b.element };
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel void main0(const device InBuf& _40 [[buffer(0)]], device OutBuf& _129 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]])
|
||||||
|
{
|
||||||
|
threadgroup Monoid sh_scratch[512];
|
||||||
|
uint ix = gl_GlobalInvocationID.x * 8u;
|
||||||
|
Monoid agg;
|
||||||
|
agg.element = _40.inbuf[ix].element;
|
||||||
|
Monoid param_1;
|
||||||
|
for (uint i = 1u; i < 8u; i++)
|
||||||
|
{
|
||||||
|
Monoid param = agg;
|
||||||
|
param_1.element = _40.inbuf[ix + i].element;
|
||||||
|
agg = combine_monoid(param, param_1);
|
||||||
|
}
|
||||||
|
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||||
|
for (uint i_1 = 0u; i_1 < 9u; i_1++)
|
||||||
|
{
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
if ((gl_LocalInvocationID.x + uint(1 << int(i_1))) < 512u)
|
||||||
|
{
|
||||||
|
Monoid other = sh_scratch[gl_LocalInvocationID.x + uint(1 << int(i_1))];
|
||||||
|
Monoid param_2 = agg;
|
||||||
|
Monoid param_3 = other;
|
||||||
|
agg = combine_monoid(param_2, param_3);
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||||
|
}
|
||||||
|
if (gl_LocalInvocationID.x == 0u)
|
||||||
|
{
|
||||||
|
_129.outbuf[gl_WorkGroupID.x].element = agg.element;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
BIN
tests/shader/gen/prefix_reduce.spv
Normal file
BIN
tests/shader/gen/prefix_reduce.spv
Normal file
Binary file not shown.
80
tests/shader/gen/prefix_root.hlsl
Normal file
80
tests/shader/gen/prefix_root.hlsl
Normal file
|
@ -0,0 +1,80 @@
|
||||||
|
struct Monoid
|
||||||
|
{
|
||||||
|
uint element;
|
||||||
|
};
|
||||||
|
|
||||||
|
static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u);
|
||||||
|
|
||||||
|
static const Monoid _133 = { 0u };
|
||||||
|
|
||||||
|
RWByteAddressBuffer _42 : register(u0);
|
||||||
|
|
||||||
|
static uint3 gl_LocalInvocationID;
|
||||||
|
static uint3 gl_GlobalInvocationID;
|
||||||
|
struct SPIRV_Cross_Input
|
||||||
|
{
|
||||||
|
uint3 gl_LocalInvocationID : SV_GroupThreadID;
|
||||||
|
uint3 gl_GlobalInvocationID : SV_DispatchThreadID;
|
||||||
|
};
|
||||||
|
|
||||||
|
groupshared Monoid sh_scratch[512];
|
||||||
|
|
||||||
|
Monoid combine_monoid(Monoid a, Monoid b)
|
||||||
|
{
|
||||||
|
Monoid _22 = { a.element + b.element };
|
||||||
|
return _22;
|
||||||
|
}
|
||||||
|
|
||||||
|
void comp_main()
|
||||||
|
{
|
||||||
|
uint ix = gl_GlobalInvocationID.x * 8u;
|
||||||
|
Monoid _46;
|
||||||
|
_46.element = _42.Load(ix * 4 + 0);
|
||||||
|
Monoid local[8];
|
||||||
|
local[0].element = _46.element;
|
||||||
|
Monoid param_1;
|
||||||
|
for (uint i = 1u; i < 8u; i++)
|
||||||
|
{
|
||||||
|
Monoid param = local[i - 1u];
|
||||||
|
Monoid _71;
|
||||||
|
_71.element = _42.Load((ix + i) * 4 + 0);
|
||||||
|
param_1.element = _71.element;
|
||||||
|
local[i] = combine_monoid(param, param_1);
|
||||||
|
}
|
||||||
|
Monoid agg = local[7];
|
||||||
|
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||||
|
for (uint i_1 = 0u; i_1 < 9u; i_1++)
|
||||||
|
{
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
if (gl_LocalInvocationID.x >= uint(1 << int(i_1)))
|
||||||
|
{
|
||||||
|
Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))];
|
||||||
|
Monoid param_2 = other;
|
||||||
|
Monoid param_3 = agg;
|
||||||
|
agg = combine_monoid(param_2, param_3);
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
Monoid row = _133;
|
||||||
|
if (gl_LocalInvocationID.x > 0u)
|
||||||
|
{
|
||||||
|
row = sh_scratch[gl_LocalInvocationID.x - 1u];
|
||||||
|
}
|
||||||
|
for (uint i_2 = 0u; i_2 < 8u; i_2++)
|
||||||
|
{
|
||||||
|
Monoid param_4 = row;
|
||||||
|
Monoid param_5 = local[i_2];
|
||||||
|
Monoid m = combine_monoid(param_4, param_5);
|
||||||
|
_42.Store((ix + i_2) * 4 + 0, m.element);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[numthreads(512, 1, 1)]
|
||||||
|
void main(SPIRV_Cross_Input stage_input)
|
||||||
|
{
|
||||||
|
gl_LocalInvocationID = stage_input.gl_LocalInvocationID;
|
||||||
|
gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID;
|
||||||
|
comp_main();
|
||||||
|
}
|
112
tests/shader/gen/prefix_root.msl
Normal file
112
tests/shader/gen/prefix_root.msl
Normal file
|
@ -0,0 +1,112 @@
|
||||||
|
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
||||||
|
#pragma clang diagnostic ignored "-Wmissing-braces"
|
||||||
|
|
||||||
|
#include <metal_stdlib>
|
||||||
|
#include <simd/simd.h>
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
template<typename T, size_t Num>
|
||||||
|
struct spvUnsafeArray
|
||||||
|
{
|
||||||
|
T elements[Num ? Num : 1];
|
||||||
|
|
||||||
|
thread T& operator [] (size_t pos) thread
|
||||||
|
{
|
||||||
|
return elements[pos];
|
||||||
|
}
|
||||||
|
constexpr const thread T& operator [] (size_t pos) const thread
|
||||||
|
{
|
||||||
|
return elements[pos];
|
||||||
|
}
|
||||||
|
|
||||||
|
device T& operator [] (size_t pos) device
|
||||||
|
{
|
||||||
|
return elements[pos];
|
||||||
|
}
|
||||||
|
constexpr const device T& operator [] (size_t pos) const device
|
||||||
|
{
|
||||||
|
return elements[pos];
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr const constant T& operator [] (size_t pos) const constant
|
||||||
|
{
|
||||||
|
return elements[pos];
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup T& operator [] (size_t pos) threadgroup
|
||||||
|
{
|
||||||
|
return elements[pos];
|
||||||
|
}
|
||||||
|
constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
|
||||||
|
{
|
||||||
|
return elements[pos];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Monoid
|
||||||
|
{
|
||||||
|
uint element;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Monoid_1
|
||||||
|
{
|
||||||
|
uint element;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct DataBuf
|
||||||
|
{
|
||||||
|
Monoid_1 data[1];
|
||||||
|
};
|
||||||
|
|
||||||
|
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u);
|
||||||
|
|
||||||
|
static inline __attribute__((always_inline))
|
||||||
|
Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b)
|
||||||
|
{
|
||||||
|
return Monoid{ a.element + b.element };
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel void main0(device DataBuf& _42 [[buffer(0)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]])
|
||||||
|
{
|
||||||
|
threadgroup Monoid sh_scratch[512];
|
||||||
|
uint ix = gl_GlobalInvocationID.x * 8u;
|
||||||
|
spvUnsafeArray<Monoid, 8> local;
|
||||||
|
local[0].element = _42.data[ix].element;
|
||||||
|
Monoid param_1;
|
||||||
|
for (uint i = 1u; i < 8u; i++)
|
||||||
|
{
|
||||||
|
Monoid param = local[i - 1u];
|
||||||
|
param_1.element = _42.data[ix + i].element;
|
||||||
|
local[i] = combine_monoid(param, param_1);
|
||||||
|
}
|
||||||
|
Monoid agg = local[7];
|
||||||
|
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||||
|
for (uint i_1 = 0u; i_1 < 9u; i_1++)
|
||||||
|
{
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
if (gl_LocalInvocationID.x >= uint(1 << int(i_1)))
|
||||||
|
{
|
||||||
|
Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))];
|
||||||
|
Monoid param_2 = other;
|
||||||
|
Monoid param_3 = agg;
|
||||||
|
agg = combine_monoid(param_2, param_3);
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
Monoid row = Monoid{ 0u };
|
||||||
|
if (gl_LocalInvocationID.x > 0u)
|
||||||
|
{
|
||||||
|
row = sh_scratch[gl_LocalInvocationID.x - 1u];
|
||||||
|
}
|
||||||
|
for (uint i_2 = 0u; i_2 < 8u; i_2++)
|
||||||
|
{
|
||||||
|
Monoid param_4 = row;
|
||||||
|
Monoid param_5 = local[i_2];
|
||||||
|
Monoid m = combine_monoid(param_4, param_5);
|
||||||
|
_42.data[ix + i_2].element = m.element;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
BIN
tests/shader/gen/prefix_root.spv
Normal file
BIN
tests/shader/gen/prefix_root.spv
Normal file
Binary file not shown.
92
tests/shader/gen/prefix_scan.hlsl
Normal file
92
tests/shader/gen/prefix_scan.hlsl
Normal file
|
@ -0,0 +1,92 @@
|
||||||
|
struct Monoid
|
||||||
|
{
|
||||||
|
uint element;
|
||||||
|
};
|
||||||
|
|
||||||
|
static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u);
|
||||||
|
|
||||||
|
static const Monoid _133 = { 0u };
|
||||||
|
|
||||||
|
RWByteAddressBuffer _42 : register(u0);
|
||||||
|
RWByteAddressBuffer _143 : register(u1);
|
||||||
|
|
||||||
|
static uint3 gl_WorkGroupID;
|
||||||
|
static uint3 gl_LocalInvocationID;
|
||||||
|
static uint3 gl_GlobalInvocationID;
|
||||||
|
struct SPIRV_Cross_Input
|
||||||
|
{
|
||||||
|
uint3 gl_WorkGroupID : SV_GroupID;
|
||||||
|
uint3 gl_LocalInvocationID : SV_GroupThreadID;
|
||||||
|
uint3 gl_GlobalInvocationID : SV_DispatchThreadID;
|
||||||
|
};
|
||||||
|
|
||||||
|
groupshared Monoid sh_scratch[512];
|
||||||
|
|
||||||
|
Monoid combine_monoid(Monoid a, Monoid b)
|
||||||
|
{
|
||||||
|
Monoid _22 = { a.element + b.element };
|
||||||
|
return _22;
|
||||||
|
}
|
||||||
|
|
||||||
|
void comp_main()
|
||||||
|
{
|
||||||
|
uint ix = gl_GlobalInvocationID.x * 8u;
|
||||||
|
Monoid _46;
|
||||||
|
_46.element = _42.Load(ix * 4 + 0);
|
||||||
|
Monoid local[8];
|
||||||
|
local[0].element = _46.element;
|
||||||
|
Monoid param_1;
|
||||||
|
for (uint i = 1u; i < 8u; i++)
|
||||||
|
{
|
||||||
|
Monoid param = local[i - 1u];
|
||||||
|
Monoid _71;
|
||||||
|
_71.element = _42.Load((ix + i) * 4 + 0);
|
||||||
|
param_1.element = _71.element;
|
||||||
|
local[i] = combine_monoid(param, param_1);
|
||||||
|
}
|
||||||
|
Monoid agg = local[7];
|
||||||
|
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||||
|
for (uint i_1 = 0u; i_1 < 9u; i_1++)
|
||||||
|
{
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
if (gl_LocalInvocationID.x >= uint(1 << int(i_1)))
|
||||||
|
{
|
||||||
|
Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))];
|
||||||
|
Monoid param_2 = other;
|
||||||
|
Monoid param_3 = agg;
|
||||||
|
agg = combine_monoid(param_2, param_3);
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||||
|
}
|
||||||
|
GroupMemoryBarrierWithGroupSync();
|
||||||
|
Monoid row = _133;
|
||||||
|
if (gl_WorkGroupID.x > 0u)
|
||||||
|
{
|
||||||
|
Monoid _148;
|
||||||
|
_148.element = _143.Load((gl_WorkGroupID.x - 1u) * 4 + 0);
|
||||||
|
row.element = _148.element;
|
||||||
|
}
|
||||||
|
if (gl_LocalInvocationID.x > 0u)
|
||||||
|
{
|
||||||
|
Monoid param_4 = row;
|
||||||
|
Monoid param_5 = sh_scratch[gl_LocalInvocationID.x - 1u];
|
||||||
|
row = combine_monoid(param_4, param_5);
|
||||||
|
}
|
||||||
|
for (uint i_2 = 0u; i_2 < 8u; i_2++)
|
||||||
|
{
|
||||||
|
Monoid param_6 = row;
|
||||||
|
Monoid param_7 = local[i_2];
|
||||||
|
Monoid m = combine_monoid(param_6, param_7);
|
||||||
|
_42.Store((ix + i_2) * 4 + 0, m.element);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[numthreads(512, 1, 1)]
|
||||||
|
void main(SPIRV_Cross_Input stage_input)
|
||||||
|
{
|
||||||
|
gl_WorkGroupID = stage_input.gl_WorkGroupID;
|
||||||
|
gl_LocalInvocationID = stage_input.gl_LocalInvocationID;
|
||||||
|
gl_GlobalInvocationID = stage_input.gl_GlobalInvocationID;
|
||||||
|
comp_main();
|
||||||
|
}
|
123
tests/shader/gen/prefix_scan.msl
Normal file
123
tests/shader/gen/prefix_scan.msl
Normal file
|
@ -0,0 +1,123 @@
|
||||||
|
#pragma clang diagnostic ignored "-Wmissing-prototypes"
|
||||||
|
#pragma clang diagnostic ignored "-Wmissing-braces"
|
||||||
|
|
||||||
|
#include <metal_stdlib>
|
||||||
|
#include <simd/simd.h>
|
||||||
|
|
||||||
|
using namespace metal;
|
||||||
|
|
||||||
|
template<typename T, size_t Num>
|
||||||
|
struct spvUnsafeArray
|
||||||
|
{
|
||||||
|
T elements[Num ? Num : 1];
|
||||||
|
|
||||||
|
thread T& operator [] (size_t pos) thread
|
||||||
|
{
|
||||||
|
return elements[pos];
|
||||||
|
}
|
||||||
|
constexpr const thread T& operator [] (size_t pos) const thread
|
||||||
|
{
|
||||||
|
return elements[pos];
|
||||||
|
}
|
||||||
|
|
||||||
|
device T& operator [] (size_t pos) device
|
||||||
|
{
|
||||||
|
return elements[pos];
|
||||||
|
}
|
||||||
|
constexpr const device T& operator [] (size_t pos) const device
|
||||||
|
{
|
||||||
|
return elements[pos];
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr const constant T& operator [] (size_t pos) const constant
|
||||||
|
{
|
||||||
|
return elements[pos];
|
||||||
|
}
|
||||||
|
|
||||||
|
threadgroup T& operator [] (size_t pos) threadgroup
|
||||||
|
{
|
||||||
|
return elements[pos];
|
||||||
|
}
|
||||||
|
constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
|
||||||
|
{
|
||||||
|
return elements[pos];
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Monoid
|
||||||
|
{
|
||||||
|
uint element;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Monoid_1
|
||||||
|
{
|
||||||
|
uint element;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct DataBuf
|
||||||
|
{
|
||||||
|
Monoid_1 data[1];
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ParentBuf
|
||||||
|
{
|
||||||
|
Monoid_1 parent[1];
|
||||||
|
};
|
||||||
|
|
||||||
|
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u);
|
||||||
|
|
||||||
|
static inline __attribute__((always_inline))
|
||||||
|
Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b)
|
||||||
|
{
|
||||||
|
return Monoid{ a.element + b.element };
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel void main0(device DataBuf& _42 [[buffer(0)]], device ParentBuf& _143 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]])
|
||||||
|
{
|
||||||
|
threadgroup Monoid sh_scratch[512];
|
||||||
|
uint ix = gl_GlobalInvocationID.x * 8u;
|
||||||
|
spvUnsafeArray<Monoid, 8> local;
|
||||||
|
local[0].element = _42.data[ix].element;
|
||||||
|
Monoid param_1;
|
||||||
|
for (uint i = 1u; i < 8u; i++)
|
||||||
|
{
|
||||||
|
Monoid param = local[i - 1u];
|
||||||
|
param_1.element = _42.data[ix + i].element;
|
||||||
|
local[i] = combine_monoid(param, param_1);
|
||||||
|
}
|
||||||
|
Monoid agg = local[7];
|
||||||
|
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||||
|
for (uint i_1 = 0u; i_1 < 9u; i_1++)
|
||||||
|
{
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
if (gl_LocalInvocationID.x >= uint(1 << int(i_1)))
|
||||||
|
{
|
||||||
|
Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))];
|
||||||
|
Monoid param_2 = other;
|
||||||
|
Monoid param_3 = agg;
|
||||||
|
agg = combine_monoid(param_2, param_3);
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||||
|
}
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
Monoid row = Monoid{ 0u };
|
||||||
|
if (gl_WorkGroupID.x > 0u)
|
||||||
|
{
|
||||||
|
row.element = _143.parent[gl_WorkGroupID.x - 1u].element;
|
||||||
|
}
|
||||||
|
if (gl_LocalInvocationID.x > 0u)
|
||||||
|
{
|
||||||
|
Monoid param_4 = row;
|
||||||
|
Monoid param_5 = sh_scratch[gl_LocalInvocationID.x - 1u];
|
||||||
|
row = combine_monoid(param_4, param_5);
|
||||||
|
}
|
||||||
|
for (uint i_2 = 0u; i_2 < 8u; i_2++)
|
||||||
|
{
|
||||||
|
Monoid param_6 = row;
|
||||||
|
Monoid param_7 = local[i_2];
|
||||||
|
Monoid m = combine_monoid(param_6, param_7);
|
||||||
|
_42.data[ix + i_2].element = m.element;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
BIN
tests/shader/gen/prefix_scan.spv
Normal file
BIN
tests/shader/gen/prefix_scan.spv
Normal file
Binary file not shown.
188
tests/shader/prefix.comp
Normal file
188
tests/shader/prefix.comp
Normal file
|
@ -0,0 +1,188 @@
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense
|
||||||
|
|
||||||
|
// A prefix sum.
|
||||||
|
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#define N_ROWS 16
|
||||||
|
#define LG_WG_SIZE 9
|
||||||
|
#define WG_SIZE (1 << LG_WG_SIZE)
|
||||||
|
#define PARTITION_SIZE (WG_SIZE * N_ROWS)
|
||||||
|
|
||||||
|
layout(local_size_x = WG_SIZE, local_size_y = 1) in;
|
||||||
|
|
||||||
|
struct Monoid {
|
||||||
|
uint element;
|
||||||
|
};
|
||||||
|
|
||||||
|
layout(set = 0, binding = 0) readonly buffer InBuf {
|
||||||
|
Monoid[] inbuf;
|
||||||
|
};
|
||||||
|
|
||||||
|
layout(set = 0, binding = 1) buffer OutBuf {
|
||||||
|
Monoid[] outbuf;
|
||||||
|
};
|
||||||
|
|
||||||
|
// These correspond to X, A, P respectively in the prefix sum paper.
|
||||||
|
#define FLAG_NOT_READY 0
|
||||||
|
#define FLAG_AGGREGATE_READY 1
|
||||||
|
#define FLAG_PREFIX_READY 2
|
||||||
|
|
||||||
|
struct State {
|
||||||
|
uint flag;
|
||||||
|
Monoid aggregate;
|
||||||
|
Monoid prefix;
|
||||||
|
};
|
||||||
|
|
||||||
|
layout(set = 0, binding = 2) volatile buffer StateBuf {
|
||||||
|
uint part_counter;
|
||||||
|
State[] state;
|
||||||
|
};
|
||||||
|
|
||||||
|
shared Monoid sh_scratch[WG_SIZE];
|
||||||
|
|
||||||
|
Monoid combine_monoid(Monoid a, Monoid b) {
|
||||||
|
return Monoid(a.element + b.element);
|
||||||
|
}
|
||||||
|
|
||||||
|
shared uint sh_part_ix;
|
||||||
|
shared Monoid sh_prefix;
|
||||||
|
shared uint sh_flag;
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
Monoid local[N_ROWS];
|
||||||
|
// Determine partition to process by atomic counter (described in Section
|
||||||
|
// 4.4 of prefix sum paper).
|
||||||
|
if (gl_LocalInvocationID.x == 0) {
|
||||||
|
sh_part_ix = atomicAdd(part_counter, 1);
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
uint part_ix = sh_part_ix;
|
||||||
|
|
||||||
|
uint ix = part_ix * PARTITION_SIZE + gl_LocalInvocationID.x * N_ROWS;
|
||||||
|
|
||||||
|
// TODO: gate buffer read? (evaluate whether shader check or
|
||||||
|
// CPU-side padding is better)
|
||||||
|
local[0] = inbuf[ix];
|
||||||
|
for (uint i = 1; i < N_ROWS; i++) {
|
||||||
|
local[i] = combine_monoid(local[i - 1], inbuf[ix + i]);
|
||||||
|
}
|
||||||
|
Monoid agg = local[N_ROWS - 1];
|
||||||
|
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||||
|
for (uint i = 0; i < LG_WG_SIZE; i++) {
|
||||||
|
barrier();
|
||||||
|
if (gl_LocalInvocationID.x >= (1 << i)) {
|
||||||
|
Monoid other = sh_scratch[gl_LocalInvocationID.x - (1 << i)];
|
||||||
|
agg = combine_monoid(other, agg);
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish aggregate for this partition
|
||||||
|
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
|
||||||
|
state[part_ix].aggregate = agg;
|
||||||
|
if (part_ix == 0) {
|
||||||
|
state[0].prefix = agg;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Write flag with release semantics; this is done portably with a barrier.
|
||||||
|
memoryBarrierBuffer();
|
||||||
|
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
|
||||||
|
uint flag = FLAG_AGGREGATE_READY;
|
||||||
|
if (part_ix == 0) {
|
||||||
|
flag = FLAG_PREFIX_READY;
|
||||||
|
}
|
||||||
|
state[part_ix].flag = flag;
|
||||||
|
}
|
||||||
|
|
||||||
|
Monoid exclusive = Monoid(0);
|
||||||
|
if (part_ix != 0) {
|
||||||
|
// step 4 of paper: decoupled lookback
|
||||||
|
uint look_back_ix = part_ix - 1;
|
||||||
|
|
||||||
|
Monoid their_agg;
|
||||||
|
uint their_ix = 0;
|
||||||
|
while (true) {
|
||||||
|
// Read flag with acquire semantics.
|
||||||
|
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
|
||||||
|
sh_flag = state[look_back_ix].flag;
|
||||||
|
}
|
||||||
|
// The flag load is done only in the last thread. However, because the
|
||||||
|
// translation of memoryBarrierBuffer to Metal requires uniform control
|
||||||
|
// flow, we broadcast it to all threads.
|
||||||
|
barrier();
|
||||||
|
memoryBarrierBuffer();
|
||||||
|
uint flag = sh_flag;
|
||||||
|
|
||||||
|
if (flag == FLAG_PREFIX_READY) {
|
||||||
|
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
|
||||||
|
Monoid their_prefix = state[look_back_ix].prefix;
|
||||||
|
exclusive = combine_monoid(their_prefix, exclusive);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
} else if (flag == FLAG_AGGREGATE_READY) {
|
||||||
|
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
|
||||||
|
their_agg = state[look_back_ix].aggregate;
|
||||||
|
exclusive = combine_monoid(their_agg, exclusive);
|
||||||
|
}
|
||||||
|
look_back_ix--;
|
||||||
|
their_ix = 0;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
// else spin
|
||||||
|
|
||||||
|
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
|
||||||
|
// Unfortunately there's no guarantee of forward progress of other
|
||||||
|
// workgroups, so compute a bit of the aggregate before trying again.
|
||||||
|
// In the worst case, spinning stops when the aggregate is complete.
|
||||||
|
Monoid m = inbuf[look_back_ix * PARTITION_SIZE + their_ix];
|
||||||
|
if (their_ix == 0) {
|
||||||
|
their_agg = m;
|
||||||
|
} else {
|
||||||
|
their_agg = combine_monoid(their_agg, m);
|
||||||
|
}
|
||||||
|
their_ix++;
|
||||||
|
if (their_ix == PARTITION_SIZE) {
|
||||||
|
exclusive = combine_monoid(their_agg, exclusive);
|
||||||
|
if (look_back_ix == 0) {
|
||||||
|
sh_flag = FLAG_PREFIX_READY;
|
||||||
|
} else {
|
||||||
|
look_back_ix--;
|
||||||
|
their_ix = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
flag = sh_flag;
|
||||||
|
if (flag == FLAG_PREFIX_READY) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// step 5 of paper: compute inclusive prefix
|
||||||
|
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
|
||||||
|
Monoid inclusive_prefix = combine_monoid(exclusive, agg);
|
||||||
|
sh_prefix = exclusive;
|
||||||
|
state[part_ix].prefix = inclusive_prefix;
|
||||||
|
}
|
||||||
|
memoryBarrierBuffer();
|
||||||
|
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
|
||||||
|
state[part_ix].flag = FLAG_PREFIX_READY;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
if (part_ix != 0) {
|
||||||
|
exclusive = sh_prefix;
|
||||||
|
}
|
||||||
|
|
||||||
|
Monoid row = exclusive;
|
||||||
|
if (gl_LocalInvocationID.x > 0) {
|
||||||
|
Monoid other = sh_scratch[gl_LocalInvocationID.x - 1];
|
||||||
|
row = combine_monoid(row, other);
|
||||||
|
}
|
||||||
|
for (uint i = 0; i < N_ROWS; i++) {
|
||||||
|
Monoid m = combine_monoid(row, local[i]);
|
||||||
|
// Make sure buffer allocation is padded appropriately.
|
||||||
|
outbuf[ix + i] = m;
|
||||||
|
}
|
||||||
|
}
|
53
tests/shader/prefix_reduce.comp
Normal file
53
tests/shader/prefix_reduce.comp
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense
|
||||||
|
|
||||||
|
// The reduction phase for prefix sum implemented as a tree reduction.
|
||||||
|
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#define N_ROWS 8
|
||||||
|
#define LG_WG_SIZE 9
|
||||||
|
#define WG_SIZE (1 << LG_WG_SIZE)
|
||||||
|
#define PARTITION_SIZE (WG_SIZE * N_ROWS)
|
||||||
|
|
||||||
|
layout(local_size_x = WG_SIZE, local_size_y = 1) in;
|
||||||
|
|
||||||
|
struct Monoid {
|
||||||
|
uint element;
|
||||||
|
};
|
||||||
|
|
||||||
|
layout(set = 0, binding = 0) readonly buffer InBuf {
|
||||||
|
Monoid[] inbuf;
|
||||||
|
};
|
||||||
|
|
||||||
|
layout(set = 0, binding = 1) buffer OutBuf {
|
||||||
|
Monoid[] outbuf;
|
||||||
|
};
|
||||||
|
|
||||||
|
shared Monoid sh_scratch[WG_SIZE];
|
||||||
|
|
||||||
|
Monoid combine_monoid(Monoid a, Monoid b) {
|
||||||
|
return Monoid(a.element + b.element);
|
||||||
|
}
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
uint ix = gl_GlobalInvocationID.x * N_ROWS;
|
||||||
|
// TODO: gate buffer read
|
||||||
|
Monoid agg = inbuf[ix];
|
||||||
|
for (uint i = 1; i < N_ROWS; i++) {
|
||||||
|
agg = combine_monoid(agg, inbuf[ix + i]);
|
||||||
|
}
|
||||||
|
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||||
|
for (uint i = 0; i < LG_WG_SIZE; i++) {
|
||||||
|
barrier();
|
||||||
|
// We could make this predicate tighter, but would it help?
|
||||||
|
if (gl_LocalInvocationID.x + (1 << i) < WG_SIZE) {
|
||||||
|
Monoid other = sh_scratch[gl_LocalInvocationID.x + (1 << i)];
|
||||||
|
agg = combine_monoid(agg, other);
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||||
|
}
|
||||||
|
if (gl_LocalInvocationID.x == 0) {
|
||||||
|
outbuf[gl_WorkGroupID.x] = agg;
|
||||||
|
}
|
||||||
|
}
|
77
tests/shader/prefix_scan.comp
Normal file
77
tests/shader/prefix_scan.comp
Normal file
|
@ -0,0 +1,77 @@
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense
|
||||||
|
|
||||||
|
// A scan for a tree reduction prefix scan (either root or not, by ifdef).
|
||||||
|
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#define N_ROWS 8
|
||||||
|
#define LG_WG_SIZE 9
|
||||||
|
#define WG_SIZE (1 << LG_WG_SIZE)
|
||||||
|
#define PARTITION_SIZE (WG_SIZE * N_ROWS)
|
||||||
|
|
||||||
|
layout(local_size_x = WG_SIZE, local_size_y = 1) in;
|
||||||
|
|
||||||
|
struct Monoid {
|
||||||
|
uint element;
|
||||||
|
};
|
||||||
|
|
||||||
|
layout(set = 0, binding = 0) buffer DataBuf {
|
||||||
|
Monoid[] data;
|
||||||
|
};
|
||||||
|
|
||||||
|
#ifndef ROOT
|
||||||
|
layout(set = 0, binding = 1) buffer ParentBuf {
|
||||||
|
Monoid[] parent;
|
||||||
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
|
shared Monoid sh_scratch[WG_SIZE];
|
||||||
|
|
||||||
|
Monoid combine_monoid(Monoid a, Monoid b) {
|
||||||
|
return Monoid(a.element + b.element);
|
||||||
|
}
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
Monoid local[N_ROWS];
|
||||||
|
|
||||||
|
uint ix = gl_GlobalInvocationID.x * N_ROWS;
|
||||||
|
|
||||||
|
// TODO: gate buffer read
|
||||||
|
local[0] = data[ix];
|
||||||
|
for (uint i = 1; i < N_ROWS; i++) {
|
||||||
|
local[i] = combine_monoid(local[i - 1], data[ix + i]);
|
||||||
|
}
|
||||||
|
Monoid agg = local[N_ROWS - 1];
|
||||||
|
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||||
|
for (uint i = 0; i < LG_WG_SIZE; i++) {
|
||||||
|
barrier();
|
||||||
|
if (gl_LocalInvocationID.x >= (1 << i)) {
|
||||||
|
Monoid other = sh_scratch[gl_LocalInvocationID.x - (1 << i)];
|
||||||
|
agg = combine_monoid(other, agg);
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
sh_scratch[gl_LocalInvocationID.x] = agg;
|
||||||
|
}
|
||||||
|
|
||||||
|
barrier();
|
||||||
|
// This could be a semigroup instead of a monoid if we reworked the
|
||||||
|
// conditional logic, but that might impact performance.
|
||||||
|
Monoid row = Monoid(0);
|
||||||
|
#ifdef ROOT
|
||||||
|
if (gl_LocalInvocationID.x > 0) {
|
||||||
|
row = sh_scratch[gl_LocalInvocationID.x - 1];
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
if (gl_WorkGroupID.x > 0) {
|
||||||
|
row = parent[gl_WorkGroupID.x - 1];
|
||||||
|
}
|
||||||
|
if (gl_LocalInvocationID.x > 0) {
|
||||||
|
row = combine_monoid(row, sh_scratch[gl_LocalInvocationID.x - 1]);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
for (uint i = 0; i < N_ROWS; i++) {
|
||||||
|
Monoid m = combine_monoid(row, local[i]);
|
||||||
|
// TODO: gate buffer write
|
||||||
|
data[ix + i] = m;
|
||||||
|
}
|
||||||
|
}
|
72
tests/src/config.rs
Normal file
72
tests/src/config.rs
Normal file
|
@ -0,0 +1,72 @@
|
||||||
|
// Copyright 2021 The piet-gpu authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
// Also licensed under MIT license, at your choice.
|
||||||
|
|
||||||
|
//! Test config parameters.
|
||||||
|
|
||||||
|
use clap::ArgMatches;
|
||||||
|
|
||||||
|
pub struct Config {
|
||||||
|
pub groups: Groups,
|
||||||
|
pub size: Size,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Groups(String);
|
||||||
|
|
||||||
|
pub enum Size {
|
||||||
|
Small,
|
||||||
|
Medium,
|
||||||
|
Large,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
pub fn from_matches(matches: &ArgMatches) -> Config {
|
||||||
|
let groups = Groups::from_str(matches.value_of("groups").unwrap_or("all"));
|
||||||
|
let size = Size::from_str(matches.value_of("size").unwrap_or("m"));
|
||||||
|
Config {
|
||||||
|
groups, size
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Groups {
|
||||||
|
pub fn from_str(s: &str) -> Groups {
|
||||||
|
Groups(s.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn matches(&self, group_name: &str) -> bool {
|
||||||
|
self.0 == "all" || self.0 == group_name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Size {
|
||||||
|
fn from_str(s: &str) -> Size {
|
||||||
|
if s == "small" || s == "s" {
|
||||||
|
Size::Small
|
||||||
|
} else if s == "large" || s == "l" {
|
||||||
|
Size::Large
|
||||||
|
} else {
|
||||||
|
Size::Medium
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn choose<T>(&self, small: T, medium: T, large: T) -> T {
|
||||||
|
match self {
|
||||||
|
Size::Small => small,
|
||||||
|
Size::Medium => medium,
|
||||||
|
Size::Large => large,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
77
tests/src/main.rs
Normal file
77
tests/src/main.rs
Normal file
|
@ -0,0 +1,77 @@
|
||||||
|
// Copyright 2021 The piet-gpu authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
// Also licensed under MIT license, at your choice.
|
||||||
|
|
||||||
|
//! Tests for piet-gpu shaders and GPU capabilities.
|
||||||
|
|
||||||
|
mod config;
|
||||||
|
mod prefix;
|
||||||
|
mod prefix_tree;
|
||||||
|
mod runner;
|
||||||
|
mod test_result;
|
||||||
|
|
||||||
|
use clap::{App, Arg};
|
||||||
|
|
||||||
|
use crate::config::Config;
|
||||||
|
use crate::runner::Runner;
|
||||||
|
use crate::test_result::{ReportStyle, TestResult};
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let matches = App::new("piet-gpu-tests")
|
||||||
|
.arg(
|
||||||
|
Arg::with_name("verbose")
|
||||||
|
.short("v")
|
||||||
|
.long("verbose")
|
||||||
|
.help("Verbose reporting of results"),
|
||||||
|
)
|
||||||
|
.arg(
|
||||||
|
Arg::with_name("groups")
|
||||||
|
.short("g")
|
||||||
|
.long("groups")
|
||||||
|
.help("Groups to run")
|
||||||
|
.takes_value(true)
|
||||||
|
)
|
||||||
|
.arg(
|
||||||
|
Arg::with_name("size")
|
||||||
|
.short("s")
|
||||||
|
.long("size")
|
||||||
|
.help("Size of tests")
|
||||||
|
.takes_value(true)
|
||||||
|
)
|
||||||
|
.arg(
|
||||||
|
Arg::with_name("n_iter")
|
||||||
|
.short("n")
|
||||||
|
.long("n_iter")
|
||||||
|
.help("Number of iterations")
|
||||||
|
.takes_value(true)
|
||||||
|
)
|
||||||
|
.get_matches();
|
||||||
|
let style = if matches.is_present("verbose") {
|
||||||
|
ReportStyle::Verbose
|
||||||
|
} else {
|
||||||
|
ReportStyle::Short
|
||||||
|
};
|
||||||
|
let config = Config::from_matches(&matches);
|
||||||
|
unsafe {
|
||||||
|
let report = |test_result: &TestResult| {
|
||||||
|
test_result.report(style);
|
||||||
|
};
|
||||||
|
let mut runner = Runner::new();
|
||||||
|
if config.groups.matches("prefix") {
|
||||||
|
report(&prefix::run_prefix_test(&mut runner, &config));
|
||||||
|
report(&prefix_tree::run_prefix_test(&mut runner, &config));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
149
tests/src/prefix.rs
Normal file
149
tests/src/prefix.rs
Normal file
|
@ -0,0 +1,149 @@
|
||||||
|
// Copyright 2021 The piet-gpu authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
// Also licensed under MIT license, at your choice.
|
||||||
|
|
||||||
|
use piet_gpu_hal::{include_shader, BufferUsage, DescriptorSet};
|
||||||
|
use piet_gpu_hal::{Buffer, Pipeline};
|
||||||
|
|
||||||
|
use crate::config::Config;
|
||||||
|
use crate::runner::{Commands, Runner};
|
||||||
|
use crate::test_result::TestResult;
|
||||||
|
|
||||||
|
const WG_SIZE: u64 = 512;
|
||||||
|
const N_ROWS: u64 = 16;
|
||||||
|
const ELEMENTS_PER_WG: u64 = WG_SIZE * N_ROWS;
|
||||||
|
|
||||||
|
/// The shader code for the prefix sum example.
|
||||||
|
///
|
||||||
|
/// A code struct can be created once and reused any number of times.
|
||||||
|
struct PrefixCode {
|
||||||
|
pipeline: Pipeline,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The stage resources for the prefix sum example.
|
||||||
|
///
|
||||||
|
/// A stage resources struct is specific to a particular problem size
|
||||||
|
/// and queue.
|
||||||
|
struct PrefixStage {
|
||||||
|
// This is the actual problem size but perhaps it would be better to
|
||||||
|
// treat it as a capacity.
|
||||||
|
n_elements: u64,
|
||||||
|
state_buf: Buffer,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The binding for the prefix sum example.
|
||||||
|
struct PrefixBinding {
|
||||||
|
descriptor_set: DescriptorSet,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResult {
|
||||||
|
let mut result = TestResult::new("prefix sum, decoupled look-back");
|
||||||
|
// This will be configurable.
|
||||||
|
let n_elements: u64 = config.size.choose(1 << 12, 1 << 24, 1 << 25);
|
||||||
|
let data: Vec<u32> = (0..n_elements as u32).collect();
|
||||||
|
let data_buf = runner
|
||||||
|
.session
|
||||||
|
.create_buffer_init(&data, BufferUsage::STORAGE)
|
||||||
|
.unwrap();
|
||||||
|
let out_buf = runner.buf_down(data_buf.size());
|
||||||
|
let code = PrefixCode::new(runner);
|
||||||
|
let stage = PrefixStage::new(runner, n_elements);
|
||||||
|
let binding = stage.bind(runner, &code, &data_buf, &out_buf.dev_buf);
|
||||||
|
// Also will be configurable of course.
|
||||||
|
let n_iter = 1000;
|
||||||
|
let mut total_elapsed = 0.0;
|
||||||
|
for i in 0..n_iter {
|
||||||
|
let mut commands = runner.commands();
|
||||||
|
commands.write_timestamp(0);
|
||||||
|
stage.record(&mut commands, &code, &binding);
|
||||||
|
commands.write_timestamp(1);
|
||||||
|
if i == 0 {
|
||||||
|
commands.cmd_buf.memory_barrier();
|
||||||
|
commands.download(&out_buf);
|
||||||
|
}
|
||||||
|
total_elapsed += runner.submit(commands);
|
||||||
|
if i == 0 {
|
||||||
|
let mut dst: Vec<u32> = Default::default();
|
||||||
|
out_buf.read(&mut dst);
|
||||||
|
if let Some(failure) = verify(&dst) {
|
||||||
|
result.fail(format!("failure at {}", failure));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.timing(total_elapsed, n_elements * n_iter);
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PrefixCode {
|
||||||
|
unsafe fn new(runner: &mut Runner) -> PrefixCode {
|
||||||
|
let code = include_shader!(&runner.session, "../shader/gen/prefix");
|
||||||
|
let pipeline = runner
|
||||||
|
.session
|
||||||
|
.create_simple_compute_pipeline(code, 3)
|
||||||
|
.unwrap();
|
||||||
|
PrefixCode { pipeline }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PrefixStage {
|
||||||
|
unsafe fn new(runner: &mut Runner, n_elements: u64) -> PrefixStage {
|
||||||
|
let n_workgroups = (n_elements + ELEMENTS_PER_WG - 1) / ELEMENTS_PER_WG;
|
||||||
|
let state_buf_size = 4 + 12 * n_workgroups;
|
||||||
|
let state_buf = runner
|
||||||
|
.session
|
||||||
|
.create_buffer(state_buf_size, BufferUsage::STORAGE | BufferUsage::COPY_DST)
|
||||||
|
.unwrap();
|
||||||
|
PrefixStage {
|
||||||
|
n_elements,
|
||||||
|
state_buf,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe fn bind(
|
||||||
|
&self,
|
||||||
|
runner: &mut Runner,
|
||||||
|
code: &PrefixCode,
|
||||||
|
in_buf: &Buffer,
|
||||||
|
out_buf: &Buffer,
|
||||||
|
) -> PrefixBinding {
|
||||||
|
let descriptor_set = runner
|
||||||
|
.session
|
||||||
|
.create_simple_descriptor_set(&code.pipeline, &[in_buf, out_buf, &self.state_buf])
|
||||||
|
.unwrap();
|
||||||
|
PrefixBinding { descriptor_set }
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe fn record(&self, commands: &mut Commands, code: &PrefixCode, bindings: &PrefixBinding) {
|
||||||
|
let n_workgroups = (self.n_elements + ELEMENTS_PER_WG - 1) / ELEMENTS_PER_WG;
|
||||||
|
commands.cmd_buf.clear_buffer(&self.state_buf, None);
|
||||||
|
commands.cmd_buf.memory_barrier();
|
||||||
|
commands.cmd_buf.dispatch(
|
||||||
|
&code.pipeline,
|
||||||
|
&bindings.descriptor_set,
|
||||||
|
(n_workgroups as u32, 1, 1),
|
||||||
|
(WG_SIZE as u32, 1, 1),
|
||||||
|
);
|
||||||
|
// One thing that's missing here is registering the buffers so
|
||||||
|
// they can be safely dropped by Rust code before the execution
|
||||||
|
// of the command buffer completes.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the data is OEIS A000217
|
||||||
|
fn verify(data: &[u32]) -> Option<usize> {
|
||||||
|
data.iter()
|
||||||
|
.enumerate()
|
||||||
|
.position(|(i, val)| ((i * (i + 1)) / 2) as u32 != *val)
|
||||||
|
}
|
211
tests/src/prefix_tree.rs
Normal file
211
tests/src/prefix_tree.rs
Normal file
|
@ -0,0 +1,211 @@
|
||||||
|
// Copyright 2021 The piet-gpu authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
// Also licensed under MIT license, at your choice.
|
||||||
|
|
||||||
|
use piet_gpu_hal::{include_shader, BufferUsage, DescriptorSet};
|
||||||
|
use piet_gpu_hal::{Buffer, Pipeline};
|
||||||
|
|
||||||
|
use crate::config::Config;
|
||||||
|
use crate::runner::{Commands, Runner};
|
||||||
|
use crate::test_result::TestResult;
|
||||||
|
|
||||||
|
const WG_SIZE: u64 = 512;
|
||||||
|
const N_ROWS: u64 = 8;
|
||||||
|
const ELEMENTS_PER_WG: u64 = WG_SIZE * N_ROWS;
|
||||||
|
|
||||||
|
struct PrefixTreeCode {
|
||||||
|
reduce_pipeline: Pipeline,
|
||||||
|
scan_pipeline: Pipeline,
|
||||||
|
root_pipeline: Pipeline,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct PrefixTreeStage {
|
||||||
|
sizes: Vec<u64>,
|
||||||
|
tmp_bufs: Vec<Buffer>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct PrefixTreeBinding {
|
||||||
|
// All but the first and last can be moved to stage.
|
||||||
|
descriptor_sets: Vec<DescriptorSet>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResult {
|
||||||
|
let mut result = TestResult::new("prefix sum, tree reduction");
|
||||||
|
// This will be configurable. Note though that the current code is
|
||||||
|
// prone to reading and writing past the end of buffers if this is
|
||||||
|
// not a power of the number of elements processed in a workgroup.
|
||||||
|
let n_elements: u64 = config.size.choose(1 << 12, 1 << 24, 1 << 24);
|
||||||
|
let data: Vec<u32> = (0..n_elements as u32).collect();
|
||||||
|
let data_buf = runner
|
||||||
|
.session
|
||||||
|
.create_buffer_init(&data, BufferUsage::STORAGE)
|
||||||
|
.unwrap();
|
||||||
|
let out_buf = runner.buf_down(data_buf.size());
|
||||||
|
let code = PrefixTreeCode::new(runner);
|
||||||
|
let stage = PrefixTreeStage::new(runner, n_elements);
|
||||||
|
let binding = stage.bind(runner, &code, &out_buf.dev_buf);
|
||||||
|
// Also will be configurable of course.
|
||||||
|
let n_iter = 1000;
|
||||||
|
let mut total_elapsed = 0.0;
|
||||||
|
for i in 0..n_iter {
|
||||||
|
let mut commands = runner.commands();
|
||||||
|
commands.cmd_buf.copy_buffer(&data_buf, &out_buf.dev_buf);
|
||||||
|
commands.cmd_buf.memory_barrier();
|
||||||
|
commands.write_timestamp(0);
|
||||||
|
stage.record(&mut commands, &code, &binding);
|
||||||
|
commands.write_timestamp(1);
|
||||||
|
if i == 0 {
|
||||||
|
commands.cmd_buf.memory_barrier();
|
||||||
|
commands.download(&out_buf);
|
||||||
|
}
|
||||||
|
total_elapsed += runner.submit(commands);
|
||||||
|
if i == 0 {
|
||||||
|
let mut dst: Vec<u32> = Default::default();
|
||||||
|
out_buf.read(&mut dst);
|
||||||
|
if let Some(failure) = verify(&dst) {
|
||||||
|
result.fail(format!("failure at {}", failure));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.timing(total_elapsed, n_elements * n_iter);
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PrefixTreeCode {
|
||||||
|
unsafe fn new(runner: &mut Runner) -> PrefixTreeCode {
|
||||||
|
let reduce_code = include_shader!(&runner.session, "../shader/gen/prefix_reduce");
|
||||||
|
let reduce_pipeline = runner
|
||||||
|
.session
|
||||||
|
.create_simple_compute_pipeline(reduce_code, 2)
|
||||||
|
.unwrap();
|
||||||
|
let scan_code = include_shader!(&runner.session, "../shader/gen/prefix_scan");
|
||||||
|
let scan_pipeline = runner
|
||||||
|
.session
|
||||||
|
.create_simple_compute_pipeline(scan_code, 2)
|
||||||
|
.unwrap();
|
||||||
|
let root_code = include_shader!(&runner.session, "../shader/gen/prefix_root");
|
||||||
|
let root_pipeline = runner
|
||||||
|
.session
|
||||||
|
.create_simple_compute_pipeline(root_code, 1)
|
||||||
|
.unwrap();
|
||||||
|
PrefixTreeCode {
|
||||||
|
reduce_pipeline,
|
||||||
|
scan_pipeline,
|
||||||
|
root_pipeline,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PrefixTreeStage {
|
||||||
|
unsafe fn new(runner: &mut Runner, n_elements: u64) -> PrefixTreeStage {
|
||||||
|
let mut size = n_elements;
|
||||||
|
let mut sizes = vec![size];
|
||||||
|
let mut tmp_bufs = Vec::new();
|
||||||
|
while size > ELEMENTS_PER_WG {
|
||||||
|
size = (size + ELEMENTS_PER_WG - 1) / ELEMENTS_PER_WG;
|
||||||
|
sizes.push(size);
|
||||||
|
let buf = runner
|
||||||
|
.session
|
||||||
|
.create_buffer(4 * size, BufferUsage::STORAGE)
|
||||||
|
.unwrap();
|
||||||
|
tmp_bufs.push(buf);
|
||||||
|
}
|
||||||
|
PrefixTreeStage { sizes, tmp_bufs }
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe fn bind(
|
||||||
|
&self,
|
||||||
|
runner: &mut Runner,
|
||||||
|
code: &PrefixTreeCode,
|
||||||
|
data_buf: &Buffer,
|
||||||
|
) -> PrefixTreeBinding {
|
||||||
|
let mut descriptor_sets = Vec::with_capacity(2 * self.tmp_bufs.len() + 1);
|
||||||
|
for i in 0..self.tmp_bufs.len() {
|
||||||
|
let buf0 = if i == 0 {
|
||||||
|
data_buf
|
||||||
|
} else {
|
||||||
|
&self.tmp_bufs[i - 1]
|
||||||
|
};
|
||||||
|
let buf1 = &self.tmp_bufs[i];
|
||||||
|
let descriptor_set = runner
|
||||||
|
.session
|
||||||
|
.create_simple_descriptor_set(&code.reduce_pipeline, &[buf0, buf1])
|
||||||
|
.unwrap();
|
||||||
|
descriptor_sets.push(descriptor_set);
|
||||||
|
}
|
||||||
|
let buf0 = self.tmp_bufs.last().unwrap_or(data_buf);
|
||||||
|
let descriptor_set = runner
|
||||||
|
.session
|
||||||
|
.create_simple_descriptor_set(&code.root_pipeline, &[buf0])
|
||||||
|
.unwrap();
|
||||||
|
descriptor_sets.push(descriptor_set);
|
||||||
|
for i in (0..self.tmp_bufs.len()).rev() {
|
||||||
|
let buf0 = if i == 0 {
|
||||||
|
data_buf
|
||||||
|
} else {
|
||||||
|
&self.tmp_bufs[i - 1]
|
||||||
|
};
|
||||||
|
let buf1 = &self.tmp_bufs[i];
|
||||||
|
let descriptor_set = runner
|
||||||
|
.session
|
||||||
|
.create_simple_descriptor_set(&code.scan_pipeline, &[buf0, buf1])
|
||||||
|
.unwrap();
|
||||||
|
descriptor_sets.push(descriptor_set);
|
||||||
|
}
|
||||||
|
PrefixTreeBinding { descriptor_sets }
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe fn record(
|
||||||
|
&self,
|
||||||
|
commands: &mut Commands,
|
||||||
|
code: &PrefixTreeCode,
|
||||||
|
bindings: &PrefixTreeBinding,
|
||||||
|
) {
|
||||||
|
let n = self.tmp_bufs.len();
|
||||||
|
for i in 0..n {
|
||||||
|
let n_workgroups = self.sizes[i + 1];
|
||||||
|
commands.cmd_buf.dispatch(
|
||||||
|
&code.reduce_pipeline,
|
||||||
|
&bindings.descriptor_sets[i],
|
||||||
|
(n_workgroups as u32, 1, 1),
|
||||||
|
(WG_SIZE as u32, 1, 1),
|
||||||
|
);
|
||||||
|
commands.cmd_buf.memory_barrier();
|
||||||
|
}
|
||||||
|
commands.cmd_buf.dispatch(
|
||||||
|
&code.root_pipeline,
|
||||||
|
&bindings.descriptor_sets[n],
|
||||||
|
(1, 1, 1),
|
||||||
|
(WG_SIZE as u32, 1, 1),
|
||||||
|
);
|
||||||
|
for i in (0..n).rev() {
|
||||||
|
commands.cmd_buf.memory_barrier();
|
||||||
|
let n_workgroups = self.sizes[i + 1];
|
||||||
|
commands.cmd_buf.dispatch(
|
||||||
|
&code.scan_pipeline,
|
||||||
|
&bindings.descriptor_sets[2 * n - i],
|
||||||
|
(n_workgroups as u32, 1, 1),
|
||||||
|
(WG_SIZE as u32, 1, 1),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the data is OEIS A000217
|
||||||
|
fn verify(data: &[u32]) -> Option<usize> {
|
||||||
|
data.iter()
|
||||||
|
.enumerate()
|
||||||
|
.position(|(i, val)| ((i * (i + 1)) / 2) as u32 != *val)
|
||||||
|
}
|
138
tests/src/runner.rs
Normal file
138
tests/src/runner.rs
Normal file
|
@ -0,0 +1,138 @@
|
||||||
|
// Copyright 2021 The piet-gpu authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
// Also licensed under MIT license, at your choice.
|
||||||
|
|
||||||
|
//! Test runner intended to make it easy to write tests.
|
||||||
|
|
||||||
|
use piet_gpu_hal::{Buffer, BufferUsage, CmdBuf, Instance, PlainData, QueryPool, Session};
|
||||||
|
|
||||||
|
pub struct Runner {
|
||||||
|
#[allow(unused)]
|
||||||
|
instance: Instance,
|
||||||
|
pub session: Session,
|
||||||
|
cmd_buf_pool: Vec<CmdBuf>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A wrapper around command buffers.
|
||||||
|
pub struct Commands {
|
||||||
|
pub cmd_buf: CmdBuf,
|
||||||
|
query_pool: QueryPool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Buffer for uploading data to GPU.
|
||||||
|
#[allow(unused)]
|
||||||
|
pub struct BufUp {
|
||||||
|
pub stage_buf: Buffer,
|
||||||
|
pub dev_buf: Buffer,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Buffer for downloading data from GPU.
|
||||||
|
pub struct BufDown {
|
||||||
|
pub stage_buf: Buffer,
|
||||||
|
pub dev_buf: Buffer,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Runner {
|
||||||
|
pub unsafe fn new() -> Runner {
|
||||||
|
let (instance, _) = Instance::new(None).unwrap();
|
||||||
|
let device = instance.device(None).unwrap();
|
||||||
|
let session = Session::new(device);
|
||||||
|
let cmd_buf_pool = Vec::new();
|
||||||
|
Runner {
|
||||||
|
instance,
|
||||||
|
session,
|
||||||
|
cmd_buf_pool,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub unsafe fn commands(&mut self) -> Commands {
|
||||||
|
let mut cmd_buf = self
|
||||||
|
.cmd_buf_pool
|
||||||
|
.pop()
|
||||||
|
.unwrap_or_else(|| self.session.cmd_buf().unwrap());
|
||||||
|
cmd_buf.begin();
|
||||||
|
// TODO: also pool these. But we might sort by size, as
|
||||||
|
// we might not always be doing two.
|
||||||
|
let query_pool = self.session.create_query_pool(2).unwrap();
|
||||||
|
cmd_buf.reset_query_pool(&query_pool);
|
||||||
|
Commands {
|
||||||
|
cmd_buf,
|
||||||
|
query_pool,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub unsafe fn submit(&mut self, commands: Commands) -> f64 {
|
||||||
|
let mut cmd_buf = commands.cmd_buf;
|
||||||
|
let query_pool = commands.query_pool;
|
||||||
|
cmd_buf.finish_timestamps(&query_pool);
|
||||||
|
cmd_buf.host_barrier();
|
||||||
|
cmd_buf.finish();
|
||||||
|
let submitted = self.session.run_cmd_buf(cmd_buf, &[], &[]).unwrap();
|
||||||
|
self.cmd_buf_pool.extend(submitted.wait().unwrap());
|
||||||
|
let timestamps = self.session.fetch_query_pool(&query_pool).unwrap();
|
||||||
|
timestamps[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
|
pub fn buf_up(&self, size: u64) -> BufUp {
|
||||||
|
let stage_buf = self
|
||||||
|
.session
|
||||||
|
.create_buffer(size, BufferUsage::MAP_WRITE | BufferUsage::COPY_SRC)
|
||||||
|
.unwrap();
|
||||||
|
let dev_buf = self
|
||||||
|
.session
|
||||||
|
.create_buffer(size, BufferUsage::COPY_DST | BufferUsage::STORAGE)
|
||||||
|
.unwrap();
|
||||||
|
BufUp { stage_buf, dev_buf }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn buf_down(&self, size: u64) -> BufDown {
|
||||||
|
let stage_buf = self
|
||||||
|
.session
|
||||||
|
.create_buffer(size, BufferUsage::MAP_READ | BufferUsage::COPY_DST)
|
||||||
|
.unwrap();
|
||||||
|
// Note: the COPY_DST isn't needed in all use cases, but I don't think
|
||||||
|
// making this tighter would help.
|
||||||
|
let dev_buf = self
|
||||||
|
.session
|
||||||
|
.create_buffer(
|
||||||
|
size,
|
||||||
|
BufferUsage::COPY_SRC | BufferUsage::COPY_DST | BufferUsage::STORAGE,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
BufDown { stage_buf, dev_buf }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Commands {
|
||||||
|
pub unsafe fn write_timestamp(&mut self, query: u32) {
|
||||||
|
self.cmd_buf.write_timestamp(&self.query_pool, query);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
|
pub unsafe fn upload(&mut self, buf: &BufUp) {
|
||||||
|
self.cmd_buf.copy_buffer(&buf.stage_buf, &buf.dev_buf);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub unsafe fn download(&mut self, buf: &BufDown) {
|
||||||
|
self.cmd_buf.copy_buffer(&buf.dev_buf, &buf.stage_buf);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl BufDown {
|
||||||
|
pub unsafe fn read(&self, dst: &mut Vec<impl PlainData>) {
|
||||||
|
self.stage_buf.read(dst).unwrap()
|
||||||
|
}
|
||||||
|
}
|
110
tests/src/test_result.rs
Normal file
110
tests/src/test_result.rs
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
// Copyright 2021 The piet-gpu authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
// Also licensed under MIT license, at your choice.
|
||||||
|
|
||||||
|
//! Recording of results from tests.
|
||||||
|
|
||||||
|
pub struct TestResult {
|
||||||
|
name: String,
|
||||||
|
// TODO: statistics. We're lean and mean for now.
|
||||||
|
total_time: f64,
|
||||||
|
n_elements: u64,
|
||||||
|
failure: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy)]
|
||||||
|
pub enum ReportStyle {
|
||||||
|
Short,
|
||||||
|
Verbose,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TestResult {
|
||||||
|
pub fn new(name: &str) -> TestResult {
|
||||||
|
TestResult {
|
||||||
|
name: name.to_string(),
|
||||||
|
total_time: 0.0,
|
||||||
|
n_elements: 0,
|
||||||
|
failure: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn report(&self, style: ReportStyle) {
|
||||||
|
let fail_string = match &self.failure {
|
||||||
|
None => "pass".into(),
|
||||||
|
Some(s) => format!("fail ({})", s),
|
||||||
|
};
|
||||||
|
match style {
|
||||||
|
ReportStyle::Short => {
|
||||||
|
let mut timing_string = String::new();
|
||||||
|
if self.total_time > 0.0 {
|
||||||
|
if self.n_elements > 0 {
|
||||||
|
let throughput = self.n_elements as f64 / self.total_time;
|
||||||
|
timing_string = format!(" {} elements/s", format_nice(throughput, 1));
|
||||||
|
} else {
|
||||||
|
timing_string = format!(" {}s", format_nice(self.total_time, 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
println!("{}: {}{}", self.name, fail_string, timing_string)
|
||||||
|
}
|
||||||
|
ReportStyle::Verbose => {
|
||||||
|
println!("test {}", self.name);
|
||||||
|
println!(" {}", fail_string);
|
||||||
|
if self.total_time > 0.0 {
|
||||||
|
println!(" {}s total time", format_nice(self.total_time, 1));
|
||||||
|
if self.n_elements > 0 {
|
||||||
|
println!(" {} elements", self.n_elements);
|
||||||
|
let throughput = self.n_elements as f64 / self.total_time;
|
||||||
|
println!(" {} elements/s", format_nice(throughput, 1));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn fail(&mut self, explanation: String) {
|
||||||
|
self.failure = Some(explanation);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn timing(&mut self, total_time: f64, n_elements: u64) {
|
||||||
|
self.total_time = total_time;
|
||||||
|
self.n_elements = n_elements;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn format_nice(x: f64, precision: usize) -> String {
|
||||||
|
// Precision should probably scale; later
|
||||||
|
let (scale, suffix) = if x >= 1e12 && x < 1e15 {
|
||||||
|
(1e-12, "T")
|
||||||
|
} else if x >= 1e9 {
|
||||||
|
(1e-9, "G")
|
||||||
|
} else if x >= 1e6 {
|
||||||
|
(1e-6, "M")
|
||||||
|
} else if x >= 1e3 {
|
||||||
|
(1e-3, "k")
|
||||||
|
} else if x >= 1.0 {
|
||||||
|
(1.0, "")
|
||||||
|
} else if x >= 1e-3 {
|
||||||
|
(1e3, "m")
|
||||||
|
} else if x >= 1e-6 {
|
||||||
|
(1e6, "\u{00b5}")
|
||||||
|
} else if x >= 1e-9 {
|
||||||
|
(1e9, "n")
|
||||||
|
} else if x >= 1e-12 {
|
||||||
|
(1e12, "p")
|
||||||
|
} else {
|
||||||
|
return format!("{:.*e}", precision, x);
|
||||||
|
};
|
||||||
|
format!("{:.*}{}", precision, scale * x, suffix)
|
||||||
|
}
|
Loading…
Reference in a new issue