Add atomic versions of prefix

This adds both regular and Vulkan memory model atomic versions of the
prefix sum test, selected at compile time via #ifdef. The build chain is
getting messy, but I think it's important to test this stuff.
This commit is contained in:
Raph Levien 2021-11-11 11:47:46 -08:00
parent 3f1bbe4af1
commit 825a1eb04c
11 changed files with 579 additions and 12 deletions

View file

@ -32,6 +32,16 @@ build gen/prefix.hlsl: hlsl gen/prefix.spv
build gen/prefix.dxil: dxil gen/prefix.hlsl
build gen/prefix.msl: msl gen/prefix.spv
build gen/prefix_atomic.spv: glsl prefix.comp
flags = -DATOMIC
build gen/prefix_atomic.hlsl: hlsl gen/prefix_atomic.spv
build gen/prefix_atomic.dxil: dxil gen/prefix_atomic.hlsl
build gen/prefix_atomic.msl: msl gen/prefix_atomic.spv
build gen/prefix_vkmm.spv: glsl prefix.comp
flags = -DATOMIC -DVKMM
# The Vulkan memory model doesn't translate to HLSL/MSL, so no
# cross-compiled outputs are built for the vkmm variant.
build gen/prefix_reduce.spv: glsl prefix_reduce.comp
build gen/prefix_reduce.hlsl: hlsl gen/prefix_reduce.spv
build gen/prefix_reduce.dxil: dxil gen/prefix_reduce.hlsl

Binary file not shown.

Binary file not shown.

View file

