Use bytemuck

Get rid of `PlainData` trait and use `Pod` from bytemuck instead.
This commit is contained in:
Raph Levien 2021-11-23 08:24:16 -08:00
parent ecdd7fd817
commit 2ebdd942cf
5 changed files with 27 additions and 30 deletions

1
Cargo.lock generated
View file

@ -904,6 +904,7 @@ dependencies = [
"ash-window",
"bitflags",
"block",
"bytemuck",
"cocoa-foundation",
"metal",
"objc",

View file

@ -12,6 +12,7 @@ ash-window = "0.7"
raw-window-handle = "0.3"
bitflags = "1.2.1"
smallvec = "1.6.1"
bytemuck = "1.7.2"
[target.'cfg(target_os="windows")'.dependencies]
winapi = { version = "0.3.9", features = [

View file

@ -9,6 +9,7 @@
use std::convert::TryInto;
use std::sync::{Arc, Mutex, Weak};
use bytemuck::Pod;
use smallvec::SmallVec;
use crate::{mux, BackendType};
@ -105,20 +106,6 @@ struct BufferInner {
/// Add bindings to the descriptor set before dispatching a shader.
pub struct DescriptorSetBuilder(mux::DescriptorSetBuilder);
/// Data types that can be stored in a GPU buffer.
pub unsafe trait PlainData {}
unsafe impl PlainData for u8 {}
unsafe impl PlainData for u16 {}
unsafe impl PlainData for u32 {}
unsafe impl PlainData for u64 {}
unsafe impl PlainData for i8 {}
unsafe impl PlainData for i16 {}
unsafe impl PlainData for i32 {}
unsafe impl PlainData for i64 {}
unsafe impl PlainData for f32 {}
unsafe impl PlainData for f64 {}
/// A resource to retain during the lifetime of a command submission.
pub enum RetainResource {
Buffer(Buffer),
@ -242,15 +229,12 @@ impl Session {
/// the buffer will subsequently be written by the host.
pub fn create_buffer_init(
&self,
contents: &[impl PlainData],
contents: &[impl Pod],
usage: BufferUsage,
) -> Result<Buffer, Error> {
unsafe {
self.create_buffer_init_raw(
contents.as_ptr() as *const u8,
std::mem::size_of_val(contents).try_into()?,
usage,
)
let bytes = bytemuck::cast_slice(contents);
self.create_buffer_init_raw(bytes.as_ptr(), bytes.len().try_into()?, usage)
}
}
@ -682,13 +666,14 @@ impl Buffer {
///
/// The buffer must have been created with `MAP_WRITE` usage, and with
/// a size large enough to accommodate the given slice.
pub unsafe fn write<T: PlainData>(&mut self, contents: &[T]) -> Result<(), Error> {
pub unsafe fn write(&mut self, contents: &[impl Pod]) -> Result<(), Error> {
let bytes = bytemuck::cast_slice(contents);
if let Some(session) = Weak::upgrade(&self.0.session) {
session.device.write_buffer(
&self.0.buffer,
contents.as_ptr() as *const u8,
bytes.as_ptr(),
0,
std::mem::size_of_val(contents).try_into()?,
bytes.len().try_into()?,
)?;
}
// else session lost error?
@ -700,8 +685,10 @@ impl Buffer {
/// The buffer must have been created with `MAP_READ` usage. The caller
/// is also responsible for ensuring that this does not read uninitialized
/// memory.
pub unsafe fn read<T: PlainData>(&self, result: &mut Vec<T>) -> Result<(), Error> {
pub unsafe fn read<T: Pod>(&self, result: &mut Vec<T>) -> Result<(), Error> {
let size = self.mux_buffer().size();
// TODO: can bytemuck grow a method to do this more safely?
// It's similar to pod_collect_to_vec.
let len = size as usize / std::mem::size_of::<T>();
if len > result.len() {
result.reserve(len - result.len());

View file

@ -19,8 +19,7 @@ pub use crate::mux::{
Swapchain,
};
pub use hub::{
Buffer, CmdBuf, DescriptorSetBuilder, Image, PlainData, RetainResource, Session,
SubmittedCmdBuf,
Buffer, CmdBuf, DescriptorSetBuilder, Image, RetainResource, Session, SubmittedCmdBuf,
};
// TODO: because these are conditionally included, "cargo fmt" does not

View file

@ -42,16 +42,25 @@ pub fn make_clear_pipeline(device: &Device) -> ComputePipelineState {
let library = device.new_library_with_source(CLEAR_MSL, &options).unwrap();
let function = library.get_function("main0", None).unwrap();
device
.new_compute_pipeline_state_with_function(&function).unwrap()
.new_compute_pipeline_state_with_function(&function)
.unwrap()
}
pub fn encode_clear(encoder: &metal::ComputeCommandEncoderRef, clear_pipeline: &ComputePipelineState, buffer: &metal::Buffer, size: u64) {
pub fn encode_clear(
encoder: &metal::ComputeCommandEncoderRef,
clear_pipeline: &ComputePipelineState,
buffer: &metal::Buffer,
size: u64,
) {
// TODO: should be more careful with overflow
let size_in_u32s = (size / 4) as u32;
encoder.set_compute_pipeline_state(&clear_pipeline);
let config = [size_in_u32s, 0];
encoder.set_bytes(0, std::mem::size_of_val(&config) as u64, config.as_ptr() as *const _);
encoder.set_bytes(
0,
std::mem::size_of_val(&config) as u64,
config.as_ptr() as *const _,
);
encoder.set_buffer(1, Some(buffer), 0);
let n_wg = (size_in_u32s + 255) / 256;
let workgroup_count = metal::MTLSize {