Fix up merge

Update signatures to match changes to backend trait, and add new types
and stub methods to complete mux abstraction.
This commit is contained in:
Raph Levien 2021-05-26 19:08:56 -07:00
parent 0d5ff515ec
commit ebefd025f7
3 changed files with 230 additions and 69 deletions

View file

@ -25,6 +25,10 @@ macro_rules! mux_cfg {
( #[cfg(dx12)] $($tokens:tt)* ) => { ( #[cfg(dx12)] $($tokens:tt)* ) => {
#[cfg(target_os="windows")] $( $tokens )* #[cfg(target_os="windows")] $( $tokens )*
}; };
( #[cfg(mtl)] $($tokens:tt)* ) => {
#[cfg(target_os="macos")] $( $tokens )*
};
} }
#[macro_export] #[macro_export]
@ -32,12 +36,15 @@ macro_rules! mux_enum {
( $(#[$outer:meta])* $v:vis enum $name:ident { ( $(#[$outer:meta])* $v:vis enum $name:ident {
Vk($vk:ty), Vk($vk:ty),
Dx12($dx12:ty), Dx12($dx12:ty),
Mtl($mtl:ty),
} ) => { } ) => {
$(#[$outer])* $v enum $name { $(#[$outer])* $v enum $name {
#[cfg(not(target_os="macos"))] #[cfg(not(target_os="macos"))]
Vk($vk), Vk($vk),
#[cfg(target_os="windows")] #[cfg(target_os="windows")]
Dx12($dx12), Dx12($dx12),
#[cfg(target_os="macos")]
Mtl($mtl),
} }
impl $name { impl $name {
@ -62,6 +69,17 @@ macro_rules! mux_enum {
} }
} }
} }
$crate::mux_cfg! {
#[cfg(mtl)]
#[allow(unused)]
fn mtl(&self) -> &$mtl {
match self {
$name::Mtl(x) => x,
_ => panic!("downcast error")
}
}
}
} }
}; };
} }
@ -73,6 +91,7 @@ macro_rules! mux_device_enum {
pub enum $assoc_type { pub enum $assoc_type {
Vk(<$crate::vulkan::VkDevice as $crate::Device>::$assoc_type), Vk(<$crate::vulkan::VkDevice as $crate::Device>::$assoc_type),
Dx12(<$crate::dx12::Dx12Device as $crate::Device>::$assoc_type), Dx12(<$crate::dx12::Dx12Device as $crate::Device>::$assoc_type),
Mtl(<$crate::metal::MtlDevice as $crate::Device>::$assoc_type),
} }
} }
} }
@ -83,22 +102,27 @@ macro_rules! mux_match {
( $e:expr ; ( $e:expr ;
$vkname:ident::Vk($vkvar:ident) => $vkblock: block $vkname:ident::Vk($vkvar:ident) => $vkblock: block
$dx12name:ident::Dx12($dx12var:ident) => $dx12block: block $dx12name:ident::Dx12($dx12var:ident) => $dx12block: block
$mtlname:ident::Mtl($mtlvar:ident) => $mtlblock: block
) => { ) => {
match $e { match $e {
#[cfg(not(target_os="macos"))] #[cfg(not(target_os="macos"))]
$vkname::Vk($vkvar) => $vkblock $vkname::Vk($vkvar) => $vkblock
#[cfg(target_os="windows")] #[cfg(target_os="windows")]
$dx12name::Dx12($dx12var) => $dx12block $dx12name::Dx12($dx12var) => $dx12block
#[cfg(target_os="macos")]
$mtlname::Mtl($mtlvar) => $mtlblock
} }
}; };
( $e:expr ; ( $e:expr ;
$vkname:ident::Vk($vkvar:ident) => $vkblock: expr, $vkname:ident::Vk($vkvar:ident) => $vkblock: expr,
$dx12name:ident::Dx12($dx12var:ident) => $dx12block: expr, $dx12name:ident::Dx12($dx12var:ident) => $dx12block: expr,
$mtlname:ident::Mtl($mtlvar:ident) => $mtlblock: expr,
) => { ) => {
$crate::mux_match! { $e; $crate::mux_match! { $e;
$vkname::Vk($vkvar) => { $vkblock } $vkname::Vk($vkvar) => { $vkblock }
$dx12name::Dx12($dx12var) => { $dx12block } $dx12name::Dx12($dx12var) => { $dx12block }
$mtlname::Mtl($mtlvar) => { $mtlblock }
} }
}; };
} }

View file

