Get Collatz example working

Rework Collatz example to use new traits.
This commit is contained in:
Raph Levien 2020-04-06 12:11:37 -07:00
parent 1b0248fbbf
commit 1e1b9ff319
6 changed files with 154 additions and 56 deletions

View file

@ -0,0 +1,28 @@
use piet_gpu_hal::vulkan::VkInstance;
use piet_gpu_hal::{CmdBuf, Device, MemFlags};
fn main() {
let instance = VkInstance::new().unwrap();
unsafe {
let device = instance.device().unwrap();
let mem_flags = MemFlags::host_coherent();
let src = (0..256).map(|x| x + 1).collect::<Vec<u32>>();
let buffer = device
.create_buffer(std::mem::size_of_val(&src[..]) as u64, mem_flags)
.unwrap();
device.write_buffer(&buffer, &src).unwrap();
let code = include_bytes!("./shader/collatz.spv");
let pipeline = device.create_simple_compute_pipeline(code, 1).unwrap();
let descriptor_set = device.create_descriptor_set(&pipeline, &[&buffer]).unwrap();
let mut cmd_buf = device.create_cmd_buf().unwrap();
cmd_buf.begin();
cmd_buf.dispatch(&pipeline, &descriptor_set);
cmd_buf.finish();
device.run_cmd_buf(&cmd_buf).unwrap();
let mut dst: Vec<u32> = Default::default();
device.read_buffer(&buffer, &mut dst).unwrap();
for (i, val) in dst.iter().enumerate().take(16) {
println!("{}: {}", i, val);
}
}
}

View file

@ -0,0 +1,10 @@
# Build file for shaders.
# You must have glslangValidator in your path, or patch here.
glslang_validator = glslangValidator
rule glsl
command = $glslang_validator -V -o $out $in
build collatz.spv: glsl collatz.comp

View file

@ -0,0 +1,35 @@
// Copied from wgpu hello-compute example
// TODO: delete or clean up attribution before releasing
#version 450
layout(local_size_x = 1) in;
layout(set = 0, binding = 0) buffer PrimeIndices {
uint[] indices;
}; // this is used as both input and output for convenience
// The Collatz Conjecture states that for any integer n:
// If n is even, n = n/2
// If n is odd, n = 3n+1
// And repeat this process for each new n, you will always eventually reach 1.
// Though the conjecture has not been proven, no counterexample has ever been found.
// This function returns how many times this recurrence needs to be applied to reach 1.
uint collatz_iterations(uint n) {
uint i = 0;
while(n != 1) {
if (mod(n, 2) == 0) {
n = n / 2;
}
else {
n = (3 * n) + 1;
}
i++;
}
return i;
}
void main() {
uint index = gl_GlobalInvocationID.x;
indices[index] = collatz_iterations(indices[index]);
}

Binary file not shown.

View file

