2021-11-07 10:08:43 +11:00
|
|
|
// Copyright 2021 The piet-gpu authors.
|
|
|
|
//
|
|
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
// you may not use this file except in compliance with the License.
|
|
|
|
// You may obtain a copy of the License at
|
|
|
|
//
|
|
|
|
// https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
//
|
|
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
// limitations under the License.
|
|
|
|
//
|
|
|
|
// Also licensed under MIT license, at your choice.
|
|
|
|
|
2021-11-10 15:28:06 +11:00
|
|
|
use piet_gpu_hal::{include_shader, BindType, BufferUsage, DescriptorSet};
|
2021-11-07 10:08:43 +11:00
|
|
|
use piet_gpu_hal::{Buffer, Pipeline};
|
|
|
|
|
2021-11-10 09:04:58 +11:00
|
|
|
use crate::config::Config;
|
2021-11-07 10:08:43 +11:00
|
|
|
use crate::runner::{Commands, Runner};
|
2021-11-10 09:04:58 +11:00
|
|
|
use crate::test_result::TestResult;
|
2021-11-07 10:08:43 +11:00
|
|
|
|
|
|
|
/// Workgroup size; passed as the workgroup-size argument of every dispatch
/// recorded by this test.
const WG_SIZE: u64 = 512;

/// Multiplier applied to `WG_SIZE` to obtain `ELEMENTS_PER_WG`.
// NOTE(review): presumably the number of elements each thread handles in the
// shaders -- confirm against the shader sources.
const N_ROWS: u64 = 8;

