mirror of
https://github.com/italicsjenga/vello.git
synced 2025-01-10 12:41:30 +11:00
a0648a2153
The MSL translation of the prefix example had its bindings permuted; a flag prevents this (but, as is typical for shader translation, potentially creates other problems). Also use explicit unsigned literal to avoid DXC warnings.
78 lines
1.9 KiB
GLSL
78 lines
1.9 KiB
GLSL
// SPDX-License-Identifier: Apache-2.0 OR MIT OR Unlicense
|
|
|
|
// A scan pass for a tree-reduction prefix scan (root or non-root variant, selected by whether ROOT is defined).
|
|
|
|
#version 450

// Tuning constants for the scan.
//   N_ROWS         - number of consecutive elements each invocation handles.
//   LG_WG_SIZE     - log2 of the workgroup size; also the number of steps
//                    in the workgroup-wide scan loop in main().
//   WG_SIZE        - invocations per workgroup (1 << 9 = 512).
//   PARTITION_SIZE - elements covered by one workgroup (512 * 8 = 4096).
#define N_ROWS         8
#define LG_WG_SIZE     9
#define WG_SIZE        (1 << LG_WG_SIZE)
#define PARTITION_SIZE (WG_SIZE * N_ROWS)

// One-dimensional workgroup of WG_SIZE invocations.
layout(local_size_x = WG_SIZE, local_size_y = 1) in;
|
|
|
|
// Element type of the scan. Values are merged with combine_monoid(),
// and Monoid(0) is used as the identity value in main().
struct Monoid {
    uint element;
};
|
|
|
|
// Elements to scan. main() reads each invocation's N_ROWS entries from this
// buffer and then overwrites the same entries with the scanned values.
layout(set = 0, binding = 0) buffer DataBuf {
    Monoid data[];
};
|
|
|
|
#ifndef ROOT
// Only present in the non-root variant. main() reads entry
// gl_WorkGroupID.x - 1 as the prefix to apply to the whole workgroup.
// NOTE(review): presumably filled with scanned per-workgroup aggregates by
// an earlier pass -- confirm against the host-side dispatch code.
layout(set = 0, binding = 1) readonly buffer ParentBuf {
    Monoid parent[];
};
#endif
|
|
|
|
// Workgroup-shared scratch space, one slot per invocation, used by main()
// to exchange running aggregates during the workgroup-wide scan.
shared Monoid[WG_SIZE] sh_scratch;
|
|
|
|
// Binary operator of the scan: returns a Monoid whose element is the
// (wrapping, unsigned) sum of the two operands' elements.
Monoid combine_monoid(Monoid a, Monoid b) {
    Monoid combined;
    combined.element = a.element + b.element;
    return combined;
}
|
|
|
|
// Entry point. Each invocation:
//   1. loads N_ROWS consecutive elements of `data` and scans them locally;
//   2. takes part in a workgroup-wide inclusive scan of the per-invocation
//      totals through sh_scratch;
//   3. derives the prefix contributed by everything before it (plus, in the
//      non-root variant, the entry of `parent` for the preceding workgroup);
//   4. writes prefix-combined values back over its N_ROWS elements of `data`.
void main() {
    uint lid = gl_LocalInvocationID.x;
    // Index of the first element owned by this invocation.
    uint base_ix = gl_GlobalInvocationID.x * N_ROWS;

    // Inclusive scan of this invocation's own elements.
    Monoid local_scan[N_ROWS];
    // TODO: gate buffer read
    local_scan[0] = data[base_ix];
    for (uint r = 1; r < N_ROWS; r++) {
        local_scan[r] = combine_monoid(local_scan[r - 1], data[base_ix + r]);
    }

    // Workgroup-wide inclusive scan of the per-invocation totals. At each
    // step an invocation whose index is at least `stride` folds in the value
    // held by the invocation `stride` slots below it.
    Monoid agg = local_scan[N_ROWS - 1];
    sh_scratch[lid] = agg;
    for (uint step = 0; step < LG_WG_SIZE; step++) {
        uint stride = 1u << step;
        // Wait for every slot written in the previous step (or above).
        barrier();
        if (lid >= stride) {
            agg = combine_monoid(sh_scratch[lid - stride], agg);
        }
        // Wait for every read of this step before slots are overwritten.
        barrier();
        sh_scratch[lid] = agg;
    }

    // Make the final slot values available to all invocations.
    barrier();

    // This could be a semigroup instead of a monoid if we reworked the
    // conditional logic, but that might impact performance.
    Monoid row = Monoid(0);
#ifdef ROOT
    // Root variant: prefix is the scanned total of the preceding invocation.
    if (lid > 0) {
        row = sh_scratch[lid - 1];
    }
#else
    // Non-root variant: start from the parent entry for the preceding
    // workgroup, then add the preceding invocation's scanned total.
    if (gl_WorkGroupID.x > 0) {
        row = parent[gl_WorkGroupID.x - 1];
    }
    if (lid > 0) {
        row = combine_monoid(row, sh_scratch[lid - 1]);
    }
#endif

    // Apply the prefix to the local scan and store the results in place.
    for (uint r = 0; r < N_ROWS; r++) {
        Monoid scanned = combine_monoid(row, local_scan[r]);
        // TODO: gate buffer write
        data[base_ix + r] = scanned;
    }
}
|