diff --git a/agb/examples/output.rs b/agb/examples/output.rs index 93b8595a..8aabc69d 100644 --- a/agb/examples/output.rs +++ b/agb/examples/output.rs @@ -3,13 +3,16 @@ use agb::sync::Static; +static COUNT: Static = Static::new(0); + #[agb::entry] fn main(_gba: agb::Gba) -> ! { - let count = Static::new(0); - let _a = agb::interrupt::add_interrupt_handler(agb::interrupt::Interrupt::VBlank, |_| { - let cur_count = count.read(); - agb::println!("Hello, world, frame = {}", cur_count); - count.write(cur_count + 1); - }); + let _a = unsafe { + agb::interrupt::add_interrupt_handler(agb::interrupt::Interrupt::VBlank, |_| { + let cur_count = COUNT.read(); + agb::println!("Hello, world, frame = {}", cur_count); + COUNT.write(cur_count + 1); + }) + }; loop {} } diff --git a/agb/examples/wave.rs b/agb/examples/wave.rs index 491c0e0b..1e9cf223 100644 --- a/agb/examples/wave.rs +++ b/agb/examples/wave.rs @@ -18,6 +18,11 @@ struct BackCosines { row: usize, } +static BACK: Mutex> = Mutex::new(RefCell::new(BackCosines { + cosines: [0; 32], + row: 0, +})); + #[agb::entry] fn main(mut gba: agb::Gba) -> ! { let (gfx, mut vram) = gba.display.video.tiled0(); @@ -30,24 +35,22 @@ fn main(mut gba: agb::Gba) -> ! { example_logo::display_logo(&mut background, &mut vram); - let mut time = 0; - let cosines = [0_u16; 32]; - - let back = Mutex::new(RefCell::new(BackCosines { cosines, row: 0 })); - - let _a = agb::interrupt::add_interrupt_handler(Interrupt::HBlank, |key: CriticalSection| { - let mut back = back.borrow(key).borrow_mut(); - let deflection = back.cosines[back.row % 32]; - unsafe { ((0x0400_0010) as *mut u16).write_volatile(deflection) } - back.row += 1; - }); + let _a = unsafe { + agb::interrupt::add_interrupt_handler(Interrupt::HBlank, |key: CriticalSection| { + let mut back = BACK.borrow(key).borrow_mut(); + let deflection = back.cosines[back.row % 32]; + ((0x0400_0010) as *mut u16).write_volatile(deflection); + back.row += 1; + }) + }; let vblank = agb::interrupt::VBlank::get(); + let mut time = 0; loop { vblank.wait_for_vblank(); free(|key| { - let mut back = back.borrow(key).borrow_mut(); + let mut back = BACK.borrow(key).borrow_mut(); back.row = 0; time += 1; for (r, a) in back.cosines.iter_mut().enumerate() { diff --git a/agb/src/interrupt.rs b/agb/src/interrupt.rs index 68aa0789..83c30357 100644 --- a/agb/src/interrupt.rs +++ b/agb/src/interrupt.rs @@ -1,8 +1,4 @@ -use core::{ - cell::Cell, - marker::{PhantomData, PhantomPinned}, - pin::Pin, -}; +use core::{cell::Cell, marker::PhantomPinned, pin::Pin}; use alloc::boxed::Box; use bare_metal::CriticalSection; @@ -206,9 +202,8 @@ impl Drop for InterruptInner { } } -pub struct InterruptHandler<'a> { +pub struct InterruptHandler { _inner: Pin>, - _lifetime: PhantomData<&'a ()>, } impl InterruptRoot { @@ -231,6 +226,13 @@ fn interrupt_to_root(interrupt: Interrupt) -> &'static InterruptRoot { /// Adds an interrupt handler as long as the returned value is alive. The /// closure takes a [`CriticalSection`] which can be used for mutexes. /// +/// SAFETY: +/// * You *must not* allocate in an interrupt. +/// +/// STATICNESS: +/// * The closure must be static because forgetting the interrupt handler will +/// cause a use after free. +/// /// [`CriticalSection`]: bare_metal::CriticalSection /// /// # Examples @@ -247,14 +249,11 @@ fn interrupt_to_root(interrupt: Interrupt) -> &'static InterruptRoot { /// }); /// # } /// ``` -pub fn add_interrupt_handler<'a>( +pub unsafe fn add_interrupt_handler( interrupt: Interrupt, - handler: impl Fn(CriticalSection) + Send + Sync + 'a, -) -> InterruptHandler<'a> { - fn do_with_inner<'a>( - interrupt: Interrupt, - inner: Pin>, - ) -> InterruptHandler<'a> { + handler: impl Fn(CriticalSection) + Send + Sync + 'static, +) -> InterruptHandler { + fn do_with_inner(interrupt: Interrupt, inner: Pin>) -> InterruptHandler { free(|_| { let root = interrupt_to_root(interrupt); root.add(); @@ -274,10 +273,7 @@ pub fn add_interrupt_handler<'a>( } }); - InterruptHandler { - _inner: inner, - _lifetime: PhantomData, - } + InterruptHandler { _inner: inner } } let root = interrupt_to_root(interrupt) as *const _; let inner = unsafe { create_interrupt_inner(handler, root) }; @@ -322,9 +318,12 @@ impl VBlank { #[must_use] pub fn get() -> Self { if !HAS_CREATED_INTERRUPT.read() { - let handler = add_interrupt_handler(Interrupt::VBlank, |_| { - NUM_VBLANKS.write(NUM_VBLANKS.read() + 1); - }); + // safety: we don't allocate in the interrupt + let handler = unsafe { + add_interrupt_handler(Interrupt::VBlank, |_| { + NUM_VBLANKS.write(NUM_VBLANKS.read() + 1); + }) + }; core::mem::forget(handler); HAS_CREATED_INTERRUPT.write(true); @@ -351,36 +350,6 @@ impl VBlank { #[cfg(test)] mod tests { use super::*; - use bare_metal::Mutex; - use core::cell::RefCell; - - #[test_case] - fn test_can_create_and_destroy_interrupt_handlers(_gba: &mut crate::Gba) { - let mut counter = Mutex::new(RefCell::new(0)); - let counter_2 = Mutex::new(RefCell::new(0)); - - let vblank = VBlank::get(); - - { - let _a = add_interrupt_handler(Interrupt::VBlank, |key: CriticalSection| { - *counter.borrow(key).borrow_mut() += 1; - }); - let _b = add_interrupt_handler(Interrupt::VBlank, |key: CriticalSection| { - *counter_2.borrow(key).borrow_mut() += 1; - }); - - while free(|key| { - *counter.borrow(key).borrow() < 100 || *counter_2.borrow(key).borrow() < 100 - }) { - vblank.wait_for_vblank(); - } - } - - vblank.wait_for_vblank(); - vblank.wait_for_vblank(); - - assert_eq!(*counter.get_mut().get_mut(), 100); - } #[test_case] fn test_interrupt_table_length(_gba: &mut crate::Gba) { @@ -406,7 +375,9 @@ pub fn profiler(timer: &mut crate::timer::Timer, period: u16) -> InterruptHandle timer.set_overflow_amount(period); timer.set_enabled(true); - add_interrupt_handler(timer.interrupt(), |_key: CriticalSection| { - crate::println!("{:#010x}", crate::program_counter_before_interrupt()); - }) + unsafe { + add_interrupt_handler(timer.interrupt(), |_key: CriticalSection| { + crate::println!("{:#010x}", crate::program_counter_before_interrupt()); + }) + } } diff --git a/agb/src/sound/mixer/sw_mixer.rs b/agb/src/sound/mixer/sw_mixer.rs index 4f32e786..8984f842 100644 --- a/agb/src/sound/mixer/sw_mixer.rs +++ b/agb/src/sound/mixer/sw_mixer.rs @@ -82,7 +82,7 @@ extern "C" { pub struct Mixer<'gba> { interrupt_timer: Timer, // SAFETY: Has to go before buffer because it holds a reference to it - _interrupt_handler: InterruptHandler<'static>, + _interrupt_handler: InterruptHandler, buffer: Pin>, channels: [Option; 8], @@ -140,9 +140,11 @@ impl Mixer<'_> { // In the case of the mixer being forgotten, both stay alive so okay let buffer_pointer_for_interrupt_handler: &MixerBuffer = unsafe { core::mem::transmute(buffer_pointer_for_interrupt_handler) }; - let interrupt_handler = add_interrupt_handler(interrupt_timer.interrupt(), |cs| { - buffer_pointer_for_interrupt_handler.swap(cs); - }); + let interrupt_handler = unsafe { + add_interrupt_handler(interrupt_timer.interrupt(), |cs| { + buffer_pointer_for_interrupt_handler.swap(cs); + }) + }; set_asm_buffer_size(frequency); diff --git a/agb/src/sync/statics.rs b/agb/src/sync/statics.rs index 69536c6c..68eb5030 100644 --- a/agb/src/sync/statics.rs +++ b/agb/src/sync/statics.rs @@ -267,65 +267,71 @@ mod test { use crate::timer::Divider; use crate::Gba; - fn write_read_concurrency_test_impl(gba: &mut Gba) { - let sentinel = [0x12345678; COUNT]; - let value: Static<[u32; COUNT]> = Static::new(sentinel); + macro_rules! generate_concurrency_test { + ($count:literal, $gba:ident) => {{ + (|gba: &mut Gba| { + const SENTINEL: [u32; $count] = [0x12345678; $count]; + static VALUE: Static<[u32; $count]> = Static::new(SENTINEL); - // set up a timer and an interrupt that uses the timer - let mut timer = gba.timers.timers().timer2; - timer.set_cascade(false); - timer.set_divider(Divider::Divider1); - timer.set_overflow_amount(1049); - timer.set_interrupt(true); - timer.set_enabled(true); + // set up a timer and an interrupt that uses the timer + let mut timer = gba.timers.timers().timer2; + timer.set_cascade(false); + timer.set_divider(Divider::Divider1); + timer.set_overflow_amount(1049); + timer.set_interrupt(true); + timer.set_enabled(true); - let _int = crate::interrupt::add_interrupt_handler(Interrupt::Timer2, |_| { - value.write(sentinel); - }); + let _int = unsafe { + crate::interrupt::add_interrupt_handler(Interrupt::Timer2, |_| { + VALUE.write(SENTINEL); + }) + }; - // the actual main test loop - let mut interrupt_seen = false; - let mut no_interrupt_seen = false; - for i in 0..250000 { - // write to the static - let new_value = [i; COUNT]; - value.write(new_value); + // the actual main test loop + let mut interrupt_seen = false; + let mut no_interrupt_seen = false; + for i in 0..250000 { + // write to the static + let new_value = [i; $count]; + VALUE.write(new_value); - // check the current value - let current = value.read(); - if current == new_value { - no_interrupt_seen = true; - } else if current == sentinel { - interrupt_seen = true; - } else { - panic!("Unexpected value found in `Static`."); - } + // check the current value + let current = VALUE.read(); + if current == new_value { + no_interrupt_seen = true; + } else if current == SENTINEL { + interrupt_seen = true; + } else { + panic!("Unexpected value found in `Static`."); + } - // we return as soon as we've seen both the value written by the main thread - // and interrupt - if interrupt_seen && no_interrupt_seen { - timer.set_enabled(false); - return; - } + // we return as soon as we've seen both the value written by the main thread + // and interrupt + if interrupt_seen && no_interrupt_seen { + timer.set_enabled(false); + return; + } - if i % 8192 == 0 && i != 0 { - timer.set_overflow_amount(1049 + (i / 64) as u16); - } - } - panic!("Concurrency test timed out: {}", COUNT) + if i % 8192 == 0 && i != 0 { + timer.set_overflow_amount(1049 + (i / 64) as u16); + } + } + panic!("Concurrency test timed out: {}", $count) + })($gba); + }}; } #[test_case] fn write_read_concurrency_test(gba: &mut Gba) { - write_read_concurrency_test_impl::<1>(gba); - write_read_concurrency_test_impl::<2>(gba); - write_read_concurrency_test_impl::<3>(gba); - write_read_concurrency_test_impl::<4>(gba); - write_read_concurrency_test_impl::<5>(gba); - write_read_concurrency_test_impl::<6>(gba); - write_read_concurrency_test_impl::<7>(gba); - write_read_concurrency_test_impl::<8>(gba); - write_read_concurrency_test_impl::<9>(gba); - write_read_concurrency_test_impl::<10>(gba); + generate_concurrency_test!(1, gba); + generate_concurrency_test!(2, gba); + generate_concurrency_test!(3, gba); + generate_concurrency_test!(4, gba); + generate_concurrency_test!(5, gba); + generate_concurrency_test!(6, gba); + generate_concurrency_test!(7, gba); + generate_concurrency_test!(8, gba); + generate_concurrency_test!(9, gba); + generate_concurrency_test!(10, gba); } }