Merge pull request #128 from linebender/prefix_war

Fix write-after-read in prefix test
This commit is contained in:
Raph Levien 2021-11-16 08:10:20 -08:00 committed by GitHub
commit 95d356c08f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 13 additions and 1 deletion

View file

@@ -248,9 +248,10 @@ void main() {
// The flag load is done only in the last thread. However, because the // The flag load is done only in the last thread. However, because the
// translation of memoryBarrierBuffer to Metal requires uniform control // translation of memoryBarrierBuffer to Metal requires uniform control
// flow, we broadcast it to all threads. // flow, we broadcast it to all threads.
barrier();
memoryBarrierBuffer(); memoryBarrierBuffer();
barrier();
uint flag = sh_flag; uint flag = sh_flag;
barrier();
if (flag == FLAG_PREFIX_READY) { if (flag == FLAG_PREFIX_READY) {
if (gl_LocalInvocationID.x == WG_SIZE - 1) { if (gl_LocalInvocationID.x == WG_SIZE - 1) {
@@ -293,6 +294,7 @@ void main() {
} }
barrier(); barrier();
flag = sh_flag; flag = sh_flag;
barrier();
if (flag == FLAG_PREFIX_READY) { if (flag == FLAG_PREFIX_READY) {
break; break;
} }

Binary file not shown.

Binary file not shown.

Binary file not shown.

View file

@@ -109,6 +109,7 @@ void comp_main()
GroupMemoryBarrierWithGroupSync(); GroupMemoryBarrierWithGroupSync();
DeviceMemoryBarrier(); DeviceMemoryBarrier();
uint flag_1 = sh_flag; uint flag_1 = sh_flag;
GroupMemoryBarrierWithGroupSync();
if (flag_1 == 2u) if (flag_1 == 2u)
{ {
if (gl_LocalInvocationID.x == 511u) if (gl_LocalInvocationID.x == 511u)
@@ -174,6 +175,7 @@ void comp_main()
} }
GroupMemoryBarrierWithGroupSync(); GroupMemoryBarrierWithGroupSync();
flag_1 = sh_flag; flag_1 = sh_flag;
GroupMemoryBarrierWithGroupSync();
if (flag_1 == 2u) if (flag_1 == 2u)
{ {
break; break;

View file

@@ -160,6 +160,7 @@ kernel void main0(const device InBuf& _67 [[buffer(0)]], device OutBuf& _372 [[b
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_device); threadgroup_barrier(mem_flags::mem_device);
uint flag_1 = sh_flag; uint flag_1 = sh_flag;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (flag_1 == 2u) if (flag_1 == 2u)
{ {
if (gl_LocalInvocationID.x == 511u) if (gl_LocalInvocationID.x == 511u)
@@ -219,6 +220,7 @@ kernel void main0(const device InBuf& _67 [[buffer(0)]], device OutBuf& _372 [[b
} }
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
flag_1 = sh_flag; flag_1 = sh_flag;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (flag_1 == 2u) if (flag_1 == 2u)
{ {
break; break;

Binary file not shown.

Binary file not shown.

View file

@@ -112,6 +112,7 @@ void comp_main()
GroupMemoryBarrierWithGroupSync(); GroupMemoryBarrierWithGroupSync();
DeviceMemoryBarrier(); DeviceMemoryBarrier();
uint flag_1 = sh_flag; uint flag_1 = sh_flag;
GroupMemoryBarrierWithGroupSync();
if (flag_1 == 2u) if (flag_1 == 2u)
{ {
if (gl_LocalInvocationID.x == 511u) if (gl_LocalInvocationID.x == 511u)
@@ -177,6 +178,7 @@ void comp_main()
} }
GroupMemoryBarrierWithGroupSync(); GroupMemoryBarrierWithGroupSync();
flag_1 = sh_flag; flag_1 = sh_flag;
GroupMemoryBarrierWithGroupSync();
if (flag_1 == 2u) if (flag_1 == 2u)
{ {
break; break;

View file

@@ -161,6 +161,7 @@ kernel void main0(const device InBuf& _67 [[buffer(0)]], device OutBuf& _372 [[b
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_device); threadgroup_barrier(mem_flags::mem_device);
uint flag_1 = sh_flag; uint flag_1 = sh_flag;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (flag_1 == 2u) if (flag_1 == 2u)
{ {
if (gl_LocalInvocationID.x == 511u) if (gl_LocalInvocationID.x == 511u)
@@ -220,6 +221,7 @@ kernel void main0(const device InBuf& _67 [[buffer(0)]], device OutBuf& _372 [[b
} }
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
flag_1 = sh_flag; flag_1 = sh_flag;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (flag_1 == 2u) if (flag_1 == 2u)
{ {
break; break;

Binary file not shown.

Binary file not shown.

View file

@@ -144,6 +144,7 @@ void main() {
memoryBarrierBuffer(); memoryBarrierBuffer();
#endif #endif
uint flag = sh_flag; uint flag = sh_flag;
barrier();
if (flag == FLAG_PREFIX_READY) { if (flag == FLAG_PREFIX_READY) {
if (gl_LocalInvocationID.x == WG_SIZE - 1) { if (gl_LocalInvocationID.x == WG_SIZE - 1) {
@@ -185,6 +186,7 @@ void main() {
} }
barrier(); barrier();
flag = sh_flag; flag = sh_flag;
barrier();
if (flag == FLAG_PREFIX_READY) { if (flag == FLAG_PREFIX_READY) {
break; break;
} }