diff --git a/piet-gpu-hal/examples/collatz.rs b/piet-gpu-hal/examples/collatz.rs index dae5b31..7aff938 100644 --- a/piet-gpu-hal/examples/collatz.rs +++ b/piet-gpu-hal/examples/collatz.rs @@ -1,4 +1,4 @@ -use piet_gpu_hal::{include_shader, BindType}; +use piet_gpu_hal::{include_shader, BindType, ComputePassDescriptor}; use piet_gpu_hal::{BufferUsage, Instance, InstanceFlags, Session}; fn main() { @@ -20,9 +20,9 @@ fn main() { let mut cmd_buf = session.cmd_buf().unwrap(); cmd_buf.begin(); cmd_buf.reset_query_pool(&query_pool); - cmd_buf.write_timestamp(&query_pool, 0); - cmd_buf.dispatch(&pipeline, &descriptor_set, (256, 1, 1), (1, 1, 1)); - cmd_buf.write_timestamp(&query_pool, 1); + let mut pass = cmd_buf.begin_compute_pass(&ComputePassDescriptor::timer(&query_pool, 0, 1)); + pass.dispatch(&pipeline, &descriptor_set, (256, 1, 1), (1, 1, 1)); + pass.end(); cmd_buf.finish_timestamps(&query_pool); cmd_buf.host_barrier(); cmd_buf.finish(); diff --git a/piet-gpu-hal/src/backend.rs b/piet-gpu-hal/src/backend.rs index c1b2132..f2c67a1 100644 --- a/piet-gpu-hal/src/backend.rs +++ b/piet-gpu-hal/src/backend.rs @@ -17,7 +17,8 @@ //! The generic trait for backends to implement. use crate::{ - BindType, BufferUsage, Error, GpuInfo, ImageFormat, ImageLayout, MapMode, SamplerParams, + BindType, BufferUsage, ComputePassDescriptor, Error, GpuInfo, ImageFormat, ImageLayout, + MapMode, SamplerParams, }; pub trait Device: Sized { @@ -159,16 +160,32 @@ pub trait Device: Sized { unsafe fn create_sampler(&self, params: SamplerParams) -> Result; } +/// The trait implemented by backend command buffer implementations. +/// +/// Valid encoding is represented by a state machine (currently not validated +/// but it is easy to imagine there might be at least debug validation). Most +/// methods are only valid in a particular state, and some move it to another +/// state. pub trait CmdBuf { - type ComputeEncoder; - + /// Begin encoding. + /// + /// State: init -> ready unsafe fn begin(&mut self); + /// State: ready -> finished unsafe fn finish(&mut self); /// Return true if the command buffer is suitable for reuse. unsafe fn reset(&mut self) -> bool; + /// Begin a compute pass. + /// + /// State: ready -> in_compute_pass + unsafe fn begin_compute_pass(&mut self, desc: &ComputePassDescriptor); + + /// Dispatch + /// + /// State: in_compute_pass unsafe fn dispatch( &mut self, pipeline: &D::Pipeline, @@ -177,6 +194,9 @@ pub trait CmdBuf { workgroup_size: (u32, u32, u32), ); + /// State: in_compute_pass -> ready + unsafe fn end_compute_pass(&mut self); + /// Insert an execution and memory barrier. /// /// Compute kernels (and other actions) after this barrier may read from buffers @@ -229,12 +249,10 @@ pub trait CmdBuf { unsafe fn finish_timestamps(&mut self, _pool: &D::QueryPool) {} /// Begin a labeled section for debugging and profiling purposes. - unsafe fn begin_debug_label(&mut self, label: &str) {} + unsafe fn begin_debug_label(&mut self, _label: &str) {} /// End a section opened by `begin_debug_label`. unsafe fn end_debug_label(&mut self) {} - - unsafe fn new_compute_encoder(&mut self) -> Self::ComputeEncoder; } /// A builder for descriptor sets with more complex layouts. @@ -256,16 +274,3 @@ pub trait DescriptorSetBuilder { fn add_textures(&mut self, images: &[&D::Image]); unsafe fn build(self, device: &D, pipeline: &D::Pipeline) -> Result; } - -pub trait ComputeEncoder { - unsafe fn dispatch( - &mut self, - pipeline: &D::Pipeline, - descriptor_set: &D::DescriptorSet, - workgroup_count: (u32, u32, u32), - workgroup_size: (u32, u32, u32), - ); - - // Question: should be self? - unsafe fn finish(&mut self); -} diff --git a/piet-gpu-hal/src/hub.rs b/piet-gpu-hal/src/hub.rs index cc09832..37c59df 100644 --- a/piet-gpu-hal/src/hub.rs +++ b/piet-gpu-hal/src/hub.rs @@ -13,7 +13,7 @@ use std::sync::{Arc, Mutex, Weak}; use bytemuck::Pod; use smallvec::SmallVec; -use crate::{mux, BackendType, BufWrite, ImageFormat, MapMode}; +use crate::{mux, BackendType, BufWrite, ComputePassDescriptor, ImageFormat, MapMode}; use crate::{BindType, BufferUsage, Error, GpuInfo, ImageLayout, SamplerParams}; @@ -135,6 +135,11 @@ pub struct BufReadGuard<'a> { size: u64, } +/// A sub-object of a command buffer for a sequence of compute dispatches. +pub struct ComputePass<'a> { + cmd_buf: &'a mut CmdBuf, +} + impl Session { /// Create a new session, choosing the best backend. pub fn new(device: mux::Device) -> Session { @@ -471,6 +476,12 @@ impl CmdBuf { self.cmd_buf().finish(); } + /// Begin a compute pass. + pub unsafe fn begin_compute_pass(&mut self, desc: &ComputePassDescriptor) -> ComputePass { + self.cmd_buf().begin_compute_pass(desc); + ComputePass { cmd_buf: self } + } + /// Dispatch a compute shader. /// /// Request a compute shader to be run, using the pipeline to specify the @@ -479,6 +490,11 @@ impl CmdBuf { /// Both the workgroup count (number of workgroups) and the workgroup size /// (number of threads in a workgroup) must be specified here, though not /// all back-ends require the latter info. + /// + /// This version is deprecated because (a) you do not get timer queries and + /// (b) it doesn't aggregate multiple dispatches into a single compute + /// pass, which is a performance concern. + #[deprecated(note = "moving to ComputePass")] pub unsafe fn dispatch( &mut self, pipeline: &Pipeline, @@ -486,8 +502,9 @@ impl CmdBuf { workgroup_count: (u32, u32, u32), workgroup_size: (u32, u32, u32), ) { - self.cmd_buf() - .dispatch(pipeline, descriptor_set, workgroup_count, workgroup_size); + let mut pass = self.begin_compute_pass(&Default::default()); + pass.dispatch(pipeline, descriptor_set, workgroup_count, workgroup_size); + pass.end(); } /// Insert an execution and memory barrier. @@ -692,6 +709,32 @@ impl Drop for SubmittedCmdBuf { } } +impl<'a> ComputePass<'a> { + /// Dispatch a compute shader. + /// + /// Request a compute shader to be run, using the pipeline to specify the + /// code, and the descriptor set to address the resources read and written. + /// + /// Both the workgroup count (number of workgroups) and the workgroup size + /// (number of threads in a workgroup) must be specified here, though not + /// all back-ends require the latter info. + pub unsafe fn dispatch( + &mut self, + pipeline: &Pipeline, + descriptor_set: &DescriptorSet, + workgroup_count: (u32, u32, u32), + workgroup_size: (u32, u32, u32), + ) { + self.cmd_buf + .cmd_buf() + .dispatch(pipeline, descriptor_set, workgroup_count, workgroup_size); + } + + pub unsafe fn end(&mut self) { + self.cmd_buf.cmd_buf().end_compute_pass(); + } +} + impl Drop for BufferInner { fn drop(&mut self) { if let Some(session) = Weak::upgrade(&self.session) { diff --git a/piet-gpu-hal/src/lib.rs b/piet-gpu-hal/src/lib.rs index fab7d65..241cdfd 100644 --- a/piet-gpu-hal/src/lib.rs +++ b/piet-gpu-hal/src/lib.rs @@ -189,3 +189,17 @@ pub struct WorkgroupLimits { /// dimension. pub max_invocations: u32, } + +#[derive(Default)] +pub struct ComputePassDescriptor<'a> { + // Maybe label should go here? It does in wgpu and wgpu_hal. + timer_queries: Option<(&'a QueryPool, u32, u32)>, +} + +impl<'a> ComputePassDescriptor<'a> { + pub fn timer(pool: &'a QueryPool, start_query: u32, end_query: u32) -> ComputePassDescriptor { + ComputePassDescriptor { + timer_queries: Some((pool, start_query, end_query)), + } + } +} diff --git a/piet-gpu-hal/src/metal.rs b/piet-gpu-hal/src/metal.rs index 23cc256..c907d77 100644 --- a/piet-gpu-hal/src/metal.rs +++ b/piet-gpu-hal/src/metal.rs @@ -33,11 +33,13 @@ use metal::{CGFloat, CommandBufferRef, MTLFeatureSet}; use raw_window_handle::{HasRawWindowHandle, RawWindowHandle}; -use crate::{BufferUsage, Error, GpuInfo, ImageFormat, MapMode, WorkgroupLimits}; +use crate::{ + BufferUsage, ComputePassDescriptor, Error, GpuInfo, ImageFormat, MapMode, WorkgroupLimits, +}; use util::*; -use self::timer::{CounterSampleBuffer, CounterSet}; +use self::timer::{CounterSampleBuffer, CounterSet, TimeCalibration}; pub struct MtlInstance; @@ -110,15 +112,11 @@ enum Encoder { } #[derive(Default)] -struct TimeCalibration { - cpu_start_ts: u64, - gpu_start_ts: u64, - cpu_end_ts: u64, - gpu_end_ts: u64, +pub struct QueryPool { + counter_sample_buf: Option, + calibration: Arc>>>>, } -pub struct QueryPool(Option); - pub struct Pipeline(metal::ComputePipelineState); #[derive(Default)] @@ -134,10 +132,6 @@ struct Helpers { clear_pipeline: metal::ComputePipelineState, } -pub struct ComputeEncoder { - raw: metal::ComputeCommandEncoder, -} - impl MtlInstance { pub fn new( window_handle: Option<&dyn HasRawWindowHandle>, @@ -263,7 +257,7 @@ impl MtlDevice { helpers, timer_set, counter_style, - } + } } pub fn cmd_buf_from_raw_mtl(&self, raw_cmd_buf: metal::CommandBuffer) -> CmdBuf { @@ -409,16 +403,28 @@ impl crate::backend::Device for MtlDevice { if let Some(timer_set) = &self.timer_set { let pool = CounterSampleBuffer::new(&self.device, n_queries as u64, timer_set) .ok_or("error creating timer query pool")?; - return Ok(QueryPool(Some(pool))); + return Ok(QueryPool { + counter_sample_buf: Some(pool), + calibration: Default::default(), + }); } - Ok(QueryPool(None)) + Ok(QueryPool::default()) } unsafe fn fetch_query_pool(&self, pool: &Self::QueryPool) -> Result, Error> { - if let Some(raw) = &pool.0 { + if let Some(raw) = &pool.counter_sample_buf { let resolved = raw.resolve(); - println!("resolved = {:?}", resolved); + let calibration = pool.calibration.lock().unwrap(); + if let Some(calibration) = &*calibration { + let calibration = calibration.lock().unwrap(); + let result = resolved + .iter() + .map(|time_ns| calibration.correlate(*time_ns)) + .collect(); + return Ok(result); + } } + // Maybe should return None indicating it wasn't successful? But that might break. Ok(Vec::new()) } @@ -444,10 +450,6 @@ impl crate::backend::Device for MtlDevice { let gpu_ts_ptr = &mut time_calibration.gpu_start_ts as *mut _; // TODO: only do this if supported. let () = msg_send![device, sampleTimestamps: cpu_ts_ptr gpuTimestamp: gpu_ts_ptr]; - println!( - "scheduled, {}, {}", - time_calibration.cpu_start_ts, time_calibration.gpu_start_ts - ); }) .copy(); add_scheduled_handler(&cmd_buf.cmd_buf, &start_block); @@ -461,10 +463,6 @@ impl crate::backend::Device for MtlDevice { // TODO: only do this if supported. let () = msg_send![device, sampleTimestamps: cpu_ts_ptr gpuTimestamp: gpu_ts_ptr]; - println!( - "completed, {}, {}", - time_calibration.cpu_end_ts, time_calibration.gpu_end_ts - ); }) .copy(); cmd_buf.cmd_buf.add_completed_handler(&completed_block); @@ -546,8 +544,6 @@ impl crate::backend::Device for MtlDevice { } impl crate::backend::CmdBuf for CmdBuf { - type ComputeEncoder = ComputeEncoder; - unsafe fn begin(&mut self) {} unsafe fn finish(&mut self) { @@ -558,6 +554,35 @@ impl crate::backend::CmdBuf for CmdBuf { false } + unsafe fn begin_compute_pass(&mut self, desc: &ComputePassDescriptor) { + debug_assert!(matches!(self.cur_encoder, Encoder::None)); + let encoder = if let Some(queries) = &desc.timer_queries { + let descriptor: id = msg_send![class!(MTLComputePassDescriptor), computePassDescriptor]; + let attachments: id = msg_send![descriptor, sampleBufferAttachments]; + let index: NSUInteger = 0; + let attachment: id = msg_send![attachments, objectAtIndexedSubscript: index]; + // Here we break the hub/mux separation a bit, for expedience + #[allow(irrefutable_let_patterns)] + if let crate::hub::QueryPool::Mtl(query_pool) = queries.0 { + if let Some(sample_buf) = &query_pool.counter_sample_buf { + let () = msg_send![attachment, setSampleBuffer: sample_buf.id()]; + } + } + let start_index = queries.1 as NSUInteger; + let end_index = queries.2 as NSInteger; + let () = msg_send![attachment, setStartOfEncoderSampleIndex: start_index]; + let () = msg_send![attachment, setEndOfEncoderSampleIndex: end_index]; + let encoder = msg_send![ + self.cmd_buf, + computeCommandEncoderWithDescriptor: descriptor + ]; + encoder + } else { + self.cmd_buf.new_compute_command_encoder() + }; + self.cur_encoder = Encoder::Compute(encoder.to_owned()); + } + unsafe fn dispatch( &mut self, pipeline: &Pipeline, @@ -590,6 +615,11 @@ impl crate::backend::CmdBuf for CmdBuf { encoder.dispatch_thread_groups(workgroup_count, workgroup_size); } + unsafe fn end_compute_pass(&mut self) { + // TODO: might validate that we are in a compute encoder state + self.flush_encoder(); + } + unsafe fn memory_barrier(&mut self) { // We'll probably move to explicit barriers, but for now rely on // Metal's own tracking. @@ -690,10 +720,13 @@ impl crate::backend::CmdBuf for CmdBuf { ); } - unsafe fn reset_query_pool(&mut self, _pool: &QueryPool) {} + unsafe fn reset_query_pool(&mut self, pool: &QueryPool) { + let mut calibration = pool.calibration.lock().unwrap(); + *calibration = Some(self.time_calibration.clone()); + } unsafe fn write_timestamp(&mut self, pool: &QueryPool, query: u32) { - if let Some(buf) = &pool.0 { + if let Some(buf) = &pool.counter_sample_buf { if matches!(self.cur_encoder, Encoder::None) { self.cur_encoder = Encoder::Compute(self.cmd_buf.new_compute_command_encoder().to_owned()); @@ -709,21 +742,14 @@ impl crate::backend::CmdBuf for CmdBuf { } } else if self.counter_style == CounterStyle::Stage { match &self.cur_encoder { - Encoder::Compute(e) => { - println!("here we are"); + Encoder::Compute(_e) => { + println!("write_timestamp is not supported for stage-style encoders"); } _ => (), } } } } - - unsafe fn new_compute_encoder(&mut self) -> Self::ComputeEncoder { - let raw = self.cmd_buf.new_compute_command_encoder().to_owned(); - ComputeEncoder { - raw - } - } } impl CmdBuf { @@ -761,43 +787,6 @@ impl CmdBuf { } } -impl crate::backend::ComputeEncoder for ComputeEncoder { - unsafe fn dispatch( - &mut self, - pipeline: &Pipeline, - descriptor_set: &DescriptorSet, - workgroup_count: (u32, u32, u32), - workgroup_size: (u32, u32, u32), - ) { - self.raw.set_compute_pipeline_state(&pipeline.0); - let mut buf_ix = 0; - for buffer in &descriptor_set.buffers { - self.raw.set_buffer(buf_ix, Some(&buffer.buffer), 0); - buf_ix += 1; - } - let mut img_ix = buf_ix; - for image in &descriptor_set.images { - self.raw.set_texture(img_ix, Some(&image.texture)); - img_ix += 1; - } - let workgroup_count = metal::MTLSize { - width: workgroup_count.0 as u64, - height: workgroup_count.1 as u64, - depth: workgroup_count.2 as u64, - }; - let workgroup_size = metal::MTLSize { - width: workgroup_size.0 as u64, - height: workgroup_size.1 as u64, - depth: workgroup_size.2 as u64, - }; - self.raw.dispatch_thread_groups(workgroup_count, workgroup_size); - } - - unsafe fn finish(&mut self) { - self.raw.end_encoding(); - } -} - impl crate::backend::DescriptorSetBuilder for DescriptorSetBuilder { fn add_buffers(&mut self, buffers: &[&Buffer]) { self.0.buffers.extend(buffers.iter().copied().cloned()); diff --git a/piet-gpu-hal/src/metal/timer.rs b/piet-gpu-hal/src/metal/timer.rs index a51bc6d..a8b80d6 100644 --- a/piet-gpu-hal/src/metal/timer.rs +++ b/piet-gpu-hal/src/metal/timer.rs @@ -36,6 +36,14 @@ pub struct CounterSet { id: id, } +#[derive(Default)] +pub struct TimeCalibration { + pub cpu_start_ts: u64, + pub gpu_start_ts: u64, + pub cpu_end_ts: u64, + pub gpu_end_ts: u64, +} + impl Drop for CounterSampleBuffer { fn drop(&mut self) { unsafe { msg_send![self.id, release] } @@ -87,7 +95,6 @@ impl CounterSampleBuffer { unsafe { let desc_cls = class!(MTLCounterSampleBufferDescriptor); let descriptor: id = msg_send![desc_cls, alloc]; - println!("descriptor = {:?}", descriptor); let _: id = msg_send![descriptor, init]; let count = count as NSUInteger; let () = msg_send![descriptor, setSampleCount: count]; @@ -121,3 +128,21 @@ impl CounterSampleBuffer { } } } + +impl TimeCalibration { + /// Convert GPU timestamp into CPU time base. + /// + /// See https://developer.apple.com/documentation/metal/performance_tuning/correlating_cpu_and_gpu_timestamps + pub fn correlate(&self, raw_ts: u64) -> f64 { + let delta_cpu = self.cpu_end_ts - self.cpu_start_ts; + let delta_gpu = self.gpu_end_ts - self.gpu_start_ts; + let adj_ts = if delta_gpu > 0 { + let scale = delta_cpu as f64 / delta_gpu as f64; + self.cpu_start_ts as f64 + (raw_ts - self.gpu_start_ts) as f64 * scale + } else { + // Default is ns on Apple Silicon; on other hardware this will be wrong + raw_ts as f64 + }; + adj_ts * 1e-9 + } +} diff --git a/piet-gpu-hal/src/mux.rs b/piet-gpu-hal/src/mux.rs index 7853c2b..9795193 100644 --- a/piet-gpu-hal/src/mux.rs +++ b/piet-gpu-hal/src/mux.rs @@ -35,6 +35,7 @@ use crate::backend::DescriptorSetBuilder as DescriptorSetBuilderTrait; use crate::backend::Device as DeviceTrait; use crate::BackendType; use crate::BindType; +use crate::ComputePassDescriptor; use crate::ImageFormat; use crate::MapMode; use crate::{BufferUsage, Error, GpuInfo, ImageLayout, InstanceFlags}; @@ -100,14 +101,6 @@ mux_device_enum! { QueryPool } mux_device_enum! { Sampler } -mux_enum! { - pub enum ComputeEncoder { - Vk(>::ComputeEncoder), - Dx12(>::ComputeEncoder), - Mtl(>::ComputeEncoder), - } -} - /// The code for a shader, either as source or intermediate representation. pub enum ShaderCode<'a> { /// SPIR-V (binary intermediate representation) @@ -666,6 +659,14 @@ impl CmdBuf { } } + pub unsafe fn begin_compute_pass(&mut self, desc: &ComputePassDescriptor) { + mux_match! { self; + CmdBuf::Vk(c) => c.begin_compute_pass(desc), + CmdBuf::Dx12(c) => c.begin_compute_pass(desc), + CmdBuf::Mtl(c) => c.begin_compute_pass(desc), + } + } + /// Dispatch a compute shader. /// /// Note that both the number of workgroups (`workgroup_count`) and the number of @@ -688,6 +689,14 @@ impl CmdBuf { } } + pub unsafe fn end_compute_pass(&mut self) { + mux_match! { self; + CmdBuf::Vk(c) => c.end_compute_pass(), + CmdBuf::Dx12(c) => c.end_compute_pass(), + CmdBuf::Mtl(c) => c.end_compute_pass(), + } + } + pub unsafe fn memory_barrier(&mut self) { mux_match! { self; CmdBuf::Vk(c) => c.memory_barrier(),