resolve interrupts unsoundness

This commit is contained in:
Corwin 2023-04-06 21:16:13 +01:00
parent d3834ed2ff
commit dffda47985
No known key found for this signature in database
5 changed files with 111 additions and 126 deletions

View file

@ -3,13 +3,16 @@
use agb::sync::Static; use agb::sync::Static;
static COUNT: Static<u32> = Static::new(0);
#[agb::entry] #[agb::entry]
fn main(_gba: agb::Gba) -> ! { fn main(_gba: agb::Gba) -> ! {
let count = Static::new(0); let _a = unsafe {
let _a = agb::interrupt::add_interrupt_handler(agb::interrupt::Interrupt::VBlank, |_| { agb::interrupt::add_interrupt_handler(agb::interrupt::Interrupt::VBlank, |_| {
let cur_count = count.read(); let cur_count = COUNT.read();
agb::println!("Hello, world, frame = {}", cur_count); agb::println!("Hello, world, frame = {}", cur_count);
count.write(cur_count + 1); COUNT.write(cur_count + 1);
}); })
};
loop {} loop {}
} }

View file

@ -18,6 +18,11 @@ struct BackCosines {
row: usize, row: usize,
} }
static BACK: Mutex<RefCell<BackCosines>> = Mutex::new(RefCell::new(BackCosines {
cosines: [0; 32],
row: 0,
}));
#[agb::entry] #[agb::entry]
fn main(mut gba: agb::Gba) -> ! { fn main(mut gba: agb::Gba) -> ! {
let (gfx, mut vram) = gba.display.video.tiled0(); let (gfx, mut vram) = gba.display.video.tiled0();
@ -30,24 +35,22 @@ fn main(mut gba: agb::Gba) -> ! {
example_logo::display_logo(&mut background, &mut vram); example_logo::display_logo(&mut background, &mut vram);
let mut time = 0; let _a = unsafe {
let cosines = [0_u16; 32]; agb::interrupt::add_interrupt_handler(Interrupt::HBlank, |key: CriticalSection| {
let mut back = BACK.borrow(key).borrow_mut();
let back = Mutex::new(RefCell::new(BackCosines { cosines, row: 0 }));
let _a = agb::interrupt::add_interrupt_handler(Interrupt::HBlank, |key: CriticalSection| {
let mut back = back.borrow(key).borrow_mut();
let deflection = back.cosines[back.row % 32]; let deflection = back.cosines[back.row % 32];
unsafe { ((0x0400_0010) as *mut u16).write_volatile(deflection) } ((0x0400_0010) as *mut u16).write_volatile(deflection);
back.row += 1; back.row += 1;
}); })
};
let vblank = agb::interrupt::VBlank::get(); let vblank = agb::interrupt::VBlank::get();
let mut time = 0;
loop { loop {
vblank.wait_for_vblank(); vblank.wait_for_vblank();
free(|key| { free(|key| {
let mut back = back.borrow(key).borrow_mut(); let mut back = BACK.borrow(key).borrow_mut();
back.row = 0; back.row = 0;
time += 1; time += 1;
for (r, a) in back.cosines.iter_mut().enumerate() { for (r, a) in back.cosines.iter_mut().enumerate() {

View file

@ -1,8 +1,4 @@
use core::{ use core::{cell::Cell, marker::PhantomPinned, pin::Pin};
cell::Cell,
marker::{PhantomData, PhantomPinned},
pin::Pin,
};
use alloc::boxed::Box; use alloc::boxed::Box;
use bare_metal::CriticalSection; use bare_metal::CriticalSection;
@ -206,9 +202,8 @@ impl Drop for InterruptInner {
} }
} }
pub struct InterruptHandler<'a> { pub struct InterruptHandler {
_inner: Pin<Box<InterruptInner>>, _inner: Pin<Box<InterruptInner>>,
_lifetime: PhantomData<&'a ()>,
} }
impl InterruptRoot { impl InterruptRoot {
@ -231,6 +226,13 @@ fn interrupt_to_root(interrupt: Interrupt) -> &'static InterruptRoot {
/// Adds an interrupt handler as long as the returned value is alive. The /// Adds an interrupt handler as long as the returned value is alive. The
/// closure takes a [`CriticalSection`] which can be used for mutexes. /// closure takes a [`CriticalSection`] which can be used for mutexes.
/// ///
/// SAFETY:
/// * You *must not* allocate in an interrupt.
///
/// STATICNESS:
/// * The closure must be static because forgetting the interrupt handler will
/// cause a use after free.
///
/// [`CriticalSection`]: bare_metal::CriticalSection /// [`CriticalSection`]: bare_metal::CriticalSection
/// ///
/// # Examples /// # Examples
@ -247,14 +249,11 @@ fn interrupt_to_root(interrupt: Interrupt) -> &'static InterruptRoot {
/// }); /// });
/// # } /// # }
/// ``` /// ```
pub fn add_interrupt_handler<'a>( pub unsafe fn add_interrupt_handler(
interrupt: Interrupt, interrupt: Interrupt,
handler: impl Fn(CriticalSection) + Send + Sync + 'a, handler: impl Fn(CriticalSection) + Send + Sync + 'static,
) -> InterruptHandler<'a> { ) -> InterruptHandler {
fn do_with_inner<'a>( fn do_with_inner(interrupt: Interrupt, inner: Pin<Box<InterruptInner>>) -> InterruptHandler {
interrupt: Interrupt,
inner: Pin<Box<InterruptInner>>,
) -> InterruptHandler<'a> {
free(|_| { free(|_| {
let root = interrupt_to_root(interrupt); let root = interrupt_to_root(interrupt);
root.add(); root.add();
@ -274,10 +273,7 @@ pub fn add_interrupt_handler<'a>(
} }
}); });
InterruptHandler { InterruptHandler { _inner: inner }
_inner: inner,
_lifetime: PhantomData,
}
} }
let root = interrupt_to_root(interrupt) as *const _; let root = interrupt_to_root(interrupt) as *const _;
let inner = unsafe { create_interrupt_inner(handler, root) }; let inner = unsafe { create_interrupt_inner(handler, root) };
@ -322,9 +318,12 @@ impl VBlank {
#[must_use] #[must_use]
pub fn get() -> Self { pub fn get() -> Self {
if !HAS_CREATED_INTERRUPT.read() { if !HAS_CREATED_INTERRUPT.read() {
let handler = add_interrupt_handler(Interrupt::VBlank, |_| { // safety: we don't allocate in the interrupt
let handler = unsafe {
add_interrupt_handler(Interrupt::VBlank, |_| {
NUM_VBLANKS.write(NUM_VBLANKS.read() + 1); NUM_VBLANKS.write(NUM_VBLANKS.read() + 1);
}); })
};
core::mem::forget(handler); core::mem::forget(handler);
HAS_CREATED_INTERRUPT.write(true); HAS_CREATED_INTERRUPT.write(true);
@ -351,36 +350,6 @@ impl VBlank {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use bare_metal::Mutex;
use core::cell::RefCell;
#[test_case]
fn test_can_create_and_destroy_interrupt_handlers(_gba: &mut crate::Gba) {
let mut counter = Mutex::new(RefCell::new(0));
let counter_2 = Mutex::new(RefCell::new(0));
let vblank = VBlank::get();
{
let _a = add_interrupt_handler(Interrupt::VBlank, |key: CriticalSection| {
*counter.borrow(key).borrow_mut() += 1;
});
let _b = add_interrupt_handler(Interrupt::VBlank, |key: CriticalSection| {
*counter_2.borrow(key).borrow_mut() += 1;
});
while free(|key| {
*counter.borrow(key).borrow() < 100 || *counter_2.borrow(key).borrow() < 100
}) {
vblank.wait_for_vblank();
}
}
vblank.wait_for_vblank();
vblank.wait_for_vblank();
assert_eq!(*counter.get_mut().get_mut(), 100);
}
#[test_case] #[test_case]
fn test_interrupt_table_length(_gba: &mut crate::Gba) { fn test_interrupt_table_length(_gba: &mut crate::Gba) {
@ -406,7 +375,9 @@ pub fn profiler(timer: &mut crate::timer::Timer, period: u16) -> InterruptHandle
timer.set_overflow_amount(period); timer.set_overflow_amount(period);
timer.set_enabled(true); timer.set_enabled(true);
unsafe {
add_interrupt_handler(timer.interrupt(), |_key: CriticalSection| { add_interrupt_handler(timer.interrupt(), |_key: CriticalSection| {
crate::println!("{:#010x}", crate::program_counter_before_interrupt()); crate::println!("{:#010x}", crate::program_counter_before_interrupt());
}) })
}
} }

View file

@ -82,7 +82,7 @@ extern "C" {
pub struct Mixer<'gba> { pub struct Mixer<'gba> {
interrupt_timer: Timer, interrupt_timer: Timer,
// SAFETY: Has to go before buffer because it holds a reference to it // SAFETY: Has to go before buffer because it holds a reference to it
_interrupt_handler: InterruptHandler<'static>, _interrupt_handler: InterruptHandler,
buffer: Pin<Box<MixerBuffer, InternalAllocator>>, buffer: Pin<Box<MixerBuffer, InternalAllocator>>,
channels: [Option<SoundChannel>; 8], channels: [Option<SoundChannel>; 8],
@ -140,9 +140,11 @@ impl Mixer<'_> {
// In the case of the mixer being forgotten, both stay alive so okay // In the case of the mixer being forgotten, both stay alive so okay
let buffer_pointer_for_interrupt_handler: &MixerBuffer = let buffer_pointer_for_interrupt_handler: &MixerBuffer =
unsafe { core::mem::transmute(buffer_pointer_for_interrupt_handler) }; unsafe { core::mem::transmute(buffer_pointer_for_interrupt_handler) };
let interrupt_handler = add_interrupt_handler(interrupt_timer.interrupt(), |cs| { let interrupt_handler = unsafe {
add_interrupt_handler(interrupt_timer.interrupt(), |cs| {
buffer_pointer_for_interrupt_handler.swap(cs); buffer_pointer_for_interrupt_handler.swap(cs);
}); })
};
set_asm_buffer_size(frequency); set_asm_buffer_size(frequency);

View file

@ -267,9 +267,11 @@ mod test {
use crate::timer::Divider; use crate::timer::Divider;
use crate::Gba; use crate::Gba;
fn write_read_concurrency_test_impl<const COUNT: usize>(gba: &mut Gba) { macro_rules! generate_concurrency_test {
let sentinel = [0x12345678; COUNT]; ($count:literal, $gba:ident) => {{
let value: Static<[u32; COUNT]> = Static::new(sentinel); (|gba: &mut Gba| {
const SENTINEL: [u32; $count] = [0x12345678; $count];
static VALUE: Static<[u32; $count]> = Static::new(SENTINEL);
// set up a timer and an interrupt that uses the timer // set up a timer and an interrupt that uses the timer
let mut timer = gba.timers.timers().timer2; let mut timer = gba.timers.timers().timer2;
@ -279,23 +281,25 @@ mod test {
timer.set_interrupt(true); timer.set_interrupt(true);
timer.set_enabled(true); timer.set_enabled(true);
let _int = crate::interrupt::add_interrupt_handler(Interrupt::Timer2, |_| { let _int = unsafe {
value.write(sentinel); crate::interrupt::add_interrupt_handler(Interrupt::Timer2, |_| {
}); VALUE.write(SENTINEL);
})
};
// the actual main test loop // the actual main test loop
let mut interrupt_seen = false; let mut interrupt_seen = false;
let mut no_interrupt_seen = false; let mut no_interrupt_seen = false;
for i in 0..250000 { for i in 0..250000 {
// write to the static // write to the static
let new_value = [i; COUNT]; let new_value = [i; $count];
value.write(new_value); VALUE.write(new_value);
// check the current value // check the current value
let current = value.read(); let current = VALUE.read();
if current == new_value { if current == new_value {
no_interrupt_seen = true; no_interrupt_seen = true;
} else if current == sentinel { } else if current == SENTINEL {
interrupt_seen = true; interrupt_seen = true;
} else { } else {
panic!("Unexpected value found in `Static`."); panic!("Unexpected value found in `Static`.");
@ -312,20 +316,22 @@ mod test {
timer.set_overflow_amount(1049 + (i / 64) as u16); timer.set_overflow_amount(1049 + (i / 64) as u16);
} }
} }
panic!("Concurrency test timed out: {}", COUNT) panic!("Concurrency test timed out: {}", $count)
})($gba);
}};
} }
#[test_case] #[test_case]
fn write_read_concurrency_test(gba: &mut Gba) { fn write_read_concurrency_test(gba: &mut Gba) {
write_read_concurrency_test_impl::<1>(gba); generate_concurrency_test!(1, gba);
write_read_concurrency_test_impl::<2>(gba); generate_concurrency_test!(2, gba);
write_read_concurrency_test_impl::<3>(gba); generate_concurrency_test!(3, gba);
write_read_concurrency_test_impl::<4>(gba); generate_concurrency_test!(4, gba);
write_read_concurrency_test_impl::<5>(gba); generate_concurrency_test!(5, gba);
write_read_concurrency_test_impl::<6>(gba); generate_concurrency_test!(6, gba);
write_read_concurrency_test_impl::<7>(gba); generate_concurrency_test!(7, gba);
write_read_concurrency_test_impl::<8>(gba); generate_concurrency_test!(8, gba);
write_read_concurrency_test_impl::<9>(gba); generate_concurrency_test!(9, gba);
write_read_concurrency_test_impl::<10>(gba); generate_concurrency_test!(10, gba);
} }
} }