Portability fixes

The MSL translation of the prefix example had its buffer bindings permuted;
passing the `--msl-decoration-binding` flag to spirv-cross prevents this
(but, as is typical for shader translation, potentially creates other
problems).

Also use explicit unsigned literals (`1u << i` instead of `1 << i`) to avoid DXC warnings.
This commit is contained in:
Raph Levien 2021-11-11 06:59:27 -08:00
parent fbfd4ee81b
commit a0648a2153
16 changed files with 56 additions and 52 deletions

View file

@ -5,6 +5,10 @@
glslang_validator = glslangValidator glslang_validator = glslangValidator
spirv_cross = spirv-cross spirv_cross = spirv-cross
# See https://github.com/KhronosGroup/SPIRV-Cross/issues/1248 for
# why we set this.
msl_flags = --msl-decoration-binding
rule glsl rule glsl
command = $glslang_validator $flags -V -o $out $in command = $glslang_validator $flags -V -o $out $in
@ -12,7 +16,7 @@ rule hlsl
command = $spirv_cross --hlsl $in --output $out command = $spirv_cross --hlsl $in --output $out
rule msl rule msl
command = $spirv_cross --msl $in --output $out command = $spirv_cross --msl $in --output $out $msl_flags
build gen/clear.spv: glsl clear.comp build gen/clear.spv: glsl clear.comp
build gen/clear.hlsl: hlsl gen/clear.spv build gen/clear.hlsl: hlsl gen/clear.spv

View file