@ -14,17 +14,22 @@
// //
// Also licensed under MIT license, at your choice. // Also licensed under MIT license, at your choice.
use crate::Error; use crate::{BufferUsage, Error};
use bitflags::bitflags; pub struct MtlInstance;
pub struct MetalInstance; pub struct MtlDevice {
pub struct MetalDevice {
device: metal::Device, device: metal::Device,
} }
pub struct Buffer(metal::Buffer); pub struct MtlSurface;
pub struct MtlSwapchain;
pub struct Buffer {
buffer: metal::Buffer,
pub(crate) size: u64,
}
pub struct Image; pub struct Image;
@ -36,19 +41,6 @@ pub struct Fence;
pub struct Semaphore; pub struct Semaphore;
// This is the new direction of how I want this to go, and will
// move it to crate level. It's very similar to wgpu's BufferUsage.
bitflags! {
pub struct MemFlags: u32 {
const MAP_READ = 1;
const MAP_WRITE = 2;
const COPY_SRC = 4;
const COPY_DST = 8;
const STORAGE = 128;
}
}
pub struct CmdBuf; pub struct CmdBuf;
pub struct QueryPool; pub struct QueryPool;
@ -57,29 +49,39 @@ pub struct PipelineBuilder;
pub struct DescriptorSetBuilder; pub struct DescriptorSetBuilder;
impl MetalInstance { impl MtlInstance {
pub fn new() -> MetalInstance { pub fn new(
MetalInstance window_handle: Option<&dyn raw_window_handle::HasRawWindowHandle>,
) -> Result<(MtlInstance, Option<MtlSurface>), Error> {
Ok((MtlInstance, None))
} }
// TODO might do some enumeration of devices // TODO might do some enumeration of devices
pub fn device(&self) -> Result<MetalDevice, Error> { pub fn device(&self, surface: Option<&MtlSurface>) -> Result<MtlDevice, Error> {
if let Some(device) = metal::Device::system_default() { if let Some(device) = metal::Device::system_default() {
Ok(MetalDevice { device }) Ok(MtlDevice { device })
} else { } else {
Err("can't create system default Metal device".into()) Err("can't create system default Metal device".into())
} }
} }
pub unsafe fn swapchain(
&self,
width: usize,
height: usize,
device: &MtlDevice,
surface: &MtlSurface,
) -> Result<MtlSwapchain, Error> {
todo!()
}
} }
impl crate::Device for MetalDevice { impl crate::Device for MtlDevice {
type Buffer = Buffer; type Buffer = Buffer;
type Image = Image; type Image = Image;
type MemFlags = MemFlags;
type Pipeline = Pipeline; type Pipeline = Pipeline;
type DescriptorSet = DescriptorSet; type DescriptorSet = DescriptorSet;
@ -98,20 +100,22 @@ impl crate::Device for MetalDevice {
type Sampler = (); type Sampler = ();
type ShaderSource = str;
fn query_gpu_info(&self) -> crate::GpuInfo { fn query_gpu_info(&self) -> crate::GpuInfo {
todo!() todo!()
} }
fn create_buffer(&self, size: u64, mem_flags: Self::MemFlags) -> Result<Self::Buffer, Error> { fn create_buffer(&self, size: u64, usage: BufferUsage) -> Result<Self::Buffer, Error> {
let options = if mem_flags.contains(MemFlags::MAP_READ) { let options = if usage.contains(BufferUsage::MAP_READ) {
metal::MTLResourceOptions::StorageModeShared | metal::MTLResourceOptions::CPUCacheModeDefaultCache metal::MTLResourceOptions::StorageModeShared | metal::MTLResourceOptions::CPUCacheModeDefaultCache
} else if mem_flags.contains(MemFlags::MAP_WRITE) { } else if usage.contains(BufferUsage::MAP_WRITE) {
metal::MTLResourceOptions::StorageModeShared | metal::MTLResourceOptions::CPUCacheModeWriteCombined metal::MTLResourceOptions::StorageModeShared | metal::MTLResourceOptions::CPUCacheModeWriteCombined
} else { } else {
metal::MTLResourceOptions::StorageModePrivate metal::MTLResourceOptions::StorageModePrivate
}; };
let buffer = self.device.new_buffer(size, options); let buffer = self.device.new_buffer(size, options);
Ok(Buffer(buffer)) Ok(Buffer { buffer, size })
} }
unsafe fn destroy_buffer(&self, buffer: &Self::Buffer) -> Result<(), Error> { unsafe fn destroy_buffer(&self, buffer: &Self::Buffer) -> Result<(), Error> {
@ -122,7 +126,6 @@ impl crate::Device for MetalDevice {
&self, &self,
width: u32, width: u32,
height: u32, height: u32,
mem_flags: Self::MemFlags,
) -> Result<Self::Image, Error> { ) -> Result<Self::Image, Error> {
todo!() todo!()
} }
@ -151,46 +154,43 @@ impl crate::Device for MetalDevice {
todo!() todo!()
} }
unsafe fn run_cmd_buf( unsafe fn run_cmd_bufs(
&self, &self,
cmd_buf: &Self::CmdBuf, cmd_bufs: &[&Self::CmdBuf],
wait_semaphores: &[Self::Semaphore], wait_semaphores: &[&Self::Semaphore],
signal_semaphores: &[Self::Semaphore], signal_semaphores: &[&Self::Semaphore],
fence: Option<&Self::Fence>, fence: Option<&Self::Fence>,
) -> Result<(), Error> { ) -> Result<(), Error> {
todo!() todo!()
} }
unsafe fn read_buffer<T: Sized>( unsafe fn read_buffer(
&self, &self,
buffer: &Self::Buffer, buffer: &Self::Buffer,
result: &mut Vec<T>, dst: *mut u8,
offset: u64,
size: u64,
) -> Result<(), Error> { ) -> Result<(), Error> {
let contents_ptr = buffer.0.contents(); let contents_ptr = buffer.buffer.contents();
if contents_ptr.is_null() { if contents_ptr.is_null() {
return Err("probably trying to read from private buffer".into()); return Err("probably trying to read from private buffer".into());
} }
let len = buffer.0.length() as usize / std::mem::size_of::<T>(); std::ptr::copy_nonoverlapping((contents_ptr as *const u8).add(offset as usize), dst, size as usize);
if len > result.len() {
result.reserve(len - result.len());
}
std::ptr::copy_nonoverlapping(contents_ptr as *const T, result.as_mut_ptr(), len);
result.set_len(len);
Ok(()) Ok(())
} }
unsafe fn write_buffer<T: Sized>( unsafe fn write_buffer(
&self, &self,
buffer: &Self::Buffer, buffer: &Buffer,
contents: &[T], contents: *const u8,
offset: u64,
size: u64,
) -> Result<(), Error> { ) -> Result<(), Error> {
let contents_ptr = buffer.0.contents(); let contents_ptr = buffer.buffer.contents();
if contents_ptr.is_null() { if contents_ptr.is_null() {
return Err("probably trying to write to private buffer".into()); return Err("probably trying to write to private buffer".into());
} }
let len = buffer.0.length() as usize / std::mem::size_of::<T>(); std::ptr::copy_nonoverlapping(contents, (contents_ptr as *mut u8).add(offset as usize), size as usize);
assert!(len >= contents.len());
std::ptr::copy_nonoverlapping(contents.as_ptr(), contents_ptr as *mut T, len);
Ok(()) Ok(())
} }
@ -202,11 +202,11 @@ impl crate::Device for MetalDevice {
todo!() todo!()
} }
unsafe fn wait_and_reset(&self, fences: &[Self::Fence]) -> Result<(), Error> { unsafe fn wait_and_reset(&self, fences: &[&Self::Fence]) -> Result<(), Error> {
todo!() todo!()
} }
unsafe fn get_fence_status(&self, fence: Self::Fence) -> Result<bool, Error> { unsafe fn get_fence_status(&self, fence: &Self::Fence) -> Result<bool, Error> {
todo!() todo!()
} }
@ -215,17 +215,7 @@ impl crate::Device for MetalDevice {
} }
} }
impl crate::MemFlags for MemFlags { impl crate::CmdBuf<MtlDevice> for CmdBuf {
fn device_local() -> Self {
MemFlags::COPY_SRC | MemFlags::COPY_DST | MemFlags::STORAGE
}
fn host_coherent() -> Self {
MemFlags::device_local() | MemFlags::MAP_READ | MemFlags::MAP_WRITE
}
}
impl crate::CmdBuf<MetalDevice> for CmdBuf {
unsafe fn begin(&mut self) { unsafe fn begin(&mut self) {
todo!() todo!()
} }
@ -289,7 +279,7 @@ impl crate::CmdBuf<MetalDevice> for CmdBuf {
} }
} }
impl crate::PipelineBuilder<MetalDevice> for PipelineBuilder { impl crate::PipelineBuilder<MtlDevice> for PipelineBuilder {
fn add_buffers(&mut self, n_buffers: u32) { fn add_buffers(&mut self, n_buffers: u32) {
todo!() todo!()
} }
@ -302,12 +292,12 @@ impl crate::PipelineBuilder<MetalDevice> for PipelineBuilder {
todo!() todo!()
} }
unsafe fn create_compute_pipeline(self, device: &MetalDevice, code: &[u8]) -> Result<Pipeline, Error> { unsafe fn create_compute_pipeline(self, device: &MtlDevice, code: &str) -> Result<Pipeline, Error> {
todo!() todo!()
} }
} }
impl crate::DescriptorSetBuilder<MetalDevice> for DescriptorSetBuilder { impl crate::DescriptorSetBuilder<MtlDevice> for DescriptorSetBuilder {
fn add_buffers(&mut self, buffers: &[&Buffer]) { fn add_buffers(&mut self, buffers: &[&Buffer]) {
todo!() todo!()
} }
@ -320,7 +310,25 @@ impl crate::DescriptorSetBuilder<MetalDevice> for DescriptorSetBuilder {
todo!() todo!()
} }
unsafe fn build(self, device: &MetalDevice, pipeline: &Pipeline) -> Result<DescriptorSet, Error> { unsafe fn build(self, device: &MtlDevice, pipeline: &Pipeline) -> Result<DescriptorSet, Error> {
todo!()
}
}
impl MtlSwapchain {
pub unsafe fn next(&mut self) -> Result<(usize, Semaphore), Error> {
todo!()
}
pub unsafe fn image(&self, idx: usize) -> Image {
todo!()
}
pub unsafe fn present(
&self,
image_idx: usize,
semaphores: &[&Semaphore],
) -> Result<bool, Error> {
todo!() todo!()
} }
} }

