Mirror of https://github.com/italicsjenga/vello.git (synced 2025-01-10 12:41:30 +11:00)
Add atomic versions of prefix
This adds both regular and Vulkan memory model atomic versions of the prefix sum test, selected at compile time by #ifdef. The build chain is getting messy, but I think it's important to test this stuff.
Parent: 3f1bbe4af1
Commit: 825a1eb04c
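The protocol these variants exercise is the decoupled look-back handshake: each workgroup publishes its aggregate (and eventually its prefix) into a per-partition State record and then raises a flag, and later partitions spin on that flag before reading the payload. Below is a minimal CPU-side sketch of that handshake using Rust's std::sync::atomic, purely as an illustration of the ordering the shader needs; none of this code is part of the commit, and the names merely mirror the shader's State struct. The Compatibility variant gets this ordering from volatile buffer accesses plus memoryBarrierBuffer, the Atomic variant makes the flag accesses themselves atomic but keeps the barriers, and the Vkmm variant expresses the acquire/release ordering directly.

use std::sync::atomic::{AtomicU32, Ordering};

const FLAG_NOT_READY: u32 = 0;
const FLAG_AGGREGATE_READY: u32 = 1;
const FLAG_PREFIX_READY: u32 = 2;

// Mirrors the shader's per-partition `State` record (illustrative only).
struct State {
    flag: AtomicU32,
    aggregate: AtomicU32,
    prefix: AtomicU32,
}

// Producer side: publish the payload, then raise the flag. The release
// ordering on the flag store keeps the aggregate write visible before the
// flag becomes observable, which is what the VKMM variant states directly.
fn publish_aggregate(s: &State, agg: u32) {
    s.aggregate.store(agg, Ordering::Relaxed);
    s.flag.store(FLAG_AGGREGATE_READY, Ordering::Release);
}

// Look-back side: spin until a predecessor has published something, then
// read whichever payload the flag says is valid.
fn look_back(s: &State) -> u32 {
    loop {
        match s.flag.load(Ordering::Acquire) {
            FLAG_PREFIX_READY => return s.prefix.load(Ordering::Relaxed),
            FLAG_AGGREGATE_READY => return s.aggregate.load(Ordering::Relaxed),
            _ => std::hint::spin_loop(), // still FLAG_NOT_READY
        }
    }
}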
@@ -32,6 +32,16 @@ build gen/prefix.hlsl: hlsl gen/prefix.spv
 build gen/prefix.dxil: dxil gen/prefix.hlsl
 build gen/prefix.msl: msl gen/prefix.spv

+build gen/prefix_atomic.spv: glsl prefix.comp
+  flags = -DATOMIC
+build gen/prefix_atomic.hlsl: hlsl gen/prefix_atomic.spv
+build gen/prefix_atomic.dxil: dxil gen/prefix_atomic.hlsl
+build gen/prefix_atomic.msl: msl gen/prefix_atomic.spv
+
+build gen/prefix_vkmm.spv: glsl prefix.comp
+  flags = -DATOMIC -DVKMM
+# Vulkan memory model doesn't translate
+
 build gen/prefix_reduce.spv: glsl prefix_reduce.comp
 build gen/prefix_reduce.hlsl: hlsl gen/prefix_reduce.spv
 build gen/prefix_reduce.dxil: dxil gen/prefix_reduce.hlsl

BIN tests/shader/gen/prefix_atomic.dxil (new file; binary not shown)
tests/shader/gen/prefix_atomic.hlsl (new file, 227 lines)
@@ -0,0 +1,227 @@
struct Monoid
{
    uint element;
};

struct State
{
    uint flag;
    Monoid aggregate;
    Monoid prefix;
};

static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u);

static const Monoid _185 = { 0u };

globallycoherent RWByteAddressBuffer _43 : register(u2);
ByteAddressBuffer _67 : register(t0);
RWByteAddressBuffer _372 : register(u1);

static uint3 gl_LocalInvocationID;
struct SPIRV_Cross_Input
{
    uint3 gl_LocalInvocationID : SV_GroupThreadID;
};

groupshared uint sh_part_ix;
groupshared Monoid sh_scratch[512];
groupshared uint sh_flag;
groupshared Monoid sh_prefix;

Monoid combine_monoid(Monoid a, Monoid b)
{
    Monoid _22 = { a.element + b.element };
    return _22;
}

void comp_main()
{
    if (gl_LocalInvocationID.x == 0u)
    {
        uint _47;
        _43.InterlockedAdd(0, 1u, _47);
        sh_part_ix = _47;
    }
    GroupMemoryBarrierWithGroupSync();
    uint part_ix = sh_part_ix;
    uint ix = (part_ix * 8192u) + (gl_LocalInvocationID.x * 16u);
    Monoid _71;
    _71.element = _67.Load(ix * 4 + 0);
    Monoid local[16];
    local[0].element = _71.element;
    Monoid param_1;
    for (uint i = 1u; i < 16u; i++)
    {
        Monoid param = local[i - 1u];
        Monoid _94;
        _94.element = _67.Load((ix + i) * 4 + 0);
        param_1.element = _94.element;
        local[i] = combine_monoid(param, param_1);
    }
    Monoid agg = local[15];
    sh_scratch[gl_LocalInvocationID.x] = agg;
    for (uint i_1 = 0u; i_1 < 9u; i_1++)
    {
        GroupMemoryBarrierWithGroupSync();
        if (gl_LocalInvocationID.x >= (1u << i_1))
        {
            Monoid other = sh_scratch[gl_LocalInvocationID.x - (1u << i_1)];
            Monoid param_2 = other;
            Monoid param_3 = agg;
            agg = combine_monoid(param_2, param_3);
        }
        GroupMemoryBarrierWithGroupSync();
        sh_scratch[gl_LocalInvocationID.x] = agg;
    }
    if (gl_LocalInvocationID.x == 511u)
    {
        _43.Store(part_ix * 12 + 8, agg.element);
        if (part_ix == 0u)
        {
            _43.Store(12, agg.element);
        }
    }
    DeviceMemoryBarrier();
    if (gl_LocalInvocationID.x == 511u)
    {
        uint flag = 1u;
        if (part_ix == 0u)
        {
            flag = 2u;
        }
        uint _383;
        _43.InterlockedExchange(part_ix * 12 + 4, flag, _383);
    }
    Monoid exclusive = _185;
    if (part_ix != 0u)
    {
        uint look_back_ix = part_ix - 1u;
        uint their_ix = 0u;
        Monoid their_prefix;
        Monoid their_agg;
        Monoid m;
        while (true)
        {
            if (gl_LocalInvocationID.x == 511u)
            {
                uint _208;
                _43.InterlockedAdd(look_back_ix * 12 + 4, 0, _208);
                sh_flag = _208;
            }
            GroupMemoryBarrierWithGroupSync();
            DeviceMemoryBarrier();
            uint flag_1 = sh_flag;
            if (flag_1 == 2u)
            {
                if (gl_LocalInvocationID.x == 511u)
                {
                    Monoid _223;
                    _223.element = _43.Load(look_back_ix * 12 + 12);
                    their_prefix.element = _223.element;
                    Monoid param_4 = their_prefix;
                    Monoid param_5 = exclusive;
                    exclusive = combine_monoid(param_4, param_5);
                }
                break;
            }
            else
            {
                if (flag_1 == 1u)
                {
                    if (gl_LocalInvocationID.x == 511u)
                    {
                        Monoid _245;
                        _245.element = _43.Load(look_back_ix * 12 + 8);
                        their_agg.element = _245.element;
                        Monoid param_6 = their_agg;
                        Monoid param_7 = exclusive;
                        exclusive = combine_monoid(param_6, param_7);
                    }
                    look_back_ix--;
                    their_ix = 0u;
                    continue;
                }
            }
            if (gl_LocalInvocationID.x == 511u)
            {
                Monoid _267;
                _267.element = _67.Load(((look_back_ix * 8192u) + their_ix) * 4 + 0);
                m.element = _267.element;
                if (their_ix == 0u)
                {
                    their_agg = m;
                }
                else
                {
                    Monoid param_8 = their_agg;
                    Monoid param_9 = m;
                    their_agg = combine_monoid(param_8, param_9);
                }
                their_ix++;
                if (their_ix == 8192u)
                {
                    Monoid param_10 = their_agg;
                    Monoid param_11 = exclusive;
                    exclusive = combine_monoid(param_10, param_11);
                    if (look_back_ix == 0u)
                    {
                        sh_flag = 2u;
                    }
                    else
                    {
                        look_back_ix--;
                        their_ix = 0u;
                    }
                }
            }
            GroupMemoryBarrierWithGroupSync();
            flag_1 = sh_flag;
            if (flag_1 == 2u)
            {
                break;
            }
        }
        if (gl_LocalInvocationID.x == 511u)
        {
            Monoid param_12 = exclusive;
            Monoid param_13 = agg;
            Monoid inclusive_prefix = combine_monoid(param_12, param_13);
            sh_prefix = exclusive;
            _43.Store(part_ix * 12 + 12, inclusive_prefix.element);
        }
        DeviceMemoryBarrier();
        if (gl_LocalInvocationID.x == 511u)
        {
            uint _384;
            _43.InterlockedExchange(part_ix * 12 + 4, 2u, _384);
        }
    }
    GroupMemoryBarrierWithGroupSync();
    if (part_ix != 0u)
    {
        exclusive = sh_prefix;
    }
    Monoid row = exclusive;
    if (gl_LocalInvocationID.x > 0u)
    {
        Monoid other_1 = sh_scratch[gl_LocalInvocationID.x - 1u];
        Monoid param_14 = row;
        Monoid param_15 = other_1;
        row = combine_monoid(param_14, param_15);
    }
    for (uint i_2 = 0u; i_2 < 16u; i_2++)
    {
        Monoid param_16 = row;
        Monoid param_17 = local[i_2];
        Monoid m_1 = combine_monoid(param_16, param_17);
        _372.Store((ix + i_2) * 4 + 0, m_1.element);
    }
}

[numthreads(512, 1, 1)]
void main(SPIRV_Cross_Input stage_input)
{
    gl_LocalInvocationID = stage_input.gl_LocalInvocationID;
    comp_main();
}

tests/shader/gen/prefix_atomic.msl (new file, 263 lines)
@@ -0,0 +1,263 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wmissing-braces"
#pragma clang diagnostic ignored "-Wunused-variable"

#include <metal_stdlib>
#include <simd/simd.h>
#include <metal_atomic>

using namespace metal;

template<typename T, size_t Num>
struct spvUnsafeArray
{
    T elements[Num ? Num : 1];

    thread T& operator [] (size_t pos) thread
    {
        return elements[pos];
    }
    constexpr const thread T& operator [] (size_t pos) const thread
    {
        return elements[pos];
    }

    device T& operator [] (size_t pos) device
    {
        return elements[pos];
    }
    constexpr const device T& operator [] (size_t pos) const device
    {
        return elements[pos];
    }

    constexpr const constant T& operator [] (size_t pos) const constant
    {
        return elements[pos];
    }

    threadgroup T& operator [] (size_t pos) threadgroup
    {
        return elements[pos];
    }
    constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
    {
        return elements[pos];
    }
};

struct Monoid
{
    uint element;
};

struct Monoid_1
{
    uint element;
};

struct State
{
    uint flag;
    Monoid_1 aggregate;
    Monoid_1 prefix;
};

struct StateBuf
{
    uint part_counter;
    State state[1];
};

struct InBuf
{
    Monoid_1 inbuf[1];
};

struct OutBuf
{
    Monoid_1 outbuf[1];
};

constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u);

static inline __attribute__((always_inline))
Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b)
{
    return Monoid{ a.element + b.element };
}

kernel void main0(const device InBuf& _67 [[buffer(0)]], device OutBuf& _372 [[buffer(1)]], volatile device StateBuf& _43 [[buffer(2)]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]])
{
    threadgroup uint sh_part_ix;
    threadgroup Monoid sh_scratch[512];
    threadgroup uint sh_flag;
    threadgroup Monoid sh_prefix;
    if (gl_LocalInvocationID.x == 0u)
    {
        uint _47 = atomic_fetch_add_explicit((volatile device atomic_uint*)&_43.part_counter, 1u, memory_order_relaxed);
        sh_part_ix = _47;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    uint part_ix = sh_part_ix;
    uint ix = (part_ix * 8192u) + (gl_LocalInvocationID.x * 16u);
    spvUnsafeArray<Monoid, 16> local;
    local[0].element = _67.inbuf[ix].element;
    Monoid param_1;
    for (uint i = 1u; i < 16u; i++)
    {
        Monoid param = local[i - 1u];
        param_1.element = _67.inbuf[ix + i].element;
        local[i] = combine_monoid(param, param_1);
    }
    Monoid agg = local[15];
    sh_scratch[gl_LocalInvocationID.x] = agg;
    for (uint i_1 = 0u; i_1 < 9u; i_1++)
    {
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (gl_LocalInvocationID.x >= (1u << i_1))
        {
            Monoid other = sh_scratch[gl_LocalInvocationID.x - (1u << i_1)];
            Monoid param_2 = other;
            Monoid param_3 = agg;
            agg = combine_monoid(param_2, param_3);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        sh_scratch[gl_LocalInvocationID.x] = agg;
    }
    if (gl_LocalInvocationID.x == 511u)
    {
        _43.state[part_ix].aggregate.element = agg.element;
        if (part_ix == 0u)
        {
            _43.state[0].prefix.element = agg.element;
        }
    }
    threadgroup_barrier(mem_flags::mem_device);
    if (gl_LocalInvocationID.x == 511u)
    {
        uint flag = 1u;
        if (part_ix == 0u)
        {
            flag = 2u;
        }
        atomic_store_explicit((volatile device atomic_uint*)&_43.state[part_ix].flag, flag, memory_order_relaxed);
    }
    Monoid exclusive = Monoid{ 0u };
    if (part_ix != 0u)
    {
        uint look_back_ix = part_ix - 1u;
        uint their_ix = 0u;
        Monoid their_prefix;
        Monoid their_agg;
        Monoid m;
        while (true)
        {
            if (gl_LocalInvocationID.x == 511u)
            {
                uint _208 = atomic_load_explicit((volatile device atomic_uint*)&_43.state[look_back_ix].flag, memory_order_relaxed);
                sh_flag = _208;
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);
            threadgroup_barrier(mem_flags::mem_device);
            uint flag_1 = sh_flag;
            if (flag_1 == 2u)
            {
                if (gl_LocalInvocationID.x == 511u)
                {
                    their_prefix.element = _43.state[look_back_ix].prefix.element;
                    Monoid param_4 = their_prefix;
                    Monoid param_5 = exclusive;
                    exclusive = combine_monoid(param_4, param_5);
                }
                break;
            }
            else
            {
                if (flag_1 == 1u)
                {
                    if (gl_LocalInvocationID.x == 511u)
                    {
                        their_agg.element = _43.state[look_back_ix].aggregate.element;
                        Monoid param_6 = their_agg;
                        Monoid param_7 = exclusive;
                        exclusive = combine_monoid(param_6, param_7);
                    }
                    look_back_ix--;
                    their_ix = 0u;
                    continue;
                }
            }
            if (gl_LocalInvocationID.x == 511u)
            {
                m.element = _67.inbuf[(look_back_ix * 8192u) + their_ix].element;
                if (their_ix == 0u)
                {
                    their_agg = m;
                }
                else
                {
                    Monoid param_8 = their_agg;
                    Monoid param_9 = m;
                    their_agg = combine_monoid(param_8, param_9);
                }
                their_ix++;
                if (their_ix == 8192u)
                {
                    Monoid param_10 = their_agg;
                    Monoid param_11 = exclusive;
                    exclusive = combine_monoid(param_10, param_11);
                    if (look_back_ix == 0u)
                    {
                        sh_flag = 2u;
                    }
                    else
                    {
                        look_back_ix--;
                        their_ix = 0u;
                    }
                }
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);
            flag_1 = sh_flag;
            if (flag_1 == 2u)
            {
                break;
            }
        }
        if (gl_LocalInvocationID.x == 511u)
        {
            Monoid param_12 = exclusive;
            Monoid param_13 = agg;
            Monoid inclusive_prefix = combine_monoid(param_12, param_13);
            sh_prefix = exclusive;
            _43.state[part_ix].prefix.element = inclusive_prefix.element;
        }
        threadgroup_barrier(mem_flags::mem_device);
        if (gl_LocalInvocationID.x == 511u)
        {
            atomic_store_explicit((volatile device atomic_uint*)&_43.state[part_ix].flag, 2u, memory_order_relaxed);
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (part_ix != 0u)
    {
        exclusive = sh_prefix;
    }
    Monoid row = exclusive;
    if (gl_LocalInvocationID.x > 0u)
    {
        Monoid other_1 = sh_scratch[gl_LocalInvocationID.x - 1u];
        Monoid param_14 = row;
        Monoid param_15 = other_1;
        row = combine_monoid(param_14, param_15);
    }
    for (uint i_2 = 0u; i_2 < 16u; i_2++)
    {
        Monoid param_16 = row;
        Monoid param_17 = local[i_2];
        Monoid m_1 = combine_monoid(param_16, param_17);
        _372.outbuf[ix + i_2].element = m_1.element;
    }
}

BIN tests/shader/gen/prefix_atomic.spv (new file; binary not shown)
BIN tests/shader/gen/prefix_vkmm.spv (new file; binary not shown)
@@ -1,9 +1,26 @@
 // SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense

 // A prefix sum.
+//
+// This test builds in three configurations. The default is a
+// compatibility mode, essentially plain GLSL. With ATOMIC set, the
+// flag loads and stores are atomic operations, but barriers are still
+// used for ordering. With both ATOMIC and VKMM set, it uses
+// acquire/release semantics instead of barriers.

 #version 450

+#extension GL_KHR_memory_scope_semantics : enable
+
+#ifdef VKMM
+#pragma use_vulkan_memory_model
+#define ACQUIRE gl_StorageSemanticsBuffer, gl_SemanticsAcquire
+#define RELEASE gl_StorageSemanticsBuffer, gl_SemanticsRelease
+#else
+#define ACQUIRE 0, 0
+#define RELEASE 0, 0
+#endif
+
 #define N_ROWS 16
 #define LG_WG_SIZE 9
 #define WG_SIZE (1 << LG_WG_SIZE)
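The ACQUIRE and RELEASE macros above expand to real semantics arguments only under VKMM; in the other builds they expand to 0, 0 and ordering comes from explicit memoryBarrierBuffer() calls next to the flag accesses. As a rough analogy (Rust atomics, not code from this commit), the two styles correspond to a fence beside a relaxed access versus folding the ordering into the access itself:

use std::sync::atomic::{fence, AtomicU32, Ordering};

// Barrier style (what the non-VKMM builds do): an explicit fence orders the
// surrounding buffer traffic, and the flag access itself carries no ordering.
fn store_flag_with_fence(flag: &AtomicU32, v: u32) {
    fence(Ordering::Release); // counterpart of memoryBarrierBuffer() before the store
    flag.store(v, Ordering::Relaxed);
}

fn load_flag_with_fence(flag: &AtomicU32) -> u32 {
    let v = flag.load(Ordering::Relaxed);
    fence(Ordering::Acquire); // counterpart of memoryBarrierBuffer() after the load
    v
}

// Memory-model style (ATOMIC + VKMM): the ordering is part of the access.
fn store_flag_release(flag: &AtomicU32, v: u32) {
    flag.store(v, Ordering::Release);
}

fn load_flag_acquire(flag: &AtomicU32) -> u32 {
    flag.load(Ordering::Acquire)
}

Both formulations publish and observe the flag with equivalent guarantees; the difference is only how the ordering is written down, which is what the ATOMIC and VKMM variants are meant to compare.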
@@ -24,9 +41,9 @@ layout(set = 0, binding = 1) buffer OutBuf {
 };

 // These correspond to X, A, P respectively in the prefix sum paper.
-#define FLAG_NOT_READY 0
-#define FLAG_AGGREGATE_READY 1
-#define FLAG_PREFIX_READY 2
+#define FLAG_NOT_READY 0u
+#define FLAG_AGGREGATE_READY 1u
+#define FLAG_PREFIX_READY 2u

 struct State {
     uint flag;
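As a reading aid, the three flag values form a small per-partition state machine; the letters in the comment are the status names from the decoupled look-back scan paper (presumably Merrill and Garland's single-pass prefix scan). A hypothetical Rust rendering, not part of this change:

/// Per-partition status in decoupled look-back (illustrative names).
#[repr(u32)]
enum PartitionStatus {
    NotReady = 0,       // "X": nothing published for this partition yet
    AggregateReady = 1, // "A": the partition-local aggregate can be read
    PrefixReady = 2,    // "P": the full inclusive prefix can be read
}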
@@ -34,6 +51,7 @@ struct State {
     Monoid prefix;
 };

+// Perhaps this should be "nonprivate" with VKMM
 layout(set = 0, binding = 2) volatile buffer StateBuf {
     uint part_counter;
     State[] state;
@@ -87,13 +105,19 @@ void main() {
         }
     }
     // Write flag with release semantics; this is done portably with a barrier.
+#ifndef VKMM
     memoryBarrierBuffer();
+#endif
     if (gl_LocalInvocationID.x == WG_SIZE - 1) {
         uint flag = FLAG_AGGREGATE_READY;
         if (part_ix == 0) {
             flag = FLAG_PREFIX_READY;
         }
+#ifdef ATOMIC
+        atomicStore(state[part_ix].flag, flag, gl_ScopeDevice, RELEASE);
+#else
         state[part_ix].flag = flag;
+#endif
     }

     Monoid exclusive = Monoid(0);
@@ -106,13 +130,19 @@ void main() {
         while (true) {
             // Read flag with acquire semantics.
             if (gl_LocalInvocationID.x == WG_SIZE - 1) {
+#ifdef ATOMIC
+                sh_flag = atomicLoad(state[look_back_ix].flag, gl_ScopeDevice, ACQUIRE);
+#else
                 sh_flag = state[look_back_ix].flag;
+#endif
             }
             // The flag load is done only in the last thread. However, because the
             // translation of memoryBarrierBuffer to Metal requires uniform control
             // flow, we broadcast it to all threads.
             barrier();
+#ifndef VKMM
             memoryBarrierBuffer();
+#endif
             uint flag = sh_flag;

             if (flag == FLAG_PREFIX_READY) {
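The comment in the hunk above is about uniform control flow: only the last thread performs the flag load, so the value is broadcast through shared memory (sh_flag) and a workgroup barrier before every thread branches on it, because the Metal translation of memoryBarrierBuffer must be reached by all threads. A loose CPU analogy with Rust threads, illustrative only, with std::sync::Barrier standing in for the shader's barrier():

use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Barrier;

// One designated "leader" reads the remote flag; everyone else waits at the
// barrier and then reads the broadcast copy, so all threads see the same
// value and take the same branch.
fn broadcast_flag(
    is_leader: bool,
    remote_flag: &AtomicU32,
    shared_copy: &AtomicU32,
    group: &Barrier,
) -> u32 {
    if is_leader {
        let f = remote_flag.load(Ordering::Acquire);
        shared_copy.store(f, Ordering::Relaxed);
    }
    group.wait(); // plays the role of barrier() in the shader
    shared_copy.load(Ordering::Relaxed)
}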
@@ -165,9 +195,15 @@ void main() {
             sh_prefix = exclusive;
             state[part_ix].prefix = inclusive_prefix;
         }
+#ifndef VKMM
         memoryBarrierBuffer();
+#endif
         if (gl_LocalInvocationID.x == WG_SIZE - 1) {
+#ifdef ATOMIC
+            atomicStore(state[part_ix].flag, FLAG_PREFIX_READY, gl_ScopeDevice, RELEASE);
+#else
             state[part_ix].flag = FLAG_PREFIX_READY;
+#endif
         }
     }
     barrier();

@@ -86,7 +86,23 @@ fn main() {
     }
     report(&clear::run_clear_test(&mut runner, &config));
     if config.groups.matches("prefix") {
-        report(&prefix::run_prefix_test(&mut runner, &config));
+        report(&prefix::run_prefix_test(
+            &mut runner,
+            &config,
+            prefix::Variant::Compatibility,
+        ));
+        report(&prefix::run_prefix_test(
+            &mut runner,
+            &config,
+            prefix::Variant::Atomic,
+        ));
+        if runner.session.gpu_info().has_memory_model {
+            report(&prefix::run_prefix_test(
+                &mut runner,
+                &config,
+                prefix::Variant::Vkmm,
+            ));
+        }
         report(&prefix_tree::run_prefix_test(&mut runner, &config));
     }
 }

@@ -14,7 +14,7 @@
 //
 // Also licensed under MIT license, at your choice.

-use piet_gpu_hal::{include_shader, BackendType, BindType, BufferUsage, DescriptorSet};
+use piet_gpu_hal::{include_shader, BackendType, BindType, BufferUsage, DescriptorSet, ShaderCode};
 use piet_gpu_hal::{Buffer, Pipeline};

 use crate::clear::{ClearBinding, ClearCode, ClearStage};

@@ -51,8 +51,19 @@ struct PrefixBinding {
     descriptor_set: DescriptorSet,
 }

-pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResult {
-    let mut result = TestResult::new("prefix sum, decoupled look-back");
+#[derive(Debug)]
+pub enum Variant {
+    Compatibility,
+    Atomic,
+    Vkmm,
+}
+
+pub unsafe fn run_prefix_test(
+    runner: &mut Runner,
+    config: &Config,
+    variant: Variant,
+) -> TestResult {
+    let mut result = TestResult::new(format!("prefix sum, decoupled look-back, {:?}", variant));
     /*
     // We're good if we're using DXC.
     if runner.backend_type() == BackendType::Dx12 {

@@ -67,7 +78,7 @@ pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResul
         .create_buffer_init(&data, BufferUsage::STORAGE)
         .unwrap();
     let out_buf = runner.buf_down(data_buf.size());
-    let code = PrefixCode::new(runner);
+    let code = PrefixCode::new(runner, variant);
     let stage = PrefixStage::new(runner, &code, n_elements);
     let binding = stage.bind(runner, &code, &data_buf, &out_buf.dev_buf);
     let n_iter = config.n_iter;

@@ -95,8 +106,12 @@ pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResul
 }

 impl PrefixCode {
-    unsafe fn new(runner: &mut Runner) -> PrefixCode {
-        let code = include_shader!(&runner.session, "../shader/gen/prefix");
+    unsafe fn new(runner: &mut Runner, variant: Variant) -> PrefixCode {
+        let code = match variant {
+            Variant::Compatibility => include_shader!(&runner.session, "../shader/gen/prefix"),
+            Variant::Atomic => include_shader!(&runner.session, "../shader/gen/prefix_atomic"),
+            Variant::Vkmm => ShaderCode::Spv(include_bytes!("../shader/gen/prefix_vkmm.spv")),
+        };
         let pipeline = runner
             .session
             .create_compute_pipeline(

@@ -38,9 +38,9 @@ pub enum ReportStyle {
 }

 impl TestResult {
-    pub fn new(name: &str) -> TestResult {
+    pub fn new(name: impl Into<String>) -> TestResult {
         TestResult {
-            name: name.to_string(),
+            name: name.into(),
             total_time: 0.0,
             n_elements: 0,
             status: Status::Pass,
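Making the constructor generic over impl Into<String> lets the existing &str call sites and the new format! call in the prefix test share one signature. Hypothetical call sites for illustration only (assumes TestResult and the prefix test's Variant are in scope):

// Both a &str and an owned String satisfy `impl Into<String>`.
fn example_names() -> (TestResult, TestResult) {
    let fixed = TestResult::new("prefix sum, decoupled look-back");
    let labeled = TestResult::new(format!("prefix sum, decoupled look-back, {:?}", Variant::Atomic));
    (fixed, labeled)
}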