Use bytemuck

Get rid of `PlainData` trait and use `Pod` from bytemuck instead.
This commit is contained in:
Raph Levien 2021-11-23 08:24:16 -08:00
parent ecdd7fd817
commit 2ebdd942cf
5 changed files with 27 additions and 30 deletions

1
Cargo.lock generated
View file

@ -904,6 +904,7 @@ dependencies = [
"ash-window", "ash-window",
"bitflags", "bitflags",
"block", "block",
"bytemuck",
"cocoa-foundation", "cocoa-foundation",
"metal", "metal",
"objc", "objc",

View file

@ -12,6 +12,7 @@ ash-window = "0.7"
raw-window-handle = "0.3" raw-window-handle = "0.3"
bitflags = "1.2.1" bitflags = "1.2.1"
smallvec = "1.6.1" smallvec = "1.6.1"
bytemuck = "1.7.2"
[target.'cfg(target_os="windows")'.dependencies] [target.'cfg(target_os="windows")'.dependencies]
winapi = { version = "0.3.9", features = [ winapi = { version = "0.3.9", features = [

View file

@ -9,6 +9,7 @@
use std::convert::TryInto; use std::convert::TryInto;
use std::sync::{Arc, Mutex, Weak}; use std::sync::{Arc, Mutex, Weak};
use bytemuck::Pod;
use smallvec::SmallVec; use smallvec::SmallVec;
use crate::{mux, BackendType}; use crate::{mux, BackendType};
@ -105,20 +106,6 @@ struct BufferInner {
/// Add bindings to the descriptor set before dispatching a shader. /// Add bindings to the descriptor set before dispatching a shader.
pub struct DescriptorSetBuilder(mux::DescriptorSetBuilder); pub struct DescriptorSetBuilder(mux::DescriptorSetBuilder);
/// Data types that can be stored in a GPU buffer.
pub unsafe trait PlainData {}
unsafe impl PlainData for u8 {}
unsafe impl PlainData for u16 {}
unsafe impl PlainData for u32 {}
unsafe impl PlainData for u64 {}
unsafe impl PlainData for i8 {}
unsafe impl PlainData for i16 {}
unsafe impl PlainData for i32 {}
unsafe impl PlainData for i64 {}
unsafe impl PlainData for f32 {}
unsafe impl PlainData for f64 {}
/// A resource to retain during the lifetime of a command submission. /// A resource to retain during the lifetime of a command submission.
pub enum RetainResource { pub enum RetainResource {
Buffer(Buffer), Buffer(Buffer),
@ -242,15 +229,12 @@ impl Session {
/// the buffer will subsequently be written by the host. /// the buffer will subsequently be written by the host.
pub fn create_buffer_init( pub fn create_buffer_init(
&self, &self,
contents: &[impl PlainData], contents: &[impl Pod],
usage: BufferUsage, usage: BufferUsage,
) -> Result<Buffer, Error> { ) -> Result<Buffer, Error> {
unsafe { unsafe {
self.create_buffer_init_raw( let bytes = bytemuck::cast_slice(contents);
contents.as_ptr() as *const u8, self.create_buffer_init_raw(bytes.as_ptr(), bytes.len().try_into()?, usage)
std::mem::size_of_val(contents).try_into()?,
usage,
)
} }
} }
@ -682,13 +666,14 @@ impl Buffer {
/// ///
/// The buffer must have been created with `MAP_WRITE` usage, and with /// The buffer must have been created with `MAP_WRITE` usage, and with
/// a size large enough to accommodate the given slice. /// a size large enough to accommodate the given slice.
pub unsafe fn write<T: PlainData>(&mut self, contents: &[T]) -> Result<(), Error> { pub unsafe fn write(&mut self, contents: &[impl Pod]) -> Result<(), Error> {
let bytes = bytemuck::cast_slice(contents);
if let Some(session) = Weak::upgrade(&self.0.session) { if let Some(session) = Weak::upgrade(&self.0.session) {
session.device.write_buffer( session.device.write_buffer(
&self.0.buffer, &self.0.buffer,
contents.as_ptr() as *const u8, bytes.as_ptr(),
0, 0,
std::mem::size_of_val(contents).try_into()?, bytes.len().try_into()?,
)?; )?;
} }
// else session lost error? // else session lost error?
@ -700,8 +685,10 @@ impl Buffer {
/// The buffer must have been created with `MAP_READ` usage. The caller /// The buffer must have been created with `MAP_READ` usage. The caller
/// is also responsible for ensuring that this does not read uninitialized /// is also responsible for ensuring that this does not read uninitialized
/// memory. /// memory.
pub unsafe fn read<T: PlainData>(&self, result: &mut Vec<T>) -> Result<(), Error> { pub unsafe fn read<T: Pod>(&self, result: &mut Vec<T>) -> Result<(), Error> {
let size = self.mux_buffer().size(); let size = self.mux_buffer().size();
// TODO: can bytemuck grow a method to do this more safely?
// It's similar to pod_collect_to_vec.
let len = size as usize / std::mem::size_of::<T>(); let len = size as usize / std::mem::size_of::<T>();
if len > result.len() { if len > result.len() {
result.reserve(len - result.len()); result.reserve(len - result.len());

View file

@ -19,8 +19,7 @@ pub use crate::mux::{
Swapchain, Swapchain,
}; };
pub use hub::{ pub use hub::{
Buffer, CmdBuf, DescriptorSetBuilder, Image, PlainData, RetainResource, Session, Buffer, CmdBuf, DescriptorSetBuilder, Image, RetainResource, Session, SubmittedCmdBuf,
SubmittedCmdBuf,
}; };
// TODO: because these are conditionally included, "cargo fmt" does not // TODO: because these are conditionally included, "cargo fmt" does not

View file

@ -42,16 +42,25 @@ pub fn make_clear_pipeline(device: &Device) -> ComputePipelineState {
let library = device.new_library_with_source(CLEAR_MSL, &options).unwrap(); let library = device.new_library_with_source(CLEAR_MSL, &options).unwrap();
let function = library.get_function("main0", None).unwrap(); let function = library.get_function("main0", None).unwrap();
device device
.new_compute_pipeline_state_with_function(&function).unwrap() .new_compute_pipeline_state_with_function(&function)
.unwrap()
} }
pub fn encode_clear(encoder: &metal::ComputeCommandEncoderRef, clear_pipeline: &ComputePipelineState, buffer: &metal::Buffer, size: u64) { pub fn encode_clear(
encoder: &metal::ComputeCommandEncoderRef,
clear_pipeline: &ComputePipelineState,
buffer: &metal::Buffer,
size: u64,
) {
// TODO: should be more careful with overflow // TODO: should be more careful with overflow
let size_in_u32s = (size / 4) as u32; let size_in_u32s = (size / 4) as u32;
encoder.set_compute_pipeline_state(&clear_pipeline); encoder.set_compute_pipeline_state(&clear_pipeline);
let config = [size_in_u32s, 0]; let config = [size_in_u32s, 0];
encoder.set_bytes(0, std::mem::size_of_val(&config) as u64, config.as_ptr() as *const _); encoder.set_bytes(
0,
std::mem::size_of_val(&config) as u64,
config.as_ptr() as *const _,
);
encoder.set_buffer(1, Some(buffer), 0); encoder.set_buffer(1, Some(buffer), 0);
let n_wg = (size_in_u32s + 255) / 256; let n_wg = (size_in_u32s + 255) / 256;
let workgroup_count = metal::MTLSize { let workgroup_count = metal::MTLSize {