From a9aad11dd72a1c3c7e99ef40cf29cf8d1936b324 Mon Sep 17 00:00:00 2001 From: Corwin Kuiper Date: Thu, 24 Jun 2021 00:58:25 +0100 Subject: [PATCH] make interrupt system interrupt safe(r) FnMut is not something that can be used in a interrupt safe manner. Instead use Fn with a Mutex that disables interrupts with a lock. --- agb/src/interrupt.rs | 90 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 79 insertions(+), 11 deletions(-) diff --git a/agb/src/interrupt.rs b/agb/src/interrupt.rs index c8ffa02a..9f73970e 100644 --- a/agb/src/interrupt.rs +++ b/agb/src/interrupt.rs @@ -1,6 +1,7 @@ use core::{ - cell::Cell, + cell::{Cell, UnsafeCell}, marker::{PhantomData, PhantomPinned}, + ops::{Deref, DerefMut}, pin::Pin, }; @@ -132,7 +133,7 @@ pub struct InterruptClosureBounded<'a> { } struct InterruptClosure { - closure: *mut (dyn FnMut()), + closure: *const (dyn Fn()), next: Cell<*const InterruptClosure>, root: *const InterruptRoot, } @@ -142,7 +143,7 @@ impl InterruptRoot { let mut c = self.next.get(); while !c.is_null() { let closure_ptr = unsafe { &*c }.closure; - let closure_ref = unsafe { &mut *closure_ptr }; + let closure_ref = unsafe { &*closure_ptr }; closure_ref(); c = unsafe { &*c }.next.get(); } @@ -188,12 +189,12 @@ fn interrupt_to_root(interrupt: Interrupt) -> &'static InterruptRoot { } fn get_interrupt_handle_root<'a>( - f: &'a mut dyn FnMut(), + f: &'a dyn Fn(), root: &InterruptRoot, ) -> InterruptClosureBounded<'a> { InterruptClosureBounded { c: InterruptClosure { - closure: unsafe { core::mem::transmute(f as *mut _) }, + closure: unsafe { core::mem::transmute(f as *const _) }, next: Cell::new(core::ptr::null()), root: root as *const _, }, @@ -202,7 +203,10 @@ fn get_interrupt_handle_root<'a>( } } -pub fn get_interrupt_handle(f: &mut dyn FnMut(), interrupt: Interrupt) -> InterruptClosureBounded { +pub fn get_interrupt_handle( + f: &(dyn Fn() + Send + Sync), + interrupt: Interrupt, +) -> InterruptClosureBounded { let root = interrupt_to_root(interrupt); get_interrupt_handle_root(f, root) @@ -229,11 +233,11 @@ pub fn add_interrupt<'a>(interrupt: Pin<&'a InterruptClosureBounded<'a>>) { #[test_case] fn test_vblank_interrupt_handler(gba: &mut crate::Gba) { { - let mut counter = 0; - let mut counter_2 = 0; + let counter = Mutex::new(0); + let counter_2 = Mutex::new(0); - let mut vblank_interrupt = || counter += 1; - let mut vblank_interrupt_2 = || counter_2 += 1; + let mut vblank_interrupt = || *counter.lock() += 1; + let mut vblank_interrupt_2 = || *counter_2.lock() += 1; let interrupt_closure = get_interrupt_handle(&mut vblank_interrupt, Interrupt::VBlank); let interrupt_closure = unsafe { Pin::new_unchecked(&interrupt_closure) }; @@ -245,7 +249,7 @@ fn test_vblank_interrupt_handler(gba: &mut crate::Gba) { let vblank = gba.display.vblank.get(); - while counter < 100 || counter_2 < 100 { + while *counter.lock() < 100 || *counter_2.lock() < 100 { vblank.wait_for_VBlank(); } } @@ -256,3 +260,67 @@ fn test_vblank_interrupt_handler(gba: &mut crate::Gba) { "expected the interrupt table for vblank to be empty" ); } + +#[derive(Clone, Copy)] +enum MutexState { + Locked, + Unlocked(bool), +} + +pub struct Mutex { + internal: UnsafeCell, + state: UnsafeCell, +} + +unsafe impl Send for Mutex {} +unsafe impl Sync for Mutex {} + +impl Mutex { + pub fn lock(&self) -> MutexRef { + let state = INTERRUPTS_ENABLED.get(); + INTERRUPTS_ENABLED.set(0); + unsafe { *self.state.get() = MutexState::Unlocked(state != 0) }; + MutexRef { + internal: &self.internal, + state: &self.state, + } + } + pub fn new(val: T) -> Self { + Mutex { + internal: UnsafeCell::new(val), + state: UnsafeCell::new(MutexState::Locked), + } + } +} + +pub struct MutexRef<'a, T> { + internal: &'a UnsafeCell, + state: &'a UnsafeCell, +} + +impl<'a, T> Drop for MutexRef<'a, T> { + fn drop(&mut self) { + unsafe { + let state = &mut *self.state.get(); + let prev_state = *state; + *state = MutexState::Locked; + match prev_state { + MutexState::Unlocked(b) => INTERRUPTS_ENABLED.set(b as u16), + MutexState::Locked => {} + } + } + } +} + +impl<'a, T> Deref for MutexRef<'a, T> { + type Target = T; + fn deref(&self) -> &Self::Target { + unsafe { &*self.internal.get() } + } +} + +impl<'a, T> DerefMut for MutexRef<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.internal.get() } + } +}