Improved get_acceleration_structure_handle and keep local device handle in rt extension

This commit is contained in:
Graham Wihlidal 2019-02-10 13:37:53 +01:00
parent e10bbb6298
commit 9b6fa860c9

View file

@ -8,6 +8,7 @@ use RawPtr;
#[derive(Clone)]
pub struct RayTracing {
handle: vk::Device,
ray_tracing_fn: vk::NvRayTracingFn,
}
@ -16,18 +17,20 @@ impl RayTracing {
let ray_tracing_fn = vk::NvRayTracingFn::load(|name| unsafe {
mem::transmute(instance.get_device_proc_addr(device.handle(), name.as_ptr()))
});
RayTracing { ray_tracing_fn }
RayTracing {
handle: device.handle(),
ray_tracing_fn,
}
}
pub unsafe fn create_acceleration_structure(
&self,
device: vk::Device,
create_info: &vk::AccelerationStructureCreateInfoNV,
allocation_callbacks: Option<&vk::AllocationCallbacks>,
) -> VkResult<vk::AccelerationStructureNV> {
let mut accel_struct = mem::uninitialized();
let err_code = self.ray_tracing_fn.create_acceleration_structure_nv(
device,
self.handle,
create_info,
allocation_callbacks.as_raw_ptr(),
&mut accel_struct,
@ -40,12 +43,11 @@ impl RayTracing {
pub unsafe fn destroy_acceleration_structure(
&self,
device: vk::Device,
accel_struct: vk::AccelerationStructureNV,
allocation_callbacks: Option<&vk::AllocationCallbacks>,
) {
self.ray_tracing_fn.destroy_acceleration_structure_nv(
device,
self.handle,
accel_struct,
allocation_callbacks.as_raw_ptr(),
);
@ -53,22 +55,20 @@ impl RayTracing {
pub unsafe fn get_acceleration_structure_memory_requirements(
&self,
device: vk::Device,
info: &vk::AccelerationStructureMemoryRequirementsInfoNV,
) -> vk::MemoryRequirements2KHR {
let mut requirements = mem::uninitialized();
self.ray_tracing_fn
.get_acceleration_structure_memory_requirements_nv(device, info, &mut requirements);
.get_acceleration_structure_memory_requirements_nv(self.handle, info, &mut requirements);
requirements
}
pub unsafe fn bind_acceleration_structure_memory(
&self,
device: vk::Device,
bind_info: &[vk::BindAccelerationStructureMemoryInfoNV],
) -> VkResult<()> {
let err_code = self.ray_tracing_fn.bind_acceleration_structure_memory_nv(
device,
self.handle,
bind_info.len() as u32,
bind_info.as_ptr(),
);
@ -153,14 +153,13 @@ impl RayTracing {
pub unsafe fn create_ray_tracing_pipelines(
&self,
device: vk::Device,
pipeline_cache: vk::PipelineCache,
create_info: &[vk::RayTracingPipelineCreateInfoNV],
allocation_callbacks: Option<&vk::AllocationCallbacks>,
) -> VkResult<Vec<vk::Pipeline>> {
let mut pipelines = vec![mem::uninitialized(); create_info.len()];
let err_code = self.ray_tracing_fn.create_ray_tracing_pipelines_nv(
device,
self.handle,
pipeline_cache,
create_info.len() as u32,
create_info.as_ptr(),
@ -175,14 +174,13 @@ impl RayTracing {
pub unsafe fn get_ray_tracing_shader_group_handles(
&self,
device: vk::Device,
pipeline: vk::Pipeline,
first_group: u32,
group_count: u32,
data: &mut [u8],
) -> VkResult<()> {
let err_code = self.ray_tracing_fn.get_ray_tracing_shader_group_handles_nv(
device,
self.handle,
pipeline,
first_group,
group_count,
@ -197,18 +195,18 @@ impl RayTracing {
pub unsafe fn get_acceleration_structure_handle(
&self,
device: vk::Device,
accel_struct: vk::AccelerationStructureNV,
data: &mut [u8],
) -> VkResult<()> {
) -> VkResult<u64> {
let mut handle: u64 = 0;
let handle_ptr: *mut u64 = &mut handle;
let err_code = self.ray_tracing_fn.get_acceleration_structure_handle_nv(
device,
self.handle,
accel_struct,
data.len(),
data.as_mut_ptr() as *mut std::ffi::c_void,
8, // sizeof(u64)
handle_ptr as *mut std::ffi::c_void,
);
match err_code {
vk::Result::SUCCESS => Ok(()),
vk::Result::SUCCESS => Ok(handle),
_ => Err(err_code),
}
}
@ -234,13 +232,12 @@ impl RayTracing {
pub unsafe fn compile_deferred(
&self,
device: vk::Device,
pipeline: vk::Pipeline,
shader: u32,
) -> VkResult<()> {
let err_code = self
.ray_tracing_fn
.compile_deferred_nv(device, pipeline, shader);
.compile_deferred_nv(self.handle, pipeline, shader);
match err_code {
vk::Result::SUCCESS => Ok(()),
_ => Err(err_code),