make interrupt system interrupt safe(r)

FnMut is not something that can be used in a interrupt safe manner.
Instead use Fn with a Mutex that disables interrupts with a lock.
This commit is contained in:
Corwin Kuiper 2021-06-24 00:58:25 +01:00
parent a5488fab56
commit a9aad11dd7

View file

@ -1,6 +1,7 @@
use core::{ use core::{
cell::Cell, cell::{Cell, UnsafeCell},
marker::{PhantomData, PhantomPinned}, marker::{PhantomData, PhantomPinned},
ops::{Deref, DerefMut},
pin::Pin, pin::Pin,
}; };
@ -132,7 +133,7 @@ pub struct InterruptClosureBounded<'a> {
} }
struct InterruptClosure { struct InterruptClosure {
closure: *mut (dyn FnMut()), closure: *const (dyn Fn()),
next: Cell<*const InterruptClosure>, next: Cell<*const InterruptClosure>,
root: *const InterruptRoot, root: *const InterruptRoot,
} }
@ -142,7 +143,7 @@ impl InterruptRoot {
let mut c = self.next.get(); let mut c = self.next.get();
while !c.is_null() { while !c.is_null() {
let closure_ptr = unsafe { &*c }.closure; let closure_ptr = unsafe { &*c }.closure;
let closure_ref = unsafe { &mut *closure_ptr }; let closure_ref = unsafe { &*closure_ptr };
closure_ref(); closure_ref();
c = unsafe { &*c }.next.get(); c = unsafe { &*c }.next.get();
} }
@ -188,12 +189,12 @@ fn interrupt_to_root(interrupt: Interrupt) -> &'static InterruptRoot {
} }
fn get_interrupt_handle_root<'a>( fn get_interrupt_handle_root<'a>(
f: &'a mut dyn FnMut(), f: &'a dyn Fn(),
root: &InterruptRoot, root: &InterruptRoot,
) -> InterruptClosureBounded<'a> { ) -> InterruptClosureBounded<'a> {
InterruptClosureBounded { InterruptClosureBounded {
c: InterruptClosure { c: InterruptClosure {
closure: unsafe { core::mem::transmute(f as *mut _) }, closure: unsafe { core::mem::transmute(f as *const _) },
next: Cell::new(core::ptr::null()), next: Cell::new(core::ptr::null()),
root: root as *const _, root: root as *const _,
}, },
@ -202,7 +203,10 @@ fn get_interrupt_handle_root<'a>(
} }
} }
pub fn get_interrupt_handle(f: &mut dyn FnMut(), interrupt: Interrupt) -> InterruptClosureBounded { pub fn get_interrupt_handle(
f: &(dyn Fn() + Send + Sync),
interrupt: Interrupt,
) -> InterruptClosureBounded {
let root = interrupt_to_root(interrupt); let root = interrupt_to_root(interrupt);
get_interrupt_handle_root(f, root) get_interrupt_handle_root(f, root)
@ -229,11 +233,11 @@ pub fn add_interrupt<'a>(interrupt: Pin<&'a InterruptClosureBounded<'a>>) {
#[test_case] #[test_case]
fn test_vblank_interrupt_handler(gba: &mut crate::Gba) { fn test_vblank_interrupt_handler(gba: &mut crate::Gba) {
{ {
let mut counter = 0; let counter = Mutex::new(0);
let mut counter_2 = 0; let counter_2 = Mutex::new(0);
let mut vblank_interrupt = || counter += 1; let mut vblank_interrupt = || *counter.lock() += 1;
let mut vblank_interrupt_2 = || counter_2 += 1; let mut vblank_interrupt_2 = || *counter_2.lock() += 1;
let interrupt_closure = get_interrupt_handle(&mut vblank_interrupt, Interrupt::VBlank); let interrupt_closure = get_interrupt_handle(&mut vblank_interrupt, Interrupt::VBlank);
let interrupt_closure = unsafe { Pin::new_unchecked(&interrupt_closure) }; let interrupt_closure = unsafe { Pin::new_unchecked(&interrupt_closure) };
@ -245,7 +249,7 @@ fn test_vblank_interrupt_handler(gba: &mut crate::Gba) {
let vblank = gba.display.vblank.get(); let vblank = gba.display.vblank.get();
while counter < 100 || counter_2 < 100 { while *counter.lock() < 100 || *counter_2.lock() < 100 {
vblank.wait_for_VBlank(); vblank.wait_for_VBlank();
} }
} }
@ -256,3 +260,67 @@ fn test_vblank_interrupt_handler(gba: &mut crate::Gba) {
"expected the interrupt table for vblank to be empty" "expected the interrupt table for vblank to be empty"
); );
} }
#[derive(Clone, Copy)]
enum MutexState {
Locked,
Unlocked(bool),
}
pub struct Mutex<T> {
internal: UnsafeCell<T>,
state: UnsafeCell<MutexState>,
}
unsafe impl<T> Send for Mutex<T> {}
unsafe impl<T> Sync for Mutex<T> {}
impl<T> Mutex<T> {
pub fn lock(&self) -> MutexRef<T> {
let state = INTERRUPTS_ENABLED.get();
INTERRUPTS_ENABLED.set(0);
unsafe { *self.state.get() = MutexState::Unlocked(state != 0) };
MutexRef {
internal: &self.internal,
state: &self.state,
}
}
pub fn new(val: T) -> Self {
Mutex {
internal: UnsafeCell::new(val),
state: UnsafeCell::new(MutexState::Locked),
}
}
}
pub struct MutexRef<'a, T> {
internal: &'a UnsafeCell<T>,
state: &'a UnsafeCell<MutexState>,
}
impl<'a, T> Drop for MutexRef<'a, T> {
fn drop(&mut self) {
unsafe {
let state = &mut *self.state.get();
let prev_state = *state;
*state = MutexState::Locked;
match prev_state {
MutexState::Unlocked(b) => INTERRUPTS_ENABLED.set(b as u16),
MutexState::Locked => {}
}
}
}
}
impl<'a, T> Deref for MutexRef<'a, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
unsafe { &*self.internal.get() }
}
}
impl<'a, T> DerefMut for MutexRef<'a, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe { &mut *self.internal.get() }
}
}