Rework of compute encoder abstraction

The current plan is to more or less follow the wgpu/wgpu-hal approach. In the mux/backend layer (which corresponds fairly closely to wgpu-hal), there is no explicit compute encoder object, but there are new methods for beginning and ending a compute pass. At the hub layer (which corresponds to wgpu) there is a ComputePass object.
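
In code terms, the split looks roughly like this (a condensed sketch of the shapes introduced by this commit; the full definitions are in the diffs below):

    // Backend (mux) layer: no encoder object. Passes are begun and ended
    // directly on the command buffer, and the descriptor carries anything
    // that must be known at encoder creation time.
    pub trait CmdBuf<D: Device> {
        unsafe fn begin_compute_pass(&mut self, desc: &ComputePassDescriptor);
        // ... dispatch and the other existing methods ...
        unsafe fn end_compute_pass(&mut self);
    }

    // Hub layer: an explicit pass object borrowing the command buffer, with
    // an explicit end() rather than Drop.
    pub struct ComputePass<'a> {
        cmd_buf: &'a mut CmdBuf,
    }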

That said, there will be some differences. The WebGPU "end" method on a compute encoder is implemented in wgpu as Drop, and that is not ideal. Also, the wgpu-hal approach to timer queries (still based on write_timestamp) is not up to the task of Metal timer queries, where the query offsets have to be specified at compute encoder creation. That's why there are different projects :)
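
Concretely, at the hub level a timed dispatch now reads like this (excerpted from the example update below; the two integers are the start and end query indices within the pool):

    let mut pass = cmd_buf.begin_compute_pass(&ComputePassDescriptor::timer(&query_pool, 0, 1));
    pass.dispatch(&pipeline, &descriptor_set, (256, 1, 1), (1, 1, 1));
    pass.end();
    cmd_buf.finish_timestamps(&query_pool);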

WIP: the current state is that stage-style queries work on Apple Silicon, but the non-Metal backends are broken, and piet-gpu has not yet been updated to use the new API.
Raph Levien 2022-04-13 10:31:38 -07:00
parent 290d5d2e13
commit ba2b27cc3c
7 changed files with 196 additions and 111 deletions

File 1 of 7:

@@ -1,4 +1,4 @@
-use piet_gpu_hal::{include_shader, BindType};
+use piet_gpu_hal::{include_shader, BindType, ComputePassDescriptor};
 use piet_gpu_hal::{BufferUsage, Instance, InstanceFlags, Session};

 fn main() {
@@ -20,9 +20,9 @@ fn main() {
     let mut cmd_buf = session.cmd_buf().unwrap();
     cmd_buf.begin();
     cmd_buf.reset_query_pool(&query_pool);
-    cmd_buf.write_timestamp(&query_pool, 0);
-    cmd_buf.dispatch(&pipeline, &descriptor_set, (256, 1, 1), (1, 1, 1));
-    cmd_buf.write_timestamp(&query_pool, 1);
+    let mut pass = cmd_buf.begin_compute_pass(&ComputePassDescriptor::timer(&query_pool, 0, 1));
+    pass.dispatch(&pipeline, &descriptor_set, (256, 1, 1), (1, 1, 1));
+    pass.end();
     cmd_buf.finish_timestamps(&query_pool);
     cmd_buf.host_barrier();
     cmd_buf.finish();

File 2 of 7:

@@ -17,7 +17,8 @@
 //! The generic trait for backends to implement.

 use crate::{
-    BindType, BufferUsage, Error, GpuInfo, ImageFormat, ImageLayout, MapMode, SamplerParams,
+    BindType, BufferUsage, ComputePassDescriptor, Error, GpuInfo, ImageFormat, ImageLayout,
+    MapMode, SamplerParams,
 };

 pub trait Device: Sized {
@@ -159,16 +160,32 @@ pub trait Device: Sized {
     unsafe fn create_sampler(&self, params: SamplerParams) -> Result<Self::Sampler, Error>;
 }

+/// The trait implemented by backend command buffer implementations.
+///
+/// Valid encoding is represented by a state machine (currently not validated
+/// but it is easy to imagine there might be at least debug validation). Most
+/// methods are only valid in a particular state, and some move it to another
+/// state.
 pub trait CmdBuf<D: Device> {
-    type ComputeEncoder;
-
+    /// Begin encoding.
+    ///
+    /// State: init -> ready
     unsafe fn begin(&mut self);

+    /// State: ready -> finished
     unsafe fn finish(&mut self);

     /// Return true if the command buffer is suitable for reuse.
     unsafe fn reset(&mut self) -> bool;

+    /// Begin a compute pass.
+    ///
+    /// State: ready -> in_compute_pass
+    unsafe fn begin_compute_pass(&mut self, desc: &ComputePassDescriptor);
+
+    /// Dispatch
+    ///
+    /// State: in_compute_pass
     unsafe fn dispatch(
         &mut self,
         pipeline: &D::Pipeline,
@@ -177,6 +194,9 @@ pub trait CmdBuf<D: Device> {
         workgroup_size: (u32, u32, u32),
     );

+    /// State: in_compute_pass -> ready
+    unsafe fn end_compute_pass(&mut self);
+
     /// Insert an execution and memory barrier.
     ///
     /// Compute kernels (and other actions) after this barrier may read from buffers
@@ -229,12 +249,10 @@ pub trait CmdBuf<D: Device> {
     unsafe fn finish_timestamps(&mut self, _pool: &D::QueryPool) {}

     /// Begin a labeled section for debugging and profiling purposes.
-    unsafe fn begin_debug_label(&mut self, label: &str) {}
+    unsafe fn begin_debug_label(&mut self, _label: &str) {}

     /// End a section opened by `begin_debug_label`.
     unsafe fn end_debug_label(&mut self) {}
-
-    unsafe fn new_compute_encoder(&mut self) -> Self::ComputeEncoder;
 }

 /// A builder for descriptor sets with more complex layouts.
@@ -256,16 +274,3 @@ pub trait DescriptorSetBuilder<D: Device> {
     fn add_textures(&mut self, images: &[&D::Image]);
     unsafe fn build(self, device: &D, pipeline: &D::Pipeline) -> Result<D::DescriptorSet, Error>;
 }
-
-pub trait ComputeEncoder<D: Device> {
-    unsafe fn dispatch(
-        &mut self,
-        pipeline: &D::Pipeline,
-        descriptor_set: &D::DescriptorSet,
-        workgroup_count: (u32, u32, u32),
-        workgroup_size: (u32, u32, u32),
-    );
-
-    // Question: should be self?
-    unsafe fn finish(&mut self);
-}

File 3 of 7:

@@ -13,7 +13,7 @@ use std::sync::{Arc, Mutex, Weak};
 use bytemuck::Pod;
 use smallvec::SmallVec;

-use crate::{mux, BackendType, BufWrite, ImageFormat, MapMode};
+use crate::{mux, BackendType, BufWrite, ComputePassDescriptor, ImageFormat, MapMode};
 use crate::{BindType, BufferUsage, Error, GpuInfo, ImageLayout, SamplerParams};
@@ -135,6 +135,11 @@ pub struct BufReadGuard<'a> {
     size: u64,
 }

+/// A sub-object of a command buffer for a sequence of compute dispatches.
+pub struct ComputePass<'a> {
+    cmd_buf: &'a mut CmdBuf,
+}
+
 impl Session {
     /// Create a new session, choosing the best backend.
     pub fn new(device: mux::Device) -> Session {
@@ -471,6 +476,12 @@ impl CmdBuf {
         self.cmd_buf().finish();
     }

+    /// Begin a compute pass.
+    pub unsafe fn begin_compute_pass(&mut self, desc: &ComputePassDescriptor) -> ComputePass {
+        self.cmd_buf().begin_compute_pass(desc);
+        ComputePass { cmd_buf: self }
+    }
+
     /// Dispatch a compute shader.
     ///
     /// Request a compute shader to be run, using the pipeline to specify the
@@ -479,6 +490,11 @@ impl CmdBuf {
     /// Both the workgroup count (number of workgroups) and the workgroup size
     /// (number of threads in a workgroup) must be specified here, though not
     /// all back-ends require the latter info.
+    ///
+    /// This version is deprecated because (a) you do not get timer queries and
+    /// (b) it doesn't aggregate multiple dispatches into a single compute
+    /// pass, which is a performance concern.
+    #[deprecated(note = "moving to ComputePass")]
     pub unsafe fn dispatch(
         &mut self,
         pipeline: &Pipeline,
@@ -486,8 +502,9 @@ impl CmdBuf {
         workgroup_count: (u32, u32, u32),
         workgroup_size: (u32, u32, u32),
     ) {
-        self.cmd_buf()
-            .dispatch(pipeline, descriptor_set, workgroup_count, workgroup_size);
+        let mut pass = self.begin_compute_pass(&Default::default());
+        pass.dispatch(pipeline, descriptor_set, workgroup_count, workgroup_size);
+        pass.end();
     }

     /// Insert an execution and memory barrier.
@@ -692,6 +709,32 @@ impl Drop for SubmittedCmdBuf {
     }
 }

+impl<'a> ComputePass<'a> {
+    /// Dispatch a compute shader.
+    ///
+    /// Request a compute shader to be run, using the pipeline to specify the
+    /// code, and the descriptor set to address the resources read and written.
+    ///
+    /// Both the workgroup count (number of workgroups) and the workgroup size
+    /// (number of threads in a workgroup) must be specified here, though not
+    /// all back-ends require the latter info.
+    pub unsafe fn dispatch(
+        &mut self,
+        pipeline: &Pipeline,
+        descriptor_set: &DescriptorSet,
+        workgroup_count: (u32, u32, u32),
+        workgroup_size: (u32, u32, u32),
+    ) {
+        self.cmd_buf
+            .cmd_buf()
+            .dispatch(pipeline, descriptor_set, workgroup_count, workgroup_size);
+    }
+
+    pub unsafe fn end(&mut self) {
+        self.cmd_buf.cmd_buf().end_compute_pass();
+    }
+}
+
 impl Drop for BufferInner {
     fn drop(&mut self) {
         if let Some(session) = Weak::upgrade(&self.session) {

File 4 of 7:

@@ -189,3 +189,17 @@ pub struct WorkgroupLimits {
     /// dimension.
     pub max_invocations: u32,
 }
+
+#[derive(Default)]
+pub struct ComputePassDescriptor<'a> {
+    // Maybe label should go here? It does in wgpu and wgpu_hal.
+    timer_queries: Option<(&'a QueryPool, u32, u32)>,
+}
+
+impl<'a> ComputePassDescriptor<'a> {
+    pub fn timer(pool: &'a QueryPool, start_query: u32, end_query: u32) -> ComputePassDescriptor {
+        ComputePassDescriptor {
+            timer_queries: Some((pool, start_query, end_query)),
+        }
+    }
+}

File 5 of 7:

@@ -33,11 +33,13 @@ use metal::{CGFloat, CommandBufferRef, MTLFeatureSet};
 use raw_window_handle::{HasRawWindowHandle, RawWindowHandle};

-use crate::{BufferUsage, Error, GpuInfo, ImageFormat, MapMode, WorkgroupLimits};
+use crate::{
+    BufferUsage, ComputePassDescriptor, Error, GpuInfo, ImageFormat, MapMode, WorkgroupLimits,
+};

 use util::*;

-use self::timer::{CounterSampleBuffer, CounterSet};
+use self::timer::{CounterSampleBuffer, CounterSet, TimeCalibration};

 pub struct MtlInstance;
@@ -110,15 +112,11 @@ enum Encoder {
 }

 #[derive(Default)]
-struct TimeCalibration {
-    cpu_start_ts: u64,
-    gpu_start_ts: u64,
-    cpu_end_ts: u64,
-    gpu_end_ts: u64,
+pub struct QueryPool {
+    counter_sample_buf: Option<CounterSampleBuffer>,
+    calibration: Arc<Mutex<Option<Arc<Mutex<TimeCalibration>>>>>,
 }

-pub struct QueryPool(Option<CounterSampleBuffer>);
-
 pub struct Pipeline(metal::ComputePipelineState);

 #[derive(Default)]
@@ -134,10 +132,6 @@ struct Helpers {
     clear_pipeline: metal::ComputePipelineState,
 }

-pub struct ComputeEncoder {
-    raw: metal::ComputeCommandEncoder,
-}
-
 impl MtlInstance {
     pub fn new(
         window_handle: Option<&dyn HasRawWindowHandle>,
@@ -263,7 +257,7 @@ impl MtlDevice {
             helpers,
             timer_set,
             counter_style,
         }
     }

     pub fn cmd_buf_from_raw_mtl(&self, raw_cmd_buf: metal::CommandBuffer) -> CmdBuf {
@@ -409,16 +403,28 @@ impl crate::backend::Device for MtlDevice {
         if let Some(timer_set) = &self.timer_set {
             let pool = CounterSampleBuffer::new(&self.device, n_queries as u64, timer_set)
                 .ok_or("error creating timer query pool")?;
-            return Ok(QueryPool(Some(pool)));
+            return Ok(QueryPool {
+                counter_sample_buf: Some(pool),
+                calibration: Default::default(),
+            });
         }
-        Ok(QueryPool(None))
+        Ok(QueryPool::default())
     }

     unsafe fn fetch_query_pool(&self, pool: &Self::QueryPool) -> Result<Vec<f64>, Error> {
-        if let Some(raw) = &pool.0 {
+        if let Some(raw) = &pool.counter_sample_buf {
             let resolved = raw.resolve();
-            println!("resolved = {:?}", resolved);
+            let calibration = pool.calibration.lock().unwrap();
+            if let Some(calibration) = &*calibration {
+                let calibration = calibration.lock().unwrap();
+                let result = resolved
+                    .iter()
+                    .map(|time_ns| calibration.correlate(*time_ns))
+                    .collect();
+                return Ok(result);
+            }
         }
+        // Maybe should return None indicating it wasn't successful? But that might break.
         Ok(Vec::new())
     }
@@ -444,10 +450,6 @@ impl crate::backend::Device for MtlDevice {
             let gpu_ts_ptr = &mut time_calibration.gpu_start_ts as *mut _;
             // TODO: only do this if supported.
             let () = msg_send![device, sampleTimestamps: cpu_ts_ptr gpuTimestamp: gpu_ts_ptr];
-            println!(
-                "scheduled, {}, {}",
-                time_calibration.cpu_start_ts, time_calibration.gpu_start_ts
-            );
         })
         .copy();
         add_scheduled_handler(&cmd_buf.cmd_buf, &start_block);
@@ -461,10 +463,6 @@ impl crate::backend::Device for MtlDevice {
                 // TODO: only do this if supported.
                 let () =
                     msg_send![device, sampleTimestamps: cpu_ts_ptr gpuTimestamp: gpu_ts_ptr];
-                println!(
-                    "completed, {}, {}",
-                    time_calibration.cpu_end_ts, time_calibration.gpu_end_ts
-                );
             })
             .copy();
         cmd_buf.cmd_buf.add_completed_handler(&completed_block);
@@ -546,8 +544,6 @@ impl crate::backend::Device for MtlDevice {
 }

 impl crate::backend::CmdBuf<MtlDevice> for CmdBuf {
-    type ComputeEncoder = ComputeEncoder;
-
     unsafe fn begin(&mut self) {}

     unsafe fn finish(&mut self) {
@@ -558,6 +554,35 @@ impl crate::backend::CmdBuf<MtlDevice> for CmdBuf {
         false
     }

+    unsafe fn begin_compute_pass(&mut self, desc: &ComputePassDescriptor) {
+        debug_assert!(matches!(self.cur_encoder, Encoder::None));
+        let encoder = if let Some(queries) = &desc.timer_queries {
+            let descriptor: id = msg_send![class!(MTLComputePassDescriptor), computePassDescriptor];
+            let attachments: id = msg_send![descriptor, sampleBufferAttachments];
+            let index: NSUInteger = 0;
+            let attachment: id = msg_send![attachments, objectAtIndexedSubscript: index];
+            // Here we break the hub/mux separation a bit, for expedience
+            #[allow(irrefutable_let_patterns)]
+            if let crate::hub::QueryPool::Mtl(query_pool) = queries.0 {
+                if let Some(sample_buf) = &query_pool.counter_sample_buf {
+                    let () = msg_send![attachment, setSampleBuffer: sample_buf.id()];
+                }
+            }
+            let start_index = queries.1 as NSUInteger;
+            let end_index = queries.2 as NSUInteger;
+            let () = msg_send![attachment, setStartOfEncoderSampleIndex: start_index];
+            let () = msg_send![attachment, setEndOfEncoderSampleIndex: end_index];
+            let encoder = msg_send![
+                self.cmd_buf,
+                computeCommandEncoderWithDescriptor: descriptor
+            ];
+            encoder
+        } else {
+            self.cmd_buf.new_compute_command_encoder()
+        };
+        self.cur_encoder = Encoder::Compute(encoder.to_owned());
+    }
+
     unsafe fn dispatch(
         &mut self,
         pipeline: &Pipeline,
@@ -590,6 +615,11 @@ impl crate::backend::CmdBuf<MtlDevice> for CmdBuf {
         encoder.dispatch_thread_groups(workgroup_count, workgroup_size);
     }

+    unsafe fn end_compute_pass(&mut self) {
+        // TODO: might validate that we are in a compute encoder state
+        self.flush_encoder();
+    }
+
     unsafe fn memory_barrier(&mut self) {
         // We'll probably move to explicit barriers, but for now rely on
         // Metal's own tracking.
@@ -690,10 +720,13 @@ impl crate::backend::CmdBuf<MtlDevice> for CmdBuf {
         );
     }

-    unsafe fn reset_query_pool(&mut self, _pool: &QueryPool) {}
+    unsafe fn reset_query_pool(&mut self, pool: &QueryPool) {
+        let mut calibration = pool.calibration.lock().unwrap();
+        *calibration = Some(self.time_calibration.clone());
+    }

     unsafe fn write_timestamp(&mut self, pool: &QueryPool, query: u32) {
-        if let Some(buf) = &pool.0 {
+        if let Some(buf) = &pool.counter_sample_buf {
             if matches!(self.cur_encoder, Encoder::None) {
                 self.cur_encoder =
                     Encoder::Compute(self.cmd_buf.new_compute_command_encoder().to_owned());
@@ -709,21 +742,14 @@ impl crate::backend::CmdBuf<MtlDevice> for CmdBuf {
             }
         } else if self.counter_style == CounterStyle::Stage {
             match &self.cur_encoder {
-                Encoder::Compute(e) => {
-                    println!("here we are");
+                Encoder::Compute(_e) => {
+                    println!("write_timestamp is not supported for stage-style encoders");
                 }
                 _ => (),
             }
         }
     }
-
-    unsafe fn new_compute_encoder(&mut self) -> Self::ComputeEncoder {
-        let raw = self.cmd_buf.new_compute_command_encoder().to_owned();
-        ComputeEncoder {
-            raw
-        }
-    }
 }

 impl CmdBuf {
@@ -761,43 +787,6 @@ impl CmdBuf {
     }
 }

-impl crate::backend::ComputeEncoder<MtlDevice> for ComputeEncoder {
-    unsafe fn dispatch(
-        &mut self,
-        pipeline: &Pipeline,
-        descriptor_set: &DescriptorSet,
-        workgroup_count: (u32, u32, u32),
-        workgroup_size: (u32, u32, u32),
-    ) {
-        self.raw.set_compute_pipeline_state(&pipeline.0);
-        let mut buf_ix = 0;
-        for buffer in &descriptor_set.buffers {
-            self.raw.set_buffer(buf_ix, Some(&buffer.buffer), 0);
-            buf_ix += 1;
-        }
-        let mut img_ix = buf_ix;
-        for image in &descriptor_set.images {
-            self.raw.set_texture(img_ix, Some(&image.texture));
-            img_ix += 1;
-        }
-        let workgroup_count = metal::MTLSize {
-            width: workgroup_count.0 as u64,
-            height: workgroup_count.1 as u64,
-            depth: workgroup_count.2 as u64,
-        };
-        let workgroup_size = metal::MTLSize {
-            width: workgroup_size.0 as u64,
-            height: workgroup_size.1 as u64,
-            depth: workgroup_size.2 as u64,
-        };
-        self.raw.dispatch_thread_groups(workgroup_count, workgroup_size);
-    }
-
-    unsafe fn finish(&mut self) {
-        self.raw.end_encoding();
-    }
-}
-
 impl crate::backend::DescriptorSetBuilder<MtlDevice> for DescriptorSetBuilder {
     fn add_buffers(&mut self, buffers: &[&Buffer]) {
         self.0.buffers.extend(buffers.iter().copied().cloned());

File 6 of 7:

@@ -36,6 +36,14 @@ pub struct CounterSet {
     id: id,
 }

+#[derive(Default)]
+pub struct TimeCalibration {
+    pub cpu_start_ts: u64,
+    pub gpu_start_ts: u64,
+    pub cpu_end_ts: u64,
+    pub gpu_end_ts: u64,
+}
+
 impl Drop for CounterSampleBuffer {
     fn drop(&mut self) {
         unsafe { msg_send![self.id, release] }
@@ -87,7 +95,6 @@ impl CounterSampleBuffer {
         unsafe {
             let desc_cls = class!(MTLCounterSampleBufferDescriptor);
             let descriptor: id = msg_send![desc_cls, alloc];
-            println!("descriptor = {:?}", descriptor);
             let _: id = msg_send![descriptor, init];
             let count = count as NSUInteger;
             let () = msg_send![descriptor, setSampleCount: count];
@@ -121,3 +128,21 @@ impl CounterSampleBuffer {
         }
     }
 }
+
+impl TimeCalibration {
+    /// Convert GPU timestamp into CPU time base.
+    ///
+    /// See https://developer.apple.com/documentation/metal/performance_tuning/correlating_cpu_and_gpu_timestamps
+    pub fn correlate(&self, raw_ts: u64) -> f64 {
+        let delta_cpu = self.cpu_end_ts - self.cpu_start_ts;
+        let delta_gpu = self.gpu_end_ts - self.gpu_start_ts;
+        let adj_ts = if delta_gpu > 0 {
+            let scale = delta_cpu as f64 / delta_gpu as f64;
+            self.cpu_start_ts as f64 + (raw_ts - self.gpu_start_ts) as f64 * scale
+        } else {
+            // Default is ns on Apple Silicon; on other hardware this will be wrong
+            raw_ts as f64
+        };
+        adj_ts * 1e-9
+    }
+}
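
As a quick sanity check of the correlate math above, with made-up calibration numbers (not from a real device):

    // 10 ms of CPU time (in ns) spans 10,000 GPU ticks, so scale = 1000.
    let cal = TimeCalibration {
        cpu_start_ts: 1_000_000_000,
        gpu_start_ts: 500_000,
        cpu_end_ts: 1_010_000_000,
        gpu_end_ts: 510_000,
    };
    // A GPU timestamp halfway through the window maps halfway through the
    // CPU window: 1.005e9 ns, returned in seconds as approximately 1.005.
    let t = cal.correlate(505_000);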

File 7 of 7:

@@ -35,6 +35,7 @@ use crate::backend::DescriptorSetBuilder as DescriptorSetBuilderTrait;
 use crate::backend::Device as DeviceTrait;
 use crate::BackendType;
 use crate::BindType;
+use crate::ComputePassDescriptor;
 use crate::ImageFormat;
 use crate::MapMode;
 use crate::{BufferUsage, Error, GpuInfo, ImageLayout, InstanceFlags};
@@ -100,14 +101,6 @@ mux_device_enum! {
 QueryPool }
 mux_device_enum! { Sampler }

-mux_enum! {
-    pub enum ComputeEncoder {
-        Vk(<crate::vulkan::CmdBuf as crate::backend::CmdBuf<vulkan::VkDevice>>::ComputeEncoder),
-        Dx12(<crate::dx12::Dx12Device as crate::backend::CmdBuf<dx12::Dx12Device>>::ComputeEncoder),
-        Mtl(<crate::metal::CmdBuf as crate::backend::CmdBuf<metal::MtlDevice>>::ComputeEncoder),
-    }
-}
-
 /// The code for a shader, either as source or intermediate representation.
 pub enum ShaderCode<'a> {
     /// SPIR-V (binary intermediate representation)
@@ -666,6 +659,14 @@ impl CmdBuf {
         }
     }

+    pub unsafe fn begin_compute_pass(&mut self, desc: &ComputePassDescriptor) {
+        mux_match! { self;
+            CmdBuf::Vk(c) => c.begin_compute_pass(desc),
+            CmdBuf::Dx12(c) => c.begin_compute_pass(desc),
+            CmdBuf::Mtl(c) => c.begin_compute_pass(desc),
+        }
+    }
+
     /// Dispatch a compute shader.
     ///
     /// Note that both the number of workgroups (`workgroup_count`) and the number of
@@ -688,6 +689,14 @@ impl CmdBuf {
         }
     }

+    pub unsafe fn end_compute_pass(&mut self) {
+        mux_match! { self;
+            CmdBuf::Vk(c) => c.end_compute_pass(),
+            CmdBuf::Dx12(c) => c.end_compute_pass(),
+            CmdBuf::Mtl(c) => c.end_compute_pass(),
+        }
+    }
+
     pub unsafe fn memory_barrier(&mut self) {
         mux_match! { self;
             CmdBuf::Vk(c) => c.memory_barrier(),