Improved get_acceleration_structure_handle and keep local device handle in rt extension

This commit is contained in:
Graham Wihlidal 2019-02-10 13:37:53 +01:00
parent e10bbb6298
commit 9b6fa860c9

View file

@ -8,6 +8,7 @@ use RawPtr;
#[derive(Clone)] #[derive(Clone)]
pub struct RayTracing { pub struct RayTracing {
handle: vk::Device,
ray_tracing_fn: vk::NvRayTracingFn, ray_tracing_fn: vk::NvRayTracingFn,
} }
@ -16,18 +17,20 @@ impl RayTracing {
let ray_tracing_fn = vk::NvRayTracingFn::load(|name| unsafe { let ray_tracing_fn = vk::NvRayTracingFn::load(|name| unsafe {
mem::transmute(instance.get_device_proc_addr(device.handle(), name.as_ptr())) mem::transmute(instance.get_device_proc_addr(device.handle(), name.as_ptr()))
}); });
RayTracing { ray_tracing_fn } RayTracing {
handle: device.handle(),
ray_tracing_fn,
}
} }
pub unsafe fn create_acceleration_structure( pub unsafe fn create_acceleration_structure(
&self, &self,
device: vk::Device,
create_info: &vk::AccelerationStructureCreateInfoNV, create_info: &vk::AccelerationStructureCreateInfoNV,
allocation_callbacks: Option<&vk::AllocationCallbacks>, allocation_callbacks: Option<&vk::AllocationCallbacks>,
) -> VkResult<vk::AccelerationStructureNV> { ) -> VkResult<vk::AccelerationStructureNV> {
let mut accel_struct = mem::uninitialized(); let mut accel_struct = mem::uninitialized();
let err_code = self.ray_tracing_fn.create_acceleration_structure_nv( let err_code = self.ray_tracing_fn.create_acceleration_structure_nv(
device, self.handle,
create_info, create_info,
allocation_callbacks.as_raw_ptr(), allocation_callbacks.as_raw_ptr(),
&mut accel_struct, &mut accel_struct,
@ -40,12 +43,11 @@ impl RayTracing {
pub unsafe fn destroy_acceleration_structure( pub unsafe fn destroy_acceleration_structure(
&self, &self,
device: vk::Device,
accel_struct: vk::AccelerationStructureNV, accel_struct: vk::AccelerationStructureNV,
allocation_callbacks: Option<&vk::AllocationCallbacks>, allocation_callbacks: Option<&vk::AllocationCallbacks>,
) { ) {
self.ray_tracing_fn.destroy_acceleration_structure_nv( self.ray_tracing_fn.destroy_acceleration_structure_nv(
device, self.handle,
accel_struct, accel_struct,
allocation_callbacks.as_raw_ptr(), allocation_callbacks.as_raw_ptr(),
); );
@ -53,22 +55,20 @@ impl RayTracing {
pub unsafe fn get_acceleration_structure_memory_requirements( pub unsafe fn get_acceleration_structure_memory_requirements(
&self, &self,
device: vk::Device,
info: &vk::AccelerationStructureMemoryRequirementsInfoNV, info: &vk::AccelerationStructureMemoryRequirementsInfoNV,
) -> vk::MemoryRequirements2KHR { ) -> vk::MemoryRequirements2KHR {
let mut requirements = mem::uninitialized(); let mut requirements = mem::uninitialized();
self.ray_tracing_fn self.ray_tracing_fn
.get_acceleration_structure_memory_requirements_nv(device, info, &mut requirements); .get_acceleration_structure_memory_requirements_nv(self.handle, info, &mut requirements);
requirements requirements
} }
pub unsafe fn bind_acceleration_structure_memory( pub unsafe fn bind_acceleration_structure_memory(
&self, &self,
device: vk::Device,
bind_info: &[vk::BindAccelerationStructureMemoryInfoNV], bind_info: &[vk::BindAccelerationStructureMemoryInfoNV],
) -> VkResult<()> { ) -> VkResult<()> {
let err_code = self.ray_tracing_fn.bind_acceleration_structure_memory_nv( let err_code = self.ray_tracing_fn.bind_acceleration_structure_memory_nv(
device, self.handle,
bind_info.len() as u32, bind_info.len() as u32,
bind_info.as_ptr(), bind_info.as_ptr(),
); );
@ -153,14 +153,13 @@ impl RayTracing {
pub unsafe fn create_ray_tracing_pipelines( pub unsafe fn create_ray_tracing_pipelines(
&self, &self,
device: vk::Device,
pipeline_cache: vk::PipelineCache, pipeline_cache: vk::PipelineCache,
create_info: &[vk::RayTracingPipelineCreateInfoNV], create_info: &[vk::RayTracingPipelineCreateInfoNV],
allocation_callbacks: Option<&vk::AllocationCallbacks>, allocation_callbacks: Option<&vk::AllocationCallbacks>,
) -> VkResult<Vec<vk::Pipeline>> { ) -> VkResult<Vec<vk::Pipeline>> {
let mut pipelines = vec![mem::uninitialized(); create_info.len()]; let mut pipelines = vec![mem::uninitialized(); create_info.len()];
let err_code = self.ray_tracing_fn.create_ray_tracing_pipelines_nv( let err_code = self.ray_tracing_fn.create_ray_tracing_pipelines_nv(
device, self.handle,
pipeline_cache, pipeline_cache,
create_info.len() as u32, create_info.len() as u32,
create_info.as_ptr(), create_info.as_ptr(),
@ -175,14 +174,13 @@ impl RayTracing {
pub unsafe fn get_ray_tracing_shader_group_handles( pub unsafe fn get_ray_tracing_shader_group_handles(
&self, &self,
device: vk::Device,
pipeline: vk::Pipeline, pipeline: vk::Pipeline,
first_group: u32, first_group: u32,
group_count: u32, group_count: u32,
data: &mut [u8], data: &mut [u8],
) -> VkResult<()> { ) -> VkResult<()> {
let err_code = self.ray_tracing_fn.get_ray_tracing_shader_group_handles_nv( let err_code = self.ray_tracing_fn.get_ray_tracing_shader_group_handles_nv(
device, self.handle,
pipeline, pipeline,
first_group, first_group,
group_count, group_count,
@ -197,18 +195,18 @@ impl RayTracing {
pub unsafe fn get_acceleration_structure_handle( pub unsafe fn get_acceleration_structure_handle(
&self, &self,
device: vk::Device,
accel_struct: vk::AccelerationStructureNV, accel_struct: vk::AccelerationStructureNV,
data: &mut [u8], ) -> VkResult<u64> {
) -> VkResult<()> { let mut handle: u64 = 0;
let handle_ptr: *mut u64 = &mut handle;
let err_code = self.ray_tracing_fn.get_acceleration_structure_handle_nv( let err_code = self.ray_tracing_fn.get_acceleration_structure_handle_nv(
device, self.handle,
accel_struct, accel_struct,
data.len(), 8, // sizeof(u64)
data.as_mut_ptr() as *mut std::ffi::c_void, handle_ptr as *mut std::ffi::c_void,
); );
match err_code { match err_code {
vk::Result::SUCCESS => Ok(()), vk::Result::SUCCESS => Ok(handle),
_ => Err(err_code), _ => Err(err_code),
} }
} }
@ -234,13 +232,12 @@ impl RayTracing {
pub unsafe fn compile_deferred( pub unsafe fn compile_deferred(
&self, &self,
device: vk::Device,
pipeline: vk::Pipeline, pipeline: vk::Pipeline,
shader: u32, shader: u32,
) -> VkResult<()> { ) -> VkResult<()> {
let err_code = self let err_code = self
.ray_tracing_fn .ray_tracing_fn
.compile_deferred_nv(device, pipeline, shader); .compile_deferred_nv(self.handle, pipeline, shader);
match err_code { match err_code {
vk::Result::SUCCESS => Ok(()), vk::Result::SUCCESS => Ok(()),
_ => Err(err_code), _ => Err(err_code),