@ -0,0 +1,227 @@
// Scan monoid: addition over uint.
struct Monoid
{
uint element;
};
// Per-partition bookkeeping record for decoupled look-back. The flag values
// match the FLAG_* constants in the GLSL source: 0 = not ready,
// 1 = aggregate ready, 2 = prefix ready.
struct State
{
uint flag;
Monoid aggregate;
Monoid prefix;
};
static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u);
// Identity element of the monoid (zero).
static const Monoid _185 = { 0u };
// u2: state buffer — a uint partition counter at offset 0, followed by one
// 12-byte State record per partition. globallycoherent so stores become
// visible to other workgroups during the look-back.
globallycoherent RWByteAddressBuffer _43 : register(u2);
// t0: input elements (one uint per Monoid).
ByteAddressBuffer _67 : register(t0);
// u1: output prefix sums.
RWByteAddressBuffer _372 : register(u1);
static uint3 gl_LocalInvocationID;
struct SPIRV_Cross_Input
{
uint3 gl_LocalInvocationID : SV_GroupThreadID;
};
// Workgroup-shared: broadcast partition index, per-thread scan scratch,
// broadcast look-back flag, and the partition's exclusive prefix.
groupshared uint sh_part_ix;
groupshared Monoid sh_scratch[512];
groupshared uint sh_flag;
groupshared Monoid sh_prefix;
// Associative combine for the scan monoid: componentwise addition.
Monoid combine_monoid(Monoid a, Monoid b)
{
    Monoid result;
    result.element = a.element + b.element;
    return result;
}
// Decoupled look-back prefix sum: each 512-thread workgroup scans one
// partition of 8192 uints (16 per thread), then resolves its exclusive
// prefix by inspecting the State records of preceding partitions.
void comp_main()
{
// Thread 0 claims a unique partition index by bumping the counter at
// state-buffer offset 0, then broadcasts it through shared memory.
if (gl_LocalInvocationID.x == 0u)
{
uint _47;
_43.InterlockedAdd(0, 1u, _47);
sh_part_ix = _47;
}
GroupMemoryBarrierWithGroupSync();
uint part_ix = sh_part_ix;
// First input element for this thread: 8192 per partition, 16 per thread.
uint ix = (part_ix * 8192u) + (gl_LocalInvocationID.x * 16u);
Monoid _71;
_71.element = _67.Load(ix * 4 + 0);
// Sequential inclusive scan of this thread's 16 contiguous elements.
Monoid local[16];
local[0].element = _71.element;
Monoid param_1;
for (uint i = 1u; i < 16u; i++)
{
Monoid param = local[i - 1u];
Monoid _94;
_94.element = _67.Load((ix + i) * 4 + 0);
param_1.element = _94.element;
local[i] = combine_monoid(param, param_1);
}
// Workgroup-level inclusive scan over the 512 per-thread totals
// (9 = log2(512) doubling rounds through shared memory).
Monoid agg = local[15];
sh_scratch[gl_LocalInvocationID.x] = agg;
for (uint i_1 = 0u; i_1 < 9u; i_1++)
{
GroupMemoryBarrierWithGroupSync();
if (gl_LocalInvocationID.x >= (1u << i_1))
{
Monoid other = sh_scratch[gl_LocalInvocationID.x - (1u << i_1)];
Monoid param_2 = other;
Monoid param_3 = agg;
agg = combine_monoid(param_2, param_3);
}
GroupMemoryBarrierWithGroupSync();
sh_scratch[gl_LocalInvocationID.x] = agg;
}
// The last thread now holds the partition aggregate; publish it at
// offset +8 of this partition's State record. Partition 0 has no
// predecessors, so its prefix (offset +12) equals its aggregate.
if (gl_LocalInvocationID.x == 511u)
{
_43.Store(part_ix * 12 + 8, agg.element);
if (part_ix == 0u)
{
_43.Store(12, agg.element);
}
}
// Portable release semantics: device-wide barrier before the flag store.
DeviceMemoryBarrier();
if (gl_LocalInvocationID.x == 511u)
{
// 1 = aggregate ready; partition 0 jumps straight to 2 = prefix ready.
uint flag = 1u;
if (part_ix == 0u)
{
flag = 2u;
}
uint _383;
_43.InterlockedExchange(part_ix * 12 + 4, flag, _383);
}
Monoid exclusive = _185;
if (part_ix != 0u)
{
// Look-back loop: walk predecessor partitions, folding their published
// aggregates into `exclusive` until a published prefix is found.
uint look_back_ix = part_ix - 1u;
uint their_ix = 0u;
Monoid their_prefix;
Monoid their_agg;
Monoid m;
while (true)
{
// Last thread polls the predecessor's flag (InterlockedAdd of 0
// acts as an atomic load) and broadcasts it via shared memory.
if (gl_LocalInvocationID.x == 511u)
{
uint _208;
_43.InterlockedAdd(look_back_ix * 12 + 4, 0, _208);
sh_flag = _208;
}
GroupMemoryBarrierWithGroupSync();
DeviceMemoryBarrier();
uint flag_1 = sh_flag;
if (flag_1 == 2u)
{
// Predecessor's full prefix is available: fold it in and stop.
if (gl_LocalInvocationID.x == 511u)
{
Monoid _223;
_223.element = _43.Load(look_back_ix * 12 + 12);
their_prefix.element = _223.element;
Monoid param_4 = their_prefix;
Monoid param_5 = exclusive;
exclusive = combine_monoid(param_4, param_5);
}
break;
}
else
{
if (flag_1 == 1u)
{
// Only the aggregate is available: fold it in and step back
// one more partition.
if (gl_LocalInvocationID.x == 511u)
{
Monoid _245;
_245.element = _43.Load(look_back_ix * 12 + 8);
their_agg.element = _245.element;
Monoid param_6 = their_agg;
Monoid param_7 = exclusive;
exclusive = combine_monoid(param_6, param_7);
}
look_back_ix--;
their_ix = 0u;
continue;
}
}
// Predecessor not ready: instead of spinning idle, the last thread
// recomputes that partition's aggregate from the raw input, one
// element per poll iteration.
if (gl_LocalInvocationID.x == 511u)
{
Monoid _267;
_267.element = _67.Load(((look_back_ix * 8192u) + their_ix) * 4 + 0);
m.element = _267.element;
if (their_ix == 0u)
{
their_agg = m;
}
else
{
Monoid param_8 = their_agg;
Monoid param_9 = m;
their_agg = combine_monoid(param_8, param_9);
}
their_ix++;
// Recomputed the entire predecessor: fold it in, then either
// finish (reached partition 0) or move to the next predecessor.
if (their_ix == 8192u)
{
Monoid param_10 = their_agg;
Monoid param_11 = exclusive;
exclusive = combine_monoid(param_10, param_11);
if (look_back_ix == 0u)
{
sh_flag = 2u;
}
else
{
look_back_ix--;
their_ix = 0u;
}
}
}
GroupMemoryBarrierWithGroupSync();
flag_1 = sh_flag;
if (flag_1 == 2u)
{
break;
}
}
// Look-back done: publish this partition's inclusive prefix (offset +12)
// and broadcast the exclusive prefix to the whole workgroup.
if (gl_LocalInvocationID.x == 511u)
{
Monoid param_12 = exclusive;
Monoid param_13 = agg;
Monoid inclusive_prefix = combine_monoid(param_12, param_13);
sh_prefix = exclusive;
_43.Store(part_ix * 12 + 12, inclusive_prefix.element);
}
// Release: make the prefix visible before flagging 2 = prefix ready.
DeviceMemoryBarrier();
if (gl_LocalInvocationID.x == 511u)
{
uint _384;
_43.InterlockedExchange(part_ix * 12 + 4, 2u, _384);
}
}
GroupMemoryBarrierWithGroupSync();
if (part_ix != 0u)
{
exclusive = sh_prefix;
}
// Final pass: each thread combines the partition-exclusive prefix with the
// preceding threads' totals, then writes its 16 output values.
Monoid row = exclusive;
if (gl_LocalInvocationID.x > 0u)
{
Monoid other_1 = sh_scratch[gl_LocalInvocationID.x - 1u];
Monoid param_14 = row;
Monoid param_15 = other_1;
row = combine_monoid(param_14, param_15);
}
for (uint i_2 = 0u; i_2 < 16u; i_2++)
{
Monoid param_16 = row;
Monoid param_17 = local[i_2];
Monoid m_1 = combine_monoid(param_16, param_17);
_372.Store((ix + i_2) * 4 + 0, m_1.element);
}
}
// Shader entry point: copy the SPIRV-Cross stage input into the emulated
// GLSL built-in, then run the translated compute body.
[numthreads(512, 1, 1)]
void main(SPIRV_Cross_Input stage_input)
{
gl_LocalInvocationID = stage_input.gl_LocalInvocationID;
comp_main();
}

