diff --git a/agb-fixnum/src/lib.rs b/agb-fixnum/src/lib.rs index be5a97be..f058b1b5 100644 --- a/agb-fixnum/src/lib.rs +++ b/agb-fixnum/src/lib.rs @@ -91,7 +91,7 @@ pub trait FixedWidthSignedInteger: FixedWidthUnsignedInteger + Neg { + ($T: ty, $Upcast: ident) => { impl FixedWidthUnsignedInteger for $T { #[inline(always)] fn zero() -> Self { @@ -109,10 +109,37 @@ macro_rules! fixed_width_unsigned_integer_impl { fn from_as_i32(v: i32) -> Self { v as $T } - #[inline(always)] - fn upcast_multiply(a: Self, b: Self, n: usize) -> Self { - (((a as $Upcast) * (b as $Upcast)) >> n) as $T - } + + upcast_multiply_impl!($T, $Upcast); + } + }; +} + +macro_rules! upcast_multiply_impl { + ($T: ty, optimised_64_bit) => { + #[inline(always)] + fn upcast_multiply(a: Self, b: Self, n: usize) -> Self { + let mask = (Self::one() << n).wrapping_sub(1); + + let a_floor = a >> n; + let a_frac = a & mask; + + let b_floor = b >> n; + let b_frac = b & mask; + + (a_floor.wrapping_mul(b_floor) << n) + .wrapping_add( + a_floor + .wrapping_mul(b_frac) + .wrapping_add(b_floor.wrapping_mul(a_frac)), + ) + .wrapping_add(a_frac.wrapping_mul(b_frac) >> n) + } + }; + ($T: ty, $Upcast: ty) => { + #[inline(always)] + fn upcast_multiply(a: Self, b: Self, n: usize) -> Self { + (((a as $Upcast) * (b as $Upcast)) >> n) as $T } }; } @@ -131,8 +158,9 @@ macro_rules! fixed_width_signed_integer_impl { fixed_width_unsigned_integer_impl!(u8, u32); fixed_width_unsigned_integer_impl!(i16, i32); fixed_width_unsigned_integer_impl!(u16, u32); -fixed_width_unsigned_integer_impl!(i32, i64); -fixed_width_unsigned_integer_impl!(u32, u64); + +fixed_width_unsigned_integer_impl!(i32, optimised_64_bit); +fixed_width_unsigned_integer_impl!(u32, optimised_64_bit); fixed_width_signed_integer_impl!(i16); fixed_width_signed_integer_impl!(i32); @@ -1370,4 +1398,17 @@ mod tests { ] ); } + + #[cfg(not(debug_assertions))] + #[test] + fn test_all_multiplies() { + use super::*; + + for i in 0..u32::MAX { + let fix_num: Num<_, 7> = Num::from_raw(i); + let upcasted = ((i as u64 * i as u64) >> 7) as u32; + + assert_eq!((fix_num * fix_num).to_raw(), upcasted); + } + } } diff --git a/justfile b/justfile index d9a17a3d..05a12dce 100644 --- a/justfile +++ b/justfile @@ -20,6 +20,7 @@ test: test-release: just _test-release agb + just _test-release agb-fixnum just _test-release-arm agb doctest-agb: