Rework of compute encoder abstraction

The current plan is to more or less follow the wgpu/wgpu-hal approach. In the mux/backend layer (which corresponds fairly closely to wgpu-hal), there is no explicit compute encoder object; instead, there are new methods for beginning and ending a compute pass. At the hub layer (which corresponds to wgpu), there is a ComputePass object that plays the compute encoder role.
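
At the hub level, the intended usage looks like this (taken from the updated compute example later in this diff; the query indices, workgroup count, and workgroup size are just that example's values):

    let mut pass = cmd_buf.begin_compute_pass(&ComputePassDescriptor::timer(&query_pool, 0, 1));
    pass.dispatch(&pipeline, &descriptor_set, (256, 1, 1), (1, 1, 1));
    pass.end();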

That said, there will be some differences. The WebGPU "end" method on a compute encoder is implemented in wgpu as Drop, and that is not ideal. Also, the wgpu-hal approach to timer queries (still based on write_timestamp) is not up to the task of Metal timer queries, where the query offsets have to be specified at compute encoder creation. That's why there are different projects :)
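
For reference, this is roughly how the Metal backend in this diff consumes the descriptor: the query range is attached to an MTLComputePassDescriptor before the encoder exists. A condensed sketch, assuming sample_buf (the pool's counter sample buffer), start_query and end_query (u32), and cmd_buf (the raw metal::CommandBuffer) are in scope; the hub/mux downcast and the no-query fallback path are elided:

    let descriptor: id = msg_send![class!(MTLComputePassDescriptor), computePassDescriptor];
    let attachments: id = msg_send![descriptor, sampleBufferAttachments];
    let index: NSUInteger = 0;
    let attachment: id = msg_send![attachments, objectAtIndexedSubscript: index];
    // The counter sample buffer and the start/end sample indices must be set
    // here; they cannot be added after the encoder has been created.
    let () = msg_send![attachment, setSampleBuffer: sample_buf.id()];
    let start_index = start_query as NSUInteger;
    let end_index = end_query as NSUInteger;
    let () = msg_send![attachment, setStartOfEncoderSampleIndex: start_index];
    let () = msg_send![attachment, setEndOfEncoderSampleIndex: end_index];
    let encoder: id = msg_send![cmd_buf, computeCommandEncoderWithDescriptor: descriptor];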

WIP: the current state is that stage-style timer queries work on Apple Silicon, but the non-Metal backends are broken, and piet-gpu has not yet been updated to use the new API.
Raph Levien 2022-04-13 10:31:38 -07:00
parent 290d5d2e13
commit ba2b27cc3c
7 changed files with 196 additions and 111 deletions


@ -1,4 +1,4 @@
use piet_gpu_hal::{include_shader, BindType};
use piet_gpu_hal::{include_shader, BindType, ComputePassDescriptor};
use piet_gpu_hal::{BufferUsage, Instance, InstanceFlags, Session};
fn main() {
@ -20,9 +20,9 @@ fn main() {
let mut cmd_buf = session.cmd_buf().unwrap();
cmd_buf.begin();
cmd_buf.reset_query_pool(&query_pool);
cmd_buf.write_timestamp(&query_pool, 0);
cmd_buf.dispatch(&pipeline, &descriptor_set, (256, 1, 1), (1, 1, 1));
cmd_buf.write_timestamp(&query_pool, 1);
let mut pass = cmd_buf.begin_compute_pass(&ComputePassDescriptor::timer(&query_pool, 0, 1));
pass.dispatch(&pipeline, &descriptor_set, (256, 1, 1), (1, 1, 1));
pass.end();
cmd_buf.finish_timestamps(&query_pool);
cmd_buf.host_barrier();
cmd_buf.finish();


@ -17,7 +17,8 @@
//! The generic trait for backends to implement.
use crate::{
BindType, BufferUsage, Error, GpuInfo, ImageFormat, ImageLayout, MapMode, SamplerParams,
BindType, BufferUsage, ComputePassDescriptor, Error, GpuInfo, ImageFormat, ImageLayout,
MapMode, SamplerParams,
};
pub trait Device: Sized {
@ -159,16 +160,32 @@ pub trait Device: Sized {
unsafe fn create_sampler(&self, params: SamplerParams) -> Result<Self::Sampler, Error>;
}
/// The trait implemented by backend command buffer implementations.
///
/// Valid encoding is represented by a state machine (currently not validated
/// but it is easy to imagine there might be at least debug validation). Most
/// methods are only valid in a particular state, and some move it to another
/// state.
pub trait CmdBuf<D: Device> {
type ComputeEncoder;
/// Begin encoding.
///
/// State: init -> ready
unsafe fn begin(&mut self);
/// State: ready -> finished
unsafe fn finish(&mut self);
/// Return true if the command buffer is suitable for reuse.
unsafe fn reset(&mut self) -> bool;
/// Begin a compute pass.
///
/// State: ready -> in_compute_pass
unsafe fn begin_compute_pass(&mut self, desc: &ComputePassDescriptor);
/// Dispatch
///
/// State: in_compute_pass
unsafe fn dispatch(
&mut self,
pipeline: &D::Pipeline,
@ -177,6 +194,9 @@ pub trait CmdBuf<D: Device> {
workgroup_size: (u32, u32, u32),
);
/// State: in_compute_pass -> ready
unsafe fn end_compute_pass(&mut self);
/// Insert an execution and memory barrier.
///
/// Compute kernels (and other actions) after this barrier may read from buffers
@ -229,12 +249,10 @@ pub trait CmdBuf<D: Device> {
unsafe fn finish_timestamps(&mut self, _pool: &D::QueryPool) {}
/// Begin a labeled section for debugging and profiling purposes.
unsafe fn begin_debug_label(&mut self, label: &str) {}
unsafe fn begin_debug_label(&mut self, _label: &str) {}
/// End a section opened by `begin_debug_label`.
unsafe fn end_debug_label(&mut self) {}
unsafe fn new_compute_encoder(&mut self) -> Self::ComputeEncoder;
}
/// A builder for descriptor sets with more complex layouts.
@ -256,16 +274,3 @@ pub trait DescriptorSetBuilder<D: Device> {
fn add_textures(&mut self, images: &[&D::Image]);
unsafe fn build(self, device: &D, pipeline: &D::Pipeline) -> Result<D::DescriptorSet, Error>;
}
pub trait ComputeEncoder<D: Device> {
unsafe fn dispatch(
&mut self,
pipeline: &D::Pipeline,
descriptor_set: &D::DescriptorSet,
workgroup_count: (u32, u32, u32),
workgroup_size: (u32, u32, u32),
);
// Question: should be self?
unsafe fn finish(&mut self);
}


@ -13,7 +13,7 @@ use std::sync::{Arc, Mutex, Weak};
use bytemuck::Pod;
use smallvec::SmallVec;
use crate::{mux, BackendType, BufWrite, ImageFormat, MapMode};
use crate::{mux, BackendType, BufWrite, ComputePassDescriptor, ImageFormat, MapMode};
use crate::{BindType, BufferUsage, Error, GpuInfo, ImageLayout, SamplerParams};
@ -135,6 +135,11 @@ pub struct BufReadGuard<'a> {
size: u64,
}
/// A sub-object of a command buffer for a sequence of compute dispatches.
pub struct ComputePass<'a> {
cmd_buf: &'a mut CmdBuf,
}
impl Session {
/// Create a new session, choosing the best backend.
pub fn new(device: mux::Device) -> Session {
@ -471,6 +476,12 @@ impl CmdBuf {
self.cmd_buf().finish();
}
/// Begin a compute pass.
pub unsafe fn begin_compute_pass(&mut self, desc: &ComputePassDescriptor) -> ComputePass {
self.cmd_buf().begin_compute_pass(desc);
ComputePass { cmd_buf: self }
}
/// Dispatch a compute shader.
///
/// Request a compute shader to be run, using the pipeline to specify the
@ -479,6 +490,11 @@ impl CmdBuf {
/// Both the workgroup count (number of workgroups) and the workgroup size
/// (number of threads in a workgroup) must be specified here, though not
/// all back-ends require the latter info.
///
/// This version is deprecated because (a) you do not get timer queries and
/// (b) it doesn't aggregate multiple dispatches into a single compute
/// pass, which is a performance concern.
#[deprecated(note = "moving to ComputePass")]
pub unsafe fn dispatch(
&mut self,
pipeline: &Pipeline,
@ -486,8 +502,9 @@ impl CmdBuf {
workgroup_count: (u32, u32, u32),
workgroup_size: (u32, u32, u32),
) {
self.cmd_buf()
.dispatch(pipeline, descriptor_set, workgroup_count, workgroup_size);
let mut pass = self.begin_compute_pass(&Default::default());
pass.dispatch(pipeline, descriptor_set, workgroup_count, workgroup_size);
pass.end();
}
/// Insert an execution and memory barrier.
@ -692,6 +709,32 @@ impl Drop for SubmittedCmdBuf {
}
}
impl<'a> ComputePass<'a> {
/// Dispatch a compute shader.
///
/// Request a compute shader to be run, using the pipeline to specify the
/// code, and the descriptor set to address the resources read and written.
///
/// Both the workgroup count (number of workgroups) and the workgroup size
/// (number of threads in a workgroup) must be specified here, though not
/// all back-ends require the latter info.
pub unsafe fn dispatch(
&mut self,
pipeline: &Pipeline,
descriptor_set: &DescriptorSet,
workgroup_count: (u32, u32, u32),
workgroup_size: (u32, u32, u32),
) {
self.cmd_buf
.cmd_buf()
.dispatch(pipeline, descriptor_set, workgroup_count, workgroup_size);
}
pub unsafe fn end(&mut self) {
self.cmd_buf.cmd_buf().end_compute_pass();
}
}
impl Drop for BufferInner {
fn drop(&mut self) {
if let Some(session) = Weak::upgrade(&self.session) {


@ -189,3 +189,17 @@ pub struct WorkgroupLimits {
/// dimension.
pub max_invocations: u32,
}
#[derive(Default)]
pub struct ComputePassDescriptor<'a> {
// Maybe label should go here? It does in wgpu and wgpu_hal.
timer_queries: Option<(&'a QueryPool, u32, u32)>,
}
impl<'a> ComputePassDescriptor<'a> {
pub fn timer(pool: &'a QueryPool, start_query: u32, end_query: u32) -> ComputePassDescriptor {
ComputePassDescriptor {
timer_queries: Some((pool, start_query, end_query)),
}
}
}


@ -33,11 +33,13 @@ use metal::{CGFloat, CommandBufferRef, MTLFeatureSet};
use raw_window_handle::{HasRawWindowHandle, RawWindowHandle};
use crate::{BufferUsage, Error, GpuInfo, ImageFormat, MapMode, WorkgroupLimits};
use crate::{
BufferUsage, ComputePassDescriptor, Error, GpuInfo, ImageFormat, MapMode, WorkgroupLimits,
};
use util::*;
use self::timer::{CounterSampleBuffer, CounterSet};
use self::timer::{CounterSampleBuffer, CounterSet, TimeCalibration};
pub struct MtlInstance;
@ -110,15 +112,11 @@ enum Encoder {
}
#[derive(Default)]
struct TimeCalibration {
cpu_start_ts: u64,
gpu_start_ts: u64,
cpu_end_ts: u64,
gpu_end_ts: u64,
pub struct QueryPool {
counter_sample_buf: Option<CounterSampleBuffer>,
calibration: Arc<Mutex<Option<Arc<Mutex<TimeCalibration>>>>>,
}
pub struct QueryPool(Option<CounterSampleBuffer>);
pub struct Pipeline(metal::ComputePipelineState);
#[derive(Default)]
@ -134,10 +132,6 @@ struct Helpers {
clear_pipeline: metal::ComputePipelineState,
}
pub struct ComputeEncoder {
raw: metal::ComputeCommandEncoder,
}
impl MtlInstance {
pub fn new(
window_handle: Option<&dyn HasRawWindowHandle>,
@ -263,7 +257,7 @@ impl MtlDevice {
helpers,
timer_set,
counter_style,
}
}
}
pub fn cmd_buf_from_raw_mtl(&self, raw_cmd_buf: metal::CommandBuffer) -> CmdBuf {
@ -409,16 +403,28 @@ impl crate::backend::Device for MtlDevice {
if let Some(timer_set) = &self.timer_set {
let pool = CounterSampleBuffer::new(&self.device, n_queries as u64, timer_set)
.ok_or("error creating timer query pool")?;
return Ok(QueryPool(Some(pool)));
return Ok(QueryPool {
counter_sample_buf: Some(pool),
calibration: Default::default(),
});
}
Ok(QueryPool(None))
Ok(QueryPool::default())
}
unsafe fn fetch_query_pool(&self, pool: &Self::QueryPool) -> Result<Vec<f64>, Error> {
if let Some(raw) = &pool.0 {
if let Some(raw) = &pool.counter_sample_buf {
let resolved = raw.resolve();
println!("resolved = {:?}", resolved);
let calibration = pool.calibration.lock().unwrap();
if let Some(calibration) = &*calibration {
let calibration = calibration.lock().unwrap();
let result = resolved
.iter()
.map(|time_ns| calibration.correlate(*time_ns))
.collect();
return Ok(result);
}
}
// Maybe should return None indicating it wasn't successful? But that might break.
Ok(Vec::new())
}
@ -444,10 +450,6 @@ impl crate::backend::Device for MtlDevice {
let gpu_ts_ptr = &mut time_calibration.gpu_start_ts as *mut _;
// TODO: only do this if supported.
let () = msg_send![device, sampleTimestamps: cpu_ts_ptr gpuTimestamp: gpu_ts_ptr];
println!(
"scheduled, {}, {}",
time_calibration.cpu_start_ts, time_calibration.gpu_start_ts
);
})
.copy();
add_scheduled_handler(&cmd_buf.cmd_buf, &start_block);
@ -461,10 +463,6 @@ impl crate::backend::Device for MtlDevice {
// TODO: only do this if supported.
let () =
msg_send![device, sampleTimestamps: cpu_ts_ptr gpuTimestamp: gpu_ts_ptr];
println!(
"completed, {}, {}",
time_calibration.cpu_end_ts, time_calibration.gpu_end_ts
);
})
.copy();
cmd_buf.cmd_buf.add_completed_handler(&completed_block);
@ -546,8 +544,6 @@ impl crate::backend::Device for MtlDevice {
}
impl crate::backend::CmdBuf<MtlDevice> for CmdBuf {
type ComputeEncoder = ComputeEncoder;
unsafe fn begin(&mut self) {}
unsafe fn finish(&mut self) {
@ -558,6 +554,35 @@ impl crate::backend::CmdBuf<MtlDevice> for CmdBuf {
false
}
unsafe fn begin_compute_pass(&mut self, desc: &ComputePassDescriptor) {
debug_assert!(matches!(self.cur_encoder, Encoder::None));
let encoder = if let Some(queries) = &desc.timer_queries {
let descriptor: id = msg_send![class!(MTLComputePassDescriptor), computePassDescriptor];
let attachments: id = msg_send![descriptor, sampleBufferAttachments];
let index: NSUInteger = 0;
let attachment: id = msg_send![attachments, objectAtIndexedSubscript: index];
// Here we break the hub/mux separation a bit, for expedience
#[allow(irrefutable_let_patterns)]
if let crate::hub::QueryPool::Mtl(query_pool) = queries.0 {
if let Some(sample_buf) = &query_pool.counter_sample_buf {
let () = msg_send![attachment, setSampleBuffer: sample_buf.id()];
}
}
let start_index = queries.1 as NSUInteger;
let end_index = queries.2 as NSInteger;
let () = msg_send![attachment, setStartOfEncoderSampleIndex: start_index];
let () = msg_send![attachment, setEndOfEncoderSampleIndex: end_index];
let encoder = msg_send![
self.cmd_buf,
computeCommandEncoderWithDescriptor: descriptor
];
encoder
} else {
self.cmd_buf.new_compute_command_encoder()
};
self.cur_encoder = Encoder::Compute(encoder.to_owned());
}
unsafe fn dispatch(
&mut self,
pipeline: &Pipeline,
@ -590,6 +615,11 @@ impl crate::backend::CmdBuf<MtlDevice> for CmdBuf {
encoder.dispatch_thread_groups(workgroup_count, workgroup_size);
}
unsafe fn end_compute_pass(&mut self) {
// TODO: might validate that we are in a compute encoder state
self.flush_encoder();
}
unsafe fn memory_barrier(&mut self) {
// We'll probably move to explicit barriers, but for now rely on
// Metal's own tracking.
@ -690,10 +720,13 @@ impl crate::backend::CmdBuf<MtlDevice> for CmdBuf {
);
}
unsafe fn reset_query_pool(&mut self, _pool: &QueryPool) {}
unsafe fn reset_query_pool(&mut self, pool: &QueryPool) {
let mut calibration = pool.calibration.lock().unwrap();
*calibration = Some(self.time_calibration.clone());
}
unsafe fn write_timestamp(&mut self, pool: &QueryPool, query: u32) {
if let Some(buf) = &pool.0 {
if let Some(buf) = &pool.counter_sample_buf {
if matches!(self.cur_encoder, Encoder::None) {
self.cur_encoder =
Encoder::Compute(self.cmd_buf.new_compute_command_encoder().to_owned());
@ -709,21 +742,14 @@ impl crate::backend::CmdBuf<MtlDevice> for CmdBuf {
}
} else if self.counter_style == CounterStyle::Stage {
match &self.cur_encoder {
Encoder::Compute(e) => {
println!("here we are");
Encoder::Compute(_e) => {
println!("write_timestamp is not supported for stage-style encoders");
}
_ => (),
}
}
}
}
unsafe fn new_compute_encoder(&mut self) -> Self::ComputeEncoder {
let raw = self.cmd_buf.new_compute_command_encoder().to_owned();
ComputeEncoder {
raw
}
}
}
impl CmdBuf {
@ -761,43 +787,6 @@ impl CmdBuf {
}
}
impl crate::backend::ComputeEncoder<MtlDevice> for ComputeEncoder {
unsafe fn dispatch(
&mut self,
pipeline: &Pipeline,
descriptor_set: &DescriptorSet,
workgroup_count: (u32, u32, u32),
workgroup_size: (u32, u32, u32),
) {
self.raw.set_compute_pipeline_state(&pipeline.0);
let mut buf_ix = 0;
for buffer in &descriptor_set.buffers {
self.raw.set_buffer(buf_ix, Some(&buffer.buffer), 0);
buf_ix += 1;
}
let mut img_ix = buf_ix;
for image in &descriptor_set.images {
self.raw.set_texture(img_ix, Some(&image.texture));
img_ix += 1;
}
let workgroup_count = metal::MTLSize {
width: workgroup_count.0 as u64,
height: workgroup_count.1 as u64,
depth: workgroup_count.2 as u64,
};
let workgroup_size = metal::MTLSize {
width: workgroup_size.0 as u64,
height: workgroup_size.1 as u64,
depth: workgroup_size.2 as u64,
};
self.raw.dispatch_thread_groups(workgroup_count, workgroup_size);
}
unsafe fn finish(&mut self) {
self.raw.end_encoding();
}
}
impl crate::backend::DescriptorSetBuilder<MtlDevice> for DescriptorSetBuilder {
fn add_buffers(&mut self, buffers: &[&Buffer]) {
self.0.buffers.extend(buffers.iter().copied().cloned());


@ -36,6 +36,14 @@ pub struct CounterSet {
id: id,
}
#[derive(Default)]
pub struct TimeCalibration {
pub cpu_start_ts: u64,
pub gpu_start_ts: u64,
pub cpu_end_ts: u64,
pub gpu_end_ts: u64,
}
impl Drop for CounterSampleBuffer {
fn drop(&mut self) {
unsafe { msg_send![self.id, release] }
@ -87,7 +95,6 @@ impl CounterSampleBuffer {
unsafe {
let desc_cls = class!(MTLCounterSampleBufferDescriptor);
let descriptor: id = msg_send![desc_cls, alloc];
println!("descriptor = {:?}", descriptor);
let _: id = msg_send![descriptor, init];
let count = count as NSUInteger;
let () = msg_send![descriptor, setSampleCount: count];
@ -121,3 +128,21 @@ impl CounterSampleBuffer {
}
}
}
impl TimeCalibration {
/// Convert GPU timestamp into CPU time base.
///
/// See https://developer.apple.com/documentation/metal/performance_tuning/correlating_cpu_and_gpu_timestamps
pub fn correlate(&self, raw_ts: u64) -> f64 {
let delta_cpu = self.cpu_end_ts - self.cpu_start_ts;
let delta_gpu = self.gpu_end_ts - self.gpu_start_ts;
let adj_ts = if delta_gpu > 0 {
let scale = delta_cpu as f64 / delta_gpu as f64;
self.cpu_start_ts as f64 + (raw_ts - self.gpu_start_ts) as f64 * scale
} else {
// Default is ns on Apple Silicon; on other hardware this will be wrong
raw_ts as f64
};
adj_ts * 1e-9
}
}


@ -35,6 +35,7 @@ use crate::backend::DescriptorSetBuilder as DescriptorSetBuilderTrait;
use crate::backend::Device as DeviceTrait;
use crate::BackendType;
use crate::BindType;
use crate::ComputePassDescriptor;
use crate::ImageFormat;
use crate::MapMode;
use crate::{BufferUsage, Error, GpuInfo, ImageLayout, InstanceFlags};
@ -100,14 +101,6 @@ mux_device_enum! {
QueryPool }
mux_device_enum! { Sampler }
mux_enum! {
pub enum ComputeEncoder {
Vk(<crate::vulkan::CmdBuf as crate::backend::CmdBuf<vulkan::VkDevice>>::ComputeEncoder),
Dx12(<crate::dx12::Dx12Device as crate::backend::CmdBuf<dx12::Dx12Device>>::ComputeEncoder),
Mtl(<crate::metal::CmdBuf as crate::backend::CmdBuf<metal::MtlDevice>>::ComputeEncoder),
}
}
/// The code for a shader, either as source or intermediate representation.
pub enum ShaderCode<'a> {
/// SPIR-V (binary intermediate representation)
@ -666,6 +659,14 @@ impl CmdBuf {
}
}
pub unsafe fn begin_compute_pass(&mut self, desc: &ComputePassDescriptor) {
mux_match! { self;
CmdBuf::Vk(c) => c.begin_compute_pass(desc),
CmdBuf::Dx12(c) => c.begin_compute_pass(desc),
CmdBuf::Mtl(c) => c.begin_compute_pass(desc),
}
}
/// Dispatch a compute shader.
///
/// Note that both the number of workgroups (`workgroup_count`) and the number of
@ -688,6 +689,14 @@ impl CmdBuf {
}
}
pub unsafe fn end_compute_pass(&mut self) {
mux_match! { self;
CmdBuf::Vk(c) => c.end_compute_pass(),
CmdBuf::Dx12(c) => c.end_compute_pass(),
CmdBuf::Mtl(c) => c.end_compute_pass(),
}
}
pub unsafe fn memory_barrier(&mut self) {
mux_match! { self;
CmdBuf::Vk(c) => c.memory_barrier(),