From 825a1eb04c7582247ffad6270e8bc30cfd3c4fc6 Mon Sep 17 00:00:00 2001 From: Raph Levien Date: Thu, 11 Nov 2021 11:47:46 -0800 Subject: [PATCH] Add atomic versions of prefix This adds both regular and Vulkan memory model atomic versions of the prefix sum test, compiled by #ifdef. The build chain is getting messy, but I think it's important to test this stuff. --- tests/shader/build.ninja | 10 ++ tests/shader/gen/prefix.spv | Bin 9760 -> 9796 bytes tests/shader/gen/prefix_atomic.dxil | Bin 0 -> 4864 bytes tests/shader/gen/prefix_atomic.hlsl | 227 ++++++++++++++++++++++++ tests/shader/gen/prefix_atomic.msl | 263 ++++++++++++++++++++++++++++ tests/shader/gen/prefix_atomic.spv | Bin 0 -> 9820 bytes tests/shader/gen/prefix_vkmm.spv | Bin 0 -> 9984 bytes tests/shader/prefix.comp | 42 ++++- tests/src/main.rs | 18 +- tests/src/prefix.rs | 27 ++- tests/src/test_result.rs | 4 +- 11 files changed, 579 insertions(+), 12 deletions(-) create mode 100644 tests/shader/gen/prefix_atomic.dxil create mode 100644 tests/shader/gen/prefix_atomic.hlsl create mode 100644 tests/shader/gen/prefix_atomic.msl create mode 100644 tests/shader/gen/prefix_atomic.spv create mode 100644 tests/shader/gen/prefix_vkmm.spv diff --git a/tests/shader/build.ninja b/tests/shader/build.ninja index 19297c9..8a25473 100644 --- a/tests/shader/build.ninja +++ b/tests/shader/build.ninja @@ -32,6 +32,16 @@ build gen/prefix.hlsl: hlsl gen/prefix.spv build gen/prefix.dxil: dxil gen/prefix.hlsl build gen/prefix.msl: msl gen/prefix.spv +build gen/prefix_atomic.spv: glsl prefix.comp + flags = -DATOMIC +build gen/prefix_atomic.hlsl: hlsl gen/prefix_atomic.spv +build gen/prefix_atomic.dxil: dxil gen/prefix_atomic.hlsl +build gen/prefix_atomic.msl: msl gen/prefix_atomic.spv + +build gen/prefix_vkmm.spv: glsl prefix.comp + flags = -DATOMIC -DVKMM +# Vulkan memory model doesn't translate + build gen/prefix_reduce.spv: glsl prefix_reduce.comp build gen/prefix_reduce.hlsl: hlsl gen/prefix_reduce.spv build 
gen/prefix_reduce.dxil: dxil gen/prefix_reduce.hlsl diff --git a/tests/shader/gen/prefix.spv b/tests/shader/gen/prefix.spv index b934189c036bc18f0cdd7e5e65ea5f7aefeccf91..d2c1aadd6ee5592e1f023414f72843340b903e73 100644 GIT binary patch delta 45 zcmZ4BbHrytjS>q3Cxg3BythYCd~RxPeoqmd_8X9eE(AOXYk?5mAJ_L8VX~Y(BYy}ODxOvT2AOf~D8c;KN z{Xv7KHLXE|kjXU0$@V0(ZX2BljN^9G1QEgDM|O!iE+j~d$7C5NZggfgv-J=>?#w@X z&h8(3&dfbj)va5%Zr!?dU;V1dDazJG_8i@JZ$W$bj*{_}>ptCi5&@9H01&_?3zmFX zieagS1&4!Ea6n7~P>{Qtfz4A|U9c)?w$Hcg{;Y2(*u=P?oOwFqzcYV=6!feJ1%UU3 zFBmKeSa4XLCDBZW_mw9&WxD*ieAdpa{8>8%h1$Xi0ss~Sfo`anQ?xptUzCEjhM+uK z=ew;E1K@=+k(Hjd55oflmY74@0em$n+Z4+nxRRqFiD<+#K(l(mBS5L7+iB68jX- zVss7T;F$%0aHT5xcs7K?3jtUS=kRd}v$ENAKIsTX5U`I1;CU!hN9%*>vL^h^=o5BW zePQo($wTw^J4N;r`BZ}$XPFit4%P!gW;fz)+@2Q!K>Z|1<$Wi9>j%aV)nnC&SY$yr zQeY0WR#Hv_S4RxF|C9y=8_rdA{V{kCdGW>Is@S*%-8Ik^5dSeQ0>@XXSU%AE?rVI; ztMc`|6Wd2~Dg%`4@R{{1wK?&ib={lqwS>mK)|6Hqea6^AuD;+>S@x78{=s(`-LbT6ltDzTtP5# z`FvB)Z&O6b@?Fm0xK;hQ4owEJ3~CUUug9AV92ulAaCGeYgrKE{%AXmi>moZVaRS#n z2>o}MrUa%cXoNXohwVHp-<6Aoy4DKPl@3f5N48%Wzcg841zF0_?9F5e zOR~C=e5&h>z}(>>9O2?fDoxAvNJPKYz{D{9M969WaZfG__VUQNYs!^VpN zqh(qoK8I!GjKu$uQK3GVv!$s0Cf<27j?|=R!3l{@Ar27MM?Q@>RenssmiprJf4C&^ z-(UZ_Ceiuzn#4Cw@;>F;dnYD(ADkb(-QPQUeso~+F4=VPV$+>hul!{Am#^--bLGhJ zS_ETIfcl(yL}{ z$pA~yW6OYT8NgC2v9ty(y#>o?!Q8;`7zBP&{T>Gps#n-erP2+VCe#FHT8YfAbpCruV-&+qj|UG9vm zV1lOAY|0^Q#ei;w6I(IV;H)1R&Qz=NRY_5Yf*KBe% zn^?iFc)_kmLD1|gT=R=$^A1IGU0ZWwnvmw+$-f7u!DYrS>a>&kImoy>yY%P7T;!d5m-N#w(7NU zW5$1W^kbIOWnSWpZ~MXTd+(H+=Y`;x-yc5vho8h@7VX?$;641nHtSL8z6aY%@ANifokU6fXCT#hk4X zerkGr`3phedPxa#KjqDxtyc z9xpy;4C|EWa`abo!-Hg4M`~B+Tb)P3sf13q zTcTCuQ$1AyK$I|q((#74pA8s@a)#JAI(q*`@zsIBTX)~Xwc+ZlxxZ3(cb`c+y9l{r z7k~2c66f&A{MtPBx+-8krW3_1eskOIn8l?Xp>+AcRPn^%AHAzE`(g#0CqwzWhf zX)Ugj>k*V_-L7eD{Sy)kpRiC3#1O&2%2Efk(R+Woti%Oql8{Bj2QF-#893rsbBKbC zAqc2uI)Q`|$+G!AXRpb|5G-?vBK<9Xl%o~QHYU7GLsYOZ;j1+RS_YiPds{3B97^A2 z+PqNU^MD?8|5zUBMw7c_7!I_I`NJ|f1#xuxVM 
z20s#h#(zKo-{@9kqrgZ7nVEJ9ge$zi)*i=7M>*1tcmq=yY0piHGvBd7@8ZxM3z@lrSH6+|!36^jAx_sN3NDsA^{V-95(C$5N)#&~|%5-Wax zN`M6UJY(EuD^#b%9z_xZC>(n({2nKDNXS?P(CHB(Akju~{zPm89*~r|`B1nLwN2Xj z`k`~>P$B=P|%5WPA@?4^8Y zqWCmEY_iQe$1mZsgk7IEg$#ve!59{$Z|5Y1pmO{il4?K{mz69Q4t@lM!)lGJ5KYhr zeFz;EL5F>htc%9{&*R-L(OGzgE**b{S7w|!?-u@?DPeiX4=(j&KTly)b{?tO)g%!$#Hi0Yn4Ot?^OQ3*_g<6cunNdm4c4E20tD+zucV-$NWq z>)!+0X2-DT<#@cM zPm7H-vBzLp;PWD4>e+5HgbyM_2o>3_LceHD`5eU&c^qi(eDe7h+@=u(OfR zgKwW9wG)ZJABw)C!|nJ?9$X83Xp(+Qx$!(_Um_9#*Beo2O2(pxXHIWPXUJpb79)+9 z5h-E5wpx-f9`%a{K~!GUH^@=u7QP0TgLhH1B6x^}Nk6Dj**05~6&~R06SpUi-ky>s z$qc|LB;&17ZuD4}%tpZP*IQSza7yLdhdgIOpZnJD?BHOkf>5|^j(CY;HSpgv(|6Gp oVqv`P`9Nr^|EN(o$X+stv=P-2bt)*tR^#Q~^2&W~UU!!M3ki@d-~a#s literal 0 HcmV?d00001 diff --git a/tests/shader/gen/prefix_atomic.hlsl b/tests/shader/gen/prefix_atomic.hlsl new file mode 100644 index 0000000..10f7081 --- /dev/null +++ b/tests/shader/gen/prefix_atomic.hlsl @@ -0,0 +1,227 @@ +struct Monoid +{ + uint element; +}; + +struct State +{ + uint flag; + Monoid aggregate; + Monoid prefix; +}; + +static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u); + +static const Monoid _185 = { 0u }; + +globallycoherent RWByteAddressBuffer _43 : register(u2); +ByteAddressBuffer _67 : register(t0); +RWByteAddressBuffer _372 : register(u1); + +static uint3 gl_LocalInvocationID; +struct SPIRV_Cross_Input +{ + uint3 gl_LocalInvocationID : SV_GroupThreadID; +}; + +groupshared uint sh_part_ix; +groupshared Monoid sh_scratch[512]; +groupshared uint sh_flag; +groupshared Monoid sh_prefix; + +Monoid combine_monoid(Monoid a, Monoid b) +{ + Monoid _22 = { a.element + b.element }; + return _22; +} + +void comp_main() +{ + if (gl_LocalInvocationID.x == 0u) + { + uint _47; + _43.InterlockedAdd(0, 1u, _47); + sh_part_ix = _47; + } + GroupMemoryBarrierWithGroupSync(); + uint part_ix = sh_part_ix; + uint ix = (part_ix * 8192u) + (gl_LocalInvocationID.x * 16u); + Monoid _71; + _71.element = _67.Load(ix * 4 + 0); + Monoid local[16]; + local[0].element 
= _71.element; + Monoid param_1; + for (uint i = 1u; i < 16u; i++) + { + Monoid param = local[i - 1u]; + Monoid _94; + _94.element = _67.Load((ix + i) * 4 + 0); + param_1.element = _94.element; + local[i] = combine_monoid(param, param_1); + } + Monoid agg = local[15]; + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i_1 = 0u; i_1 < 9u; i_1++) + { + GroupMemoryBarrierWithGroupSync(); + if (gl_LocalInvocationID.x >= (1u << i_1)) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x - (1u << i_1)]; + Monoid param_2 = other; + Monoid param_3 = agg; + agg = combine_monoid(param_2, param_3); + } + GroupMemoryBarrierWithGroupSync(); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + if (gl_LocalInvocationID.x == 511u) + { + _43.Store(part_ix * 12 + 8, agg.element); + if (part_ix == 0u) + { + _43.Store(12, agg.element); + } + } + DeviceMemoryBarrier(); + if (gl_LocalInvocationID.x == 511u) + { + uint flag = 1u; + if (part_ix == 0u) + { + flag = 2u; + } + uint _383; + _43.InterlockedExchange(part_ix * 12 + 4, flag, _383); + } + Monoid exclusive = _185; + if (part_ix != 0u) + { + uint look_back_ix = part_ix - 1u; + uint their_ix = 0u; + Monoid their_prefix; + Monoid their_agg; + Monoid m; + while (true) + { + if (gl_LocalInvocationID.x == 511u) + { + uint _208; + _43.InterlockedAdd(look_back_ix * 12 + 4, 0, _208); + sh_flag = _208; + } + GroupMemoryBarrierWithGroupSync(); + DeviceMemoryBarrier(); + uint flag_1 = sh_flag; + if (flag_1 == 2u) + { + if (gl_LocalInvocationID.x == 511u) + { + Monoid _223; + _223.element = _43.Load(look_back_ix * 12 + 12); + their_prefix.element = _223.element; + Monoid param_4 = their_prefix; + Monoid param_5 = exclusive; + exclusive = combine_monoid(param_4, param_5); + } + break; + } + else + { + if (flag_1 == 1u) + { + if (gl_LocalInvocationID.x == 511u) + { + Monoid _245; + _245.element = _43.Load(look_back_ix * 12 + 8); + their_agg.element = _245.element; + Monoid param_6 = their_agg; + Monoid param_7 = exclusive; + exclusive = 
combine_monoid(param_6, param_7); + } + look_back_ix--; + their_ix = 0u; + continue; + } + } + if (gl_LocalInvocationID.x == 511u) + { + Monoid _267; + _267.element = _67.Load(((look_back_ix * 8192u) + their_ix) * 4 + 0); + m.element = _267.element; + if (their_ix == 0u) + { + their_agg = m; + } + else + { + Monoid param_8 = their_agg; + Monoid param_9 = m; + their_agg = combine_monoid(param_8, param_9); + } + their_ix++; + if (their_ix == 8192u) + { + Monoid param_10 = their_agg; + Monoid param_11 = exclusive; + exclusive = combine_monoid(param_10, param_11); + if (look_back_ix == 0u) + { + sh_flag = 2u; + } + else + { + look_back_ix--; + their_ix = 0u; + } + } + } + GroupMemoryBarrierWithGroupSync(); + flag_1 = sh_flag; + if (flag_1 == 2u) + { + break; + } + } + if (gl_LocalInvocationID.x == 511u) + { + Monoid param_12 = exclusive; + Monoid param_13 = agg; + Monoid inclusive_prefix = combine_monoid(param_12, param_13); + sh_prefix = exclusive; + _43.Store(part_ix * 12 + 12, inclusive_prefix.element); + } + DeviceMemoryBarrier(); + if (gl_LocalInvocationID.x == 511u) + { + uint _384; + _43.InterlockedExchange(part_ix * 12 + 4, 2u, _384); + } + } + GroupMemoryBarrierWithGroupSync(); + if (part_ix != 0u) + { + exclusive = sh_prefix; + } + Monoid row = exclusive; + if (gl_LocalInvocationID.x > 0u) + { + Monoid other_1 = sh_scratch[gl_LocalInvocationID.x - 1u]; + Monoid param_14 = row; + Monoid param_15 = other_1; + row = combine_monoid(param_14, param_15); + } + for (uint i_2 = 0u; i_2 < 16u; i_2++) + { + Monoid param_16 = row; + Monoid param_17 = local[i_2]; + Monoid m_1 = combine_monoid(param_16, param_17); + _372.Store((ix + i_2) * 4 + 0, m_1.element); + } +} + +[numthreads(512, 1, 1)] +void main(SPIRV_Cross_Input stage_input) +{ + gl_LocalInvocationID = stage_input.gl_LocalInvocationID; + comp_main(); +} diff --git a/tests/shader/gen/prefix_atomic.msl b/tests/shader/gen/prefix_atomic.msl new file mode 100644 index 0000000..6d7d155 --- /dev/null +++ 
b/tests/shader/gen/prefix_atomic.msl @@ -0,0 +1,263 @@ +#pragma clang diagnostic ignored "-Wmissing-prototypes" +#pragma clang diagnostic ignored "-Wmissing-braces" +#pragma clang diagnostic ignored "-Wunused-variable" + +#include <metal_stdlib> +#include <simd/simd.h> +#include <metal_atomic> + +using namespace metal; + +template<typename T, size_t Num> +struct spvUnsafeArray +{ + T elements[Num ? Num : 1]; + + thread T& operator [] (size_t pos) thread + { + return elements[pos]; + } + constexpr const thread T& operator [] (size_t pos) const thread + { + return elements[pos]; + } + + device T& operator [] (size_t pos) device + { + return elements[pos]; + } + constexpr const device T& operator [] (size_t pos) const device + { + return elements[pos]; + } + + constexpr const constant T& operator [] (size_t pos) const constant + { + return elements[pos]; + } + + threadgroup T& operator [] (size_t pos) threadgroup + { + return elements[pos]; + } + constexpr const threadgroup T& operator [] (size_t pos) const threadgroup + { + return elements[pos]; + } +}; + +struct Monoid +{ + uint element; +}; + +struct Monoid_1 +{ + uint element; +}; + +struct State +{ + uint flag; + Monoid_1 aggregate; + Monoid_1 prefix; +}; + +struct StateBuf +{ + uint part_counter; + State state[1]; +}; + +struct InBuf +{ + Monoid_1 inbuf[1]; +}; + +struct OutBuf +{ + Monoid_1 outbuf[1]; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(512u, 1u, 1u); + +static inline __attribute__((always_inline)) +Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b) +{ + return Monoid{ a.element + b.element }; +} + +kernel void main0(const device InBuf& _67 [[buffer(0)]], device OutBuf& _372 [[buffer(1)]], volatile device StateBuf& _43 [[buffer(2)]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]]) +{ + threadgroup uint sh_part_ix; + threadgroup Monoid sh_scratch[512]; + threadgroup uint sh_flag; + threadgroup Monoid sh_prefix; + if (gl_LocalInvocationID.x == 0u) + { + uint _47 = atomic_fetch_add_explicit((volatile device
atomic_uint*)&_43.part_counter, 1u, memory_order_relaxed); + sh_part_ix = _47; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + uint part_ix = sh_part_ix; + uint ix = (part_ix * 8192u) + (gl_LocalInvocationID.x * 16u); + spvUnsafeArray<Monoid, 16> local; + local[0].element = _67.inbuf[ix].element; + Monoid param_1; + for (uint i = 1u; i < 16u; i++) + { + Monoid param = local[i - 1u]; + param_1.element = _67.inbuf[ix + i].element; + local[i] = combine_monoid(param, param_1); + } + Monoid agg = local[15]; + sh_scratch[gl_LocalInvocationID.x] = agg; + for (uint i_1 = 0u; i_1 < 9u; i_1++) + { + threadgroup_barrier(mem_flags::mem_threadgroup); + if (gl_LocalInvocationID.x >= (1u << i_1)) + { + Monoid other = sh_scratch[gl_LocalInvocationID.x - (1u << i_1)]; + Monoid param_2 = other; + Monoid param_3 = agg; + agg = combine_monoid(param_2, param_3); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + sh_scratch[gl_LocalInvocationID.x] = agg; + } + if (gl_LocalInvocationID.x == 511u) + { + _43.state[part_ix].aggregate.element = agg.element; + if (part_ix == 0u) + { + _43.state[0].prefix.element = agg.element; + } + } + threadgroup_barrier(mem_flags::mem_device); + if (gl_LocalInvocationID.x == 511u) + { + uint flag = 1u; + if (part_ix == 0u) + { + flag = 2u; + } + atomic_store_explicit((volatile device atomic_uint*)&_43.state[part_ix].flag, flag, memory_order_relaxed); + } + Monoid exclusive = Monoid{ 0u }; + if (part_ix != 0u) + { + uint look_back_ix = part_ix - 1u; + uint their_ix = 0u; + Monoid their_prefix; + Monoid their_agg; + Monoid m; + while (true) + { + if (gl_LocalInvocationID.x == 511u) + { + uint _208 = atomic_load_explicit((volatile device atomic_uint*)&_43.state[look_back_ix].flag, memory_order_relaxed); + sh_flag = _208; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_device); + uint flag_1 = sh_flag; + if (flag_1 == 2u) + { + if (gl_LocalInvocationID.x == 511u) + { + their_prefix.element =
_43.state[look_back_ix].prefix.element; + Monoid param_4 = their_prefix; + Monoid param_5 = exclusive; + exclusive = combine_monoid(param_4, param_5); + } + break; + } + else + { + if (flag_1 == 1u) + { + if (gl_LocalInvocationID.x == 511u) + { + their_agg.element = _43.state[look_back_ix].aggregate.element; + Monoid param_6 = their_agg; + Monoid param_7 = exclusive; + exclusive = combine_monoid(param_6, param_7); + } + look_back_ix--; + their_ix = 0u; + continue; + } + } + if (gl_LocalInvocationID.x == 511u) + { + m.element = _67.inbuf[(look_back_ix * 8192u) + their_ix].element; + if (their_ix == 0u) + { + their_agg = m; + } + else + { + Monoid param_8 = their_agg; + Monoid param_9 = m; + their_agg = combine_monoid(param_8, param_9); + } + their_ix++; + if (their_ix == 8192u) + { + Monoid param_10 = their_agg; + Monoid param_11 = exclusive; + exclusive = combine_monoid(param_10, param_11); + if (look_back_ix == 0u) + { + sh_flag = 2u; + } + else + { + look_back_ix--; + their_ix = 0u; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + flag_1 = sh_flag; + if (flag_1 == 2u) + { + break; + } + } + if (gl_LocalInvocationID.x == 511u) + { + Monoid param_12 = exclusive; + Monoid param_13 = agg; + Monoid inclusive_prefix = combine_monoid(param_12, param_13); + sh_prefix = exclusive; + _43.state[part_ix].prefix.element = inclusive_prefix.element; + } + threadgroup_barrier(mem_flags::mem_device); + if (gl_LocalInvocationID.x == 511u) + { + atomic_store_explicit((volatile device atomic_uint*)&_43.state[part_ix].flag, 2u, memory_order_relaxed); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (part_ix != 0u) + { + exclusive = sh_prefix; + } + Monoid row = exclusive; + if (gl_LocalInvocationID.x > 0u) + { + Monoid other_1 = sh_scratch[gl_LocalInvocationID.x - 1u]; + Monoid param_14 = row; + Monoid param_15 = other_1; + row = combine_monoid(param_14, param_15); + } + for (uint i_2 = 0u; i_2 < 16u; i_2++) + { + Monoid param_16 = row; + Monoid 
param_17 = local[i_2]; + Monoid m_1 = combine_monoid(param_16, param_17); + _372.outbuf[ix + i_2].element = m_1.element; + } +} + diff --git a/tests/shader/gen/prefix_atomic.spv b/tests/shader/gen/prefix_atomic.spv new file mode 100644 index 0000000000000000000000000000000000000000..acca5452de5de6728d604a997c43cfad7b177987 GIT binary patch literal 9820 zcmZ{p33yf25yx+MSwK)g!41R&6c-i&5f|JdA|NO#?xOLLyo5-S2gwWKZq=@q?u*@5 zyQsBv*S1-H}vRmpe*5rQT9sg*oGsaq#wX@4D{3Qd@5|_TYg^|6qG%PHpC# z!TIwS|G$N~1$vPrMMOLny{?K!SL^EPXjrEcUuZT7I()zx3>8X0RpH};o0 zYil+nhrp}(P8#e4GP|j26IVCaULNeLl=^GFTD-Xj(mh!BES(Jdy#O6W>01bo`c44% z*pM`~6}_Ubwy?Fx9o>EFs+)2Cl9~x`Az7&pah(g$tW)ew`Pvb@kXQ`$6Wfp&L$3$g z`-_$K_0BIO-j~RABe+~yUoFzyH^Z~Jx90dZc=X%+cfm_r+It2Ex;J|+4RrBdbWgdw zp>18UeM9wxwKpU~*jcK(f7n)^JdBO~Ra3hDBk-O18^`@T;0?*s z!}IJ!ZO(HeV(OC}BX}Wc;Bcd-lfc>hd*=B5Io>jYHzkYU-F<0a+D7jE$Puyi$0T zuO6&aFR6D_?kEpds+ZJ#RZ`nB_T~G`_kemi@$n4LfIGCsGn==WMw{=ji1T^XHX1Si z!RSWBch+*a&#z-LlI43SGQ zX7~hj{qRn$g-1KzlCQ%_H|)q|Lf! zqV-Kn@2OauvLAl$-$BSUW_Ty`YcuzuX#1tz+;Z=VHvO|x>>PWbo{Q9FxIY?<5C44d zXaa&Sf|G2R9|N^_Q^|K6+8%0iz2j4^-MycPwhrrTMazvl3GJHN+_!v5YSX_o#mj2kfco*37jD21PHmUpC0CZp4%;)#dertC2(>q(A41~%{rL9@2Ss~HG9U&BJPxab344-&tYjmo5OST_k;Tl?(c`-o~OSdf*Zdw!+q8+$Z>xc zM7;ja40pc21G4@bbNuEU-$qEFa2ozxCmFzTf&Repie&4hHNjdH}KI``zpXDt%?l(T`_q(3ue%G_yZ+Mpb9nW&V<5})^Jh=TV zX1M+GTONM7-}5Z@doFL~mibPchWOs}SvZb6*NEO9(S92-&FB_Hn=#LFOJ|@RJBZge z9nt396`R|4+Z<#T;(OpJRyFrwh<5M4`P9dg(|4}E=gASd)Q_>oJhbAwFZOsi*g4Mi z8IgzY2(Y%B*oSK#iD>g(8Fe28&ewf3HtjKQF<6`L&+F4%$AZnL&$WFXwZ+;egSEA? 
z@>29E$P&aJ=re96dKse6{di8x5uat_&P>-@0XEKIoacQiVytnlFE>u?vu{7`;VQJ} zp?-!nPe-pt9Qw~d>sQC#&jg!e8*4SA&qB0Wzh^B!4KbH<#W8mc*tx;aPWcUKE$6_s zxxa`x7i_PBp9eNa^!j|TG1{X~7lQ4R`JCrnz6jB74}9OrqyHC!qyOe}j{Uy`(HFgL zOKqs(USA4ci$t&G#vRSsx^58}hq$J3vBz~__h?VtliXPQ?ituSb=2Js&emOmYwJMb zT(xD-)ia3o)`MNI3yGd|gS9)q6D>F1XTA@af=oqX59QPz^Ebks=bTOGexw&M&YW^% z!~S{pwiz3a~li-dzROb|vDylMf>1a;`Z4UWnSR#+R?{ z8f@BQy=%d)_dt3#uLFB8wcm-Y8GSwCJbk|NZa_QS&-Lh6AllrIIM%-joL~Qy*tEy` zuL5V+e>L3owa4@F8nE;9#TmR7>>22bb9xi{nf_2F~vD!#efdiD3VeM?GrSz6*Uf67#ap?aBTEE&oJcRx_Vw`qy+<|X`^LOA|*tEwT_%>Ku z+=1_c9p={e9YmYC#Zlk)z}66F@_lfei9XNdN!~P#=pP{Z<1DQI;nZ$TkzbvECuZ+< zto=jcoNt~-(DLy82&^sY`!QIX_seJbr)Y=u>H7)t2gLfs(Yv34^S%2yHtmu77hrAC zyI+AF=GOO1M4P$AvBzJ7qjyWezd@pR`g~6qI}`m|M1QRuCZU*8ic58vIFDyi8*2DGo&`UR#Cgb#Gl%E#Jfh!zK9}&lR z^Y|~=^Kg#7I1laiKWbA)Z7+bWE#5;fg0)$b^Yv*HXWv5)VzW=4g}Hrht}c-`U;3PbBkl&jbQf{|DJ4uYg>dZV*Fd)5IlavjDZ`cJ^l@| z3wQ}Jo`rtrdX8hk`r;hNfweiT&*y4acmopWC^s(F-wixOT+})quFX5@-JS?`SgXDX z$R3Ebip}GF*d6RTo`Ls5ZqJ=>taBoEGC1yze{Ye8Z!d7<@oz8kxI24;y*tj)7i($H zuH~9h-#*|W=0tz@g=>quqfcAJPXTL-eoTdH^UUnWG_b>d=-Us`W{u*gqZu4E`Zpta z_@;y7`Dg*lqsAFvYjlo2Ycy7S?C}7w_bT>yAY5CV*+F23d(<}*IT&$|VsrZKnTOan zqx>C|Jq=K4bj*qxP7895{0O zcWim|@_6t9=11%aaOeAbXess+!SM|-_Ev21u4#oE@32R+(I>&36YmXF1M4SB;N1e;T))~L^R=}e!eeT!a+SGgQ;BuG`2D>St}UL=)4&e*t#1{g&Aj3`kJG`Pndj;m%H2c6sw4Ic zaGdjMusnKuCOF^Qv#@EuBmFnR8nCw5<2hi5d(?L}qRl;uBlmgWeD3qHX^(fu1z>Id zUUJ?h^o8(Vw7#g}BCvMrJr^wx-^F0-(zg~Z4__PDczu_k#zpj7gxcJkN1V#ID2cK?T_Zs`z kJ+(*79&qPoV}B3gRNNf^&${3XlhDj3j&e2u-Ih$?rMdX}Ato7lQ!BD$}mbwi=I+}gRRT6+^*wpMM&?pkY;!?9PD3*{o( zZfZN)y9(>n_54rH9v0TE>n*Mu7;8T_^cLHzYt|)`u`Br&^|u3=-B97gmEzh;{XOMk zZ`D_gH+NrJgZrMMlVQJSqhnM0=3vMEn!#N*B#m8+Ue;4x*u7{wI(ybsigEsuDuuTu zS)mVcoz-a8DRigyMFZ@bqywA%6I+**!1TJWt+!Ba!(py|xX5)QxK!?_Y|`8}VP|vS zoU`x7j((ee8+LJ1TUUQy=SI(^jxKITca=)(Th|oY)>lqgzAkwHo~1f_cX74J!*JAB zN$L8JV(-phpX}To>^hQ z{Mw{#fL)WcV^`0+Cg}mk9ybh#t4(@y_GQ?o^_MG`)H|x}DD{^sm(;p8r(6ql`9AYK zpk7LRJi`;P9dbUOjyz`?Ip1Ls=kqEz1Tp_4bUorb%b3Y%M>dySHdhNZ%GDs|@OeEH 
zafCe{>{{B3(5_LfqjuMGc>tViE=D_IP4mk|exKQFer@wR%)gQ1cIV&S(`-Kb9Qo{T zHlOx*a-k~ut_Rrj1Cf2$h5H!KZ0C#jWlh&>L>%MrG@z?oZKb!`%6RrSG{yE@asFQD zM#MFzr1l7~qY--~I&$uhZ7t@v2K@)2ohv^#wY?*e&)Bhu!}x>I`s7Emi$=8kVQ9}q z-aOh9Q_g)grM_|LJu^;DIT*k9=Wt{kbG;M#<;;Bq+J4EKTid%Lr+;dSonsHw(~&_L zdj_~JV^@9zY}p-uBtG-S{G-4_ND%h1*etfoo`-2~$FklDXnQPYJquG?-n}hCyHESD z7_DvGiD=i9b6?utNjd#3DRyq0Z98~S#_j}r-chIDFcVmp?-J{hGoL?fJR^D6(=Mf) zet)azcTVKH670PT`xfvk6cF`n#kStCzXo2IvA+&BXUx9`>^`T$y3W0b_iHk}X+%GS z#QpL6lb!nr_M+5pti23?EGPd*mU6A$Pn6q!n z**9lw=lk0rJAYfo*1tVxKajKi_Gjns%-K)n?5A_K-~8-+zxP?&?|s(xd!Mxr%-MeP zvwpw%S=(=Z*7lp9wf)X#ZNKwb+i!Z-_Pd_73pv~Gdid?9-}SIPPrvP1+wXhU_WQ2g z%q{aBI}Y(Z?6WbSJ6DfB1d+dum`3z?M9!G!xTO=&ju(m7cPJv~-4&bLcjYw1_owfT zXIRzTvk>_i$V|lNU;*(H(fTGKz9U~CSL9UhU@dvYcjkPuneS-Cx#rThfcku9`+Um9+9!eKnpwF8y%br3 z*b{xmHKCUw`mD!uJh_6aai^ziodRc^!#K}>Iby8&TwmKbvF`%=YcEeldp_!CS#u@& zbi|?m477fA)V>OAj=NZ^5q%~i=l(r=?b8r*IaeHW&jLF)?6Xt*x^yq+V9Qxw#GDJZ zS7DzAHb?aOe6TU{(WeV4I7!UsJn!~}h`c>mo!Zg=R_y4%`J7|_FGlo5uL~)M+SThd z;ERyxwYG7`akj48h73ns)3~T{Z3QQZJ+UTjW9_?VVDHqi?;<$6?{(O6?MR%fT=rZ& zgIMp93al! 
z;i~_}G^cfLM7vM5_iqz=Gvd&HIa%jT-Uk@iA z>%Rej`Ynk3OeD_Z#*~jUeJi${cj6|rcKF@~ z_Pq7ojMk2O{&ukU+&TKNsqMLa1gY-1q4K@B6?+O|t;VY39raPL@eXU9ivAdKD>4GHN7}|s zMrI(f?#IEd8})wzEa&rT{hvZRtY6a$l7`&FfFkP1v|{G z?>~r~xy4cA^Wf~cynr3&qR%@V=kg+0fBszl3#ZQ!=kh=Bb4Z+vwsFz>ms7ueeF>}` zKL79(=kf}?cAU!~u;=0&eQ_@GooH+Y#@;zj~JQH(2 zmfGq#n>yx=!5)jmdi8MfaW)NLxj37lV25X;&%gi4xt2Kc37Qo2LQ){$SVf47@km_T2f#IwxWe0vltjf3MLF-&k|dv~0pFV>RJuH~Aszj5Gfe-FWyi@T#wF5(;0_~^%YY&p-&eoO#6?1#QX5jpoL zj(to7#~%HglXmzH2gmc#1lEo{P6E3}=jd~f#>z*HlfmArsPPDFxj3__V23s8n}SS3 ztWg}_Grup{@0s(AF*eqk4$eM%GqB~NfBNLC(_UJGXA!y7+R-Dmc04cY*ZOwO#P=}y z^50JRsBaedQDUMOM`Fvxw{tewVK4L@g~+*2v9-jzUtT-j8*1$b(Z(AW?~OUc=HDB~ zz{y7~$AabJy>UF)VJ-TOL*&dYj{7km9QVV&KWc~X1aJpA?6-f5)Q-Ee5bWJ?jy_}j zd!~HMUj(k^uKs_N+HnV>r;FkIEf{l7#5Rw=om$|Rfa4ou>{hsV2c3j%yu)5iMYmu( zr}_@6tJua(#y118$I%CM^mr*a-{WO)^3mgy!E({#iLunOFaWSoeKA-*yYJQ5^0Ch~VE5;}FQB!< zw-#)?zBaUWJU2zK&y92R8Dk!Kdm4LF>uX2b3w`Ur?py7@e>%~ZAP)WOQ@=XyUKiLL z@prLqu-rW6*i-Ee#9ZbU$J`Rwxv{TH!LjEJU~T(2gj{_UOfS(9Uk1y&k6yHP^r0VY zAIziAm>#rztbaLp4l!Y0fo(nU`{hcoG4iqIRbcbP@0qK?n-QNe1OItw-B9duyNVk2>nP7M!i;I_wIU&etcGJv;mSQZk5d{2=yLSNT0-9~?uGV_4UI z?vcU6z^=U~+FD1VMoYt*xM%16vT*v`$yjK-Fa9_|Zvp8QJgLj!t0ME+&g6kF^6 E0pGmNRR910 literal 0 HcmV?d00001 diff --git a/tests/shader/prefix.comp b/tests/shader/prefix.comp index 3ca1509..a6a0d57 100644 --- a/tests/shader/prefix.comp +++ b/tests/shader/prefix.comp @@ -1,9 +1,26 @@ // SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense // A prefix sum. +// +// This test builds in three configurations. The default is a +// compatibility mode, essentially plain GLSL. With ATOMIC set, the +// flag loads and stores are atomic operations, but uses barriers. +// With both ATOMIC and VKMM set, it uses acquire/release semantics +// instead of barriers. 
#version 450 +#extension GL_KHR_memory_scope_semantics : enable + +#ifdef VKMM +#pragma use_vulkan_memory_model +#define ACQUIRE gl_StorageSemanticsBuffer, gl_SemanticsAcquire +#define RELEASE gl_StorageSemanticsBuffer, gl_SemanticsRelease +#else +#define ACQUIRE 0, 0 +#define RELEASE 0, 0 +#endif + #define N_ROWS 16 #define LG_WG_SIZE 9 #define WG_SIZE (1 << LG_WG_SIZE) @@ -24,9 +41,9 @@ layout(set = 0, binding = 1) buffer OutBuf { }; // These correspond to X, A, P respectively in the prefix sum paper. -#define FLAG_NOT_READY 0 -#define FLAG_AGGREGATE_READY 1 -#define FLAG_PREFIX_READY 2 +#define FLAG_NOT_READY 0u +#define FLAG_AGGREGATE_READY 1u +#define FLAG_PREFIX_READY 2u struct State { uint flag; @@ -34,6 +51,7 @@ struct State { Monoid prefix; }; +// Perhaps this should be "nonprivate" with VKMM layout(set = 0, binding = 2) volatile buffer StateBuf { uint part_counter; State[] state; @@ -87,13 +105,19 @@ void main() { } } // Write flag with release semantics; this is done portably with a barrier. +#ifndef VKMM memoryBarrierBuffer(); +#endif if (gl_LocalInvocationID.x == WG_SIZE - 1) { uint flag = FLAG_AGGREGATE_READY; if (part_ix == 0) { flag = FLAG_PREFIX_READY; } +#ifdef ATOMIC + atomicStore(state[part_ix].flag, flag, gl_ScopeDevice, RELEASE); +#else state[part_ix].flag = flag; +#endif } Monoid exclusive = Monoid(0); @@ -106,13 +130,19 @@ void main() { while (true) { // Read flag with acquire semantics. if (gl_LocalInvocationID.x == WG_SIZE - 1) { +#ifdef ATOMIC + sh_flag = atomicLoad(state[look_back_ix].flag, gl_ScopeDevice, ACQUIRE); +#else sh_flag = state[look_back_ix].flag; +#endif } // The flag load is done only in the last thread. However, because the // translation of memoryBarrierBuffer to Metal requires uniform control // flow, we broadcast it to all threads. 
barrier(); +#ifndef VKMM memoryBarrierBuffer(); +#endif uint flag = sh_flag; if (flag == FLAG_PREFIX_READY) { @@ -165,9 +195,15 @@ void main() { sh_prefix = exclusive; state[part_ix].prefix = inclusive_prefix; } +#ifndef VKMM memoryBarrierBuffer(); +#endif if (gl_LocalInvocationID.x == WG_SIZE - 1) { +#ifdef ATOMIC + atomicStore(state[part_ix].flag, FLAG_PREFIX_READY, gl_ScopeDevice, RELEASE); +#else state[part_ix].flag = FLAG_PREFIX_READY; +#endif } } barrier(); diff --git a/tests/src/main.rs b/tests/src/main.rs index adefa7f..186ce25 100644 --- a/tests/src/main.rs +++ b/tests/src/main.rs @@ -86,7 +86,23 @@ fn main() { } report(&clear::run_clear_test(&mut runner, &config)); if config.groups.matches("prefix") { - report(&prefix::run_prefix_test(&mut runner, &config)); + report(&prefix::run_prefix_test( + &mut runner, + &config, + prefix::Variant::Compatibility, + )); + report(&prefix::run_prefix_test( + &mut runner, + &config, + prefix::Variant::Atomic, + )); + if runner.session.gpu_info().has_memory_model { + report(&prefix::run_prefix_test( + &mut runner, + &config, + prefix::Variant::Vkmm, + )); + } report(&prefix_tree::run_prefix_test(&mut runner, &config)); } } diff --git a/tests/src/prefix.rs b/tests/src/prefix.rs index a2e52c3..1391c36 100644 --- a/tests/src/prefix.rs +++ b/tests/src/prefix.rs @@ -14,7 +14,7 @@ // // Also licensed under MIT license, at your choice. 
-use piet_gpu_hal::{include_shader, BackendType, BindType, BufferUsage, DescriptorSet}; +use piet_gpu_hal::{include_shader, BackendType, BindType, BufferUsage, DescriptorSet, ShaderCode}; use piet_gpu_hal::{Buffer, Pipeline}; use crate::clear::{ClearBinding, ClearCode, ClearStage}; @@ -51,8 +51,19 @@ struct PrefixBinding { descriptor_set: DescriptorSet, } -pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResult { - let mut result = TestResult::new("prefix sum, decoupled look-back"); +#[derive(Debug)] +pub enum Variant { + Compatibility, + Atomic, + Vkmm, +} + +pub unsafe fn run_prefix_test( + runner: &mut Runner, + config: &Config, + variant: Variant, +) -> TestResult { + let mut result = TestResult::new(format!("prefix sum, decoupled look-back, {:?}", variant)); /* // We're good if we're using DXC. if runner.backend_type() == BackendType::Dx12 { @@ -67,7 +78,7 @@ pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResul .create_buffer_init(&data, BufferUsage::STORAGE) .unwrap(); let out_buf = runner.buf_down(data_buf.size()); - let code = PrefixCode::new(runner); + let code = PrefixCode::new(runner, variant); let stage = PrefixStage::new(runner, &code, n_elements); let binding = stage.bind(runner, &code, &data_buf, &out_buf.dev_buf); let n_iter = config.n_iter; @@ -95,8 +106,12 @@ pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResul } impl PrefixCode { - unsafe fn new(runner: &mut Runner) -> PrefixCode { - let code = include_shader!(&runner.session, "../shader/gen/prefix"); + unsafe fn new(runner: &mut Runner, variant: Variant) -> PrefixCode { + let code = match variant { + Variant::Compatibility => include_shader!(&runner.session, "../shader/gen/prefix"), + Variant::Atomic => include_shader!(&runner.session, "../shader/gen/prefix_atomic"), + Variant::Vkmm => ShaderCode::Spv(include_bytes!("../shader/gen/prefix_vkmm.spv")), + }; let pipeline = runner .session .create_compute_pipeline( 
diff --git a/tests/src/test_result.rs b/tests/src/test_result.rs index e582c63..05ad9b3 100644 --- a/tests/src/test_result.rs +++ b/tests/src/test_result.rs @@ -38,9 +38,9 @@ pub enum ReportStyle { } impl TestResult { - pub fn new(name: &str) -> TestResult { + pub fn new(name: impl Into<String>) -> TestResult { TestResult { - name: name.to_string(), + name: name.into(), total_time: 0.0, n_elements: 0, status: Status::Pass,