@ -12,11 +12,11 @@ struct State
static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u); static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u);
static const Monoid _187 = { 0u }; static const Monoid _185 = { 0u };
globallycoherent RWByteAddressBuffer _43 : register(u2); globallycoherent RWByteAddressBuffer _43 : register(u2);
ByteAddressBuffer _67 : register(t0); ByteAddressBuffer _67 : register(t0);
RWByteAddressBuffer _374 : register(u1); RWByteAddressBuffer _372 : register(u1);
static uint3 gl_LocalInvocationID; static uint3 gl_LocalInvocationID;
struct SPIRV_Cross_Input struct SPIRV_Cross_Input
@ -64,9 +64,9 @@ void comp_main()
for (uint i_1 = 0u; i_1 < 9u; i_1++) for (uint i_1 = 0u; i_1 < 9u; i_1++)
{ {
GroupMemoryBarrierWithGroupSync(); GroupMemoryBarrierWithGroupSync();
if (gl_LocalInvocationID.x >= uint(1 << int(i_1))) if (gl_LocalInvocationID.x >= (1u << i_1))
{ {
Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))]; Monoid other = sh_scratch[gl_LocalInvocationID.x - (1u << i_1)];
Monoid param_2 = other; Monoid param_2 = other;
Monoid param_3 = agg; Monoid param_3 = agg;
agg = combine_monoid(param_2, param_3); agg = combine_monoid(param_2, param_3);
@ -92,7 +92,7 @@ void comp_main()
} }
_43.Store(part_ix * 12 + 4, flag); _43.Store(part_ix * 12 + 4, flag);
} }
Monoid exclusive = _187; Monoid exclusive = _185;
if (part_ix != 0u) if (part_ix != 0u)
{ {
uint look_back_ix = part_ix - 1u; uint look_back_ix = part_ix - 1u;
@ -113,9 +113,9 @@ void comp_main()
{ {
if (gl_LocalInvocationID.x == 511u) if (gl_LocalInvocationID.x == 511u)
{ {
Monoid _225; Monoid _223;
_225.element = _43.Load(look_back_ix * 12 + 12); _223.element = _43.Load(look_back_ix * 12 + 12);
their_prefix.element = _225.element; their_prefix.element = _223.element;
Monoid param_4 = their_prefix; Monoid param_4 = their_prefix;
Monoid param_5 = exclusive; Monoid param_5 = exclusive;
exclusive = combine_monoid(param_4, param_5); exclusive = combine_monoid(param_4, param_5);
@ -128,9 +128,9 @@ void comp_main()
{ {
if (gl_LocalInvocationID.x == 511u) if (gl_LocalInvocationID.x == 511u)
{ {
Monoid _247; Monoid _245;
_247.element = _43.Load(look_back_ix * 12 + 8); _245.element = _43.Load(look_back_ix * 12 + 8);
their_agg.element = _247.element; their_agg.element = _245.element;
Monoid param_6 = their_agg; Monoid param_6 = their_agg;
Monoid param_7 = exclusive; Monoid param_7 = exclusive;
exclusive = combine_monoid(param_6, param_7); exclusive = combine_monoid(param_6, param_7);
@ -142,9 +142,9 @@ void comp_main()
} }
if (gl_LocalInvocationID.x == 511u) if (gl_LocalInvocationID.x == 511u)
{ {
Monoid _269; Monoid _267;
_269.element = _67.Load(((look_back_ix * 8192u) + their_ix) * 4 + 0); _267.element = _67.Load(((look_back_ix * 8192u) + their_ix) * 4 + 0);
m.element = _269.element; m.element = _267.element;
if (their_ix == 0u) if (their_ix == 0u)
{ {
their_agg = m; their_agg = m;
@ -211,7 +211,7 @@ void comp_main()
Monoid param_16 = row; Monoid param_16 = row;
Monoid param_17 = local[i_2]; Monoid param_17 = local[i_2];
Monoid m_1 = combine_monoid(param_16, param_17); Monoid m_1 = combine_monoid(param_16, param_17);
_374.Store((ix + i_2) * 4 + 0, m_1.element); _372.Store((ix + i_2) * 4 + 0, m_1.element);
} }
} }

View file

@ -87,7 +87,7 @@ Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b)
return Monoid{ a.element + b.element }; return Monoid{ a.element + b.element };
} }
kernel void main0(volatile device StateBuf& _43 [[buffer(0)]], const device InBuf& _67 [[buffer(1)]], device OutBuf& _374 [[buffer(2)]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]]) kernel void main0(const device InBuf& _67 [[buffer(0)]], device OutBuf& _372 [[buffer(1)]], volatile device StateBuf& _43 [[buffer(2)]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]])
{ {
threadgroup uint sh_part_ix; threadgroup uint sh_part_ix;
threadgroup Monoid sh_scratch[512]; threadgroup Monoid sh_scratch[512];
@ -115,9 +115,9 @@ kernel void main0(volatile device StateBuf& _43 [[buffer(0)]], const device InBu
for (uint i_1 = 0u; i_1 < 9u; i_1++) for (uint i_1 = 0u; i_1 < 9u; i_1++)
{ {
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
if (gl_LocalInvocationID.x >= uint(1 << int(i_1))) if (gl_LocalInvocationID.x >= (1u << i_1))
{ {
Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))]; Monoid other = sh_scratch[gl_LocalInvocationID.x - (1u << i_1)];
Monoid param_2 = other; Monoid param_2 = other;
Monoid param_3 = agg; Monoid param_3 = agg;
agg = combine_monoid(param_2, param_3); agg = combine_monoid(param_2, param_3);
@ -256,7 +256,7 @@ kernel void main0(volatile device StateBuf& _43 [[buffer(0)]], const device InBu
Monoid param_16 = row; Monoid param_16 = row;
Monoid param_17 = local[i_2]; Monoid param_17 = local[i_2];
Monoid m_1 = combine_monoid(param_16, param_17); Monoid m_1 = combine_monoid(param_16, param_17);
_374.outbuf[ix + i_2].element = m_1.element; _372.outbuf[ix + i_2].element = m_1.element;
} }
} }

Binary file not shown.

View file

@ -6,7 +6,7 @@ struct Monoid
static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u); static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u);
ByteAddressBuffer _40 : register(t0); ByteAddressBuffer _40 : register(t0);
RWByteAddressBuffer _129 : register(u1); RWByteAddressBuffer _127 : register(u1);
static uint3 gl_WorkGroupID; static uint3 gl_WorkGroupID;
static uint3 gl_LocalInvocationID; static uint3 gl_LocalInvocationID;
@ -46,9 +46,9 @@ void comp_main()
for (uint i_1 = 0u; i_1 < 9u; i_1++) for (uint i_1 = 0u; i_1 < 9u; i_1++)
{ {
GroupMemoryBarrierWithGroupSync(); GroupMemoryBarrierWithGroupSync();
if ((gl_LocalInvocationID.x + uint(1 << int(i_1))) < 512u) if ((gl_LocalInvocationID.x + (1u << i_1)) < 512u)
{ {
Monoid other = sh_scratch[gl_LocalInvocationID.x + uint(1 << int(i_1))]; Monoid other = sh_scratch[gl_LocalInvocationID.x + (1u << i_1)];
Monoid param_2 = agg; Monoid param_2 = agg;
Monoid param_3 = other; Monoid param_3 = other;
agg = combine_monoid(param_2, param_3); agg = combine_monoid(param_2, param_3);
@ -58,7 +58,7 @@ void comp_main()
} }
if (gl_LocalInvocationID.x == 0u) if (gl_LocalInvocationID.x == 0u)
{ {
_129.Store(gl_WorkGroupID.x * 4 + 0, agg.element); _127.Store(gl_WorkGroupID.x * 4 + 0, agg.element);
} }
} }

View file

@ -33,7 +33,7 @@ Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b)
return Monoid{ a.element + b.element }; return Monoid{ a.element + b.element };
} }
kernel void main0(const device InBuf& _40 [[buffer(0)]], device OutBuf& _129 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]]) kernel void main0(const device InBuf& _40 [[buffer(0)]], device OutBuf& _127 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]])
{ {
threadgroup Monoid sh_scratch[512]; threadgroup Monoid sh_scratch[512];
uint ix = gl_GlobalInvocationID.x * 8u; uint ix = gl_GlobalInvocationID.x * 8u;
@ -50,9 +50,9 @@ kernel void main0(const device InBuf& _40 [[buffer(0)]], device OutBuf& _129 [[b
for (uint i_1 = 0u; i_1 < 9u; i_1++) for (uint i_1 = 0u; i_1 < 9u; i_1++)
{ {
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
if ((gl_LocalInvocationID.x + uint(1 << int(i_1))) < 512u) if ((gl_LocalInvocationID.x + (1u << i_1)) < 512u)
{ {
Monoid other = sh_scratch[gl_LocalInvocationID.x + uint(1 << int(i_1))]; Monoid other = sh_scratch[gl_LocalInvocationID.x + (1u << i_1)];
Monoid param_2 = agg; Monoid param_2 = agg;
Monoid param_3 = other; Monoid param_3 = other;
agg = combine_monoid(param_2, param_3); agg = combine_monoid(param_2, param_3);
@ -62,7 +62,7 @@ kernel void main0(const device InBuf& _40 [[buffer(0)]], device OutBuf& _129 [[b
} }
if (gl_LocalInvocationID.x == 0u) if (gl_LocalInvocationID.x == 0u)
{ {
_129.outbuf[gl_WorkGroupID.x].element = agg.element; _127.outbuf[gl_WorkGroupID.x].element = agg.element;
} }
} }

Binary file not shown.

View file

@ -5,7 +5,7 @@ struct Monoid
static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u); static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u);
static const Monoid _133 = { 0u }; static const Monoid _131 = { 0u };
RWByteAddressBuffer _42 : register(u0); RWByteAddressBuffer _42 : register(u0);
@ -46,9 +46,9 @@ void comp_main()
for (uint i_1 = 0u; i_1 < 9u; i_1++) for (uint i_1 = 0u; i_1 < 9u; i_1++)
{ {
GroupMemoryBarrierWithGroupSync(); GroupMemoryBarrierWithGroupSync();
if (gl_LocalInvocationID.x >= uint(1 << int(i_1))) if (gl_LocalInvocationID.x >= (1u << i_1))
{ {
Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))]; Monoid other = sh_scratch[gl_LocalInvocationID.x - (1u << i_1)];
Monoid param_2 = other; Monoid param_2 = other;
Monoid param_3 = agg; Monoid param_3 = agg;
agg = combine_monoid(param_2, param_3); agg = combine_monoid(param_2, param_3);
@ -57,7 +57,7 @@ void comp_main()
sh_scratch[gl_LocalInvocationID.x] = agg; sh_scratch[gl_LocalInvocationID.x] = agg;
} }
GroupMemoryBarrierWithGroupSync(); GroupMemoryBarrierWithGroupSync();
Monoid row = _133; Monoid row = _131;
if (gl_LocalInvocationID.x > 0u) if (gl_LocalInvocationID.x > 0u)
{ {
row = sh_scratch[gl_LocalInvocationID.x - 1u]; row = sh_scratch[gl_LocalInvocationID.x - 1u];

View file

@ -85,9 +85,9 @@ kernel void main0(device DataBuf& _42 [[buffer(0)]], uint3 gl_GlobalInvocationID
for (uint i_1 = 0u; i_1 < 9u; i_1++) for (uint i_1 = 0u; i_1 < 9u; i_1++)
{ {
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
if (gl_LocalInvocationID.x >= uint(1 << int(i_1))) if (gl_LocalInvocationID.x >= (1u << i_1))
{ {
Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))]; Monoid other = sh_scratch[gl_LocalInvocationID.x - (1u << i_1)];
Monoid param_2 = other; Monoid param_2 = other;
Monoid param_3 = agg; Monoid param_3 = agg;
agg = combine_monoid(param_2, param_3); agg = combine_monoid(param_2, param_3);

Binary file not shown.

View file

@ -5,10 +5,10 @@ struct Monoid
static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u); static const uint3 gl_WorkGroupSize = uint3(512u, 1u, 1u);
static const Monoid _133 = { 0u }; static const Monoid _131 = { 0u };
RWByteAddressBuffer _42 : register(u0); RWByteAddressBuffer _42 : register(u0);
ByteAddressBuffer _143 : register(t1); ByteAddressBuffer _141 : register(t1);
static uint3 gl_WorkGroupID; static uint3 gl_WorkGroupID;
static uint3 gl_LocalInvocationID; static uint3 gl_LocalInvocationID;
@ -49,9 +49,9 @@ void comp_main()
for (uint i_1 = 0u; i_1 < 9u; i_1++) for (uint i_1 = 0u; i_1 < 9u; i_1++)
{ {
GroupMemoryBarrierWithGroupSync(); GroupMemoryBarrierWithGroupSync();
if (gl_LocalInvocationID.x >= uint(1 << int(i_1))) if (gl_LocalInvocationID.x >= (1u << i_1))
{ {
Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))]; Monoid other = sh_scratch[gl_LocalInvocationID.x - (1u << i_1)];
Monoid param_2 = other; Monoid param_2 = other;
Monoid param_3 = agg; Monoid param_3 = agg;
agg = combine_monoid(param_2, param_3); agg = combine_monoid(param_2, param_3);
@ -60,12 +60,12 @@ void comp_main()
sh_scratch[gl_LocalInvocationID.x] = agg; sh_scratch[gl_LocalInvocationID.x] = agg;
} }
GroupMemoryBarrierWithGroupSync(); GroupMemoryBarrierWithGroupSync();
Monoid row = _133; Monoid row = _131;
if (gl_WorkGroupID.x > 0u) if (gl_WorkGroupID.x > 0u)
{ {
Monoid _148; Monoid _146;
_148.element = _143.Load((gl_WorkGroupID.x - 1u) * 4 + 0); _146.element = _141.Load((gl_WorkGroupID.x - 1u) * 4 + 0);
row.element = _148.element; row.element = _146.element;
} }
if (gl_LocalInvocationID.x > 0u) if (gl_LocalInvocationID.x > 0u)
{ {

View file

@ -72,7 +72,7 @@ Monoid combine_monoid(thread const Monoid& a, thread const Monoid& b)
return Monoid{ a.element + b.element }; return Monoid{ a.element + b.element };
} }
kernel void main0(device DataBuf& _42 [[buffer(0)]], const device ParentBuf& _143 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]]) kernel void main0(device DataBuf& _42 [[buffer(0)]], const device ParentBuf& _141 [[buffer(1)]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_LocalInvocationID [[thread_position_in_threadgroup]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]])
{ {
threadgroup Monoid sh_scratch[512]; threadgroup Monoid sh_scratch[512];
uint ix = gl_GlobalInvocationID.x * 8u; uint ix = gl_GlobalInvocationID.x * 8u;
@ -90,9 +90,9 @@ kernel void main0(device DataBuf& _42 [[buffer(0)]], const device ParentBuf& _14
for (uint i_1 = 0u; i_1 < 9u; i_1++) for (uint i_1 = 0u; i_1 < 9u; i_1++)
{ {
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
if (gl_LocalInvocationID.x >= uint(1 << int(i_1))) if (gl_LocalInvocationID.x >= (1u << i_1))
{ {
Monoid other = sh_scratch[gl_LocalInvocationID.x - uint(1 << int(i_1))]; Monoid other = sh_scratch[gl_LocalInvocationID.x - (1u << i_1)];
Monoid param_2 = other; Monoid param_2 = other;
Monoid param_3 = agg; Monoid param_3 = agg;
agg = combine_monoid(param_2, param_3); agg = combine_monoid(param_2, param_3);
@ -104,7 +104,7 @@ kernel void main0(device DataBuf& _42 [[buffer(0)]], const device ParentBuf& _14
Monoid row = Monoid{ 0u }; Monoid row = Monoid{ 0u };
if (gl_WorkGroupID.x > 0u) if (gl_WorkGroupID.x > 0u)
{ {
row.element = _143.parent[gl_WorkGroupID.x - 1u].element; row.element = _141.parent[gl_WorkGroupID.x - 1u].element;
} }
if (gl_LocalInvocationID.x > 0u) if (gl_LocalInvocationID.x > 0u)
{ {

Binary file not shown.

View file

@ -71,8 +71,8 @@ void main() {
sh_scratch[gl_LocalInvocationID.x] = agg; sh_scratch[gl_LocalInvocationID.x] = agg;
for (uint i = 0; i < LG_WG_SIZE; i++) { for (uint i = 0; i < LG_WG_SIZE; i++) {
barrier(); barrier();
if (gl_LocalInvocationID.x >= (1 << i)) { if (gl_LocalInvocationID.x >= (1u << i)) {
Monoid other = sh_scratch[gl_LocalInvocationID.x - (1 << i)]; Monoid other = sh_scratch[gl_LocalInvocationID.x - (1u << i)];
agg = combine_monoid(other, agg); agg = combine_monoid(other, agg);
} }
barrier(); barrier();

View file

@ -40,8 +40,8 @@ void main() {
for (uint i = 0; i < LG_WG_SIZE; i++) { for (uint i = 0; i < LG_WG_SIZE; i++) {
barrier(); barrier();
// We could make this predicate tighter, but would it help? // We could make this predicate tighter, but would it help?
if (gl_LocalInvocationID.x + (1 << i) < WG_SIZE) { if (gl_LocalInvocationID.x + (1u << i) < WG_SIZE) {
Monoid other = sh_scratch[gl_LocalInvocationID.x + (1 << i)]; Monoid other = sh_scratch[gl_LocalInvocationID.x + (1u << i)];
agg = combine_monoid(agg, other); agg = combine_monoid(agg, other);
} }
barrier(); barrier();

View file

@ -45,8 +45,8 @@ void main() {
sh_scratch[gl_LocalInvocationID.x] = agg; sh_scratch[gl_LocalInvocationID.x] = agg;
for (uint i = 0; i < LG_WG_SIZE; i++) { for (uint i = 0; i < LG_WG_SIZE; i++) {
barrier(); barrier();
if (gl_LocalInvocationID.x >= (1 << i)) { if (gl_LocalInvocationID.x >= (1u << i)) {
Monoid other = sh_scratch[gl_LocalInvocationID.x - (1 << i)]; Monoid other = sh_scratch[gl_LocalInvocationID.x - (1u << i)];
agg = combine_monoid(other, agg); agg = combine_monoid(other, agg);
} }
barrier(); barrier();