diff --git a/ash/src/extensions/nv/ray_tracing.rs b/ash/src/extensions/nv/ray_tracing.rs index 8645708..95e9bf1 100644 --- a/ash/src/extensions/nv/ray_tracing.rs +++ b/ash/src/extensions/nv/ray_tracing.rs @@ -8,6 +8,7 @@ use RawPtr; #[derive(Clone)] pub struct RayTracing { + handle: vk::Device, ray_tracing_fn: vk::NvRayTracingFn, } @@ -16,18 +17,20 @@ impl RayTracing { let ray_tracing_fn = vk::NvRayTracingFn::load(|name| unsafe { mem::transmute(instance.get_device_proc_addr(device.handle(), name.as_ptr())) }); - RayTracing { ray_tracing_fn } + RayTracing { + handle: device.handle(), + ray_tracing_fn, + } } pub unsafe fn create_acceleration_structure( &self, - device: vk::Device, create_info: &vk::AccelerationStructureCreateInfoNV, allocation_callbacks: Option<&vk::AllocationCallbacks>, ) -> VkResult { let mut accel_struct = mem::uninitialized(); let err_code = self.ray_tracing_fn.create_acceleration_structure_nv( - device, + self.handle, create_info, allocation_callbacks.as_raw_ptr(), &mut accel_struct, @@ -40,12 +43,11 @@ impl RayTracing { pub unsafe fn destroy_acceleration_structure( &self, - device: vk::Device, accel_struct: vk::AccelerationStructureNV, allocation_callbacks: Option<&vk::AllocationCallbacks>, ) { self.ray_tracing_fn.destroy_acceleration_structure_nv( - device, + self.handle, accel_struct, allocation_callbacks.as_raw_ptr(), ); @@ -53,22 +55,20 @@ impl RayTracing { pub unsafe fn get_acceleration_structure_memory_requirements( &self, - device: vk::Device, info: &vk::AccelerationStructureMemoryRequirementsInfoNV, ) -> vk::MemoryRequirements2KHR { let mut requirements = mem::uninitialized(); self.ray_tracing_fn - .get_acceleration_structure_memory_requirements_nv(device, info, &mut requirements); + .get_acceleration_structure_memory_requirements_nv(self.handle, info, &mut requirements); requirements } pub unsafe fn bind_acceleration_structure_memory( &self, - device: vk::Device, bind_info: &[vk::BindAccelerationStructureMemoryInfoNV], ) -> VkResult<()> { let err_code = self.ray_tracing_fn.bind_acceleration_structure_memory_nv( - device, + self.handle, bind_info.len() as u32, bind_info.as_ptr(), ); @@ -153,14 +153,13 @@ impl RayTracing { pub unsafe fn create_ray_tracing_pipelines( &self, - device: vk::Device, pipeline_cache: vk::PipelineCache, create_info: &[vk::RayTracingPipelineCreateInfoNV], allocation_callbacks: Option<&vk::AllocationCallbacks>, ) -> VkResult> { let mut pipelines = vec![mem::uninitialized(); create_info.len()]; let err_code = self.ray_tracing_fn.create_ray_tracing_pipelines_nv( - device, + self.handle, pipeline_cache, create_info.len() as u32, create_info.as_ptr(), @@ -175,14 +174,13 @@ impl RayTracing { pub unsafe fn get_ray_tracing_shader_group_handles( &self, - device: vk::Device, pipeline: vk::Pipeline, first_group: u32, group_count: u32, data: &mut [u8], ) -> VkResult<()> { let err_code = self.ray_tracing_fn.get_ray_tracing_shader_group_handles_nv( - device, + self.handle, pipeline, first_group, group_count, @@ -197,18 +195,18 @@ impl RayTracing { pub unsafe fn get_acceleration_structure_handle( &self, - device: vk::Device, accel_struct: vk::AccelerationStructureNV, - data: &mut [u8], - ) -> VkResult<()> { + ) -> VkResult { + let mut handle: u64 = 0; + let handle_ptr: *mut u64 = &mut handle; let err_code = self.ray_tracing_fn.get_acceleration_structure_handle_nv( - device, + self.handle, accel_struct, - data.len(), - data.as_mut_ptr() as *mut std::ffi::c_void, + 8, // sizeof(u64) + handle_ptr as *mut std::ffi::c_void, ); match err_code { - vk::Result::SUCCESS => Ok(()), + vk::Result::SUCCESS => Ok(handle), _ => Err(err_code), } } @@ -234,13 +232,12 @@ impl RayTracing { pub unsafe fn compile_deferred( &self, - device: vk::Device, pipeline: vk::Pipeline, shader: u32, ) -> VkResult<()> { let err_code = self .ray_tracing_fn - .compile_deferred_nv(device, pipeline, shader); + .compile_deferred_nv(self.handle, pipeline, shader); match err_code { vk::Result::SUCCESS => Ok(()), _ => Err(err_code),