View file

@ -26,6 +26,10 @@ mux_cfg! {
#[cfg(dx12)] #[cfg(dx12)]
use crate::dx12; use crate::dx12;
} }
mux_cfg! {
#[cfg(mtl)]
use crate::metal;
}
use crate::CmdBuf as CmdBufTrait; use crate::CmdBuf as CmdBufTrait;
use crate::DescriptorSetBuilder as DescriptorSetBuilderTrait; use crate::DescriptorSetBuilder as DescriptorSetBuilderTrait;
use crate::Device as DeviceTrait; use crate::Device as DeviceTrait;
@ -37,6 +41,7 @@ mux_enum! {
pub enum Instance { pub enum Instance {
Vk(vulkan::VkInstance), Vk(vulkan::VkInstance),
Dx12(dx12::Dx12Instance), Dx12(dx12::Dx12Instance),
Mtl(metal::MtlInstance),
} }
} }
@ -45,6 +50,7 @@ mux_enum! {
pub enum Device { pub enum Device {
Vk(vulkan::VkDevice), Vk(vulkan::VkDevice),
Dx12(dx12::Dx12Device), Dx12(dx12::Dx12Device),
Mtl(metal::MtlDevice),
} }
} }
@ -53,6 +59,7 @@ mux_enum! {
pub enum Surface { pub enum Surface {
Vk(vulkan::VkSurface), Vk(vulkan::VkSurface),
Dx12(dx12::Dx12Surface), Dx12(dx12::Dx12Surface),
Mtl(metal::MtlSurface),
} }
} }
@ -61,6 +68,7 @@ mux_enum! {
pub enum Swapchain { pub enum Swapchain {
Vk(vulkan::VkSwapchain), Vk(vulkan::VkSwapchain),
Dx12(dx12::Dx12Swapchain), Dx12(dx12::Dx12Swapchain),
Mtl(metal::MtlSwapchain),
} }
} }
@ -78,8 +86,12 @@ mux_device_enum! { Sampler }
/// The code for a shader, either as source or intermediate representation. /// The code for a shader, either as source or intermediate representation.
pub enum ShaderCode<'a> { pub enum ShaderCode<'a> {
/// SPIR-V (binary intermediate representation)
Spv(&'a [u8]), Spv(&'a [u8]),
/// HLSL (source)
Hlsl(&'a str), Hlsl(&'a str),
/// Metal Shading Language (source)
Msl(&'a str),
} }
impl Instance { impl Instance {
@ -111,6 +123,15 @@ impl Instance {
} }
} }
} }
mux_cfg! {
#[cfg(mtl)]
{
let result = metal::MtlInstance::new(window_handle);
if let Ok((instance, surface)) = result {
return Ok((Instance::Mtl(instance), surface.map(Surface::Mtl)));
}
}
}
// TODO plumb creation errors through. // TODO plumb creation errors through.
Err("No suitable instances found".into()) Err("No suitable instances found".into())
} }
@ -124,6 +145,7 @@ impl Instance {
mux_match! { self; mux_match! { self;
Instance::Vk(i) => i.device(surface.map(Surface::vk)).map(Device::Vk), Instance::Vk(i) => i.device(surface.map(Surface::vk)).map(Device::Vk),
Instance::Dx12(i) => i.device(surface.map(Surface::dx12)).map(Device::Dx12), Instance::Dx12(i) => i.device(surface.map(Surface::dx12)).map(Device::Dx12),
Instance::Mtl(i) => i.device(surface.map(Surface::mtl)).map(Device::Mtl),
} }
} }
@ -147,6 +169,9 @@ impl Instance {
Instance::Dx12(i) => i Instance::Dx12(i) => i
.swapchain(width, height, device.dx12(), surface.dx12()) .swapchain(width, height, device.dx12(), surface.dx12())
.map(Swapchain::Dx12), .map(Swapchain::Dx12),
Instance::Mtl(i) => i
.swapchain(width, height, device.mtl(), surface.mtl())
.map(Swapchain::Mtl),
} }
} }
} }
@ -159,6 +184,7 @@ impl Device {
mux_match! { self; mux_match! { self;
Device::Vk(d) => d.query_gpu_info(), Device::Vk(d) => d.query_gpu_info(),
Device::Dx12(d) => d.query_gpu_info(), Device::Dx12(d) => d.query_gpu_info(),
Device::Mtl(d) => d.query_gpu_info(),
} }
} }
@ -166,6 +192,7 @@ impl Device {
mux_match! { self; mux_match! { self;
Device::Vk(d) => d.create_buffer(size, usage).map(Buffer::Vk), Device::Vk(d) => d.create_buffer(size, usage).map(Buffer::Vk),
Device::Dx12(d) => d.create_buffer(size, usage).map(Buffer::Dx12), Device::Dx12(d) => d.create_buffer(size, usage).map(Buffer::Dx12),
Device::Mtl(d) => d.create_buffer(size, usage).map(Buffer::Mtl),
} }
} }
@ -173,6 +200,7 @@ impl Device {
mux_match! { self; mux_match! { self;
Device::Vk(d) => d.destroy_buffer(buffer.vk()), Device::Vk(d) => d.destroy_buffer(buffer.vk()),
Device::Dx12(d) => d.destroy_buffer(buffer.dx12()), Device::Dx12(d) => d.destroy_buffer(buffer.dx12()),
Device::Mtl(d) => d.destroy_buffer(buffer.mtl()),
} }
} }
@ -180,6 +208,7 @@ impl Device {
mux_match! { self; mux_match! { self;
Device::Vk(d) => d.create_image2d(width, height).map(Image::Vk), Device::Vk(d) => d.create_image2d(width, height).map(Image::Vk),
Device::Dx12(d) => d.create_image2d(width, height).map(Image::Dx12), Device::Dx12(d) => d.create_image2d(width, height).map(Image::Dx12),
Device::Mtl(d) => d.create_image2d(width, height).map(Image::Mtl),
} }
} }
@ -187,6 +216,7 @@ impl Device {
mux_match! { self; mux_match! { self;
Device::Vk(d) => d.destroy_image(image.vk()), Device::Vk(d) => d.destroy_image(image.vk()),
Device::Dx12(d) => d.destroy_image(image.dx12()), Device::Dx12(d) => d.destroy_image(image.dx12()),
Device::Mtl(d) => d.destroy_image(image.mtl()),
} }
} }
@ -194,6 +224,7 @@ impl Device {
mux_match! { self; mux_match! { self;
Device::Vk(d) => d.create_fence(signaled).map(Fence::Vk), Device::Vk(d) => d.create_fence(signaled).map(Fence::Vk),
Device::Dx12(d) => d.create_fence(signaled).map(Fence::Dx12), Device::Dx12(d) => d.create_fence(signaled).map(Fence::Dx12),
Device::Mtl(d) => d.create_fence(signaled).map(Fence::Mtl),
} }
} }
@ -215,6 +246,14 @@ impl Device {
.collect::<SmallVec<[_; 4]>>(); .collect::<SmallVec<[_; 4]>>();
d.wait_and_reset(&*fences) d.wait_and_reset(&*fences)
} }
Device::Mtl(d) => {
let fences = fences
.iter()
.copied()
.map(Fence::mtl)
.collect::<SmallVec<[_; 4]>>();
d.wait_and_reset(&*fences)
}
} }
} }
@ -222,6 +261,7 @@ impl Device {
mux_match! { self; mux_match! { self;
Device::Vk(d) => d.get_fence_status(fence.vk()), Device::Vk(d) => d.get_fence_status(fence.vk()),
Device::Dx12(d) => d.get_fence_status(fence.dx12()), Device::Dx12(d) => d.get_fence_status(fence.dx12()),
Device::Mtl(d) => d.get_fence_status(fence.mtl()),
} }
} }
@ -229,6 +269,7 @@ impl Device {
mux_match! { self; mux_match! { self;
Device::Vk(d) => d.create_semaphore().map(Semaphore::Vk), Device::Vk(d) => d.create_semaphore().map(Semaphore::Vk),
Device::Dx12(d) => d.create_semaphore().map(Semaphore::Dx12), Device::Dx12(d) => d.create_semaphore().map(Semaphore::Dx12),
Device::Mtl(d) => d.create_semaphore().map(Semaphore::Mtl),
} }
} }
@ -236,6 +277,7 @@ impl Device {
mux_match! { self; mux_match! { self;
Device::Vk(d) => PipelineBuilder::Vk(d.pipeline_builder()), Device::Vk(d) => PipelineBuilder::Vk(d.pipeline_builder()),
Device::Dx12(d) => PipelineBuilder::Dx12(d.pipeline_builder()), Device::Dx12(d) => PipelineBuilder::Dx12(d.pipeline_builder()),
Device::Mtl(d) => PipelineBuilder::Mtl(d.pipeline_builder()),
} }
} }
@ -243,6 +285,7 @@ impl Device {
mux_match! { self; mux_match! { self;
Device::Vk(d) => DescriptorSetBuilder::Vk(d.descriptor_set_builder()), Device::Vk(d) => DescriptorSetBuilder::Vk(d.descriptor_set_builder()),
Device::Dx12(d) => DescriptorSetBuilder::Dx12(d.descriptor_set_builder()), Device::Dx12(d) => DescriptorSetBuilder::Dx12(d.descriptor_set_builder()),
Device::Mtl(d) => DescriptorSetBuilder::Mtl(d.descriptor_set_builder()),
} }
} }
@ -250,6 +293,7 @@ impl Device {
mux_match! { self; mux_match! { self;
Device::Vk(d) => d.create_cmd_buf().map(CmdBuf::Vk), Device::Vk(d) => d.create_cmd_buf().map(CmdBuf::Vk),
Device::Dx12(d) => d.create_cmd_buf().map(CmdBuf::Dx12), Device::Dx12(d) => d.create_cmd_buf().map(CmdBuf::Dx12),
Device::Mtl(d) => d.create_cmd_buf().map(CmdBuf::Mtl),
} }
} }
@ -257,6 +301,7 @@ impl Device {
mux_match! { self; mux_match! { self;
Device::Vk(d) => d.create_query_pool(n_queries).map(QueryPool::Vk), Device::Vk(d) => d.create_query_pool(n_queries).map(QueryPool::Vk),
Device::Dx12(d) => d.create_query_pool(n_queries).map(QueryPool::Dx12), Device::Dx12(d) => d.create_query_pool(n_queries).map(QueryPool::Dx12),
Device::Mtl(d) => d.create_query_pool(n_queries).map(QueryPool::Mtl),
} }
} }
@ -264,6 +309,7 @@ impl Device {
mux_match! { self; mux_match! { self;
Device::Vk(d) => d.fetch_query_pool(pool.vk()), Device::Vk(d) => d.fetch_query_pool(pool.vk()),
Device::Dx12(d) => d.fetch_query_pool(pool.dx12()), Device::Dx12(d) => d.fetch_query_pool(pool.dx12()),
Device::Mtl(d) => d.fetch_query_pool(pool.mtl()),
} }
} }
@ -309,6 +355,23 @@ impl Device {
.collect::<SmallVec<[_; 4]>>(), .collect::<SmallVec<[_; 4]>>(),
fence.map(Fence::dx12), fence.map(Fence::dx12),
), ),
Device::Mtl(d) => d.run_cmd_bufs(
&cmd_bufs
.iter()
.map(|c| c.mtl())
.collect::<SmallVec<[_; 4]>>(),
&wait_semaphores
.iter()
.copied()
.map(Semaphore::mtl)
.collect::<SmallVec<[_; 4]>>(),
&signal_semaphores
.iter()
.copied()
.map(Semaphore::mtl)
.collect::<SmallVec<[_; 4]>>(),
fence.map(Fence::mtl),
),
} }
} }
@ -322,6 +385,7 @@ impl Device {
mux_match! { self; mux_match! { self;
Device::Vk(d) => d.read_buffer(buffer.vk(), dst, offset, size), Device::Vk(d) => d.read_buffer(buffer.vk(), dst, offset, size),
Device::Dx12(d) => d.read_buffer(buffer.dx12(), dst, offset, size), Device::Dx12(d) => d.read_buffer(buffer.dx12(), dst, offset, size),
Device::Mtl(d) => d.read_buffer(buffer.mtl(), dst, offset, size),
} }
} }
@ -335,6 +399,7 @@ impl Device {
mux_match! { self; mux_match! { self;
Device::Vk(d) => d.write_buffer(buffer.vk(), contents, offset, size), Device::Vk(d) => d.write_buffer(buffer.vk(), contents, offset, size),
Device::Dx12(d) => d.write_buffer(buffer.dx12(), contents, offset, size), Device::Dx12(d) => d.write_buffer(buffer.dx12(), contents, offset, size),
Device::Mtl(d) => d.write_buffer(buffer.mtl(), contents, offset, size),
} }
} }
} }
@ -344,6 +409,7 @@ impl PipelineBuilder {
mux_match! { self; mux_match! { self;
PipelineBuilder::Vk(x) => x.add_buffers(n_buffers), PipelineBuilder::Vk(x) => x.add_buffers(n_buffers),
PipelineBuilder::Dx12(x) => x.add_buffers(n_buffers), PipelineBuilder::Dx12(x) => x.add_buffers(n_buffers),
PipelineBuilder::Mtl(x) => x.add_buffers(n_buffers),
} }
} }
@ -351,6 +417,7 @@ impl PipelineBuilder {
mux_match! { self; mux_match! { self;
PipelineBuilder::Vk(x) => x.add_images(n_buffers), PipelineBuilder::Vk(x) => x.add_images(n_buffers),
PipelineBuilder::Dx12(x) => x.add_images(n_buffers), PipelineBuilder::Dx12(x) => x.add_images(n_buffers),
PipelineBuilder::Mtl(x) => x.add_images(n_buffers),
} }
} }
@ -358,6 +425,7 @@ impl PipelineBuilder {
mux_match! { self; mux_match! { self;
PipelineBuilder::Vk(x) => x.add_textures(n_buffers), PipelineBuilder::Vk(x) => x.add_textures(n_buffers),
PipelineBuilder::Dx12(x) => x.add_textures(n_buffers), PipelineBuilder::Dx12(x) => x.add_textures(n_buffers),
PipelineBuilder::Mtl(x) => x.add_textures(n_buffers),
} }
} }
@ -385,6 +453,15 @@ impl PipelineBuilder {
x.create_compute_pipeline(device.dx12(), shader_code) x.create_compute_pipeline(device.dx12(), shader_code)
.map(Pipeline::Dx12) .map(Pipeline::Dx12)
} }
PipelineBuilder::Mtl(x) => {
let shader_code = match code {
ShaderCode::Msl(msl) => msl,
// Panic or return "incompatible shader" error here?
_ => panic!("Metal backend requires shader code in MSL format"),
};
x.create_compute_pipeline(device.mtl(), shader_code)
.map(Pipeline::Mtl)
}
} }
} }
} }
@ -406,6 +483,13 @@ impl DescriptorSetBuilder {
.map(Buffer::dx12) .map(Buffer::dx12)
.collect::<SmallVec<[_; 8]>>(), .collect::<SmallVec<[_; 8]>>(),
), ),
DescriptorSetBuilder::Mtl(x) => x.add_buffers(
&buffers
.iter()
.copied()
.map(Buffer::mtl)
.collect::<SmallVec<[_; 8]>>(),
),
} }
} }
@ -425,6 +509,13 @@ impl DescriptorSetBuilder {
.map(Image::dx12) .map(Image::dx12)
.collect::<SmallVec<[_; 8]>>(), .collect::<SmallVec<[_; 8]>>(),
), ),
DescriptorSetBuilder::Mtl(x) => x.add_images(
&images
.iter()
.copied()
.map(Image::mtl)
.collect::<SmallVec<[_; 8]>>(),
),
} }
} }
@ -444,6 +535,13 @@ impl DescriptorSetBuilder {
.map(Image::dx12) .map(Image::dx12)
.collect::<SmallVec<[_; 8]>>(), .collect::<SmallVec<[_; 8]>>(),
), ),
DescriptorSetBuilder::Mtl(x) => x.add_textures(
&images
.iter()
.copied()
.map(Image::mtl)
.collect::<SmallVec<[_; 8]>>(),
),
} }
} }
@ -458,6 +556,9 @@ impl DescriptorSetBuilder {
DescriptorSetBuilder::Dx12(x) => x DescriptorSetBuilder::Dx12(x) => x
.build(device.dx12(), pipeline.dx12()) .build(device.dx12(), pipeline.dx12())
.map(DescriptorSet::Dx12), .map(DescriptorSet::Dx12),
DescriptorSetBuilder::Mtl(x) => x
.build(device.mtl(), pipeline.mtl())
.map(DescriptorSet::Mtl),
} }
} }
} }
@ -467,6 +568,7 @@ impl CmdBuf {
mux_match! { self; mux_match! { self;
CmdBuf::Vk(c) => c.begin(), CmdBuf::Vk(c) => c.begin(),
CmdBuf::Dx12(c) => c.begin(), CmdBuf::Dx12(c) => c.begin(),
CmdBuf::Mtl(c) => c.begin(),
} }
} }
@ -474,6 +576,7 @@ impl CmdBuf {
mux_match! { self; mux_match! { self;
CmdBuf::Vk(c) => c.finish(), CmdBuf::Vk(c) => c.finish(),
CmdBuf::Dx12(c) => c.finish(), CmdBuf::Dx12(c) => c.finish(),
CmdBuf::Mtl(c) => c.finish(),
} }
} }
@ -486,6 +589,7 @@ impl CmdBuf {
mux_match! { self; mux_match! { self;
CmdBuf::Vk(c) => c.dispatch(pipeline.vk(), descriptor_set.vk(), size), CmdBuf::Vk(c) => c.dispatch(pipeline.vk(), descriptor_set.vk(), size),
CmdBuf::Dx12(c) => c.dispatch(pipeline.dx12(), descriptor_set.dx12(), size), CmdBuf::Dx12(c) => c.dispatch(pipeline.dx12(), descriptor_set.dx12(), size),
CmdBuf::Mtl(c) => c.dispatch(pipeline.mtl(), descriptor_set.mtl(), size),
} }
} }
@ -493,6 +597,7 @@ impl CmdBuf {
mux_match! { self; mux_match! { self;
CmdBuf::Vk(c) => c.memory_barrier(), CmdBuf::Vk(c) => c.memory_barrier(),
CmdBuf::Dx12(c) => c.memory_barrier(), CmdBuf::Dx12(c) => c.memory_barrier(),
CmdBuf::Mtl(c) => c.memory_barrier(),
} }
} }
@ -500,6 +605,7 @@ impl CmdBuf {
mux_match! { self; mux_match! { self;
CmdBuf::Vk(c) => c.host_barrier(), CmdBuf::Vk(c) => c.host_barrier(),
CmdBuf::Dx12(c) => c.host_barrier(), CmdBuf::Dx12(c) => c.host_barrier(),
CmdBuf::Mtl(c) => c.host_barrier(),
} }
} }
@ -512,6 +618,7 @@ impl CmdBuf {
mux_match! { self; mux_match! { self;
CmdBuf::Vk(c) => c.image_barrier(image.vk(), src_layout, dst_layout), CmdBuf::Vk(c) => c.image_barrier(image.vk(), src_layout, dst_layout),
CmdBuf::Dx12(c) => c.image_barrier(image.dx12(), src_layout, dst_layout), CmdBuf::Dx12(c) => c.image_barrier(image.dx12(), src_layout, dst_layout),
CmdBuf::Mtl(c) => c.image_barrier(image.mtl(), src_layout, dst_layout),
} }
} }
@ -519,6 +626,7 @@ impl CmdBuf {
mux_match! { self; mux_match! { self;
CmdBuf::Vk(c) => c.clear_buffer(buffer.vk(), size), CmdBuf::Vk(c) => c.clear_buffer(buffer.vk(), size),
CmdBuf::Dx12(c) => c.clear_buffer(buffer.dx12(), size), CmdBuf::Dx12(c) => c.clear_buffer(buffer.dx12(), size),
CmdBuf::Mtl(c) => c.clear_buffer(buffer.mtl(), size),
} }
} }
@ -526,6 +634,7 @@ impl CmdBuf {
mux_match! { self; mux_match! { self;
CmdBuf::Vk(c) => c.copy_buffer(src.vk(), dst.vk()), CmdBuf::Vk(c) => c.copy_buffer(src.vk(), dst.vk()),
CmdBuf::Dx12(c) => c.copy_buffer(src.dx12(), dst.dx12()), CmdBuf::Dx12(c) => c.copy_buffer(src.dx12(), dst.dx12()),
CmdBuf::Mtl(c) => c.copy_buffer(src.mtl(), dst.mtl()),
} }
} }
@ -533,6 +642,7 @@ impl CmdBuf {
mux_match! { self; mux_match! { self;
CmdBuf::Vk(c) => c.copy_image_to_buffer(src.vk(), dst.vk()), CmdBuf::Vk(c) => c.copy_image_to_buffer(src.vk(), dst.vk()),
CmdBuf::Dx12(c) => c.copy_image_to_buffer(src.dx12(), dst.dx12()), CmdBuf::Dx12(c) => c.copy_image_to_buffer(src.dx12(), dst.dx12()),
CmdBuf::Mtl(c) => c.copy_image_to_buffer(src.mtl(), dst.mtl()),
} }
} }
@ -540,6 +650,7 @@ impl CmdBuf {
mux_match! { self; mux_match! { self;
CmdBuf::Vk(c) => c.copy_buffer_to_image(src.vk(), dst.vk()), CmdBuf::Vk(c) => c.copy_buffer_to_image(src.vk(), dst.vk()),
CmdBuf::Dx12(c) => c.copy_buffer_to_image(src.dx12(), dst.dx12()), CmdBuf::Dx12(c) => c.copy_buffer_to_image(src.dx12(), dst.dx12()),
CmdBuf::Mtl(c) => c.copy_buffer_to_image(src.mtl(), dst.mtl()),
} }
} }
@ -547,6 +658,7 @@ impl CmdBuf {
mux_match! { self; mux_match! { self;
CmdBuf::Vk(c) => c.blit_image(src.vk(), dst.vk()), CmdBuf::Vk(c) => c.blit_image(src.vk(), dst.vk()),
CmdBuf::Dx12(c) => c.blit_image(src.dx12(), dst.dx12()), CmdBuf::Dx12(c) => c.blit_image(src.dx12(), dst.dx12()),
CmdBuf::Mtl(c) => c.blit_image(src.mtl(), dst.mtl()),
} }
} }
@ -554,6 +666,7 @@ impl CmdBuf {
mux_match! { self; mux_match! { self;
CmdBuf::Vk(c) => c.reset_query_pool(pool.vk()), CmdBuf::Vk(c) => c.reset_query_pool(pool.vk()),
CmdBuf::Dx12(c) => c.reset_query_pool(pool.dx12()), CmdBuf::Dx12(c) => c.reset_query_pool(pool.dx12()),
CmdBuf::Mtl(c) => c.reset_query_pool(pool.mtl()),
} }
} }
@ -561,6 +674,7 @@ impl CmdBuf {
mux_match! { self; mux_match! { self;
CmdBuf::Vk(c) => c.write_timestamp(pool.vk(), query), CmdBuf::Vk(c) => c.write_timestamp(pool.vk(), query),
CmdBuf::Dx12(c) => c.write_timestamp(pool.dx12(), query), CmdBuf::Dx12(c) => c.write_timestamp(pool.dx12(), query),
CmdBuf::Mtl(c) => c.write_timestamp(pool.mtl(), query),
} }
} }
@ -568,6 +682,7 @@ impl CmdBuf {
mux_match! { self; mux_match! { self;
CmdBuf::Vk(c) => c.finish_timestamps(pool.vk()), CmdBuf::Vk(c) => c.finish_timestamps(pool.vk()),
CmdBuf::Dx12(c) => c.finish_timestamps(pool.dx12()), CmdBuf::Dx12(c) => c.finish_timestamps(pool.dx12()),
CmdBuf::Mtl(c) => c.finish_timestamps(pool.mtl()),
} }
} }
} }
@ -577,6 +692,7 @@ impl Buffer {
mux_match! { self; mux_match! { self;
Buffer::Vk(b) => b.size, Buffer::Vk(b) => b.size,
Buffer::Dx12(b) => b.size, Buffer::Dx12(b) => b.size,
Buffer::Mtl(b) => b.size,
} }
} }
} }
@ -592,6 +708,10 @@ impl Swapchain {
let (idx, sem) = s.next()?; let (idx, sem) = s.next()?;
Ok((idx, Semaphore::Dx12(sem))) Ok((idx, Semaphore::Dx12(sem)))
} }
Swapchain::Mtl(s) => {
let (idx, sem) = s.next()?;
Ok((idx, Semaphore::Mtl(sem)))
}
} }
} }
@ -599,6 +719,7 @@ impl Swapchain {
mux_match! { self; mux_match! { self;
Swapchain::Vk(s) => Image::Vk(s.image(idx)), Swapchain::Vk(s) => Image::Vk(s.image(idx)),
Swapchain::Dx12(s) => Image::Dx12(s.image(idx)), Swapchain::Dx12(s) => Image::Dx12(s.image(idx)),
Swapchain::Mtl(s) => Image::Mtl(s.image(idx)),
} }
} }
@ -624,6 +745,14 @@ impl Swapchain {
.map(Semaphore::dx12) .map(Semaphore::dx12)
.collect::<SmallVec<[_; 4]>>(), .collect::<SmallVec<[_; 4]>>(),
), ),
Swapchain::Mtl(s) => s.present(
image_idx,
&semaphores
.iter()
.copied()
.map(Semaphore::mtl)
.collect::<SmallVec<[_; 4]>>(),
),
} }
} }
} }