From 3c20c83bc98e4c12442c670d8bfbd05dde08c4e7 Mon Sep 17 00:00:00 2001
From: chyyran <ronny@ronnychan.ca>
Date: Sun, 29 Sep 2024 00:30:39 -0400
Subject: [PATCH] rt(d3d12): use InterfaceRef for D3D12InputImage to avoid
 refcount for input image

---
 .../src/runtime/d3d12/filter_chain.rs         | 19 ++----
 librashader-cli/src/render/d3d12/mod.rs       | 18 +++---
 librashader-runtime-d3d12/src/framebuffer.rs  |  2 +-
 librashader-runtime-d3d12/src/luts.rs         |  2 +-
 librashader-runtime-d3d12/src/texture.rs      | 62 ++++++-------------
 .../tests/hello_triangle/mod.rs               |  2 +-
 6 files changed, 35 insertions(+), 70 deletions(-)

diff --git a/librashader-capi/src/runtime/d3d12/filter_chain.rs b/librashader-capi/src/runtime/d3d12/filter_chain.rs
index 80ab7c8..e3b3389 100644
--- a/librashader-capi/src/runtime/d3d12/filter_chain.rs
+++ b/librashader-capi/src/runtime/d3d12/filter_chain.rs
@@ -8,6 +8,7 @@ use std::ffi::CStr;
 use std::mem::{ManuallyDrop, MaybeUninit};
 use std::ptr::NonNull;
 use std::slice;
+use windows::core::Interface;
 use windows::Win32::Graphics::Direct3D12::{
     ID3D12Device, ID3D12GraphicsCommandList, ID3D12Resource, D3D12_CPU_DESCRIPTOR_HANDLE,
 };
@@ -93,19 +94,6 @@ config_struct! {
     }
 }
 
-impl TryFrom<libra_source_image_d3d12_t> for D3D12InputImage {
-    type Error = LibrashaderError;
-
-    fn try_from(value: libra_source_image_d3d12_t) -> Result<Self, Self::Error> {
-        let resource = value.resource.clone();
-
-        Ok(D3D12InputImage {
-            resource: ManuallyDrop::into_inner(resource),
-            descriptor: value.descriptor,
-        })
-    }
-}
-
 extern_fn! {
     /// Create the filter chain given the shader preset.
     ///
@@ -295,7 +283,10 @@ extern_fn! {
             }
         };
 
-        let image = image.try_into()?;
+        let image = D3D12InputImage {
+            resource: image.resource.to_ref(),
+            descriptor: image.descriptor,
+        };
         unsafe {
             chain.frame(&command_list, image, &viewport, frame_count, options.as_ref())?;
         }
