diff --git a/CHANGELOG.md b/CHANGELOG.md index c15cd3ed..841c6936 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - The template now uses rust 2021 edition by default. - All objects which should only be created once now have the correct lifetimes to only allow one to exist. - Template now uses codegen-units=1 to workaround bug in nightly. +- Allocator is no longer interrupt safe. +- Soundness issues with interrupts resolved which makes them unsafe and require the closure to be static (breaking change). ### Fixed - Alpha channel is now considered by `include_gfx!()` even when `transparent_colour` is absent. diff --git a/agb/examples/output.rs b/agb/examples/output.rs index 93b8595a..8aabc69d 100644 --- a/agb/examples/output.rs +++ b/agb/examples/output.rs @@ -3,13 +3,16 @@ use agb::sync::Static; +static COUNT: Static = Static::new(0); + #[agb::entry] fn main(_gba: agb::Gba) -> ! { - let count = Static::new(0); - let _a = agb::interrupt::add_interrupt_handler(agb::interrupt::Interrupt::VBlank, |_| { - let cur_count = count.read(); - agb::println!("Hello, world, frame = {}", cur_count); - count.write(cur_count + 1); - }); + let _a = unsafe { + agb::interrupt::add_interrupt_handler(agb::interrupt::Interrupt::VBlank, |_| { + let cur_count = COUNT.read(); + agb::println!("Hello, world, frame = {}", cur_count); + COUNT.write(cur_count + 1); + }) + }; loop {} } diff --git a/agb/examples/wave.rs b/agb/examples/wave.rs index 491c0e0b..1e9cf223 100644 --- a/agb/examples/wave.rs +++ b/agb/examples/wave.rs @@ -18,6 +18,11 @@ struct BackCosines { row: usize, } +static BACK: Mutex> = Mutex::new(RefCell::new(BackCosines { + cosines: [0; 32], + row: 0, +})); + #[agb::entry] fn main(mut gba: agb::Gba) -> ! { let (gfx, mut vram) = gba.display.video.tiled0(); @@ -30,24 +35,22 @@ fn main(mut gba: agb::Gba) -> ! { example_logo::display_logo(&mut background, &mut vram); - let mut time = 0; - let cosines = [0_u16; 32]; - - let back = Mutex::new(RefCell::new(BackCosines { cosines, row: 0 })); - - let _a = agb::interrupt::add_interrupt_handler(Interrupt::HBlank, |key: CriticalSection| { - let mut back = back.borrow(key).borrow_mut(); - let deflection = back.cosines[back.row % 32]; - unsafe { ((0x0400_0010) as *mut u16).write_volatile(deflection) } - back.row += 1; - }); + let _a = unsafe { + agb::interrupt::add_interrupt_handler(Interrupt::HBlank, |key: CriticalSection| { + let mut back = BACK.borrow(key).borrow_mut(); + let deflection = back.cosines[back.row % 32]; + ((0x0400_0010) as *mut u16).write_volatile(deflection); + back.row += 1; + }) + }; let vblank = agb::interrupt::VBlank::get(); + let mut time = 0; loop { vblank.wait_for_vblank(); free(|key| { - let mut back = back.borrow(key).borrow_mut(); + let mut back = BACK.borrow(key).borrow_mut(); back.row = 0; time += 1; for (r, a) in back.cosines.iter_mut().enumerate() { diff --git a/agb/src/agb_alloc/block_allocator.rs b/agb/src/agb_alloc/block_allocator.rs index 14eb4638..921a4308 100644 --- a/agb/src/agb_alloc/block_allocator.rs +++ b/agb/src/agb_alloc/block_allocator.rs @@ -5,13 +5,10 @@ use core::alloc::{Allocator, GlobalAlloc, Layout}; -use core::cell::RefCell; +use core::cell::UnsafeCell; use core::convert::TryInto; use core::ptr::NonNull; -use crate::interrupt::free; -use bare_metal::Mutex; - use super::bump_allocator::{BumpAllocatorInner, StartEnd}; use super::SendNonNull; @@ -53,36 +50,45 @@ struct BlockAllocatorInner { } pub struct BlockAllocator { - inner: Mutex>, + inner: UnsafeCell, } +unsafe impl Sync for BlockAllocator {} + impl BlockAllocator { pub(crate) const unsafe fn new(start: StartEnd) -> Self { Self { - inner: Mutex::new(RefCell::new(BlockAllocatorInner::new(start))), + inner: UnsafeCell::new(BlockAllocatorInner::new(start)), } } + #[inline(always)] + unsafe fn with_inner(&self, f: F) -> T + where + F: Fn(&mut BlockAllocatorInner) -> T, + { + let inner = &mut *self.inner.get(); + + f(inner) + } + #[doc(hidden)] #[cfg(any(test, feature = "testing"))] pub unsafe fn number_of_blocks(&self) -> u32 { - free(|key| self.inner.borrow(key).borrow_mut().number_of_blocks()) + self.with_inner(|inner| inner.number_of_blocks()) } pub unsafe fn alloc(&self, layout: Layout) -> Option> { - free(|key| self.inner.borrow(key).borrow_mut().alloc(layout)) + self.with_inner(|inner| inner.alloc(layout)) } pub unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) { - free(|key| self.inner.borrow(key).borrow_mut().dealloc(ptr, layout)); + self.with_inner(|inner| inner.dealloc(ptr, layout)); } pub unsafe fn dealloc_no_normalise(&self, ptr: *mut u8, layout: Layout) { - free(|key| { - self.inner - .borrow(key) - .borrow_mut() - .dealloc_no_normalise(ptr, layout); + self.with_inner(|inner| { + inner.dealloc_no_normalise(ptr, layout); }); } @@ -92,12 +98,7 @@ impl BlockAllocator { layout: Layout, new_layout: Layout, ) -> Option> { - free(|key| { - self.inner - .borrow(key) - .borrow_mut() - .grow(ptr, layout, new_layout) - }) + self.with_inner(|inner| inner.grow(ptr, layout, new_layout)) } } diff --git a/agb/src/interrupt.rs b/agb/src/interrupt.rs index 68aa0789..6cb8593d 100644 --- a/agb/src/interrupt.rs +++ b/agb/src/interrupt.rs @@ -1,8 +1,4 @@ -use core::{ - cell::Cell, - marker::{PhantomData, PhantomPinned}, - pin::Pin, -}; +use core::{cell::Cell, marker::PhantomPinned, pin::Pin}; use alloc::boxed::Box; use bare_metal::CriticalSection; @@ -206,9 +202,8 @@ impl Drop for InterruptInner { } } -pub struct InterruptHandler<'a> { +pub struct InterruptHandler { _inner: Pin>, - _lifetime: PhantomData<&'a ()>, } impl InterruptRoot { @@ -231,6 +226,14 @@ fn interrupt_to_root(interrupt: Interrupt) -> &'static InterruptRoot { /// Adds an interrupt handler as long as the returned value is alive. The /// closure takes a [`CriticalSection`] which can be used for mutexes. /// +/// # Safety +/// * You *must not* allocate in an interrupt. +/// - Many functions in agb allocate and it isn't always clear. +/// +/// # Staticness +/// * The closure must be static because forgetting the interrupt handler would +/// cause a use after free. +/// /// [`CriticalSection`]: bare_metal::CriticalSection /// /// # Examples @@ -238,23 +241,22 @@ fn interrupt_to_root(interrupt: Interrupt) -> &'static InterruptRoot { /// ```rust,no_run /// # #![no_std] /// # #![no_main] -/// use bare_metal::CriticalSection; -/// /// # fn foo() { -/// # use agb::interrupt::{add_interrupt_handler, Interrupt}; -/// let _a = add_interrupt_handler(Interrupt::VBlank, |_: CriticalSection| { -/// agb::println!("Woah there! There's been a vblank!"); -/// }); +/// use bare_metal::CriticalSection; +/// use agb::interrupt::{add_interrupt_handler, Interrupt}; +/// // Safety: doesn't allocate +/// let _a = unsafe { +/// add_interrupt_handler(Interrupt::VBlank, |_: CriticalSection| { +/// agb::println!("Woah there! There's been a vblank!"); +/// }) +/// }; /// # } /// ``` -pub fn add_interrupt_handler<'a>( +pub unsafe fn add_interrupt_handler( interrupt: Interrupt, - handler: impl Fn(CriticalSection) + Send + Sync + 'a, -) -> InterruptHandler<'a> { - fn do_with_inner<'a>( - interrupt: Interrupt, - inner: Pin>, - ) -> InterruptHandler<'a> { + handler: impl Fn(CriticalSection) + Send + Sync + 'static, +) -> InterruptHandler { + fn do_with_inner(interrupt: Interrupt, inner: Pin>) -> InterruptHandler { free(|_| { let root = interrupt_to_root(interrupt); root.add(); @@ -274,10 +276,7 @@ pub fn add_interrupt_handler<'a>( } }); - InterruptHandler { - _inner: inner, - _lifetime: PhantomData, - } + InterruptHandler { _inner: inner } } let root = interrupt_to_root(interrupt) as *const _; let inner = unsafe { create_interrupt_inner(handler, root) }; @@ -322,9 +321,12 @@ impl VBlank { #[must_use] pub fn get() -> Self { if !HAS_CREATED_INTERRUPT.read() { - let handler = add_interrupt_handler(Interrupt::VBlank, |_| { - NUM_VBLANKS.write(NUM_VBLANKS.read() + 1); - }); + // safety: we don't allocate in the interrupt + let handler = unsafe { + add_interrupt_handler(Interrupt::VBlank, |_| { + NUM_VBLANKS.write(NUM_VBLANKS.read() + 1); + }) + }; core::mem::forget(handler); HAS_CREATED_INTERRUPT.write(true); @@ -351,36 +353,6 @@ impl VBlank { #[cfg(test)] mod tests { use super::*; - use bare_metal::Mutex; - use core::cell::RefCell; - - #[test_case] - fn test_can_create_and_destroy_interrupt_handlers(_gba: &mut crate::Gba) { - let mut counter = Mutex::new(RefCell::new(0)); - let counter_2 = Mutex::new(RefCell::new(0)); - - let vblank = VBlank::get(); - - { - let _a = add_interrupt_handler(Interrupt::VBlank, |key: CriticalSection| { - *counter.borrow(key).borrow_mut() += 1; - }); - let _b = add_interrupt_handler(Interrupt::VBlank, |key: CriticalSection| { - *counter_2.borrow(key).borrow_mut() += 1; - }); - - while free(|key| { - *counter.borrow(key).borrow() < 100 || *counter_2.borrow(key).borrow() < 100 - }) { - vblank.wait_for_vblank(); - } - } - - vblank.wait_for_vblank(); - vblank.wait_for_vblank(); - - assert_eq!(*counter.get_mut().get_mut(), 100); - } #[test_case] fn test_interrupt_table_length(_gba: &mut crate::Gba) { @@ -406,7 +378,9 @@ pub fn profiler(timer: &mut crate::timer::Timer, period: u16) -> InterruptHandle timer.set_overflow_amount(period); timer.set_enabled(true); - add_interrupt_handler(timer.interrupt(), |_key: CriticalSection| { - crate::println!("{:#010x}", crate::program_counter_before_interrupt()); - }) + unsafe { + add_interrupt_handler(timer.interrupt(), |_key: CriticalSection| { + crate::println!("{:#010x}", crate::program_counter_before_interrupt()); + }) + } } diff --git a/agb/src/sound/mixer/sw_mixer.rs b/agb/src/sound/mixer/sw_mixer.rs index 4f32e786..8984f842 100644 --- a/agb/src/sound/mixer/sw_mixer.rs +++ b/agb/src/sound/mixer/sw_mixer.rs @@ -82,7 +82,7 @@ extern "C" { pub struct Mixer<'gba> { interrupt_timer: Timer, // SAFETY: Has to go before buffer because it holds a reference to it - _interrupt_handler: InterruptHandler<'static>, + _interrupt_handler: InterruptHandler, buffer: Pin>, channels: [Option; 8], @@ -140,9 +140,11 @@ impl Mixer<'_> { // In the case of the mixer being forgotten, both stay alive so okay let buffer_pointer_for_interrupt_handler: &MixerBuffer = unsafe { core::mem::transmute(buffer_pointer_for_interrupt_handler) }; - let interrupt_handler = add_interrupt_handler(interrupt_timer.interrupt(), |cs| { - buffer_pointer_for_interrupt_handler.swap(cs); - }); + let interrupt_handler = unsafe { + add_interrupt_handler(interrupt_timer.interrupt(), |cs| { + buffer_pointer_for_interrupt_handler.swap(cs); + }) + }; set_asm_buffer_size(frequency); diff --git a/agb/src/sync/statics.rs b/agb/src/sync/statics.rs index 69536c6c..68eb5030 100644 --- a/agb/src/sync/statics.rs +++ b/agb/src/sync/statics.rs @@ -267,65 +267,71 @@ mod test { use crate::timer::Divider; use crate::Gba; - fn write_read_concurrency_test_impl(gba: &mut Gba) { - let sentinel = [0x12345678; COUNT]; - let value: Static<[u32; COUNT]> = Static::new(sentinel); + macro_rules! generate_concurrency_test { + ($count:literal, $gba:ident) => {{ + (|gba: &mut Gba| { + const SENTINEL: [u32; $count] = [0x12345678; $count]; + static VALUE: Static<[u32; $count]> = Static::new(SENTINEL); - // set up a timer and an interrupt that uses the timer - let mut timer = gba.timers.timers().timer2; - timer.set_cascade(false); - timer.set_divider(Divider::Divider1); - timer.set_overflow_amount(1049); - timer.set_interrupt(true); - timer.set_enabled(true); + // set up a timer and an interrupt that uses the timer + let mut timer = gba.timers.timers().timer2; + timer.set_cascade(false); + timer.set_divider(Divider::Divider1); + timer.set_overflow_amount(1049); + timer.set_interrupt(true); + timer.set_enabled(true); - let _int = crate::interrupt::add_interrupt_handler(Interrupt::Timer2, |_| { - value.write(sentinel); - }); + let _int = unsafe { + crate::interrupt::add_interrupt_handler(Interrupt::Timer2, |_| { + VALUE.write(SENTINEL); + }) + }; - // the actual main test loop - let mut interrupt_seen = false; - let mut no_interrupt_seen = false; - for i in 0..250000 { - // write to the static - let new_value = [i; COUNT]; - value.write(new_value); + // the actual main test loop + let mut interrupt_seen = false; + let mut no_interrupt_seen = false; + for i in 0..250000 { + // write to the static + let new_value = [i; $count]; + VALUE.write(new_value); - // check the current value - let current = value.read(); - if current == new_value { - no_interrupt_seen = true; - } else if current == sentinel { - interrupt_seen = true; - } else { - panic!("Unexpected value found in `Static`."); - } + // check the current value + let current = VALUE.read(); + if current == new_value { + no_interrupt_seen = true; + } else if current == SENTINEL { + interrupt_seen = true; + } else { + panic!("Unexpected value found in `Static`."); + } - // we return as soon as we've seen both the value written by the main thread - // and interrupt - if interrupt_seen && no_interrupt_seen { - timer.set_enabled(false); - return; - } + // we return as soon as we've seen both the value written by the main thread + // and interrupt + if interrupt_seen && no_interrupt_seen { + timer.set_enabled(false); + return; + } - if i % 8192 == 0 && i != 0 { - timer.set_overflow_amount(1049 + (i / 64) as u16); - } - } - panic!("Concurrency test timed out: {}", COUNT) + if i % 8192 == 0 && i != 0 { + timer.set_overflow_amount(1049 + (i / 64) as u16); + } + } + panic!("Concurrency test timed out: {}", $count) + })($gba); + }}; } #[test_case] fn write_read_concurrency_test(gba: &mut Gba) { - write_read_concurrency_test_impl::<1>(gba); - write_read_concurrency_test_impl::<2>(gba); - write_read_concurrency_test_impl::<3>(gba); - write_read_concurrency_test_impl::<4>(gba); - write_read_concurrency_test_impl::<5>(gba); - write_read_concurrency_test_impl::<6>(gba); - write_read_concurrency_test_impl::<7>(gba); - write_read_concurrency_test_impl::<8>(gba); - write_read_concurrency_test_impl::<9>(gba); - write_read_concurrency_test_impl::<10>(gba); + generate_concurrency_test!(1, gba); + generate_concurrency_test!(2, gba); + generate_concurrency_test!(3, gba); + generate_concurrency_test!(4, gba); + generate_concurrency_test!(5, gba); + generate_concurrency_test!(6, gba); + generate_concurrency_test!(7, gba); + generate_concurrency_test!(8, gba); + generate_concurrency_test!(9, gba); + generate_concurrency_test!(10, gba); } }