/// Number of elements covered by one workgroup. Each level of the reduction
/// tree is smaller than the one below it by this factor (rounded up).
const ELEMENTS_PER_WG: u64 = WG_SIZE * N_ROWS;
|
|
|
|
|
|
|
|
struct PrefixTreeCode {
|
|
|
|
reduce_pipeline: Pipeline,
|
|
|
|
scan_pipeline: Pipeline,
|
|
|
|
root_pipeline: Pipeline,
|
|
|
|
}
|
|
|
|
|
|
|
|
struct PrefixTreeStage {
|
|
|
|
sizes: Vec<u64>,
|
|
|
|
tmp_bufs: Vec<Buffer>,
|
|
|
|
}
|
|
|
|
|
|
|
|
struct PrefixTreeBinding {
|
|
|
|
// All but the first and last can be moved to stage.
|
|
|
|
descriptor_sets: Vec<DescriptorSet>,
|
|
|
|
}
|
|
|
|
|
2021-11-10 09:04:58 +11:00
|
|
|
pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResult {
|
|
|
|
let mut result = TestResult::new("prefix sum, tree reduction");
|
2021-11-07 10:08:43 +11:00
|
|
|
// This will be configurable. Note though that the current code is
|
|
|
|
// prone to reading and writing past the end of buffers if this is
|
|
|
|
// not a power of the number of elements processed in a workgroup.
|
2021-11-10 09:04:58 +11:00
|
|
|
let n_elements: u64 = config.size.choose(1 << 12, 1 << 24, 1 << 24);
|
2021-11-07 10:08:43 +11:00
|
|
|
let data_buf = runner
|
|
|
|
.session
|
2021-11-26 08:12:25 +11:00
|
|
|
.create_buffer_with(
|
|
|
|
n_elements * 4,
|
2021-11-26 17:02:04 +11:00
|
|
|
|b| b.extend(0..n_elements as u32),
|
2021-11-26 08:12:25 +11:00
|
|
|
BufferUsage::STORAGE,
|
|
|
|
)
|
2021-11-07 10:08:43 +11:00
|
|
|
.unwrap();
|
2021-11-21 02:14:23 +11:00
|
|
|
let out_buf = runner.buf_down(data_buf.size(), BufferUsage::empty());
|
2021-11-07 10:08:43 +11:00
|
|
|
let code = PrefixTreeCode::new(runner);
|
|
|
|
let stage = PrefixTreeStage::new(runner, n_elements);
|
|
|
|
let binding = stage.bind(runner, &code, &out_buf.dev_buf);
|
|
|
|
// Also will be configurable of course.
|
2021-11-12 02:26:32 +11:00
|
|
|
let n_iter = config.n_iter;
|
2021-11-07 10:08:43 +11:00
|
|
|
let mut total_elapsed = 0.0;
|
|
|
|
for i in 0..n_iter {
|
|
|
|
let mut commands = runner.commands();
|
|
|
|
commands.cmd_buf.copy_buffer(&data_buf, &out_buf.dev_buf);
|
|
|
|
commands.cmd_buf.memory_barrier();
|
|
|
|
stage.record(&mut commands, &code, &binding);
|
2021-11-24 02:28:50 +11:00
|
|
|
if i == 0 || config.verify_all {
|
2021-11-07 10:08:43 +11:00
|
|
|
commands.cmd_buf.memory_barrier();
|
|
|
|
commands.download(&out_buf);
|
|
|
|
}
|
|
|
|
total_elapsed += runner.submit(commands);
|
2021-11-24 02:28:50 +11:00
|
|
|
if i == 0 || config.verify_all {
|
2021-11-26 08:12:25 +11:00
|
|
|
let dst = out_buf.map_read(..);
|
|
|
|
if let Some(failure) = verify(dst.cast_slice()) {
|
2021-11-10 09:04:58 +11:00
|
|
|
result.fail(format!("failure at {}", failure));
|
|
|
|
}
|
2021-11-07 10:08:43 +11:00
|
|
|
}
|
|
|
|
}
|
2021-11-10 09:04:58 +11:00
|
|
|
result.timing(total_elapsed, n_elements * n_iter);
|
|
|
|
result
|
2021-11-07 10:08:43 +11:00
|
|
|
}
|
|
|
|
|
|
|
|
impl PrefixTreeCode {
|
|
|
|
unsafe fn new(runner: &mut Runner) -> PrefixTreeCode {
|
|
|
|
let reduce_code = include_shader!(&runner.session, "../shader/gen/prefix_reduce");
|
|
|
|
let reduce_pipeline = runner
|
|
|
|
.session
|
2021-11-10 15:28:06 +11:00
|
|
|
.create_compute_pipeline(reduce_code, &[BindType::BufReadOnly, BindType::Buffer])
|
2021-11-07 10:08:43 +11:00
|
|
|
.unwrap();
|
|
|
|
let scan_code = include_shader!(&runner.session, "../shader/gen/prefix_scan");
|
|
|
|
let scan_pipeline = runner
|
|
|
|
.session
|
2021-11-10 15:28:06 +11:00
|
|
|
.create_compute_pipeline(scan_code, &[BindType::Buffer, BindType::BufReadOnly])
|
2021-11-07 10:08:43 +11:00
|
|
|
.unwrap();
|
|
|
|
let root_code = include_shader!(&runner.session, "../shader/gen/prefix_root");
|
|
|
|
let root_pipeline = runner
|
|
|
|
.session
|
2021-11-10 15:28:06 +11:00
|
|
|
.create_compute_pipeline(root_code, &[BindType::Buffer])
|
2021-11-07 10:08:43 +11:00
|
|
|
.unwrap();
|
|
|
|
PrefixTreeCode {
|
|
|
|
reduce_pipeline,
|
|
|
|
scan_pipeline,
|
|
|
|
root_pipeline,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl PrefixTreeStage {
|
|
|
|
unsafe fn new(runner: &mut Runner, n_elements: u64) -> PrefixTreeStage {
|
|
|
|
let mut size = n_elements;
|
|
|
|
let mut sizes = vec![size];
|
|
|
|
let mut tmp_bufs = Vec::new();
|
|
|
|
while size > ELEMENTS_PER_WG {
|
|
|
|
size = (size + ELEMENTS_PER_WG - 1) / ELEMENTS_PER_WG;
|
|
|
|
sizes.push(size);
|
|
|
|
let buf = runner
|
|
|
|
.session
|
|
|
|
.create_buffer(4 * size, BufferUsage::STORAGE)
|
|
|
|
.unwrap();
|
|
|
|
tmp_bufs.push(buf);
|
|
|
|
}
|
|
|
|
PrefixTreeStage { sizes, tmp_bufs }
|
|
|
|
}
|
|
|
|
|
|
|
|
unsafe fn bind(
|
|
|
|
&self,
|
|
|
|
runner: &mut Runner,
|
|
|
|
code: &PrefixTreeCode,
|
|
|
|
data_buf: &Buffer,
|
|
|
|
) -> PrefixTreeBinding {
|
|
|
|
let mut descriptor_sets = Vec::with_capacity(2 * self.tmp_bufs.len() + 1);
|
|
|
|
for i in 0..self.tmp_bufs.len() {
|
|
|
|
let buf0 = if i == 0 {
|
|
|
|
data_buf
|
|
|
|
} else {
|
|
|
|
&self.tmp_bufs[i - 1]
|
|
|
|
};
|
|
|
|
let buf1 = &self.tmp_bufs[i];
|
|
|
|
let descriptor_set = runner
|
|
|
|
.session
|
|
|
|
.create_simple_descriptor_set(&code.reduce_pipeline, &[buf0, buf1])
|
|
|
|
.unwrap();
|
|
|
|
descriptor_sets.push(descriptor_set);
|
|
|
|
}
|
|
|
|
let buf0 = self.tmp_bufs.last().unwrap_or(data_buf);
|
|
|
|
let descriptor_set = runner
|
|
|
|
.session
|
|
|
|
.create_simple_descriptor_set(&code.root_pipeline, &[buf0])
|
|
|
|
.unwrap();
|
|
|
|
descriptor_sets.push(descriptor_set);
|
|
|
|
for i in (0..self.tmp_bufs.len()).rev() {
|
|
|
|
let buf0 = if i == 0 {
|
|
|
|
data_buf
|
|
|
|
} else {
|
|
|
|
&self.tmp_bufs[i - 1]
|
|
|
|
};
|
|
|
|
let buf1 = &self.tmp_bufs[i];
|
|
|
|
let descriptor_set = runner
|
|
|
|
.session
|
|
|
|
.create_simple_descriptor_set(&code.scan_pipeline, &[buf0, buf1])
|
|
|
|
.unwrap();
|
|
|
|
descriptor_sets.push(descriptor_set);
|
|
|
|
}
|
|
|
|
PrefixTreeBinding { descriptor_sets }
|
|
|
|
}
|
|
|
|
|
|
|
|
unsafe fn record(
|
|
|
|
&self,
|
|
|
|
commands: &mut Commands,
|
|
|
|
code: &PrefixTreeCode,
|
|
|
|
bindings: &PrefixTreeBinding,
|
|
|
|
) {
|
2022-04-21 18:20:54 +10:00
|
|
|
let mut pass = commands.compute_pass(0, 1);
|
2021-11-07 10:08:43 +11:00
|
|
|
let n = self.tmp_bufs.len();
|
|
|
|
for i in 0..n {
|
|
|
|
let n_workgroups = self.sizes[i + 1];
|
2022-04-21 18:20:54 +10:00
|
|
|
pass.dispatch(
|
2021-11-07 10:08:43 +11:00
|
|
|
&code.reduce_pipeline,
|
|
|
|
&bindings.descriptor_sets[i],
|
|
|
|
(n_workgroups as u32, 1, 1),
|
|
|
|
(WG_SIZE as u32, 1, 1),
|
|
|
|
);
|
2022-04-21 18:20:54 +10:00
|
|
|
pass.memory_barrier();
|
2021-11-07 10:08:43 +11:00
|
|
|
}
|
2022-04-21 18:20:54 +10:00
|
|
|
pass.dispatch(
|
2021-11-07 10:08:43 +11:00
|
|
|
&code.root_pipeline,
|
|
|
|
&bindings.descriptor_sets[n],
|
|
|
|
(1, 1, 1),
|
|
|
|
(WG_SIZE as u32, 1, 1),
|
|
|
|
);
|
|
|
|
for i in (0..n).rev() {
|
2022-04-21 18:20:54 +10:00
|
|
|
pass.memory_barrier();
|
2021-11-07 10:08:43 +11:00
|
|
|
let n_workgroups = self.sizes[i + 1];
|
2022-04-21 18:20:54 +10:00
|
|
|
pass.dispatch(
|
2021-11-07 10:08:43 +11:00
|
|
|
&code.scan_pipeline,
|
|
|
|
&bindings.descriptor_sets[2 * n - i],
|
|
|
|
(n_workgroups as u32, 1, 1),
|
|
|
|
(WG_SIZE as u32, 1, 1),
|
|
|
|
);
|
|
|
|
}
|
2022-04-21 18:20:54 +10:00
|
|
|
pass.end();
|
2021-11-07 10:08:43 +11:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Verifies that `data` is the inclusive prefix sum of `0, 1, 2, ...`, i.e.
/// the triangular numbers (OEIS A000217) reduced modulo 2^32.
///
/// Returns the index of the first element that does not match, or `None`
/// when every element is correct (including when `data` is empty).
fn verify(data: &[u32]) -> Option<usize> {
    // Keep a running sum in `u32` with wrapping arithmetic. This yields
    // `i * (i + 1) / 2` modulo 2^32 without forming the product in `usize`,
    // which would overflow on 32-bit targets from index 65536 onwards.
    let mut expected = 0u32;
    data.iter().enumerate().position(|(i, &val)| {
        // Truncating the index is intentional: the sum is taken modulo 2^32.
        expected = expected.wrapping_add(i as u32);
        val != expected
    })
}
|