diff --git a/librashader-cli/src/render/d3d12/mod.rs b/librashader-cli/src/render/d3d12/mod.rs
index cc151b8..b905cdc 100644
--- a/librashader-cli/src/render/d3d12/mod.rs
+++ b/librashader-cli/src/render/d3d12/mod.rs
@@ -45,7 +45,7 @@ pub struct Direct3D12 {
     _cpu_heap: D3D12DescriptorHeap<CpuStagingHeap>,
     rtv_heap: D3D12DescriptorHeap<RenderTargetHeap>,
 
-    texture: D3D12InputImage,
+    texture: ID3D12Resource,
     _heap_slot: D3D12DescriptorHeapSlot<CpuStagingHeap>,
     command_pool: ID3D12CommandAllocator,
     queue: ID3D12CommandQueue,
@@ -154,7 +154,10 @@ impl RenderTest for Direct3D12 {
             for frame in 0..=frame_count {
                 filter_chain.frame(
                     &cmd,
-                    self.texture.clone(),
+                    D3D12InputImage {
+                        resource: self.texture.to_ref(),
+                        descriptor: *self._heap_slot.as_ref(),
+                    },
                     &viewport,
                     frame,
                     options.as_ref(),
@@ -249,7 +252,7 @@ impl Direct3D12 {
         path: &Path,
     ) -> anyhow::Result<(
         Image<BGRA8>,
-        D3D12InputImage,
+        ID3D12Resource,
         D3D12DescriptorHeapSlot<CpuStagingHeap>,
     )> {
         // 1 time queue infrastructure for lut uploads
@@ -392,14 +395,7 @@ impl Direct3D12 {
                 CloseHandle(fence_event)?;
             }
 
-            Ok((
-                image,
-                D3D12InputImage {
-                    resource,
-                    descriptor: descriptor.as_ref().clone(),
-                },
-                descriptor,
-            ))
+            Ok((image, resource, descriptor))
         }
     }
 
diff --git a/librashader-runtime-d3d12/src/framebuffer.rs b/librashader-runtime-d3d12/src/framebuffer.rs
index 26e0a33..f2fe905 100644
--- a/librashader-runtime-d3d12/src/framebuffer.rs
+++ b/librashader-runtime-d3d12/src/framebuffer.rs
@@ -295,7 +295,7 @@ impl OwnedImage {
             );
         }
 
-        Ok(InputTexture::new::<OutlivesFrame, _>(
+        Ok(InputTexture::new(
             &self.resource,
             descriptor,
             self.size,
diff --git a/librashader-runtime-d3d12/src/luts.rs b/librashader-runtime-d3d12/src/luts.rs
index fdcfdc8..4b24d6e 100644
--- a/librashader-runtime-d3d12/src/luts.rs
+++ b/librashader-runtime-d3d12/src/luts.rs
@@ -201,7 +201,7 @@ impl LutTexture {
             D3D12_RESOURCE_STATE_PIXEL_SHADER_RESOURCE,
         );
 
-        let view = InputTexture::new::<OutlivesFrame, _>(
+        let view = InputTexture::new(
             &resource,
             descriptor,
             source.size,
diff --git a/librashader-runtime-d3d12/src/texture.rs b/librashader-runtime-d3d12/src/texture.rs
index 5b38e30..610556e 100644
--- a/librashader-runtime-d3d12/src/texture.rs
+++ b/librashader-runtime-d3d12/src/texture.rs
@@ -1,15 +1,16 @@
 use crate::descriptor_heap::{CpuStagingHeap, RenderTargetHeap};
-use crate::resource::ResourceHandleStrategy;
-use d3d12_descriptor_heap::D3D12DescriptorHeapSlot;
+use crate::resource::{OutlivesFrame, ResourceHandleStrategy};
+use d3d12_descriptor_heap::{D3D12DescriptorHeap, D3D12DescriptorHeapSlot};
 use librashader_common::{FilterMode, GetSize, Size, WrapMode};
 use std::mem::ManuallyDrop;
+use windows::core::InterfaceRef;
 use windows::Win32::Graphics::Direct3D12::{ID3D12Resource, D3D12_CPU_DESCRIPTOR_HANDLE};
 use windows::Win32::Graphics::Dxgi::Common::DXGI_FORMAT;
 
 /// An image for use as shader resource view.
 #[derive(Clone)]
-pub struct D3D12InputImage {
-    pub resource: ID3D12Resource,
+pub struct D3D12InputImage<'a> {
+    pub resource: InterfaceRef<'a, ID3D12Resource>,
     pub descriptor: D3D12_CPU_DESCRIPTOR_HANDLE,
 }
 
@@ -95,12 +96,13 @@ pub struct InputTexture {
     pub(crate) format: DXGI_FORMAT,
     pub(crate) wrap_mode: WrapMode,
     pub(crate) filter: FilterMode,
-    drop_flag: bool,
 }
 
 impl InputTexture {
-    pub fn new<S: ResourceHandleStrategy<T>, T>(
-        resource: &T,
+    // Create a new input texture, with runtime lifetime tracking.
+    // The source owned framebuffer must outlive this input.
+    pub fn new(
+        resource: &ManuallyDrop<ID3D12Resource>,
         handle: D3D12DescriptorHeapSlot<CpuStagingHeap>,
         size: Size<u32>,
         format: DXGI_FORMAT,
@@ -114,13 +116,12 @@ impl InputTexture {
             // as valid for the lifetime of handle.
             // Also, resource is non-null by construction.
             // Option<T> and <T> have the same layout.
-            resource: unsafe { std::mem::transmute(S::obtain(resource)) },
+            resource: unsafe { std::mem::transmute(OutlivesFrame::obtain(resource)) },
             descriptor: srv,
             size,
             format,
             wrap_mode,
             filter,
-            drop_flag: S::NEEDS_CLEANUP,
         }
     }
 
@@ -132,42 +133,27 @@ impl InputTexture {
     ) -> InputTexture {
         let desc = unsafe { image.resource.GetDesc() };
         InputTexture {
-            resource: ManuallyDrop::new(image.resource.clone()),
+            resource: unsafe { std::mem::transmute(image.resource) },
             descriptor: InputDescriptor::Raw(image.descriptor),
             size: Size::new(desc.Width as u32, desc.Height),
             format: desc.Format,
             wrap_mode,
             filter,
-            drop_flag: true,
         }
     }
 }
 
 impl Clone for InputTexture {
     fn clone(&self) -> Self {
-        // ensure lifetime for raw resources or if there is a drop flag
-        if self.descriptor.is_raw() || self.drop_flag {
-            InputTexture {
-                resource: ManuallyDrop::clone(&self.resource),
-                descriptor: self.descriptor.clone(),
-                size: self.size,
-                format: self.format,
-                wrap_mode: self.wrap_mode,
-                filter: self.filter,
-                drop_flag: true,
-            }
-        } else {
-            // SAFETY: the parent doesn't have drop flag, so that means
-            // we don't need to handle drop.
-            InputTexture {
-                resource: unsafe { std::mem::transmute_copy(&self.resource) },
-                descriptor: self.descriptor.clone(),
-                size: self.size,
-                format: self.format,
-                wrap_mode: self.wrap_mode,
-                filter: self.filter,
-                drop_flag: false,
-            }
+        // SAFETY: the parent doesn't have drop flag, so that means
+        // we don't need to handle drop.
+        InputTexture {
+            resource: unsafe { std::mem::transmute_copy(&self.resource) },
+            descriptor: self.descriptor.clone(),
+            size: self.size,
+            format: self.format,
+            wrap_mode: self.wrap_mode,
+            filter: self.filter,
         }
     }
 }
@@ -185,11 +171,3 @@ impl GetSize<u32> for D3D12OutputView {
         Ok(self.size)
     }
 }
-
-impl Drop for InputTexture {
-    fn drop(&mut self) {
-        if self.drop_flag {
-            unsafe { ManuallyDrop::drop(&mut self.resource) }
-        }
-    }
-}
diff --git a/librashader-runtime-d3d12/tests/hello_triangle/mod.rs b/librashader-runtime-d3d12/tests/hello_triangle/mod.rs
index 354fdfa..fa98380 100644
--- a/librashader-runtime-d3d12/tests/hello_triangle/mod.rs
+++ b/librashader-runtime-d3d12/tests/hello_triangle/mod.rs
@@ -620,7 +620,7 @@ pub mod d3d12_hello_triangle {
                 .frame(
                     command_list,
                     D3D12InputImage {
-                        resource: ID3D12Resource::clone(&*resources.framebuffer),
+                        resource: resources.framebuffer.to_ref(),
                         descriptor: framebuffer,
                     },
                     &Viewport {