diff --git a/agb/Cargo.lock b/agb/Cargo.lock index 88f57060..a7a46153 100644 --- a/agb/Cargo.lock +++ b/agb/Cargo.lock @@ -16,6 +16,7 @@ dependencies = [ "agb_image_converter", "agb_macros", "agb_sound_converter", + "bare-metal", "bitflags", ] @@ -64,6 +65,12 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" +[[package]] +name = "bare-metal" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8fe8f5a8a398345e52358e18ff07cc17a568fbca5c6f73873d3a62056309603" + [[package]] name = "bitflags" version = "1.3.2" diff --git a/agb/Cargo.toml b/agb/Cargo.toml index 627bfc96..dbd7d298 100644 --- a/agb/Cargo.toml +++ b/agb/Cargo.toml @@ -25,6 +25,7 @@ agb_image_converter = { version = "0.6.0", path = "../agb-image-converter" } agb_sound_converter = { version = "0.1.0", path = "../agb-sound-converter" } agb_macros = { version = "0.1.0", path = "../agb-macros" } agb_fixnum = { version = "0.1.0", path = "../agb-fixnum" } +bare-metal = "1.0" [package.metadata.docs.rs] default-target = "thumbv6m-none-eabi" diff --git a/agb/examples/output.rs b/agb/examples/output.rs index 85e9cf65..e28558f6 100644 --- a/agb/examples/output.rs +++ b/agb/examples/output.rs @@ -1,13 +1,20 @@ #![no_std] #![no_main] +use core::cell::RefCell; + +use bare_metal::{CriticalSection, Mutex}; + #[agb::entry] fn main(_gba: agb::Gba) -> ! { - let count = agb::interrupt::Mutex::new(0); - agb::add_interrupt_handler!(agb::interrupt::Interrupt::VBlank, |key| { - let mut count = count.lock_with_key(&key); - agb::println!("Hello, world, frame = {}", *count); - *count += 1; - }); + let count = Mutex::new(RefCell::new(0)); + agb::add_interrupt_handler!( + agb::interrupt::Interrupt::VBlank, + |key: &CriticalSection| { + let mut count = count.borrow(*key).borrow_mut(); + agb::println!("Hello, world, frame = {}", *count); + *count += 1; + } + ); loop {} } diff --git a/agb/examples/wave.rs b/agb/examples/wave.rs index 79b11081..9036eb9a 100644 --- a/agb/examples/wave.rs +++ b/agb/examples/wave.rs @@ -1,11 +1,14 @@ #![no_std] #![no_main] +use core::cell::RefCell; + use agb::{ display::example_logo, fixnum::FixedNum, - interrupt::{Interrupt, Mutex}, + interrupt::{free, Interrupt}, }; +use bare_metal::{CriticalSection, Mutex}; struct BackCosines { cosines: [u16; 32], @@ -21,10 +24,10 @@ fn main(mut gba: agb::Gba) -> ! { let mut time = 0; let cosines = [0_u16; 32]; - let back = Mutex::new(BackCosines { cosines, row: 0 }); + let back = Mutex::new(RefCell::new(BackCosines { cosines, row: 0 })); - agb::add_interrupt_handler!(Interrupt::HBlank, |_| { - let mut backc = back.lock(); + agb::add_interrupt_handler!(Interrupt::HBlank, |key: &CriticalSection| { + let mut backc = back.borrow(*key).borrow_mut(); let deflection = backc.cosines[backc.row % 32]; unsafe { ((0x0400_0010) as *mut u16).write_volatile(deflection) } backc.row += 1; @@ -34,14 +37,17 @@ fn main(mut gba: agb::Gba) -> ! { loop { vblank.wait_for_vblank(); - let mut backc = back.lock(); - backc.row = 0; - time += 1; - for (r, a) in backc.cosines.iter_mut().enumerate() { - let n: FixedNum<8> = (FixedNum::new(r as i32) / 32 + FixedNum::new(time) / 128).cos() - * (256 * 4 - 1) - / 256; - *a = (n.trunc() % (32 * 8)) as u16; - } + free(|key| { + let mut backc = back.borrow(*key).borrow_mut(); + backc.row = 0; + time += 1; + for (r, a) in backc.cosines.iter_mut().enumerate() { + let n: FixedNum<8> = (FixedNum::new(r as i32) / 32 + FixedNum::new(time) / 128) + .cos() + * (256 * 4 - 1) + / 256; + *a = (n.trunc() % (32 * 8)) as u16; + } + }) } } diff --git a/agb/src/agb_alloc/block_allocator.rs b/agb/src/agb_alloc/block_allocator.rs index 6ca3108c..1fc17a1b 100644 --- a/agb/src/agb_alloc/block_allocator.rs +++ b/agb/src/agb_alloc/block_allocator.rs @@ -1,13 +1,17 @@ use core::alloc::{GlobalAlloc, Layout}; + +use core::cell::RefCell; use core::ptr::NonNull; -use crate::interrupt::Mutex; +use crate::interrupt::free; +use bare_metal::Mutex; use super::bump_allocator::BumpAllocator; +use super::SendNonNull; struct Block { size: usize, - next: Option>, + next: Option>, } impl Block { @@ -25,27 +29,27 @@ impl Block { } struct BlockAllocatorState { - first_free_block: Option>, + first_free_block: Option>, } pub(crate) struct BlockAllocator { inner_allocator: BumpAllocator, - state: Mutex, + state: Mutex>, } impl BlockAllocator { pub(super) const unsafe fn new() -> Self { Self { inner_allocator: BumpAllocator::new(), - state: Mutex::new(BlockAllocatorState { + state: Mutex::new(RefCell::new(BlockAllocatorState { first_free_block: None, - }), + })), } } - unsafe fn new_block(&self, layout: Layout) -> *mut u8 { + fn new_block(&self, layout: Layout) -> *mut u8 { let overall_layout = Block::either_layout(layout); - self.inner_allocator.alloc(overall_layout) + self.inner_allocator.alloc_safe(overall_layout) } } @@ -58,46 +62,50 @@ unsafe impl GlobalAlloc for BlockAllocator { full_layout.extend(Layout::new::()).unwrap(); { - let mut state = self.state.lock(); - let mut current_block = state.first_free_block; - let mut list_ptr = &mut state.first_free_block; - while let Some(mut curr) = current_block { - let curr_block = curr.as_mut(); - if curr_block.size == full_layout.size() { - *list_ptr = curr_block.next; - return curr.as_ptr().cast(); - } else if curr_block.size >= block_after_layout.size() { - // can split block - let split_block = Block { - size: curr_block.size - block_after_layout_offset, - next: curr_block.next, - }; - let split_ptr = curr - .as_ptr() - .cast::() - .add(block_after_layout_offset) - .cast(); - *split_ptr = split_block; - *list_ptr = NonNull::new(split_ptr); + free(|key| { + let mut state = self.state.borrow(*key).borrow_mut(); + let mut current_block = state.first_free_block; + let mut list_ptr = &mut state.first_free_block; + while let Some(mut curr) = current_block { + let curr_block = curr.as_mut(); + if curr_block.size == full_layout.size() { + *list_ptr = curr_block.next; + return curr.as_ptr().cast(); + } else if curr_block.size >= block_after_layout.size() { + // can split block + let split_block = Block { + size: curr_block.size - block_after_layout_offset, + next: curr_block.next, + }; + let split_ptr = curr + .as_ptr() + .cast::() + .add(block_after_layout_offset) + .cast(); + *split_ptr = split_block; + *list_ptr = NonNull::new(split_ptr).map(SendNonNull); - return curr.as_ptr().cast(); + return curr.as_ptr().cast(); + } + current_block = curr_block.next; + list_ptr = &mut curr_block.next; } - current_block = curr_block.next; - list_ptr = &mut curr_block.next; - } - } - self.new_block(layout) + self.new_block(layout) + }) + } } unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) { let new_layout = Block::either_layout(layout); - let mut state = self.state.lock(); - let new_block_content = Block { - size: new_layout.size(), - next: state.first_free_block, - }; - *ptr.cast() = new_block_content; - state.first_free_block = NonNull::new(ptr.cast()); + free(|key| { + let mut state = self.state.borrow(*key).borrow_mut(); + let new_block_content = Block { + size: new_layout.size(), + next: state.first_free_block, + }; + *ptr.cast() = new_block_content; + state.first_free_block = NonNull::new(ptr.cast()).map(SendNonNull); + }) } } diff --git a/agb/src/agb_alloc/bump_allocator.rs b/agb/src/agb_alloc/bump_allocator.rs index 1fe1ef8a..671d4abe 100644 --- a/agb/src/agb_alloc/bump_allocator.rs +++ b/agb/src/agb_alloc/bump_allocator.rs @@ -1,45 +1,50 @@ use core::alloc::{GlobalAlloc, Layout}; +use core::cell::RefCell; use core::ptr::NonNull; -use crate::interrupt::Mutex; +use super::SendNonNull; +use crate::interrupt::free; +use bare_metal::Mutex; pub(crate) struct BumpAllocator { - current_ptr: Mutex>>, + current_ptr: Mutex>>>, } impl BumpAllocator { pub const fn new() -> Self { Self { - current_ptr: Mutex::new(None), + current_ptr: Mutex::new(RefCell::new(None)), } } } impl BumpAllocator { - fn alloc_safe(&self, layout: Layout) -> *mut u8 { - let mut current_ptr = self.current_ptr.lock(); + pub fn alloc_safe(&self, layout: Layout) -> *mut u8 { + free(|key| { + let mut current_ptr = self.current_ptr.borrow(*key).borrow_mut(); - let ptr = if let Some(c) = *current_ptr { - c.as_ptr() as usize - } else { - get_data_end() - }; + let ptr = if let Some(c) = *current_ptr { + c.as_ptr() as usize + } else { + get_data_end() + }; - let alignment_bitmask = layout.align() - 1; - let fixup = ptr & alignment_bitmask; + let alignment_bitmask = layout.align() - 1; + let fixup = ptr & alignment_bitmask; - let amount_to_add = layout.align() - fixup; + let amount_to_add = layout.align() - fixup; - let resulting_ptr = ptr + amount_to_add; - let new_current_ptr = resulting_ptr + layout.size(); + let resulting_ptr = ptr + amount_to_add; + let new_current_ptr = resulting_ptr + layout.size(); - if new_current_ptr as usize >= super::EWRAM_END { - return core::ptr::null_mut(); - } + if new_current_ptr as usize >= super::EWRAM_END { + return core::ptr::null_mut(); + } - *current_ptr = NonNull::new(new_current_ptr as *mut _); + *current_ptr = NonNull::new(new_current_ptr as *mut _).map(SendNonNull); - resulting_ptr as *mut _ + resulting_ptr as *mut _ + }) } } diff --git a/agb/src/agb_alloc/mod.rs b/agb/src/agb_alloc/mod.rs index 9e9408f4..cc31c33b 100644 --- a/agb/src/agb_alloc/mod.rs +++ b/agb/src/agb_alloc/mod.rs @@ -1,10 +1,35 @@ use core::alloc::Layout; +use core::ops::{Deref, DerefMut}; +use core::ptr::NonNull; mod block_allocator; mod bump_allocator; use block_allocator::BlockAllocator; +struct SendNonNull(NonNull); +unsafe impl Send for SendNonNull {} + +impl Clone for SendNonNull { + fn clone(&self) -> Self { + SendNonNull(self.0) + } +} +impl Copy for SendNonNull {} + +impl Deref for SendNonNull { + type Target = NonNull; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for SendNonNull { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + const EWRAM_END: usize = 0x0204_0000; #[global_allocator] diff --git a/agb/src/interrupt.rs b/agb/src/interrupt.rs index ec364e76..0c81656e 100644 --- a/agb/src/interrupt.rs +++ b/agb/src/interrupt.rs @@ -1,10 +1,11 @@ use core::{ - cell::{Cell, UnsafeCell}, + cell::Cell, marker::{PhantomData, PhantomPinned}, - ops::{Deref, DerefMut}, pin::Pin, }; +use bare_metal::CriticalSection; + use crate::{display::DISPLAY_STATUS, memory_mapped::MemoryMapped}; #[derive(Clone, Copy)] @@ -70,21 +71,22 @@ impl Interrupt { const ENABLED_INTERRUPTS: MemoryMapped = unsafe { MemoryMapped::new(0x04000200) }; const INTERRUPTS_ENABLED: MemoryMapped = unsafe { MemoryMapped::new(0x04000208) }; -struct Disable {} +struct Disable { + pre: u16, +} impl Drop for Disable { fn drop(&mut self) { - enable_interrupts(); + INTERRUPTS_ENABLED.set(self.pre); } } fn temporary_interrupt_disable() -> Disable { + let d = Disable { + pre: INTERRUPTS_ENABLED.get(), + }; disable_interrupts(); - Disable {} -} - -fn enable_interrupts() { - INTERRUPTS_ENABLED.set(1); + d } fn disable_interrupts() { @@ -158,7 +160,7 @@ pub struct InterruptClosureBounded<'a> { } struct InterruptClosure { - closure: *const (dyn Fn(Key)), + closure: *const (dyn Fn(&CriticalSection)), next: Cell<*const InterruptClosure>, root: *const InterruptRoot, } @@ -169,7 +171,7 @@ impl InterruptRoot { while !c.is_null() { let closure_ptr = unsafe { &*c }.closure; let closure_ref = unsafe { &*closure_ptr }; - closure_ref(Key()); + closure_ref(unsafe { &CriticalSection::new() }); c = unsafe { &*c }.next.get(); } } @@ -201,7 +203,7 @@ fn interrupt_to_root(interrupt: Interrupt) -> &'static InterruptRoot { } fn get_interrupt_handle_root<'a>( - f: &'a dyn Fn(Key), + f: &'a dyn Fn(&CriticalSection), root: &InterruptRoot, ) -> InterruptClosureBounded<'a> { InterruptClosureBounded { @@ -218,7 +220,7 @@ fn get_interrupt_handle_root<'a>( /// The [add_interrupt_handler!] macro should be used instead of this function. /// Creates an interrupt handler from a closure. pub fn get_interrupt_handle( - f: &(dyn Fn(Key) + Send + Sync), + f: &(dyn Fn(&CriticalSection) + Send + Sync), interrupt: Interrupt, ) -> InterruptClosureBounded { let root = interrupt_to_root(interrupt); @@ -230,22 +232,24 @@ pub fn get_interrupt_handle( /// Adds an interrupt handler to the interrupt table such that when that /// interrupt is triggered the closure is called. pub fn add_interrupt<'a>(interrupt: Pin<&'a InterruptClosureBounded<'a>>) { - let root = unsafe { &*interrupt.c.root }; - root.add(); - let mut c = root.next.get(); - if c.is_null() { - root.next.set((&interrupt.c) as *const _); - return; - } - loop { - let p = unsafe { &*c }.next.get(); - if p.is_null() { - unsafe { &*c }.next.set((&interrupt.c) as *const _); + free(|_| { + let root = unsafe { &*interrupt.c.root }; + root.add(); + let mut c = root.next.get(); + if c.is_null() { + root.next.set((&interrupt.c) as *const _); return; } + loop { + let p = unsafe { &*c }.next.get(); + if p.is_null() { + unsafe { &*c }.next.set((&interrupt.c) as *const _); + return; + } - c = p; - } + c = p; + } + }) } #[macro_export] @@ -270,90 +274,18 @@ macro_rules! add_interrupt_handler { }; } -#[derive(Clone, Copy, PartialEq, Eq, Debug)] -enum MutexState { - Unlocked, - Locked(bool), -} +pub fn free(f: F) -> R +where + F: FnOnce(&CriticalSection) -> R, +{ + let enabled = INTERRUPTS_ENABLED.get(); -pub struct Mutex { - internal: UnsafeCell, - state: UnsafeCell, -} + disable_interrupts(); -#[non_exhaustive] -pub struct Key(); + let r = f(unsafe { &CriticalSection::new() }); -unsafe impl Send for Mutex {} -unsafe impl Sync for Mutex {} - -impl Mutex { - pub fn lock(&self) -> MutexRef { - let state = INTERRUPTS_ENABLED.get(); - INTERRUPTS_ENABLED.set(0); - assert_eq!( - unsafe { *self.state.get() }, - MutexState::Unlocked, - "mutex must be unlocked to be able to lock it" - ); - unsafe { *self.state.get() = MutexState::Locked(state != 0) }; - MutexRef { - internal: &self.internal, - state: &self.state, - } - } - - pub fn lock_with_key(&self, _key: &Key) -> MutexRef { - assert_eq!( - unsafe { *self.state.get() }, - MutexState::Unlocked, - "mutex must be unlocked to be able to lock it" - ); - unsafe { *self.state.get() = MutexState::Locked(false) }; - MutexRef { - internal: &self.internal, - state: &self.state, - } - } - - pub const fn new(val: T) -> Self { - Mutex { - internal: UnsafeCell::new(val), - state: UnsafeCell::new(MutexState::Unlocked), - } - } -} - -pub struct MutexRef<'a, T> { - internal: &'a UnsafeCell, - state: &'a UnsafeCell, -} - -impl<'a, T> Drop for MutexRef<'a, T> { - fn drop(&mut self) { - let state = unsafe { &mut *self.state.get() }; - - let prev_state = *state; - *state = MutexState::Unlocked; - - match prev_state { - MutexState::Locked(b) => INTERRUPTS_ENABLED.set(b as u16), - MutexState::Unlocked => {} - } - } -} - -impl<'a, T> Deref for MutexRef<'a, T> { - type Target = T; - fn deref(&self) -> &Self::Target { - unsafe { &*self.internal.get() } - } -} - -impl<'a, T> DerefMut for MutexRef<'a, T> { - fn deref_mut(&mut self) -> &mut Self::Target { - unsafe { &mut *self.internal.get() } - } + INTERRUPTS_ENABLED.set(enabled); + r } #[non_exhaustive] @@ -382,18 +314,28 @@ impl Drop for VBlank { #[cfg(test)] mod tests { use super::*; + use bare_metal::Mutex; + use core::cell::RefCell; #[test_case] fn test_vblank_interrupt_handler(_gba: &mut crate::Gba) { { - let counter = Mutex::new(0); - let counter_2 = Mutex::new(0); - add_interrupt_handler!(Interrupt::VBlank, |key| *counter.lock_with_key(&key) += 1); - add_interrupt_handler!(Interrupt::VBlank, |_| *counter_2.lock() += 1); + let counter = Mutex::new(RefCell::new(0)); + let counter_2 = Mutex::new(RefCell::new(0)); + add_interrupt_handler!(Interrupt::VBlank, |key: &CriticalSection| *counter + .borrow(*key) + .borrow_mut() += + 1); + add_interrupt_handler!(Interrupt::VBlank, |key: &CriticalSection| *counter_2 + .borrow(*key) + .borrow_mut() += + 1); let vblank = VBlank::get(); - while *counter.lock() < 100 || *counter_2.lock() < 100 { + while free(|key| { + *counter.borrow(*key).borrow() < 100 || *counter_2.borrow(*key).borrow() < 100 + }) { vblank.wait_for_vblank(); } }