Parallel merge

The fancy stuff :)
This commit is contained in:
Raph Levien 2020-05-30 15:37:34 -07:00
parent 121f29fef6
commit 192ddc5eab
2 changed files with 65 additions and 17 deletions

View file

@ -30,9 +30,16 @@ layout(set = 0, binding = 3) buffer PtclBuf {
#define N_RINGBUF 512 #define N_RINGBUF 512
#define LG_N_PART_READ 8
#define N_PART_READ (1 << LG_N_PART_READ)
shared uint sh_elements[N_RINGBUF]; shared uint sh_elements[N_RINGBUF];
shared float sh_right_edge[N_RINGBUF]; shared float sh_right_edge[N_RINGBUF];
// Number of elements in the partition; prefix sum.
shared uint sh_part_count[N_PART_READ];
shared uint sh_part_elements[N_PART_READ];
shared uint sh_bitmaps[N_SLICE][N_TILE]; shared uint sh_bitmaps[N_SLICE][N_TILE];
shared uint sh_backdrop[N_SLICE][N_TILE]; shared uint sh_backdrop[N_SLICE][N_TILE];
shared uint sh_bd_sign[N_SLICE]; shared uint sh_bd_sign[N_SLICE];
@ -89,7 +96,7 @@ void main() {
// invocations within the workgroup. We'll use variables to abstract. // invocations within the workgroup. We'll use variables to abstract.
uint bin_ix = N_TILE_X * gl_WorkGroupID.y + gl_WorkGroupID.x; uint bin_ix = N_TILE_X * gl_WorkGroupID.y + gl_WorkGroupID.x;
uint partition_ix = 0; uint partition_ix = 0;
uint my_n_elements = n_elements; uint n_partitions = (n_elements + N_TILE - 1) / N_TILE;
// Top left coordinates of this bin. // Top left coordinates of this bin.
vec2 xy0 = vec2(N_TILE_X * TILE_WIDTH_PX * gl_WorkGroupID.x, N_TILE_Y * TILE_HEIGHT_PX * gl_WorkGroupID.y); vec2 xy0 = vec2(N_TILE_X * TILE_WIDTH_PX * gl_WorkGroupID.x, N_TILE_Y * TILE_HEIGHT_PX * gl_WorkGroupID.y);
uint th_ix = gl_LocalInvocationID.x; uint th_ix = gl_LocalInvocationID.x;
@ -107,8 +114,14 @@ void main() {
SegmentRef last_chunk_segs = SegmentRef(0); SegmentRef last_chunk_segs = SegmentRef(0);
alloc_chunk_remaining = 0; alloc_chunk_remaining = 0;
uint wr_ix = 0; // I'm sure we can figure out how to do this with at least one fewer register...
// Items up to rd_ix have been read from sh_elements
uint rd_ix = 0; uint rd_ix = 0;
// Items up to wr_ix have been written into sh_elements
uint wr_ix = 0;
// Items between part_start_ix and ready_ix are ready to be transferred from sh_part_elements
uint part_start_ix = 0;
uint ready_ix = 0;
if (th_ix < N_SLICE) { if (th_ix < N_SLICE) {
sh_bd_sign[th_ix] = 0; sh_bd_sign[th_ix] = 0;
} }
@ -122,21 +135,58 @@ void main() {
sh_is_segment[th_ix] = 0; sh_is_segment[th_ix] = 0;
} }
while (wr_ix - rd_ix <= N_TILE && partition_ix * N_TILE < my_n_elements) { // parallel read of input partitions
uint in_ix = (partition_ix * N_TILE + bin_ix) * 2; do {
uint chunk_n = bins[in_ix]; if (ready_ix == wr_ix && partition_ix < n_partitions) {
uint elements_ref = bins[in_ix + 1]; part_start_ix = ready_ix;
BinInstanceRef inst_ref = BinInstanceRef(elements_ref); uint count = 0;
if (th_ix < chunk_n) { if (th_ix < N_PART_READ && partition_ix + th_ix < n_partitions) {
BinInstance inst = BinInstance_read(BinInstance_index(inst_ref, th_ix)); uint in_ix = ((partition_ix + th_ix) * N_TILE + bin_ix) * 2;
uint wr_el_ix = (wr_ix + th_ix) % N_RINGBUF; count = bins[in_ix];
sh_part_elements[th_ix] = bins[in_ix + 1];
}
// prefix sum of counts
for (uint i = 0; i < LG_N_PART_READ; i++) {
if (th_ix < N_PART_READ) {
sh_part_count[th_ix] = count;
}
barrier();
if (th_ix < N_PART_READ) {
if (th_ix >= (1 << i)) {
count += sh_part_count[th_ix - (1 << i)];
}
}
barrier();
}
if (th_ix < N_PART_READ) {
sh_part_count[th_ix] = part_start_ix + count;
}
barrier();
ready_ix = sh_part_count[N_PART_READ - 1];
partition_ix += N_PART_READ;
}
// use binary search to find element to read
uint ix = rd_ix + th_ix;
if (ix >= wr_ix && ix < ready_ix) {
uint part_ix = 0;
for (uint i = 0; i < LG_N_PART_READ; i++) {
uint probe = part_ix + ((N_PART_READ / 2) >> i);
if (ix >= sh_part_count[probe - 1]) {
part_ix = probe;
}
}
ix -= part_ix > 0 ? sh_part_count[part_ix - 1] : part_start_ix;
BinInstanceRef inst_ref = BinInstanceRef(sh_part_elements[part_ix]);
BinInstance inst = BinInstance_read(BinInstance_index(inst_ref, ix));
uint wr_el_ix = (rd_ix + th_ix) % N_RINGBUF;
sh_elements[wr_el_ix] = inst.element_ix; sh_elements[wr_el_ix] = inst.element_ix;
sh_right_edge[wr_el_ix] = inst.right_edge; sh_right_edge[wr_el_ix] = inst.right_edge;
} }
wr_ix += chunk_n; barrier();
partition_ix++;
} wr_ix = min(rd_ix + N_TILE, ready_ix);
barrier(); } while (wr_ix - rd_ix < N_TILE && (wr_ix < ready_ix || partition_ix < n_partitions));
// We've done the merge and filled the buffer. // We've done the merge and filled the buffer.
@ -475,9 +525,7 @@ void main() {
barrier(); barrier();
rd_ix += N_TILE; rd_ix += N_TILE;
// The second disjunct is there as a strange workaround on Nvidia. If it is if (rd_ix >= ready_ix && partition_ix >= n_partitions) break;
// removed, then the kernel fails with ERROR_DEVICE_LOST.
if (rd_ix >= wr_ix || bin_ix == ~0) break;
} }
Cmd_End_write(cmd_ref); Cmd_End_write(cmd_ref);
} }

Binary file not shown.