View file

@ -0,0 +1,263 @@
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wmissing-braces"
#pragma clang diagnostic ignored "-Wunused-variable"
#include <metal_stdlib>
#include <simd/simd.h>
#include <metal_atomic>
using namespace metal;
// SPIRV-Cross helper: fixed-size array wrapper providing operator[]
// overloads qualified for each Metal address space (thread, device,
// constant, threadgroup), so an array can be indexed wherever it lives.
template<typename T, size_t Num>
struct spvUnsafeArray
{
// `Num ? Num : 1` keeps the member well-formed even for Num == 0.
T elements[Num ? Num : 1];
thread T& operator [] (size_t pos) thread
{
return elements[pos];
}
constexpr const thread T& operator [] (size_t pos) const thread
{
return elements[pos];
}
device T& operator [] (size_t pos) device
{
return elements[pos];
}
constexpr const device T& operator [] (size_t pos) const device
{
return elements[pos];
}
constexpr const constant T& operator [] (size_t pos) const constant
{
return elements[pos];
}
threadgroup T& operator [] (size_t pos) threadgroup
{
return elements[pos];
}
constexpr const threadgroup T& operator [] (size_t pos) const threadgroup
{
return elements[pos];
}
};
// Scan monoid: addition over uint (thread-local flavor used by the kernel).
struct Monoid
{
uint element;
};
// Buffer-resident flavor of the same monoid; SPIRV-Cross emits a separate
// type for values stored in device buffers.
struct Monoid_1
{
uint element;
};
// Per-partition bookkeeping for decoupled look-back. Flag values match the
// GLSL FLAG_* constants: 0 = not ready, 1 = aggregate ready,
// 2 = prefix ready.
struct State
{
uint flag;
Monoid_1 aggregate;
Monoid_1 prefix;
};
// buffer(2): global partition counter followed by one State per partition.
struct StateBuf
{
uint part_counter;
State state[1];
};
// buffer(0): input elements.
struct InBuf
{
Monoid_1 inbuf[1];
};
// buffer(1): output prefix sums.
struct OutBuf
{
Monoid_1 outbuf[1];
};
constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u);
// Associative combine for the scan monoid: componentwise addition.
static inline __attribute__((always_inline))
Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b)
{
    Monoid result;
    result.element = a.element + b.element;
    return result;
}
// Decoupled look-back prefix sum, Metal translation of the same kernel:
// each 512-thread threadgroup scans one partition of 8192 uints (16 per
// thread), then resolves its exclusive prefix from predecessors' State.
kernel void main0(const device InBuf& _67 [[buffer(0)]], device OutBuf& _372 [[buffer(1)]], volatile device StateBuf& _43 [[buffer(2)]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]])
{
// Threadgroup-shared: broadcast partition index, per-thread scan scratch,
// broadcast look-back flag, and this partition's exclusive prefix.
threadgroup uint sh_part_ix;
threadgroup Monoid sh_scratch[512];
threadgroup uint sh_flag;
threadgroup Monoid sh_prefix;
// Thread 0 claims a unique partition index via an atomic bump of
// part_counter and broadcasts it through threadgroup memory.
if (gl_LocalInvocationID.x == 0u)
{
uint _47 = atomic_fetch_add_explicit((volatile device atomic_uint*)&_43.part_counter, 1u, memory_order_relaxed);
sh_part_ix = _47;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
uint part_ix = sh_part_ix;
// First input element for this thread: 8192 per partition, 16 per thread.
uint ix = (part_ix * 8192u) + (gl_LocalInvocationID.x * 16u);
// Sequential inclusive scan of this thread's 16 contiguous elements.
spvUnsafeArray<Monoid, 16> local;
local[0].element = _67.inbuf[ix].element;
Monoid param_1;
for (uint i = 1u; i < 16u; i++)
{
Monoid param = local[i - 1u];
param_1.element = _67.inbuf[ix + i].element;
local[i] = combine_monoid(param, param_1);
}
// Threadgroup-level inclusive scan over the 512 per-thread totals
// (9 = log2(512) doubling rounds through threadgroup memory).
Monoid agg = local[15];
sh_scratch[gl_LocalInvocationID.x] = agg;
for (uint i_1 = 0u; i_1 < 9u; i_1++)
{
threadgroup_barrier(mem_flags::mem_threadgroup);
if (gl_LocalInvocationID.x >= (1u << i_1))
{
Monoid other = sh_scratch[gl_LocalInvocationID.x - (1u << i_1)];
Monoid param_2 = other;
Monoid param_3 = agg;
agg = combine_monoid(param_2, param_3);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
sh_scratch[gl_LocalInvocationID.x] = agg;
}
// The last thread holds the partition aggregate; publish it. Partition 0
// has no predecessors, so its prefix equals its aggregate.
if (gl_LocalInvocationID.x == 511u)
{
_43.state[part_ix].aggregate.element = agg.element;
if (part_ix == 0u)
{
_43.state[0].prefix.element = agg.element;
}
}
// Portable release semantics: device-scope barrier before the flag store.
threadgroup_barrier(mem_flags::mem_device);
if (gl_LocalInvocationID.x == 511u)
{
// 1 = aggregate ready; partition 0 jumps straight to 2 = prefix ready.
uint flag = 1u;
if (part_ix == 0u)
{
flag = 2u;
}
atomic_store_explicit((volatile device atomic_uint*)&_43.state[part_ix].flag, flag, memory_order_relaxed);
}
Monoid exclusive = Monoid{ 0u };
if (part_ix != 0u)
{
// Look-back loop: walk predecessor partitions, folding their published
// aggregates into `exclusive` until a published prefix is found.
uint look_back_ix = part_ix - 1u;
uint their_ix = 0u;
Monoid their_prefix;
Monoid their_agg;
Monoid m;
while (true)
{
// Last thread polls the predecessor's flag and broadcasts it.
if (gl_LocalInvocationID.x == 511u)
{
uint _208 = atomic_load_explicit((volatile device atomic_uint*)&_43.state[look_back_ix].flag, memory_order_relaxed);
sh_flag = _208;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_device);
uint flag_1 = sh_flag;
if (flag_1 == 2u)
{
// Predecessor's full prefix is available: fold it in and stop.
if (gl_LocalInvocationID.x == 511u)
{
their_prefix.element = _43.state[look_back_ix].prefix.element;
Monoid param_4 = their_prefix;
Monoid param_5 = exclusive;
exclusive = combine_monoid(param_4, param_5);
}
break;
}
else
{
if (flag_1 == 1u)
{
// Only the aggregate is available: fold it in, step back one.
if (gl_LocalInvocationID.x == 511u)
{
their_agg.element = _43.state[look_back_ix].aggregate.element;
Monoid param_6 = their_agg;
Monoid param_7 = exclusive;
exclusive = combine_monoid(param_6, param_7);
}
look_back_ix--;
their_ix = 0u;
continue;
}
}
// Predecessor not ready: the last thread recomputes that partition's
// aggregate from the raw input, one element per poll iteration.
if (gl_LocalInvocationID.x == 511u)
{
m.element = _67.inbuf[(look_back_ix * 8192u) + their_ix].element;
if (their_ix == 0u)
{
their_agg = m;
}
else
{
Monoid param_8 = their_agg;
Monoid param_9 = m;
their_agg = combine_monoid(param_8, param_9);
}
their_ix++;
// Recomputed the entire predecessor: fold it in, then finish
// (reached partition 0) or move to the next predecessor.
if (their_ix == 8192u)
{
Monoid param_10 = their_agg;
Monoid param_11 = exclusive;
exclusive = combine_monoid(param_10, param_11);
if (look_back_ix == 0u)
{
sh_flag = 2u;
}
else
{
look_back_ix--;
their_ix = 0u;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
flag_1 = sh_flag;
if (flag_1 == 2u)
{
break;
}
}
// Look-back done: publish this partition's inclusive prefix and
// broadcast the exclusive prefix to the whole threadgroup.
if (gl_LocalInvocationID.x == 511u)
{
Monoid param_12 = exclusive;
Monoid param_13 = agg;
Monoid inclusive_prefix = combine_monoid(param_12, param_13);
sh_prefix = exclusive;
_43.state[part_ix].prefix.element = inclusive_prefix.element;
}
// Release: make the prefix visible before flagging 2 = prefix ready.
threadgroup_barrier(mem_flags::mem_device);
if (gl_LocalInvocationID.x == 511u)
{
atomic_store_explicit((volatile device atomic_uint*)&_43.state[part_ix].flag, 2u, memory_order_relaxed);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (part_ix != 0u)
{
exclusive = sh_prefix;
}
// Final pass: combine the partition-exclusive prefix with the preceding
// threads' totals, then write this thread's 16 output values.
Monoid row = exclusive;
if (gl_LocalInvocationID.x > 0u)
{
Monoid other_1 = sh_scratch[gl_LocalInvocationID.x - 1u];
Monoid param_14 = row;
Monoid param_15 = other_1;
row = combine_monoid(param_14, param_15);
}
for (uint i_2 = 0u; i_2 < 16u; i_2++)
{
Monoid param_16 = row;
Monoid param_17 = local[i_2];
Monoid m_1 = combine_monoid(param_16, param_17);
_372.outbuf[ix + i_2].element = m_1.element;
}
}

Binary file not shown.

Binary file not shown.

View file

@ -1,9 +1,26 @@
// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense
// A prefix sum.
//
// This test builds in three configurations. The default is a
// compatibility mode, essentially plain GLSL. With ATOMIC set, the
// flag loads and stores are atomic operations, but uses barriers.
// With both ATOMIC and VKMM set, it uses acquire/release semantics
// instead of barriers.
#version 450
#extension GL_KHR_memory_scope_semantics : enable
#ifdef VKMM
#pragma use_vulkan_memory_model
#define ACQUIRE gl_StorageSemanticsBuffer, gl_SemanticsAcquire
#define RELEASE gl_StorageSemanticsBuffer, gl_SemanticsRelease
#else
#define ACQUIRE 0, 0
#define RELEASE 0, 0
#endif
#define N_ROWS 16
#define LG_WG_SIZE 9
#define WG_SIZE (1 << LG_WG_SIZE)
@ -24,9 +41,9 @@ layout(set = 0, binding = 1) buffer OutBuf {
};
// These correspond to X, A, P respectively in the prefix sum paper.
#define FLAG_NOT_READY 0
#define FLAG_AGGREGATE_READY 1
#define FLAG_PREFIX_READY 2
#define FLAG_NOT_READY 0u
#define FLAG_AGGREGATE_READY 1u
#define FLAG_PREFIX_READY 2u
struct State {
uint flag;
@ -34,6 +51,7 @@ struct State {
Monoid prefix;
};
// Perhaps this should be "nonprivate" with VKMM
layout(set = 0, binding = 2) volatile buffer StateBuf {
uint part_counter;
State[] state;
@ -87,13 +105,19 @@ void main() {
}
}
// Write flag with release semantics; this is done portably with a barrier.
#ifndef VKMM
memoryBarrierBuffer();
#endif
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
uint flag = FLAG_AGGREGATE_READY;
if (part_ix == 0) {
flag = FLAG_PREFIX_READY;
}
#ifdef ATOMIC
atomicStore(state[part_ix].flag, flag, gl_ScopeDevice, RELEASE);
#else
state[part_ix].flag = flag;
#endif
}
Monoid exclusive = Monoid(0);
@ -106,13 +130,19 @@ void main() {
while (true) {
// Read flag with acquire semantics.
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
#ifdef ATOMIC
sh_flag = atomicLoad(state[look_back_ix].flag, gl_ScopeDevice, ACQUIRE);
#else
sh_flag = state[look_back_ix].flag;
#endif
}
// The flag load is done only in the last thread. However, because the
// translation of memoryBarrierBuffer to Metal requires uniform control
// flow, we broadcast it to all threads.
barrier();
#ifndef VKMM
memoryBarrierBuffer();
#endif
uint flag = sh_flag;
if (flag == FLAG_PREFIX_READY) {
@ -165,9 +195,15 @@ void main() {
sh_prefix = exclusive;
state[part_ix].prefix = inclusive_prefix;
}
#ifndef VKMM
memoryBarrierBuffer();
#endif
if (gl_LocalInvocationID.x == WG_SIZE - 1) {
#ifdef ATOMIC
atomicStore(state[part_ix].flag, FLAG_PREFIX_READY, gl_ScopeDevice, RELEASE);
#else
state[part_ix].flag = FLAG_PREFIX_READY;
#endif
}
}
barrier();

View file

@ -86,7 +86,23 @@ fn main() {
}
report(&clear::run_clear_test(&mut runner, &config));
if config.groups.matches("prefix") {
report(&prefix::run_prefix_test(&mut runner, &config));
report(&prefix::run_prefix_test(
&mut runner,
&config,
prefix::Variant::Compatibility,
));
report(&prefix::run_prefix_test(
&mut runner,
&config,
prefix::Variant::Atomic,
));
if runner.session.gpu_info().has_memory_model {
report(&prefix::run_prefix_test(
&mut runner,
&config,
prefix::Variant::Vkmm,
));
}
report(&prefix_tree::run_prefix_test(&mut runner, &config));
}
}

View file

@ -14,7 +14,7 @@
//
// Also licensed under MIT license, at your choice.
use piet_gpu_hal::{include_shader, BackendType, BindType, BufferUsage, DescriptorSet};
use piet_gpu_hal::{include_shader, BackendType, BindType, BufferUsage, DescriptorSet, ShaderCode};
use piet_gpu_hal::{Buffer, Pipeline};
use crate::clear::{ClearBinding, ClearCode, ClearStage};
@ -51,8 +51,19 @@ struct PrefixBinding {
descriptor_set: DescriptorSet,
}
pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResult {
let mut result = TestResult::new("prefix sum, decoupled look-back");
#[derive(Debug)]
pub enum Variant {
Compatibility,
Atomic,
Vkmm,
}
pub unsafe fn run_prefix_test(
runner: &mut Runner,
config: &Config,
variant: Variant,
) -> TestResult {
let mut result = TestResult::new(format!("prefix sum, decoupled look-back, {:?}", variant));
/*
// We're good if we're using DXC.
if runner.backend_type() == BackendType::Dx12 {
@ -67,7 +78,7 @@ pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResul
.create_buffer_init(&data, BufferUsage::STORAGE)
.unwrap();
let out_buf = runner.buf_down(data_buf.size());
let code = PrefixCode::new(runner);
let code = PrefixCode::new(runner, variant);
let stage = PrefixStage::new(runner, &code, n_elements);
let binding = stage.bind(runner, &code, &data_buf, &out_buf.dev_buf);
let n_iter = config.n_iter;
@ -95,8 +106,12 @@ pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResul
}
impl PrefixCode {
unsafe fn new(runner: &mut Runner) -> PrefixCode {
let code = include_shader!(&runner.session, "../shader/gen/prefix");
unsafe fn new(runner: &mut Runner, variant: Variant) -> PrefixCode {
let code = match variant {
Variant::Compatibility => include_shader!(&runner.session, "../shader/gen/prefix"),
Variant::Atomic => include_shader!(&runner.session, "../shader/gen/prefix_atomic"),
Variant::Vkmm => ShaderCode::Spv(include_bytes!("../shader/gen/prefix_vkmm.spv")),
};
let pipeline = runner
.session
.create_compute_pipeline(

View file

@ -38,9 +38,9 @@ pub enum ReportStyle {
}
impl TestResult {
pub fn new(name: &str) -> TestResult {
pub fn new(name: impl Into<String>) -> TestResult {
TestResult {
name: name.to_string(),
name: name.into(),
total_time: 0.0,
n_elements: 0,
status: Status::Pass,