command style metal timer queries + compute pass

This commit adds timestamps to compute pass boundaries for command style timer queries on metal.

It also updates the code in piet-gpu/stages, piet-gpu/lib.rs and tests/ to use the new ComputePass type.
This commit is contained in:
chad 2022-04-21 04:20:54 -04:00
parent ea0fbab8aa
commit 02cc867950
16 changed files with 166 additions and 122 deletions

View file

@ -107,7 +107,7 @@ pub struct CmdBuf {
enum Encoder {
None,
Compute(metal::ComputeCommandEncoder),
Compute(metal::ComputeCommandEncoder, Option<(id, u32)>),
Blit(metal::BlitCommandEncoder),
}
@ -578,7 +578,8 @@ impl crate::backend::CmdBuf<MtlDevice> for CmdBuf {
//debug_assert!(matches!(self.cur_encoder, Encoder::None));
self.flush_encoder();
autoreleasepool(|| {
let encoder = if let Some(queries) = &desc.timer_queries {
let (encoder, end_query) = match (&desc.timer_queries, self.counter_style) {
(Some(queries), CounterStyle::Stage) => {
let descriptor: id =
msg_send![class!(MTLComputePassDescriptor), computePassDescriptor];
let attachments: id = msg_send![descriptor, sampleBufferAttachments];
@ -595,14 +596,34 @@ impl crate::backend::CmdBuf<MtlDevice> for CmdBuf {
let end_index = queries.2 as NSInteger;
let () = msg_send![attachment, setStartOfEncoderSampleIndex: start_index];
let () = msg_send![attachment, setEndOfEncoderSampleIndex: end_index];
(
msg_send![
self.cmd_buf,
computeCommandEncoderWithDescriptor: descriptor
]
],
None,
)
}
(Some(queries), CounterStyle::Command) => {
let encoder = self.cmd_buf.new_compute_command_encoder();
#[allow(irrefutable_let_patterns)]
let end_query = if let crate::hub::QueryPool::Mtl(query_pool) = queries.0 {
if let Some(sample_buf) = &query_pool.counter_sample_buf {
let sample_index = queries.1 as NSUInteger;
let sample_buf = sample_buf.id();
let () = msg_send![encoder, sampleCountersInBuffer: sample_buf atSampleIndex: sample_index withBarrier: true];
Some((sample_buf, queries.2))
} else {
self.cmd_buf.new_compute_command_encoder()
None
}
} else {
None
};
self.cur_encoder = Encoder::Compute(encoder.to_owned());
(encoder, end_query)
}
_ => (self.cmd_buf.new_compute_command_encoder(), None),
};
self.cur_encoder = Encoder::Compute(encoder.to_owned(), end_query);
});
}
@ -663,7 +684,7 @@ impl crate::backend::CmdBuf<MtlDevice> for CmdBuf {
let size = size.unwrap_or(buffer.size);
let _ = self.compute_command_encoder();
// Getting this directly is a workaround for a borrow checker issue.
if let Encoder::Compute(e) = &self.cur_encoder {
if let Encoder::Compute(e, _) = &self.cur_encoder {
clear::encode_clear(e, &self.helpers.clear_pipeline, &buffer.buffer, size);
}
}
@ -752,12 +773,12 @@ impl crate::backend::CmdBuf<MtlDevice> for CmdBuf {
if let Some(buf) = &pool.counter_sample_buf {
if matches!(self.cur_encoder, Encoder::None) {
self.cur_encoder =
Encoder::Compute(self.cmd_buf.new_compute_command_encoder().to_owned());
Encoder::Compute(self.cmd_buf.new_compute_command_encoder().to_owned(), None);
}
let sample_index = query as NSUInteger;
if self.counter_style == CounterStyle::Command {
match &self.cur_encoder {
Encoder::Compute(e) => {
Encoder::Compute(e, _) => {
let () = msg_send![e.as_ptr(), sampleCountersInBuffer: buf.id() atSampleIndex: sample_index withBarrier: true];
}
Encoder::None => unreachable!(),
@ -765,7 +786,7 @@ impl crate::backend::CmdBuf<MtlDevice> for CmdBuf {
}
} else if self.counter_style == CounterStyle::Stage {
match &self.cur_encoder {
Encoder::Compute(_e) => {
Encoder::Compute(_e, _) => {
println!("write_timestamp is not supported for stage-style encoders");
}
_ => (),
@ -777,12 +798,12 @@ impl crate::backend::CmdBuf<MtlDevice> for CmdBuf {
impl CmdBuf {
fn compute_command_encoder(&mut self) -> &metal::ComputeCommandEncoder {
if !matches!(self.cur_encoder, Encoder::Compute(_)) {
if !matches!(self.cur_encoder, Encoder::Compute(..)) {
self.flush_encoder();
self.cur_encoder =
Encoder::Compute(self.cmd_buf.new_compute_command_encoder().to_owned());
Encoder::Compute(self.cmd_buf.new_compute_command_encoder().to_owned(), None);
}
if let Encoder::Compute(e) = &self.cur_encoder {
if let Encoder::Compute(e, _) = &self.cur_encoder {
e
} else {
unreachable!()
@ -803,7 +824,14 @@ impl CmdBuf {
fn flush_encoder(&mut self) {
match std::mem::replace(&mut self.cur_encoder, Encoder::None) {
Encoder::Compute(e) => e.end_encoding(),
Encoder::Compute(e, Some((sample_buf, end_query))) => {
let sample_index = end_query as NSUInteger;
unsafe {
let () = msg_send![e.as_ptr(), sampleCountersInBuffer: sample_buf atSampleIndex: sample_index withBarrier: true];
}
e.end_encoding();
}
Encoder::Compute(e, None) => e.end_encoding(),
Encoder::Blit(e) => e.end_encoding(),
Encoder::None => (),
}

View file

@ -70,7 +70,7 @@ fn main() -> Result<(), Error> {
.map(|_| session.create_semaphore())
.collect::<Result<Vec<_>, Error>>()?;
let query_pools = (0..NUM_FRAMES)
.map(|_| session.create_query_pool(8))
.map(|_| session.create_query_pool(12))
.collect::<Result<Vec<_>, Error>>()?;
let mut cmd_bufs: [Option<CmdBuf>; NUM_FRAMES] = Default::default();
let mut submitted: [Option<SubmittedCmdBuf>; NUM_FRAMES] = Default::default();
@ -112,22 +112,23 @@ fn main() -> Result<(), Error> {
if !ts.is_empty() {
info_string = format!(
"{:.3}ms :: e:{:.3}ms|alloc:{:.3}ms|cp:{:.3}ms|bd:{:.3}ms|bin:{:.3}ms|cr:{:.3}ms|r:{:.3}ms",
ts[6] * 1e3,
ts[10] * 1e3,
ts[0] * 1e3,
(ts[1] - ts[0]) * 1e3,
(ts[2] - ts[1]) * 1e3,
(ts[3] - ts[2]) * 1e3,
(ts[4] - ts[3]) * 1e3,
(ts[5] - ts[4]) * 1e3,
(ts[6] - ts[5]) * 1e3,
(ts[8] - ts[7]) * 1e3,
(ts[10] - ts[9]) * 1e3,
);
}
}
let mut ctx = PietGpuRenderContext::new();
let test_blend = false;
if let Some(svg) = &svg {
test_scenes::render_svg(&mut ctx, svg);
} else {
} else if test_blend {
use piet_gpu::{Blend, BlendMode::*, CompositionMode::*};
let blends = [
Blend::new(Normal, SrcOver),
@ -163,6 +164,8 @@ fn main() -> Result<(), Error> {
let blend = blends[mode % blends.len()];
test_scenes::render_blend_test(&mut ctx, current_frame, blend);
info_string = format!("{:?}", blend);
} else {
test_scenes::render_anim_frame(&mut ctx, current_frame);
}
render_info_string(&mut ctx, &info_string);
if let Err(e) = renderer.upload_render_ctx(&mut ctx, frame_idx) {

View file

@ -17,8 +17,8 @@ use piet::kurbo::Vec2;
use piet::{ImageFormat, RenderContext};
use piet_gpu_hal::{
include_shader, BindType, Buffer, BufferUsage, CmdBuf, DescriptorSet, Error, Image,
ImageLayout, Pipeline, QueryPool, Session,
include_shader, BindType, Buffer, BufferUsage, CmdBuf, ComputePassDescriptor, DescriptorSet,
Error, Image, ImageLayout, Pipeline, QueryPool, Session,
};
pub use pico_svg::PicoSvg;
@ -423,10 +423,11 @@ impl Renderer {
cmd_buf.copy_buffer_to_image(&self.gradient_bufs[buf_ix], &self.gradients);
cmd_buf.image_barrier(&self.gradients, ImageLayout::BlitDst, ImageLayout::General);
cmd_buf.reset_query_pool(&query_pool);
cmd_buf.write_timestamp(&query_pool, 0);
cmd_buf.begin_debug_label("Element bounding box calculation");
let mut pass = cmd_buf.begin_compute_pass(&ComputePassDescriptor::timer(&query_pool, 0, 1));
// cmd_buf.write_timestamp(&query_pool, 0);
self.element_stage.record(
cmd_buf,
&mut pass,
&self.element_code,
&self.element_bindings[buf_ix],
self.n_transform as u64,
@ -434,56 +435,64 @@ impl Renderer {
self.n_pathtag as u32,
self.n_drawobj as u64,
);
pass.end();
cmd_buf.end_debug_label();
cmd_buf.write_timestamp(&query_pool, 1);
// cmd_buf.write_timestamp(&query_pool, 1);
cmd_buf.memory_barrier();
cmd_buf.begin_debug_label("Clip bounding box calculation");
let mut pass = cmd_buf.begin_compute_pass(&ComputePassDescriptor::timer(&query_pool, 2, 3));
self.clip_binding
.record(cmd_buf, &self.clip_code, self.n_clip as u32);
cmd_buf.end_debug_label();
cmd_buf.begin_debug_label("Element binning");
cmd_buf.dispatch(
.record(&mut pass, &self.clip_code, self.n_clip as u32);
// cmd_buf.end_debug_label();
// cmd_buf.begin_debug_label("Element binning");
pass.dispatch(
&self.bin_pipeline,
&self.bin_ds,
(((self.n_paths + 255) / 256) as u32, 1, 1),
(256, 1, 1),
);
cmd_buf.end_debug_label();
cmd_buf.memory_barrier();
cmd_buf.begin_debug_label("Tile allocation");
cmd_buf.dispatch(
// cmd_buf.end_debug_label();
pass.memory_barrier();
// cmd_buf.begin_debug_label("Tile allocation");
pass.dispatch(
&self.tile_pipeline,
&self.tile_ds[buf_ix],
(((self.n_paths + 255) / 256) as u32, 1, 1),
(256, 1, 1),
);
cmd_buf.end_debug_label();
cmd_buf.write_timestamp(&query_pool, 2);
cmd_buf.memory_barrier();
// cmd_buf.end_debug_label();
pass.end();
// cmd_buf.write_timestamp(&query_pool, 2);
cmd_buf.begin_debug_label("Path flattening");
cmd_buf.dispatch(
cmd_buf.memory_barrier();
let mut pass = cmd_buf.begin_compute_pass(&ComputePassDescriptor::timer(&query_pool, 4, 5));
pass.dispatch(
&self.path_pipeline,
&self.path_ds,
(((self.n_pathseg + 31) / 32) as u32, 1, 1),
(32, 1, 1),
);
pass.end();
// cmd_buf.write_timestamp(&query_pool, 3);
cmd_buf.end_debug_label();
cmd_buf.write_timestamp(&query_pool, 3);
cmd_buf.memory_barrier();
cmd_buf.begin_debug_label("Backdrop propagation");
cmd_buf.dispatch(
let mut pass = cmd_buf.begin_compute_pass(&ComputePassDescriptor::timer(&query_pool, 6, 7));
pass.dispatch(
&self.backdrop_pipeline,
&self.backdrop_ds,
(((self.n_paths + 255) / 256) as u32, 1, 1),
(256, self.backdrop_y, 1),
);
pass.end();
// cmd_buf.write_timestamp(&query_pool, 4);
cmd_buf.end_debug_label();
cmd_buf.write_timestamp(&query_pool, 4);
// TODO: redo query accounting
cmd_buf.write_timestamp(&query_pool, 5);
// cmd_buf.write_timestamp(&query_pool, 5);
cmd_buf.memory_barrier();
cmd_buf.begin_debug_label("Coarse raster");
cmd_buf.dispatch(
let mut pass = cmd_buf.begin_compute_pass(&ComputePassDescriptor::timer(&query_pool, 8, 9));
pass.dispatch(
&self.coarse_pipeline,
&self.coarse_ds[buf_ix],
(
@ -493,11 +502,14 @@ impl Renderer {
),
(256, 1, 1),
);
pass.end();
cmd_buf.end_debug_label();
cmd_buf.write_timestamp(&query_pool, 6);
// cmd_buf.write_timestamp(&query_pool, 6);
cmd_buf.memory_barrier();
cmd_buf.begin_debug_label("Fine raster");
cmd_buf.dispatch(
let mut pass =
cmd_buf.begin_compute_pass(&ComputePassDescriptor::timer(&query_pool, 10, 11));
pass.dispatch(
&self.k4_pipeline,
&self.k4_ds,
(
@ -507,8 +519,9 @@ impl Renderer {
),
(8, 4, 1),
);
pass.end();
cmd_buf.end_debug_label();
cmd_buf.write_timestamp(&query_pool, 7);
// cmd_buf.write_timestamp(&query_pool, 7);
cmd_buf.memory_barrier();
cmd_buf.image_barrier(&self.image_dev, ImageLayout::General, ImageLayout::BlitSrc);
}

View file

@ -26,7 +26,7 @@ use bytemuck::{Pod, Zeroable};
pub use clip::{ClipBinding, ClipCode, CLIP_PART_SIZE};
pub use draw::{DrawBinding, DrawCode, DrawMonoid, DrawStage, DRAW_PART_SIZE};
pub use path::{PathBinding, PathCode, PathEncoder, PathStage, PATHSEG_PART_SIZE};
use piet_gpu_hal::{Buffer, CmdBuf, Session};
use piet_gpu_hal::{Buffer, ComputePass, Session};
pub use transform::{
Transform, TransformBinding, TransformCode, TransformStage, TRANSFORM_PART_SIZE,
};
@ -140,7 +140,7 @@ impl ElementStage {
pub unsafe fn record(
&self,
cmd_buf: &mut CmdBuf,
pass: &mut ComputePass,
code: &ElementCode,
binding: &ElementBinding,
n_transform: u64,
@ -149,14 +149,14 @@ impl ElementStage {
n_drawobj: u64,
) {
self.transform_stage.record(
cmd_buf,
pass,
&code.transform_code,
&binding.transform_binding,
n_transform,
);
// No memory barrier needed here; path has at least one before pathseg
self.path_stage.record(
cmd_buf,
pass,
&code.path_code,
&binding.path_binding,
n_paths,
@ -164,6 +164,6 @@ impl ElementStage {
);
// No memory barrier needed here; draw has at least one before draw_leaf
self.draw_stage
.record(cmd_buf, &code.draw_code, &binding.draw_binding, n_drawobj);
.record(pass, &code.draw_code, &binding.draw_binding, n_drawobj);
}
}

View file

@ -16,7 +16,7 @@
//! The clip processing stage (includes substages).
use piet_gpu_hal::{include_shader, BindType, Buffer, CmdBuf, DescriptorSet, Pipeline, Session};
use piet_gpu_hal::{include_shader, BindType, Buffer, ComputePass, DescriptorSet, Pipeline, Session};
// Note that this isn't the code/stage/binding pattern of most of the other stages
// in the new element processing pipeline. We want to move those temporary buffers
@ -69,26 +69,26 @@ impl ClipBinding {
/// Record the clip dispatches.
///
/// Assumes memory barrier on entry. Provides memory barrier on exit.
pub unsafe fn record(&self, cmd_buf: &mut CmdBuf, code: &ClipCode, n_clip: u32) {
pub unsafe fn record(&self, pass: &mut ComputePass, code: &ClipCode, n_clip: u32) {
let n_wg_reduce = n_clip.saturating_sub(1) / CLIP_PART_SIZE;
if n_wg_reduce > 0 {
cmd_buf.dispatch(
pass.dispatch(
&code.reduce_pipeline,
&self.reduce_ds,
(n_wg_reduce, 1, 1),
(CLIP_PART_SIZE, 1, 1),
);
cmd_buf.memory_barrier();
pass.memory_barrier();
}
let n_wg = (n_clip + CLIP_PART_SIZE - 1) / CLIP_PART_SIZE;
if n_wg > 0 {
cmd_buf.dispatch(
pass.dispatch(
&code.leaf_pipeline,
&self.leaf_ds,
(n_wg, 1, 1),
(CLIP_PART_SIZE, 1, 1),
);
cmd_buf.memory_barrier();
pass.memory_barrier();
}
}
}

View file

@ -19,7 +19,7 @@
use bytemuck::{Pod, Zeroable};
use piet_gpu_hal::{
include_shader, BindType, Buffer, BufferUsage, CmdBuf, DescriptorSet, Pipeline, Session,
include_shader, BindType, Buffer, BufferUsage, ComputePass, DescriptorSet, Pipeline, Session,
};
/// The output element of the draw object stage.
@ -130,7 +130,7 @@ impl DrawStage {
pub unsafe fn record(
&self,
cmd_buf: &mut CmdBuf,
pass: &mut ComputePass,
code: &DrawCode,
binding: &DrawBinding,
size: u64,
@ -140,22 +140,22 @@ impl DrawStage {
}
let n_workgroups = (size + DRAW_PART_SIZE - 1) / DRAW_PART_SIZE;
if n_workgroups > 1 {
cmd_buf.dispatch(
pass.dispatch(
&code.reduce_pipeline,
&binding.reduce_ds,
(n_workgroups as u32, 1, 1),
(DRAW_WG as u32, 1, 1),
);
cmd_buf.memory_barrier();
cmd_buf.dispatch(
pass.memory_barrier();
pass.dispatch(
&code.root_pipeline,
&self.root_ds,
(1, 1, 1),
(DRAW_WG as u32, 1, 1),
);
}
cmd_buf.memory_barrier();
cmd_buf.dispatch(
pass.memory_barrier();
pass.dispatch(
&code.leaf_pipeline,
&binding.leaf_ds,
(n_workgroups as u32, 1, 1),

View file

@ -17,7 +17,7 @@
//! The path stage (includes substages).
use piet_gpu_hal::{
include_shader, BindType, Buffer, BufferUsage, CmdBuf, DescriptorSet, Pipeline, Session,
include_shader, BindType, Buffer, BufferUsage, ComputePass, DescriptorSet, Pipeline, Session,
};
pub struct PathCode {
@ -148,7 +148,7 @@ impl PathStage {
/// those are consumed. Result is written without barrier.
pub unsafe fn record(
&self,
cmd_buf: &mut CmdBuf,
pass: &mut ComputePass,
code: &PathCode,
binding: &PathBinding,
n_paths: u32,
@ -166,15 +166,15 @@ impl PathStage {
let reduce_part_tags = REDUCE_PART_SIZE * 4;
let n_wg_tag_reduce = (n_tags + reduce_part_tags - 1) / reduce_part_tags;
if n_wg_tag_reduce > 1 {
cmd_buf.dispatch(
pass.dispatch(
&code.reduce_pipeline,
&binding.reduce_ds,
(n_wg_tag_reduce, 1, 1),
(REDUCE_WG, 1, 1),
);
// I think we can skip root if n_wg_tag_reduce == 2
cmd_buf.memory_barrier();
cmd_buf.dispatch(
pass.memory_barrier();
pass.dispatch(
&code.tag_root_pipeline,
&self.tag_root_ds,
(1, 1, 1),
@ -183,15 +183,15 @@ impl PathStage {
// No barrier needed here; clear doesn't depend on path tags
}
let n_wg_clear = (n_paths + CLEAR_WG - 1) / CLEAR_WG;
cmd_buf.dispatch(
pass.dispatch(
&code.clear_pipeline,
&binding.clear_ds,
(n_wg_clear, 1, 1),
(CLEAR_WG, 1, 1),
);
cmd_buf.memory_barrier();
pass.memory_barrier();
let n_wg_pathseg = (n_tags + SCAN_PART_SIZE - 1) / SCAN_PART_SIZE;
cmd_buf.dispatch(
pass.dispatch(
&code.pathseg_pipeline,
&binding.path_ds,
(n_wg_pathseg, 1, 1),

View file

@ -20,7 +20,7 @@ use bytemuck::{Pod, Zeroable};
use piet::kurbo::Affine;
use piet_gpu_hal::{
include_shader, BindType, Buffer, BufferUsage, CmdBuf, DescriptorSet, Pipeline, Session,
include_shader, BindType, Buffer, BufferUsage, ComputePass, DescriptorSet, Pipeline, Session,
};
/// An affine transform.
@ -132,7 +132,7 @@ impl TransformStage {
pub unsafe fn record(
&self,
cmd_buf: &mut CmdBuf,
pass: &mut ComputePass,
code: &TransformCode,
binding: &TransformBinding,
size: u64,
@ -142,22 +142,22 @@ impl TransformStage {
}
let n_workgroups = (size + TRANSFORM_PART_SIZE - 1) / TRANSFORM_PART_SIZE;
if n_workgroups > 1 {
cmd_buf.dispatch(
pass.dispatch(
&code.reduce_pipeline,
&binding.reduce_ds,
(n_workgroups as u32, 1, 1),
(TRANSFORM_WG as u32, 1, 1),
);
cmd_buf.memory_barrier();
cmd_buf.dispatch(
pass.memory_barrier();
pass.dispatch(
&code.root_pipeline,
&self.root_ds,
(1, 1, 1),
(TRANSFORM_WG as u32, 1, 1),
);
cmd_buf.memory_barrier();
pass.memory_barrier();
}
cmd_buf.dispatch(
pass.dispatch(
&code.leaf_pipeline,
&binding.leaf_ds,
(n_workgroups as u32, 1, 1),

View file

@ -58,11 +58,11 @@ pub unsafe fn clip_test(runner: &mut Runner, config: &Config) -> TestResult {
let binding = ClipBinding::new(&runner.session, &code, &config_buf, &memory.dev_buf);
let mut commands = runner.commands();
commands.write_timestamp(0);
commands.upload(&memory);
binding.record(&mut commands.cmd_buf, &code, n_clip as u32);
let mut pass = commands.compute_pass(0, 1);
binding.record(&mut pass, &code, n_clip as u32);
pass.end();
commands.download(&memory);
commands.write_timestamp(1);
runner.submit(commands);
let dst = memory.map_read(..);
if let Some(failure) = data.verify(&dst) {

View file

@ -77,9 +77,9 @@ pub unsafe fn draw_test(runner: &mut Runner, config: &Config) -> TestResult {
let n_iter = config.n_iter;
for i in 0..n_iter {
let mut commands = runner.commands();
commands.write_timestamp(0);
stage.record(&mut commands.cmd_buf, &code, &binding, n_tag);
commands.write_timestamp(1);
let mut pass = commands.compute_pass(0, 1);
stage.record(&mut pass, &code, &binding, n_tag);
pass.end();
if i == 0 || config.verify_all {
commands.cmd_buf.memory_barrier();
commands.download(&memory);

View file

@ -45,9 +45,7 @@ pub unsafe fn run_linkedlist_test(runner: &mut Runner, config: &Config) -> TestR
for i in 0..n_iter {
let mut commands = runner.commands();
// Might clear only buckets to save time.
commands.write_timestamp(0);
stage.record(&mut commands, &code, &binding, &mem_buf.dev_buf);
commands.write_timestamp(1);
if i == 0 || config.verify_all {
commands.cmd_buf.memory_barrier();
commands.download(&mem_buf);
@ -107,12 +105,14 @@ impl LinkedListStage {
commands.cmd_buf.clear_buffer(out_buf, None);
commands.cmd_buf.memory_barrier();
let n_workgroups = N_BUCKETS / WG_SIZE;
commands.cmd_buf.dispatch(
let mut pass = commands.compute_pass(0, 1);
pass.dispatch(
&code.pipeline,
&bindings.descriptor_set,
(n_workgroups as u32, 1, 1),
(WG_SIZE as u32, 1, 1),
);
pass.end();
}
}

View file

@ -59,9 +59,7 @@ pub unsafe fn run_message_passing_test(
let mut failures = 0;
for _ in 0..n_iter {
let mut commands = runner.commands();
commands.write_timestamp(0);
stage.record(&mut commands, &code, &binding, &out_buf.dev_buf);
commands.write_timestamp(1);
commands.cmd_buf.memory_barrier();
commands.download(&out_buf);
total_elapsed += runner.submit(commands);
@ -128,11 +126,13 @@ impl MessagePassingStage {
commands.cmd_buf.clear_buffer(&self.data_buf, None);
commands.cmd_buf.clear_buffer(out_buf, None);
commands.cmd_buf.memory_barrier();
commands.cmd_buf.dispatch(
let mut pass = commands.compute_pass(0, 1);
pass.dispatch(
&code.pipeline,
&bindings.descriptor_set,
(256, 1, 1),
(256, 1, 1),
);
pass.end();
}
}

View file

@ -105,15 +105,15 @@ pub unsafe fn path_test(runner: &mut Runner, config: &Config) -> TestResult {
let mut commands = runner.commands();
commands.cmd_buf.copy_buffer(&memory_init, &memory.dev_buf);
commands.cmd_buf.memory_barrier();
commands.write_timestamp(0);
let mut pass = commands.compute_pass(0, 1);
stage.record(
&mut commands.cmd_buf,
&mut pass,
&code,
&binding,
path_data.n_path,
path_data.tags.len() as u32,
);
commands.write_timestamp(1);
pass.end();
if i == 0 || config.verify_all {
commands.cmd_buf.memory_barrier();
commands.download(&memory);

View file

@ -85,9 +85,7 @@ pub unsafe fn run_prefix_test(
let mut total_elapsed = 0.0;
for i in 0..n_iter {
let mut commands = runner.commands();
commands.write_timestamp(0);
stage.record(&mut commands, &code, &binding);
commands.write_timestamp(1);
if i == 0 || config.verify_all {
commands.cmd_buf.memory_barrier();
commands.download(&out_buf);
@ -159,12 +157,14 @@ impl PrefixStage {
let n_workgroups = (self.n_elements + ELEMENTS_PER_WG - 1) / ELEMENTS_PER_WG;
commands.cmd_buf.clear_buffer(&self.state_buf, None);
commands.cmd_buf.memory_barrier();
commands.cmd_buf.dispatch(
let mut pass = commands.compute_pass(0, 1);
pass.dispatch(
&code.pipeline,
&bindings.descriptor_set,
(n_workgroups as u32, 1, 1),
(WG_SIZE as u32, 1, 1),
);
pass.end();
// One thing that's missing here is registering the buffers so
// they can be safely dropped by Rust code before the execution
// of the command buffer completes.

View file

@ -66,9 +66,7 @@ pub unsafe fn run_prefix_test(runner: &mut Runner, config: &Config) -> TestResul
let mut commands = runner.commands();
commands.cmd_buf.copy_buffer(&data_buf, &out_buf.dev_buf);
commands.cmd_buf.memory_barrier();
commands.write_timestamp(0);
stage.record(&mut commands, &code, &binding);
commands.write_timestamp(1);
if i == 0 || config.verify_all {
commands.cmd_buf.memory_barrier();
commands.download(&out_buf);
@ -175,33 +173,35 @@ impl PrefixTreeStage {
code: &PrefixTreeCode,
bindings: &PrefixTreeBinding,
) {
let mut pass = commands.compute_pass(0, 1);
let n = self.tmp_bufs.len();
for i in 0..n {
let n_workgroups = self.sizes[i + 1];
commands.cmd_buf.dispatch(
pass.dispatch(
&code.reduce_pipeline,
&bindings.descriptor_sets[i],
(n_workgroups as u32, 1, 1),
(WG_SIZE as u32, 1, 1),
);
commands.cmd_buf.memory_barrier();
pass.memory_barrier();
}
commands.cmd_buf.dispatch(
pass.dispatch(
&code.root_pipeline,
&bindings.descriptor_sets[n],
(1, 1, 1),
(WG_SIZE as u32, 1, 1),
);
for i in (0..n).rev() {
commands.cmd_buf.memory_barrier();
pass.memory_barrier();
let n_workgroups = self.sizes[i + 1];
commands.cmd_buf.dispatch(
pass.dispatch(
&code.scan_pipeline,
&bindings.descriptor_sets[2 * n - i],
(n_workgroups as u32, 1, 1),
(WG_SIZE as u32, 1, 1),
);
}
pass.end();
}
}

View file

@ -61,9 +61,9 @@ pub unsafe fn transform_test(runner: &mut Runner, config: &Config) -> TestResult
let n_iter = config.n_iter;
for i in 0..n_iter {
let mut commands = runner.commands();
commands.write_timestamp(0);
stage.record(&mut commands.cmd_buf, &code, &binding, n_elements);
commands.write_timestamp(1);
let mut pass = commands.compute_pass(0, 1);
stage.record(&mut pass, &code, &binding, n_elements);
pass.end();
if i == 0 || config.verify_all {
commands.cmd_buf.memory_barrier();
commands.download(&memory);