Improved get_acceleration_structure_handle and keep local device handle in rt extension
This commit is contained in:
parent
e10bbb6298
commit
9b6fa860c9
1 changed files with 19 additions and 22 deletions
|
@ -8,6 +8,7 @@ use RawPtr;
|
|||
|
||||
#[derive(Clone)]
|
||||
pub struct RayTracing {
|
||||
handle: vk::Device,
|
||||
ray_tracing_fn: vk::NvRayTracingFn,
|
||||
}
|
||||
|
||||
|
@ -16,18 +17,20 @@ impl RayTracing {
|
|||
let ray_tracing_fn = vk::NvRayTracingFn::load(|name| unsafe {
|
||||
mem::transmute(instance.get_device_proc_addr(device.handle(), name.as_ptr()))
|
||||
});
|
||||
RayTracing { ray_tracing_fn }
|
||||
RayTracing {
|
||||
handle: device.handle(),
|
||||
ray_tracing_fn,
|
||||
}
|
||||
}
|
||||
|
||||
pub unsafe fn create_acceleration_structure(
|
||||
&self,
|
||||
device: vk::Device,
|
||||
create_info: &vk::AccelerationStructureCreateInfoNV,
|
||||
allocation_callbacks: Option<&vk::AllocationCallbacks>,
|
||||
) -> VkResult<vk::AccelerationStructureNV> {
|
||||
let mut accel_struct = mem::uninitialized();
|
||||
let err_code = self.ray_tracing_fn.create_acceleration_structure_nv(
|
||||
device,
|
||||
self.handle,
|
||||
create_info,
|
||||
allocation_callbacks.as_raw_ptr(),
|
||||
&mut accel_struct,
|
||||
|
@ -40,12 +43,11 @@ impl RayTracing {
|
|||
|
||||
pub unsafe fn destroy_acceleration_structure(
|
||||
&self,
|
||||
device: vk::Device,
|
||||
accel_struct: vk::AccelerationStructureNV,
|
||||
allocation_callbacks: Option<&vk::AllocationCallbacks>,
|
||||
) {
|
||||
self.ray_tracing_fn.destroy_acceleration_structure_nv(
|
||||
device,
|
||||
self.handle,
|
||||
accel_struct,
|
||||
allocation_callbacks.as_raw_ptr(),
|
||||
);
|
||||
|
@ -53,22 +55,20 @@ impl RayTracing {
|
|||
|
||||
pub unsafe fn get_acceleration_structure_memory_requirements(
|
||||
&self,
|
||||
device: vk::Device,
|
||||
info: &vk::AccelerationStructureMemoryRequirementsInfoNV,
|
||||
) -> vk::MemoryRequirements2KHR {
|
||||
let mut requirements = mem::uninitialized();
|
||||
self.ray_tracing_fn
|
||||
.get_acceleration_structure_memory_requirements_nv(device, info, &mut requirements);
|
||||
.get_acceleration_structure_memory_requirements_nv(self.handle, info, &mut requirements);
|
||||
requirements
|
||||
}
|
||||
|
||||
pub unsafe fn bind_acceleration_structure_memory(
|
||||
&self,
|
||||
device: vk::Device,
|
||||
bind_info: &[vk::BindAccelerationStructureMemoryInfoNV],
|
||||
) -> VkResult<()> {
|
||||
let err_code = self.ray_tracing_fn.bind_acceleration_structure_memory_nv(
|
||||
device,
|
||||
self.handle,
|
||||
bind_info.len() as u32,
|
||||
bind_info.as_ptr(),
|
||||
);
|
||||
|
@ -153,14 +153,13 @@ impl RayTracing {
|
|||
|
||||
pub unsafe fn create_ray_tracing_pipelines(
|
||||
&self,
|
||||
device: vk::Device,
|
||||
pipeline_cache: vk::PipelineCache,
|
||||
create_info: &[vk::RayTracingPipelineCreateInfoNV],
|
||||
allocation_callbacks: Option<&vk::AllocationCallbacks>,
|
||||
) -> VkResult<Vec<vk::Pipeline>> {
|
||||
let mut pipelines = vec![mem::uninitialized(); create_info.len()];
|
||||
let err_code = self.ray_tracing_fn.create_ray_tracing_pipelines_nv(
|
||||
device,
|
||||
self.handle,
|
||||
pipeline_cache,
|
||||
create_info.len() as u32,
|
||||
create_info.as_ptr(),
|
||||
|
@ -175,14 +174,13 @@ impl RayTracing {
|
|||
|
||||
pub unsafe fn get_ray_tracing_shader_group_handles(
|
||||
&self,
|
||||
device: vk::Device,
|
||||
pipeline: vk::Pipeline,
|
||||
first_group: u32,
|
||||
group_count: u32,
|
||||
data: &mut [u8],
|
||||
) -> VkResult<()> {
|
||||
let err_code = self.ray_tracing_fn.get_ray_tracing_shader_group_handles_nv(
|
||||
device,
|
||||
self.handle,
|
||||
pipeline,
|
||||
first_group,
|
||||
group_count,
|
||||
|
@ -197,18 +195,18 @@ impl RayTracing {
|
|||
|
||||
pub unsafe fn get_acceleration_structure_handle(
|
||||
&self,
|
||||
device: vk::Device,
|
||||
accel_struct: vk::AccelerationStructureNV,
|
||||
data: &mut [u8],
|
||||
) -> VkResult<()> {
|
||||
) -> VkResult<u64> {
|
||||
let mut handle: u64 = 0;
|
||||
let handle_ptr: *mut u64 = &mut handle;
|
||||
let err_code = self.ray_tracing_fn.get_acceleration_structure_handle_nv(
|
||||
device,
|
||||
self.handle,
|
||||
accel_struct,
|
||||
data.len(),
|
||||
data.as_mut_ptr() as *mut std::ffi::c_void,
|
||||
8, // sizeof(u64)
|
||||
handle_ptr as *mut std::ffi::c_void,
|
||||
);
|
||||
match err_code {
|
||||
vk::Result::SUCCESS => Ok(()),
|
||||
vk::Result::SUCCESS => Ok(handle),
|
||||
_ => Err(err_code),
|
||||
}
|
||||
}
|
||||
|
@ -234,13 +232,12 @@ impl RayTracing {
|
|||
|
||||
pub unsafe fn compile_deferred(
|
||||
&self,
|
||||
device: vk::Device,
|
||||
pipeline: vk::Pipeline,
|
||||
shader: u32,
|
||||
) -> VkResult<()> {
|
||||
let err_code = self
|
||||
.ray_tracing_fn
|
||||
.compile_deferred_nv(device, pipeline, shader);
|
||||
.compile_deferred_nv(self.handle, pipeline, shader);
|
||||
match err_code {
|
||||
vk::Result::SUCCESS => Ok(()),
|
||||
_ => Err(err_code),
|
||||
|
|
Loading…
Add table
Reference in a new issue