add optimisation for 64 bit upcast multiply

This commit is contained in:
Corwin 2023-06-06 23:00:17 +01:00
parent 28e3a7faf4
commit 2a21c5fdab
No known key found for this signature in database
2 changed files with 49 additions and 7 deletions

View file

@ -91,7 +91,7 @@ pub trait FixedWidthSignedInteger: FixedWidthUnsignedInteger + Neg<Output = Self
}
macro_rules! fixed_width_unsigned_integer_impl {
($T: ty, $Upcast: ty) => {
($T: ty, $Upcast: ident) => {
impl FixedWidthUnsignedInteger for $T {
#[inline(always)]
fn zero() -> Self {
@ -109,10 +109,37 @@ macro_rules! fixed_width_unsigned_integer_impl {
fn from_as_i32(v: i32) -> Self {
v as $T
}
#[inline(always)]
fn upcast_multiply(a: Self, b: Self, n: usize) -> Self {
(((a as $Upcast) * (b as $Upcast)) >> n) as $T
}
upcast_multiply_impl!($T, $Upcast);
}
};
}
macro_rules! upcast_multiply_impl {
($T: ty, optimised_64_bit) => {
#[inline(always)]
fn upcast_multiply(a: Self, b: Self, n: usize) -> Self {
let mask = (Self::one() << n).wrapping_sub(1);
let a_floor = a >> n;
let a_frac = a & mask;
let b_floor = b >> n;
let b_frac = b & mask;
(a_floor.wrapping_mul(b_floor) << n)
.wrapping_add(
a_floor
.wrapping_mul(b_frac)
.wrapping_add(b_floor.wrapping_mul(a_frac)),
)
.wrapping_add(a_frac.wrapping_mul(b_frac) >> n)
}
};
($T: ty, $Upcast: ty) => {
#[inline(always)]
fn upcast_multiply(a: Self, b: Self, n: usize) -> Self {
(((a as $Upcast) * (b as $Upcast)) >> n) as $T
}
};
}
@ -131,8 +158,9 @@ macro_rules! fixed_width_signed_integer_impl {
fixed_width_unsigned_integer_impl!(u8, u32);
fixed_width_unsigned_integer_impl!(i16, i32);
fixed_width_unsigned_integer_impl!(u16, u32);
fixed_width_unsigned_integer_impl!(i32, i64);
fixed_width_unsigned_integer_impl!(u32, u64);
fixed_width_unsigned_integer_impl!(i32, optimised_64_bit);
fixed_width_unsigned_integer_impl!(u32, optimised_64_bit);
fixed_width_signed_integer_impl!(i16);
fixed_width_signed_integer_impl!(i32);
@ -1370,4 +1398,17 @@ mod tests {
]
);
}
#[cfg(not(debug_assertions))]
#[test]
fn test_all_multiplies() {
use super::*;
for i in 0..u32::MAX {
let fix_num: Num<_, 7> = Num::from_raw(i);
let upcasted = ((i as u64 * i as u64) >> 7) as u32;
assert_eq!((fix_num * fix_num).to_raw(), upcasted);
}
}
}

View file

@ -20,6 +20,7 @@ test:
test-release:
just _test-release agb
just _test-release agb-fixnum
just _test-release-arm agb
doctest-agb: