diff --git a/agb/src/interrupt.rs b/agb/src/interrupt.rs index c88fb9fd..6aba025c 100644 --- a/agb/src/interrupt.rs +++ b/agb/src/interrupt.rs @@ -288,7 +288,7 @@ pub fn add_interrupt_handler<'a>( /// [`CriticalSection`] /// /// [`CriticalSection`]: bare_metal::CriticalSection -pub fn free(f: F) -> R +pub fn free(mut f: F) -> R where F: FnOnce(CriticalSection) -> R, { @@ -296,7 +296,13 @@ where disable_interrupts(); - let r = f(unsafe { CriticalSection::new() }); + // prevents the contents of the function from being reordered before IME is disabled. + crate::sync::memory_write_hint(&mut f); + + let mut r = f(unsafe { CriticalSection::new() }); + + // prevents the contents of the function from being reordered after IME is re-enabled. + crate::sync::memory_write_hint(&mut r); INTERRUPTS_ENABLED.set(enabled); r diff --git a/agb/src/lib.rs b/agb/src/lib.rs index 957e5950..fa3dd6f9 100644 --- a/agb/src/lib.rs +++ b/agb/src/lib.rs @@ -13,6 +13,7 @@ #![feature(alloc_error_handler)] #![feature(allocator_api)] #![feature(asm_const)] +#![feature(isa_attribute)] #![warn(clippy::all)] #![deny(clippy::must_use_candidate)] #![deny(clippy::trivially_copy_pass_by_ref)] @@ -170,6 +171,8 @@ pub mod rng; mod single; /// Implements sound output. pub mod sound; +/// A module containing functions and utilities useful for synchronizing state. +pub mod sync; /// System BIOS calls / syscalls. pub mod syscall; /// Interactions with the internal timers diff --git a/agb/src/sync/locks.rs b/agb/src/sync/locks.rs new file mode 100644 index 00000000..ce93465d --- /dev/null +++ b/agb/src/sync/locks.rs @@ -0,0 +1,203 @@ +use crate::sync::Static; +use core::cell::UnsafeCell; +use core::mem::MaybeUninit; +use core::ops::{Deref, DerefMut}; +use core::ptr; +use core::sync::atomic::{compiler_fence, Ordering}; + +#[inline(never)] +fn already_locked() -> ! 
{ + panic!("IRQ and main thread are attempting to access the same Mutex!") +} + +/// A mutex that prevents code from running in both an IRQ and normal code at +/// the same time. +/// +/// Note that this does not support blocking like a typical mutex, and instead +/// mainly exists for memory safety reasons. +pub struct RawMutex(Static); +impl RawMutex { + /// Creates a new lock. + #[must_use] + pub const fn new() -> Self { + RawMutex(Static::new(false)) + } + + /// Locks the mutex and returns whether a lock was successfully acquired. + fn raw_lock(&self) -> bool { + if self.0.replace(true) { + // value was already true, oops. + false + } else { + // prevent any weird reordering, and continue + compiler_fence(Ordering::Acquire); + true + } + } + + /// Unlocks the mutex. + fn raw_unlock(&self) { + compiler_fence(Ordering::Release); + if !self.0.replace(false) { + panic!("Internal error: Attempt to unlock a `RawMutex` which is not locked.") + } + } + + /// Returns a guard for this lock, or panics if there is another lock active. + pub fn lock(&self) -> RawMutexGuard<'_> { + self.try_lock().unwrap_or_else(|| already_locked()) + } + + /// Returns a guard for this lock, or `None` if there is another lock active. + pub fn try_lock(&self) -> Option> { + if self.raw_lock() { + Some(RawMutexGuard(self)) + } else { + None + } + } +} +unsafe impl Send for RawMutex {} +unsafe impl Sync for RawMutex {} + +/// A guard representing an active lock on a [`RawMutex`]. +pub struct RawMutexGuard<'a>(&'a RawMutex); +impl<'a> Drop for RawMutexGuard<'a> { + fn drop(&mut self) { + self.0.raw_unlock(); + } +} + +/// A mutex that protects an object from being accessed in both an IRQ and +/// normal code at once. +/// +/// Note that this does not support blocking like a typical mutex, and instead +/// mainly exists for memory safety reasons. +pub struct Mutex { + raw: RawMutex, + data: UnsafeCell, +} +impl Mutex { + /// Creates a new lock containing a given value. 
+ #[must_use] + pub const fn new(t: T) -> Self { + Mutex { raw: RawMutex::new(), data: UnsafeCell::new(t) } + } + + /// Returns a guard for this lock, or panics if there is another lock active. + pub fn lock(&self) -> MutexGuard<'_, T> { + self.try_lock().unwrap_or_else(|| already_locked()) + } + + /// Returns a guard for this lock or `None` if there is another lock active. + pub fn try_lock(&self) -> Option> { + if self.raw.raw_lock() { + Some(MutexGuard { underlying: self, ptr: self.data.get() }) + } else { + None + } + } +} +unsafe impl Send for Mutex {} +unsafe impl Sync for Mutex {} + +/// A guard representing an active lock on a [`Mutex`]. +pub struct MutexGuard<'a, T> { + underlying: &'a Mutex, + ptr: *mut T, +} +impl<'a, T> Drop for MutexGuard<'a, T> { + fn drop(&mut self) { + self.underlying.raw.raw_unlock(); + } +} +impl<'a, T> Deref for MutexGuard<'a, T> { + type Target = T; + fn deref(&self) -> &Self::Target { + unsafe { &*self.ptr } + } +} +impl<'a, T> DerefMut for MutexGuard<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.ptr } + } +} + +enum Void {} + +/// A helper type that ensures a particular value is only initialized once. +pub struct InitOnce { + is_initialized: Static, + value: UnsafeCell>, +} +impl InitOnce { + /// Creates a new uninitialized object. + #[must_use] + pub const fn new() -> Self { + InitOnce { + is_initialized: Static::new(false), + value: UnsafeCell::new(MaybeUninit::uninit()), + } + } + + /// Gets the contents of this state, or initializes it if it has not already + /// been initialized. + /// + /// The initializer function is guaranteed to only be called once. + /// + /// This function disables IRQs while it is initializing the inner value. + /// While this can cause audio skipping and other similar issues, it is + /// not normally a problem as interrupts will only be disabled once per + /// `InitOnce` during the life cycle of the program. 
+ pub fn get(&self, initializer: impl FnOnce() -> T) -> &T { + match self.try_get(|| -> Result { Ok(initializer()) }) { + Ok(v) => v, + _ => unimplemented!(), + } + } + + /// Gets the contents of this state, or initializes it if it has not already + /// been initialized. + /// + /// The initializer function is guaranteed to only be called once if it + /// returns `Ok`. If it returns `Err`, it will be called again in the + /// future until an attempt at initialization succeeds. + /// + /// This function disables IRQs while it is initializing the inner value. + /// While this can cause audio skipping and other similar issues, it is + /// not normally a problem as interrupts will only be disabled once per + /// `InitOnce` during the life cycle of the program. + pub fn try_get(&self, initializer: impl FnOnce() -> Result) -> Result<&T, E> { + unsafe { + if !self.is_initialized.read() { + // We disable interrupts to make this simpler, since this is likely to + // only occur once in a program anyway. + crate::interrupt::free(|_| -> Result<(), E> { + // We check again to make sure this function wasn't called in an + // interrupt between the first check and when interrupts were + // actually disabled. + if !self.is_initialized.read() { + // Do the actual initialization. 
+ ptr::write_volatile((*self.value.get()).as_mut_ptr(), initializer()?); + self.is_initialized.write(true); + } + Ok(()) + })?; + } + compiler_fence(Ordering::Acquire); + Ok(&*(*self.value.get()).as_mut_ptr()) + } + } +} +impl Drop for InitOnce { + fn drop(&mut self) { + if self.is_initialized.read() { + // drop the value inside the `MaybeUninit` + unsafe { + ptr::read((*self.value.get()).as_ptr()); + } + } + } +} +unsafe impl Send for InitOnce {} +unsafe impl Sync for InitOnce {} diff --git a/agb/src/sync/mod.rs b/agb/src/sync/mod.rs new file mode 100644 index 00000000..17dad495 --- /dev/null +++ b/agb/src/sync/mod.rs @@ -0,0 +1,48 @@ +mod locks; +mod statics; + +pub use locks::*; +pub use statics::*; + +use core::arch::asm; + +/// Marks that a pointer is read without actually reading from it. +/// +/// This uses an [`asm!`] instruction that marks the parameter as being read, +/// requiring the compiler to treat this function as if anything could be +/// done to it. +#[inline(always)] +pub fn memory_read_hint(val: *const T) { + unsafe { asm!("/* {0} */", in(reg) val, options(readonly, nostack)) } +} + +/// Marks that a pointer is read or written to without actually writing to it. +/// +/// This uses an [`asm!`] instruction that marks the parameter as being read +/// and written, requiring the compiler to treat this function as if anything +/// could be done to it. +#[inline(always)] +pub fn memory_write_hint(val: *mut T) { + unsafe { asm!("/* {0} */", in(reg) val, options(nostack)) } +} + +/// An internal function used as a temporary hack to get `compiler_fence` +/// working. While this call is not properly inlined, working is better than not +/// working at all. +/// +/// This seems to be a problem caused by Rust issue #62256: +/// <https://github.com/rust-lang/rust/issues/62256> +/// +/// **WARNING FOR ANYONE WHO FINDS THIS**: This implementation will *only* be +/// correct on the GBA, and should not be used on any other platform. 
The GBA +/// is very old, and has no atomics to begin with - only a main thread and +/// interrupts. On any more recent CPU, this implementation is extremely +/// unlikely to be sound. +/// +/// Not public API, obviously. +#[doc(hidden)] +#[deprecated] +#[allow(dead_code)] +#[no_mangle] +#[inline(always)] +pub unsafe extern "C" fn __sync_synchronize() {} diff --git a/agb/src/sync/statics.rs b/agb/src/sync/statics.rs new file mode 100644 index 00000000..3970589f --- /dev/null +++ b/agb/src/sync/statics.rs @@ -0,0 +1,329 @@ +use core::arch::asm; +use core::cell::UnsafeCell; +use core::mem; +use core::ptr; + +/// The internal function for replacing a `Copy` (really `!Drop`) value in a +/// [`Static`]. This uses assembly to use an `stmia` instruction to ensure +/// an IRQ cannot occur during the write operation. +unsafe fn transfer(dst: *mut T, src: *const T) { + let align = mem::align_of::(); + let size = mem::size_of::(); + + if size == 0 { + // Do nothing with ZSTs. + } else if size <= 16 && align % 4 == 0 { + // We can do an 4-byte aligned transfer up to 16 bytes. + transfer_align4_thumb(dst, src); + } else if size <= 40 && align % 4 == 0 { + // We can do the same up to 40 bytes, but we need to switch to ARM. + transfer_align4_arm(dst, src); + } else if size <= 2 && align % 2 == 0 { + // We can do a 2-byte aligned transfer up to 2 bytes. + asm!( + "ldrh {2},[{0}]", + "strh {2},[{1}]", + in(reg) src, in(reg) dst, out(reg) _, + ); + } else if size == 1 { + // We can do a simple byte copy. + asm!( + "ldrb {2},[{0}]", + "strb {2},[{1}]", + in(reg) src, in(reg) dst, out(reg) _, + ); + } else { + // When we don't have an optimized path, we just disable IRQs. + crate::interrupt::free(|_| ptr::write_volatile(dst, ptr::read_volatile(src))); + } +} + +#[allow(unused_assignments)] +unsafe fn transfer_align4_thumb(mut dst: *mut T, mut src: *const T) { + let size = mem::size_of::(); + + if size <= 4 { + // We use assembly here regardless to just do the word aligned copy. 
This + // ensures it's done with a single ldr/str instruction. + asm!( + "ldr {2},[{0}]", + "str {2},[{1}]", + inout(reg) src, in(reg) dst, out(reg) _, + ); + } else if size <= 8 { + // Starting at size == 8, we begin using ldmia/stmia to load/save multiple + // words in one instruction, avoiding IRQs from interrupting our operation. + asm!( + "ldmia {0}!, {{r2-r3}}", + "stmia {1}!, {{r2-r3}}", + inout(reg) src, inout(reg) dst, + out("r2") _, out("r3") _, + ); + } else if size <= 12 { + asm!( + "ldmia {0}!, {{r2-r4}}", + "stmia {1}!, {{r2-r4}}", + inout(reg) src, inout(reg) dst, + out("r2") _, out("r3") _, out("r4") _, + ); + } else if size <= 16 { + asm!( + "ldmia {0}!, {{r2-r5}}", + "stmia {1}!, {{r2-r5}}", + inout(reg) src, inout(reg) dst, + out("r2") _, out("r3") _, out("r4") _, out("r5") _, + ); + } else { + unimplemented!("This should be done via transfer_arm."); + } +} + +#[instruction_set(arm::a32)] +#[allow(unused_assignments)] +unsafe fn transfer_align4_arm(mut dst: *mut T, mut src: *const T) { + let size = mem::size_of::(); + + if size <= 16 { + unimplemented!("This should be done via transfer_thumb."); + } else if size <= 20 { + // Above size == 16, we have to switch to ARM due to lack of + // accessible registers in THUMB mode. 
+ asm!( + "ldmia {0}!, {{r2-r5,r7}}", + "stmia {1}!, {{r2-r5,r7}}", + inout(reg) src, inout(reg) dst, + out("r2") _, out("r3") _, out("r4") _, out("r5") _, out("r7") _, + ); + } else if size <= 24 { + asm!( + "ldmia {0}!, {{r2-r5,r7-r8}}", + "stmia {1}!, {{r2-r5,r7-r8}}", + inout(reg) src, inout(reg) dst, + out("r2") _, out("r3") _, out("r4") _, out("r5") _, out("r7") _, + out("r8") _, + ); + } else if size <= 28 { + asm!( + "ldmia {0}!, {{r2-r5,r7-r9}}", + "stmia {1}!, {{r2-r5,r7-r9}}", + inout(reg) src, inout(reg) dst, + out("r2") _, out("r3") _, out("r4") _, out("r5") _, out("r7") _, + out("r8") _, out("r9") _, + ); + } else if size <= 32 { + asm!( + "ldmia {0}!, {{r2-r5,r7-r10}}", + "stmia {1}!, {{r2-r5,r7-r10}}", + inout(reg) src, inout(reg) dst, + out("r2") _, out("r3") _, out("r4") _, out("r5") _, out("r7") _, + out("r8") _, out("r9") _, out("r10") _, + ); + } else if size <= 36 { + asm!( + "ldmia {0}!, {{r2-r5,r7-r10,r12}}", + "stmia {1}!, {{r2-r5,r7-r10,r12}}", + inout(reg) src, inout(reg) dst, + out("r2") _, out("r3") _, out("r4") _, out("r5") _, out("r7") _, + out("r8") _, out("r9") _, out("r10") _, out("r12") _, + ); + } else if size <= 40 { + asm!( + "ldmia {0}!, {{r2-r5,r7-r10,r12,r14}}", + "stmia {1}!, {{r2-r5,r7-r10,r12,r14}}", + inout(reg) src, inout(reg) dst, + out("r2") _, out("r3") _, out("r4") _, out("r5") _, out("r7") _, + out("r8") _, out("r9") _, out("r10") _, out("r12") _, out("r14") _, + ); + } else { + // r13 is sp, and r15 is pc. Neither are usable + unimplemented!("Copy too large for use of ldmia/stmia."); + } +} + +/// The internal function for swapping the current value of a [`Static`] with +/// another value. +unsafe fn exchange(dst: *mut T, src: *const T) -> T { + let align = mem::align_of::(); + let size = mem::size_of::(); + if size == 0 { + // Do nothing with ZSTs. + ptr::read(dst) + } else if size <= 4 && align % 4 == 0 { + // Swap a single word with the SWP instruction. 
+ let val = ptr::read(src as *const u32); + let new_val = exchange_align4_arm(dst, val); + ptr::read(&new_val as *const _ as *const T) + } else if size == 1 { + // Swap a byte with the SWPB instruction. + let val = ptr::read(src as *const u8); + let new_val = exchange_align1_arm(dst, val); + ptr::read(&new_val as *const _ as *const T) + } else { + // fallback + crate::interrupt::free(|_| { + let cur = ptr::read_volatile(dst); + ptr::write_volatile(dst, ptr::read_volatile(src)); + cur + }) + } +} + +#[instruction_set(arm::a32)] +unsafe fn exchange_align4_arm(dst: *mut T, i: u32) -> u32 { + let out; + asm!("swp {2}, {1}, [{0}]", in(reg) dst, in(reg) i, lateout(reg) out); + out +} + +#[instruction_set(arm::a32)] +unsafe fn exchange_align1_arm(dst: *mut T, i: u8) -> u8 { + let out; + asm!("swpb {2}, {1}, [{0}]", in(reg) dst, in(reg) i, lateout(reg) out); + out +} + +/// A helper that implements static variables. +/// +/// It ensures that even if you use the same static variable in both an IRQ +/// and normal code, the IRQ will never observe an invalid value of the +/// variable. +/// +/// This type only works with owned values. If you need to work with borrows, +/// consider using [`sync::Mutex`](`crate::sync::Mutex`) instead. +/// +/// ## Performance +/// +/// Writing or reading from a static variable is efficient under the following +/// conditions: +/// +/// * The type is aligned to 4 bytes and can be stored in 40 bytes or less. +/// * The type is aligned to 2 bytes and can be stored in 2 bytes. +/// * The type can be stored in a single byte. +/// +/// Replacing the current value of the static variable is efficient under the +/// following conditions: +/// +/// * The type is aligned to 4 bytes and can be stored in 4 bytes or less. +/// * The type can be stored in a single byte. +/// +/// When these conditions are not met, static variables are handled using a +/// fallback routine that disables IRQs and does a normal copy. 
This can be +/// dangerous as disabling IRQs can cause your program to miss out on important +/// interrupts such as V-Blank. +/// +/// Consider using [`sync::Mutex`](`crate::sync::Mutex`) instead if you need to +/// use a large amount of operations that would cause IRQs to be disabled. Also +/// consider using `#[repr(align(4))]` to force proper alignment for your type. +pub struct Static { + data: UnsafeCell, +} +impl Static { + /// Creates a new static variable. + pub const fn new(val: T) -> Self { + Static { data: UnsafeCell::new(val) } + } + + /// Replaces the current value of the static variable with another, and + /// returns the old value. + #[allow(clippy::needless_pass_by_value)] // critical for safety + pub fn replace(&self, val: T) -> T { + unsafe { exchange(self.data.get(), &val) } + } + + /// Extracts the interior value of the static variable. + pub fn into_inner(self) -> T { + self.data.into_inner() + } +} +impl Static { + /// Writes a new value into this static variable. + pub fn write(&self, val: T) { + unsafe { transfer(self.data.get(), &val) } + } + + /// Reads a value from this static variable. 
+ pub fn read(&self) -> T { + unsafe { + let mut out: mem::MaybeUninit = mem::MaybeUninit::uninit(); + transfer(out.as_mut_ptr(), self.data.get()); + out.assume_init() + } + } +} +impl Default for Static { + fn default() -> Self { + Static::new(T::default()) + } +} +unsafe impl Send for Static {} +unsafe impl Sync for Static {} + +#[cfg(test)] +mod test { + use crate::Gba; + use crate::interrupt::Interrupt; + use crate::sync::Static; + use crate::timer::Divider; + + fn write_read_concurrency_test_impl(gba: &mut Gba) { + let sentinel = [0x12345678; COUNT]; + let value: Static<[u32; COUNT]> = Static::new(sentinel); + + // set up a timer and an interrupt that uses the timer + let mut timer = gba.timers.timers().timer2; + timer.set_cascade(false); + timer.set_divider(Divider::Divider1); + timer.set_overflow_amount(1049); + timer.set_interrupt(true); + timer.set_enabled(true); + + let _int = crate::interrupt::add_interrupt_handler(Interrupt::Timer2, |_| { + value.write(sentinel); + }); + + // the actual main test loop + let mut interrupt_seen = false; + let mut no_interrupt_seen = false; + for i in 0..100000 { + // write to the static + let new_value = [i; COUNT]; + value.write(new_value); + + // check the current value + let current = value.read(); + if current == new_value { + no_interrupt_seen = true; + } else if current == sentinel { + interrupt_seen = true; + } else { + panic!("Unexpected value found in `Static`."); + } + + // we return as soon as we've seen both the value written by the main thread + // and interrupt + if interrupt_seen && no_interrupt_seen { + timer.set_enabled(false); + return + } + + if i % 8192 == 0 && i != 0 { + timer.set_overflow_amount(1049 + (i / 64) as u16); + } + } + panic!("Concurrency test timed out: {}", COUNT) + } + + #[test_case] + fn write_read_concurrency_test(gba: &mut Gba) { + write_read_concurrency_test_impl::<1>(gba); + write_read_concurrency_test_impl::<2>(gba); + write_read_concurrency_test_impl::<3>(gba); + 
write_read_concurrency_test_impl::<4>(gba); + write_read_concurrency_test_impl::<5>(gba); + write_read_concurrency_test_impl::<6>(gba); + write_read_concurrency_test_impl::<7>(gba); + write_read_concurrency_test_impl::<8>(gba); + write_read_concurrency_test_impl::<9>(gba); + write_read_concurrency_test_impl::<10>(gba); + } +} \ No newline at end of file