@ -3,9 +3,6 @@
/// This abstraction is inspired by gfx-hal, but is specialized to the needs of piet-gpu. /// This abstraction is inspired by gfx-hal, but is specialized to the needs of piet-gpu.
/// In time, it may go away and be replaced by either gfx-hal or wgpu. /// In time, it may go away and be replaced by either gfx-hal or wgpu.
#[macro_use]
extern crate ash;
pub mod vulkan; pub mod vulkan;
/// This isn't great but is expedient. /// This isn't great but is expedient.
@ -13,7 +10,7 @@ type Error = Box<dyn std::error::Error>;
pub trait Device: Sized { pub trait Device: Sized {
type Buffer; type Buffer;
type MemFlags; type MemFlags: MemFlags;
type Pipeline; type Pipeline;
type DescriptorSet; type DescriptorSet;
type CmdBuf: CmdBuf<Self>; type CmdBuf: CmdBuf<Self>;
@ -58,3 +55,7 @@ pub trait CmdBuf<D: Device> {
unsafe fn memory_barrier(&mut self); unsafe fn memory_barrier(&mut self);
} }
pub trait MemFlags: Sized {
fn host_coherent() -> Self;
}

View file

@ -8,18 +8,15 @@ use ash::{vk, Device, Entry, Instance};
use crate::Error; use crate::Error;
/// A base for allocating resources and dispatching work. pub struct VkInstance {
///
/// This is quite similar to "device" in most GPU API's, but I didn't want to overload
/// that term further.
pub struct Base {
/// Retain the dynamic lib. /// Retain the dynamic lib.
#[allow(unused)] #[allow(unused)]
entry: Entry, entry: Entry,
#[allow(unused)]
instance: Instance, instance: Instance,
}
pub struct VkDevice {
device: Arc<RawDevice>, device: Arc<RawDevice>,
device_mem_props: vk::PhysicalDeviceMemoryProperties, device_mem_props: vk::PhysicalDeviceMemoryProperties,
queue: vk::Queue, queue: vk::Queue,
@ -55,12 +52,14 @@ pub struct CmdBuf {
device: Arc<RawDevice>, device: Arc<RawDevice>,
} }
impl Base { pub struct MemFlags(vk::MemoryPropertyFlags);
impl VkInstance {
/// Create a new instance. /// Create a new instance.
/// ///
/// There's more to be done to make this suitable for integration with other /// There's more to be done to make this suitable for integration with other
/// systems, but for now the goal is to make things simple. /// systems, but for now the goal is to make things simple.
pub fn new() -> Result<Base, Error> { pub fn new() -> Result<VkInstance, Error> {
unsafe { unsafe {
let app_name = CString::new("VkToy").unwrap(); let app_name = CString::new("VkToy").unwrap();
let entry = Entry::new()?; let entry = Entry::new()?;
@ -75,43 +74,63 @@ impl Base {
None, None,
)?; )?;
let devices = instance.enumerate_physical_devices()?; Ok(VkInstance {
let (pdevice, qfi) =
choose_compute_device(&instance, &devices).ok_or("no suitable device")?;
let device = instance.create_device(
pdevice,
&vk::DeviceCreateInfo::builder().queue_create_infos(&[
vk::DeviceQueueCreateInfo::builder()
.queue_family_index(qfi)
.queue_priorities(&[1.0])
.build(),
]),
None,
)?;
let device_mem_props = instance.get_physical_device_memory_properties(pdevice);
let queue_index = 0;
let queue = device.get_device_queue(qfi, queue_index);
let device = Arc::new(RawDevice { device });
Ok(Base {
entry, entry,
instance, instance,
device,
device_mem_props,
qfi,
queue,
}) })
} }
} }
pub fn create_buffer( /// Create a device from the instance, suitable for compute.
///
/// # Safety
///
/// The caller is responsible for making sure that the instance outlives the device.
/// We could enforce that, for example having an `Arc` of the raw instance, but for
/// now keep things simple.
pub unsafe fn device(&self) -> Result<VkDevice, Error> {
let devices = self.instance.enumerate_physical_devices()?;
let (pdevice, qfi) =
choose_compute_device(&self.instance, &devices).ok_or("no suitable device")?;
let device = self.instance.create_device(
pdevice,
&vk::DeviceCreateInfo::builder().queue_create_infos(&[
vk::DeviceQueueCreateInfo::builder()
.queue_family_index(qfi)
.queue_priorities(&[1.0])
.build(),
]),
None,
)?;
let device_mem_props = self.instance.get_physical_device_memory_properties(pdevice);
let queue_index = 0;
let queue = device.get_device_queue(qfi, queue_index);
let device = Arc::new(RawDevice { device });
Ok(VkDevice {
device,
device_mem_props,
qfi,
queue,
})
}
}
impl crate::Device for VkDevice {
type Buffer = Buffer;
type CmdBuf = CmdBuf;
type DescriptorSet = DescriptorSet;
type Pipeline = Pipeline;
type MemFlags = MemFlags;
fn create_buffer(
&self, &self,
size: u64, size: u64,
mem_flags: vk::MemoryPropertyFlags, mem_flags: MemFlags,
) -> Result<Buffer, Error> { ) -> Result<Buffer, Error> {
unsafe { unsafe {
let device = &self.device.device; let device = &self.device.device;
@ -125,7 +144,7 @@ impl Base {
let mem_requirements = device.get_buffer_memory_requirements(buffer); let mem_requirements = device.get_buffer_memory_requirements(buffer);
let mem_type = find_memory_type( let mem_type = find_memory_type(
mem_requirements.memory_type_bits, mem_requirements.memory_type_bits,
mem_flags, mem_flags.0,
&self.device_mem_props, &self.device_mem_props,
) )
.unwrap(); // TODO: proper error .unwrap(); // TODO: proper error
@ -148,7 +167,7 @@ impl Base {
/// ///
/// The code is included from "../comp.spv", and the descriptor set layout is just some /// The code is included from "../comp.spv", and the descriptor set layout is just some
/// number of buffers. /// number of buffers.
pub unsafe fn create_simple_compute_pipeline( unsafe fn create_simple_compute_pipeline(
&self, &self,
code: &[u8], code: &[u8],
n_buffers: u32, n_buffers: u32,
@ -199,7 +218,7 @@ impl Base {
}) })
} }
pub unsafe fn create_descriptor_set( unsafe fn create_descriptor_set(
&self, &self,
pipeline: &Pipeline, pipeline: &Pipeline,
bufs: &[&Buffer], bufs: &[&Buffer],
@ -247,7 +266,7 @@ impl Base {
}) })
} }
pub fn create_cmd_buf(&self) -> Result<CmdBuf, Error> { fn create_cmd_buf(&self) -> Result<CmdBuf, Error> {
unsafe { unsafe {
let device = &self.device.device; let device = &self.device.device;
let command_pool = device.create_command_pool( let command_pool = device.create_command_pool(
@ -272,7 +291,7 @@ impl Base {
/// Run the command buffer. /// Run the command buffer.
/// ///
/// This version simply blocks until it's complete. /// This version simply blocks until it's complete.
pub unsafe fn run_cmd_buf(&self, cmd_buf: &CmdBuf) -> Result<(), Error> { unsafe fn run_cmd_buf(&self, cmd_buf: &CmdBuf) -> Result<(), Error> {
let device = &self.device.device; let device = &self.device.device;
// Run the command buffer. // Run the command buffer.
@ -287,12 +306,12 @@ impl Base {
.build()], .build()],
fence, fence,
)?; )?;
device.wait_for_fences(&[fence], true, 1_000_000)?; device.wait_for_fences(&[fence], true, 100_000_000)?;
device.destroy_fence(fence, None); // TODO: handle errors better (currently leaks fence and can lead to other problems)
Ok(()) Ok(())
} }
pub unsafe fn read_buffer<T: Sized>( unsafe fn read_buffer<T: Sized>(
&self, &self,
buffer: &Buffer, buffer: &Buffer,
result: &mut Vec<T>, result: &mut Vec<T>,
@ -314,7 +333,7 @@ impl Base {
Ok(()) Ok(())
} }
pub unsafe fn write_buffer<T: Sized>( unsafe fn write_buffer<T: Sized>(
&self, &self,
buffer: &Buffer, buffer: &Buffer,
contents: &[T], contents: &[T],
@ -332,8 +351,8 @@ impl Base {
} }
} }
impl CmdBuf { impl crate::CmdBuf<VkDevice> for CmdBuf {
pub unsafe fn begin(&mut self) { unsafe fn begin(&mut self) {
self.device self.device
.device .device
.begin_command_buffer( .begin_command_buffer(
@ -344,11 +363,11 @@ impl CmdBuf {
.unwrap(); .unwrap();
} }
pub unsafe fn finish(&mut self) { unsafe fn finish(&mut self) {
self.device.device.end_command_buffer(self.cmd_buf).unwrap(); self.device.device.end_command_buffer(self.cmd_buf).unwrap();
} }
pub unsafe fn dispatch(&mut self, pipeline: &Pipeline, descriptor_set: &DescriptorSet) { unsafe fn dispatch(&mut self, pipeline: &Pipeline, descriptor_set: &DescriptorSet) {
let device = &self.device.device; let device = &self.device.device;
device.cmd_bind_pipeline( device.cmd_bind_pipeline(
self.cmd_buf, self.cmd_buf,
@ -367,8 +386,7 @@ impl CmdBuf {
} }
/// Insert a pipeline barrier for all memory accesses. /// Insert a pipeline barrier for all memory accesses.
#[allow(unused)] unsafe fn memory_barrier(&mut self) {
pub unsafe fn memory_barrier(&mut self) {
let device = &self.device.device; let device = &self.device.device;
device.cmd_pipeline_barrier( device.cmd_pipeline_barrier(
self.cmd_buf, self.cmd_buf,
@ -385,6 +403,12 @@ impl CmdBuf {
} }
} }
impl crate::MemFlags for MemFlags {
fn host_coherent() -> Self {
MemFlags(vk::MemoryPropertyFlags::HOST_VISIBLE | vk::MemoryPropertyFlags::HOST_COHERENT)
}
}
unsafe fn choose_compute_device( unsafe fn choose_compute_device(
instance: &Instance, instance: &Instance,
devices: &[vk::PhysicalDevice], devices: &[vk::PhysicalDevice],