diff --git a/piet-gpu-hal/examples/collatz.rs b/piet-gpu-hal/examples/collatz.rs new file mode 100644 index 0000000..533906d --- /dev/null +++ b/piet-gpu-hal/examples/collatz.rs @@ -0,0 +1,28 @@ +use piet_gpu_hal::vulkan::VkInstance; +use piet_gpu_hal::{CmdBuf, Device, MemFlags}; + +fn main() { + let instance = VkInstance::new().unwrap(); + unsafe { + let device = instance.device().unwrap(); + let mem_flags = MemFlags::host_coherent(); + let src = (0..256).map(|x| x + 1).collect::>(); + let buffer = device + .create_buffer(std::mem::size_of_val(&src[..]) as u64, mem_flags) + .unwrap(); + device.write_buffer(&buffer, &src).unwrap(); + let code = include_bytes!("./shader/collatz.spv"); + let pipeline = device.create_simple_compute_pipeline(code, 1).unwrap(); + let descriptor_set = device.create_descriptor_set(&pipeline, &[&buffer]).unwrap(); + let mut cmd_buf = device.create_cmd_buf().unwrap(); + cmd_buf.begin(); + cmd_buf.dispatch(&pipeline, &descriptor_set); + cmd_buf.finish(); + device.run_cmd_buf(&cmd_buf).unwrap(); + let mut dst: Vec = Default::default(); + device.read_buffer(&buffer, &mut dst).unwrap(); + for (i, val) in dst.iter().enumerate().take(16) { + println!("{}: {}", i, val); + } + } +} diff --git a/piet-gpu-hal/examples/shader/build.ninja b/piet-gpu-hal/examples/shader/build.ninja new file mode 100644 index 0000000..848637a --- /dev/null +++ b/piet-gpu-hal/examples/shader/build.ninja @@ -0,0 +1,10 @@ +# Build file for shaders. + +# You must have glslangValidator in your path, or patch here. + +glslang_validator = glslangValidator + +rule glsl + command = $glslang_validator -V -o $out $in + +build collatz.spv: glsl collatz.comp diff --git a/piet-gpu-hal/examples/shader/collatz.comp b/piet-gpu-hal/examples/shader/collatz.comp new file mode 100644 index 0000000..7c0e2ab --- /dev/null +++ b/piet-gpu-hal/examples/shader/collatz.comp @@ -0,0 +1,35 @@ +// Copied from wgpu hello-compute example + +// TODO: delete or clean up attribution before releasing + +#version 450 +layout(local_size_x = 1) in; + +layout(set = 0, binding = 0) buffer PrimeIndices { + uint[] indices; +}; // this is used as both input and output for convenience + +// The Collatz Conjecture states that for any integer n: +// If n is even, n = n/2 +// If n is odd, n = 3n+1 +// And repeat this process for each new n, you will always eventually reach 1. +// Though the conjecture has not been proven, no counterexample has ever been found. +// This function returns how many times this recurrence needs to be applied to reach 1. +uint collatz_iterations(uint n) { + uint i = 0; + while(n != 1) { + if (mod(n, 2) == 0) { + n = n / 2; + } + else { + n = (3 * n) + 1; + } + i++; + } + return i; +} + +void main() { + uint index = gl_GlobalInvocationID.x; + indices[index] = collatz_iterations(indices[index]); +} diff --git a/piet-gpu-hal/examples/shader/collatz.spv b/piet-gpu-hal/examples/shader/collatz.spv new file mode 100644 index 0000000..21e4e92 Binary files /dev/null and b/piet-gpu-hal/examples/shader/collatz.spv differ diff --git a/piet-gpu-hal/src/lib.rs b/piet-gpu-hal/src/lib.rs index 79c1e01..07ba686 100644 --- a/piet-gpu-hal/src/lib.rs +++ b/piet-gpu-hal/src/lib.rs @@ -3,9 +3,6 @@ /// This abstraction is inspired by gfx-hal, but is specialized to the needs of piet-gpu. /// In time, it may go away and be replaced by either gfx-hal or wgpu. -#[macro_use] -extern crate ash; - pub mod vulkan; /// This isn't great but is expedient. @@ -13,7 +10,7 @@ type Error = Box; pub trait Device: Sized { type Buffer; - type MemFlags; + type MemFlags: MemFlags; type Pipeline; type DescriptorSet; type CmdBuf: CmdBuf; @@ -58,3 +55,7 @@ pub trait CmdBuf { unsafe fn memory_barrier(&mut self); } + +pub trait MemFlags: Sized { + fn host_coherent() -> Self; +} diff --git a/piet-gpu-hal/src/vulkan.rs b/piet-gpu-hal/src/vulkan.rs index a2b5631..6d6f4d0 100644 --- a/piet-gpu-hal/src/vulkan.rs +++ b/piet-gpu-hal/src/vulkan.rs @@ -8,18 +8,15 @@ use ash::{vk, Device, Entry, Instance}; use crate::Error; -/// A base for allocating resources and dispatching work. -/// -/// This is quite similar to "device" in most GPU API's, but I didn't want to overload -/// that term further. -pub struct Base { +pub struct VkInstance { /// Retain the dynamic lib. #[allow(unused)] entry: Entry, - #[allow(unused)] instance: Instance, +} +pub struct VkDevice { device: Arc, device_mem_props: vk::PhysicalDeviceMemoryProperties, queue: vk::Queue, @@ -55,12 +52,14 @@ pub struct CmdBuf { device: Arc, } -impl Base { +pub struct MemFlags(vk::MemoryPropertyFlags); + +impl VkInstance { /// Create a new instance. /// /// There's more to be done to make this suitable for integration with other /// systems, but for now the goal is to make things simple. - pub fn new() -> Result { + pub fn new() -> Result { unsafe { let app_name = CString::new("VkToy").unwrap(); let entry = Entry::new()?; @@ -75,43 +74,63 @@ impl Base { None, )?; - let devices = instance.enumerate_physical_devices()?; - let (pdevice, qfi) = - choose_compute_device(&instance, &devices).ok_or("no suitable device")?; - - let device = instance.create_device( - pdevice, - &vk::DeviceCreateInfo::builder().queue_create_infos(&[ - vk::DeviceQueueCreateInfo::builder() - .queue_family_index(qfi) - .queue_priorities(&[1.0]) - .build(), - ]), - None, - )?; - - let device_mem_props = instance.get_physical_device_memory_properties(pdevice); - - let queue_index = 0; - let queue = device.get_device_queue(qfi, queue_index); - - let device = Arc::new(RawDevice { device }); - - Ok(Base { + Ok(VkInstance { entry, instance, - device, - device_mem_props, - qfi, - queue, }) } } - pub fn create_buffer( + /// Create a device from the instance, suitable for compute. + /// + /// # Safety + /// + /// The caller is responsible for making sure that the instance outlives the device. + /// We could enforce that, for example having an `Arc` of the raw instance, but for + /// now keep things simple. + pub unsafe fn device(&self) -> Result { + let devices = self.instance.enumerate_physical_devices()?; + let (pdevice, qfi) = + choose_compute_device(&self.instance, &devices).ok_or("no suitable device")?; + + let device = self.instance.create_device( + pdevice, + &vk::DeviceCreateInfo::builder().queue_create_infos(&[ + vk::DeviceQueueCreateInfo::builder() + .queue_family_index(qfi) + .queue_priorities(&[1.0]) + .build(), + ]), + None, + )?; + + let device_mem_props = self.instance.get_physical_device_memory_properties(pdevice); + + let queue_index = 0; + let queue = device.get_device_queue(qfi, queue_index); + + let device = Arc::new(RawDevice { device }); + + Ok(VkDevice { + device, + device_mem_props, + qfi, + queue, + }) + } +} + +impl crate::Device for VkDevice { + type Buffer = Buffer; + type CmdBuf = CmdBuf; + type DescriptorSet = DescriptorSet; + type Pipeline = Pipeline; + type MemFlags = MemFlags; + + fn create_buffer( &self, size: u64, - mem_flags: vk::MemoryPropertyFlags, + mem_flags: MemFlags, ) -> Result { unsafe { let device = &self.device.device; @@ -125,7 +144,7 @@ impl Base { let mem_requirements = device.get_buffer_memory_requirements(buffer); let mem_type = find_memory_type( mem_requirements.memory_type_bits, - mem_flags, + mem_flags.0, &self.device_mem_props, ) .unwrap(); // TODO: proper error @@ -148,7 +167,7 @@ impl Base { /// /// The code is included from "../comp.spv", and the descriptor set layout is just some /// number of buffers. - pub unsafe fn create_simple_compute_pipeline( + unsafe fn create_simple_compute_pipeline( &self, code: &[u8], n_buffers: u32, @@ -199,7 +218,7 @@ impl Base { }) } - pub unsafe fn create_descriptor_set( + unsafe fn create_descriptor_set( &self, pipeline: &Pipeline, bufs: &[&Buffer], @@ -247,7 +266,7 @@ impl Base { }) } - pub fn create_cmd_buf(&self) -> Result { + fn create_cmd_buf(&self) -> Result { unsafe { let device = &self.device.device; let command_pool = device.create_command_pool( @@ -272,7 +291,7 @@ impl Base { /// Run the command buffer. /// /// This version simply blocks until it's complete. - pub unsafe fn run_cmd_buf(&self, cmd_buf: &CmdBuf) -> Result<(), Error> { + unsafe fn run_cmd_buf(&self, cmd_buf: &CmdBuf) -> Result<(), Error> { let device = &self.device.device; // Run the command buffer. @@ -287,12 +306,12 @@ impl Base { .build()], fence, )?; - device.wait_for_fences(&[fence], true, 1_000_000)?; - device.destroy_fence(fence, None); + device.wait_for_fences(&[fence], true, 100_000_000)?; + // TODO: handle errors better (currently leaks fence and can lead to other problems) Ok(()) } - pub unsafe fn read_buffer( + unsafe fn read_buffer( &self, buffer: &Buffer, result: &mut Vec, @@ -314,7 +333,7 @@ impl Base { Ok(()) } - pub unsafe fn write_buffer( + unsafe fn write_buffer( &self, buffer: &Buffer, contents: &[T], @@ -332,8 +351,8 @@ impl Base { } } -impl CmdBuf { - pub unsafe fn begin(&mut self) { +impl crate::CmdBuf for CmdBuf { + unsafe fn begin(&mut self) { self.device .device .begin_command_buffer( @@ -344,11 +363,11 @@ impl CmdBuf { .unwrap(); } - pub unsafe fn finish(&mut self) { + unsafe fn finish(&mut self) { self.device.device.end_command_buffer(self.cmd_buf).unwrap(); } - pub unsafe fn dispatch(&mut self, pipeline: &Pipeline, descriptor_set: &DescriptorSet) { + unsafe fn dispatch(&mut self, pipeline: &Pipeline, descriptor_set: &DescriptorSet) { let device = &self.device.device; device.cmd_bind_pipeline( self.cmd_buf, @@ -367,8 +386,7 @@ impl CmdBuf { } /// Insert a pipeline barrier for all memory accesses. - #[allow(unused)] - pub unsafe fn memory_barrier(&mut self) { + unsafe fn memory_barrier(&mut self) { let device = &self.device.device; device.cmd_pipeline_barrier( self.cmd_buf, @@ -385,6 +403,12 @@ impl CmdBuf { } } +impl crate::MemFlags for MemFlags { + fn host_coherent() -> Self { + MemFlags(vk::MemoryPropertyFlags::HOST_VISIBLE | vk::MemoryPropertyFlags::HOST_COHERENT) + } +} + unsafe fn choose_compute_device( instance: &Instance, devices: &[vk::PhysicalDevice],