From 5e32a0cb658781ebe1ac09786b30c35db1ea5185 Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Sat, 15 Jun 2024 10:55:19 +0200 Subject: [PATCH 1/3] chore: use is_zero more --- src/div.rs | 6 +++--- src/log.rs | 2 +- src/modular.rs | 6 +++--- src/root.rs | 2 +- src/special.rs | 4 ++-- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/div.rs b/src/div.rs index ad6b757e..36b3e57a 100644 --- a/src/div.rs +++ b/src/div.rs @@ -7,7 +7,7 @@ impl Uint { #[must_use] #[allow(clippy::missing_const_for_fn)] // False positive pub fn checked_div(self, rhs: Self) -> Option { - if rhs == Self::ZERO { + if rhs.is_zero() { return None; } Some(self.div(rhs)) @@ -18,7 +18,7 @@ impl Uint { #[must_use] #[allow(clippy::missing_const_for_fn)] // False positive pub fn checked_rem(self, rhs: Self) -> Option { - if rhs == Self::ZERO { + if rhs.is_zero() { return None; } Some(self.rem(rhs)) @@ -35,7 +35,7 @@ impl Uint { pub fn div_ceil(self, rhs: Self) -> Self { assert!(rhs != Self::ZERO, "Division by zero"); let (q, r) = self.div_rem(rhs); - if r == Self::ZERO { + if r.is_zero() { q } else { q + Self::from(1) diff --git a/src/log.rs b/src/log.rs index 84e0fa5e..3c0bb685 100644 --- a/src/log.rs +++ b/src/log.rs @@ -9,7 +9,7 @@ impl Uint { #[inline] #[must_use] pub fn checked_log(self, base: Self) -> Option { - if base < Self::from(2) || self == Self::ZERO { + if base < Self::from(2) || self.is_zero() { return None; } Some(self.log(base)) diff --git a/src/modular.rs b/src/modular.rs index 3970b520..4969ccf6 100644 --- a/src/modular.rs +++ b/src/modular.rs @@ -17,7 +17,7 @@ impl Uint { #[inline] #[must_use] pub fn reduce_mod(mut self, modulus: Self) -> Self { - if modulus == Self::ZERO { + if modulus.is_zero() { return Self::ZERO; } if self >= modulus { @@ -53,7 +53,7 @@ impl Uint { #[inline] #[must_use] pub fn mul_mod(self, rhs: Self, mut modulus: Self) -> Self { - if modulus == Self::ZERO { + if modulus.is_zero() { return Self::ZERO; } @@ -84,7 +84,7 @@ impl Uint { #[inline] #[must_use] pub fn pow_mod(mut self, mut exp: Self, modulus: Self) -> Self { - if modulus == Self::ZERO || modulus <= Self::from(1) { + if modulus.is_zero() || modulus <= Self::from(1) { // Also covers Self::BITS == 0 return Self::ZERO; } diff --git a/src/root.rs b/src/root.rs index 695218a3..d99b7274 100644 --- a/src/root.rs +++ b/src/root.rs @@ -31,7 +31,7 @@ impl Uint { assert!(degree > 0, "degree must be greater than zero"); // Handle zero case (including BITS == 0). - if self == Self::ZERO { + if self.is_zero() { return Self::ZERO; } diff --git a/src/special.rs b/src/special.rs index cda96890..11fa8b93 100644 --- a/src/special.rs +++ b/src/special.rs @@ -109,11 +109,11 @@ impl Uint { #[inline] #[must_use] pub fn checked_next_multiple_of(self, rhs: Self) -> Option { - if rhs == Self::ZERO { + if rhs.is_zero() { return None; } let (q, r) = self.div_rem(rhs); - if r == Self::ZERO { + if r.is_zero() { return Some(self); } let q = q.checked_add(Self::from(1))?; From 2c51325f27c0b264579a8fe9e9ca4820544b6df2 Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Sat, 15 Jun 2024 10:56:14 +0200 Subject: [PATCH 2/3] perf: remove unnecessary assertions in div_* --- src/div.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/div.rs b/src/div.rs index ad6b757e..a0a09717 100644 --- a/src/div.rs +++ b/src/div.rs @@ -33,7 +33,6 @@ impl Uint { #[must_use] #[track_caller] pub fn div_ceil(self, rhs: Self) -> Self { - assert!(rhs != Self::ZERO, "Division by zero"); let (q, r) = self.div_rem(rhs); if r == Self::ZERO { q @@ -51,7 +50,6 @@ impl Uint { #[must_use] #[track_caller] pub fn div_rem(mut self, mut rhs: Self) -> (Self, Self) { - assert!(rhs != Self::ZERO, "Division by zero"); algorithms::div(&mut self.limbs, &mut rhs.limbs); (self, rhs) } From 898a95571018321f10c0c9731ab1f2b1a00cf601 Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Sat, 15 Jun 2024 11:06:11 +0200 Subject: [PATCH 3/3] perf: re-inline always addmul --- src/algorithms/mod.rs | 2 +- src/algorithms/mul.rs | 88 +++++++++++++++++++++---------------------- 2 files changed, 44 insertions(+), 46 deletions(-) diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index 69e09e2a..f056e405 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -20,7 +20,7 @@ pub use self::{ add::{adc_n, sbb_n}, div::div, gcd::{gcd, gcd_extended, inv_mod, LehmerMatrix}, - mul::{add_nx1, addmul, addmul_n, addmul_nx1, addmul_ref, mul_nx1, submul_nx1}, + mul::{add_nx1, addmul, addmul_n, addmul_nx1, mul_nx1, submul_nx1}, ops::{adc, sbb}, shift::{shift_left_small, shift_right_small}, }; diff --git a/src/algorithms/mul.rs b/src/algorithms/mul.rs index 2c5436ed..04f278a6 100644 --- a/src/algorithms/mul.rs +++ b/src/algorithms/mul.rs @@ -2,46 +2,6 @@ use crate::algorithms::{ops::sbb, DoubleWord}; -#[inline] -#[allow(clippy::cast_possible_truncation)] // Intentional truncation. -#[allow(dead_code)] // Used for testing -pub fn addmul_ref(result: &mut [u64], a: &[u64], b: &[u64]) -> bool { - let mut overflow = 0; - for (i, a) in a.iter().copied().enumerate() { - let mut result = result.iter_mut().skip(i); - let mut b = b.iter().copied(); - let mut carry = 0_u128; - loop { - match (result.next(), b.next()) { - // Partial product. - (Some(result), Some(b)) => { - carry += u128::from(*result) + u128::from(a) * u128::from(b); - *result = carry as u64; - carry >>= 64; - } - // Carry propagation. - (Some(result), None) => { - carry += u128::from(*result); - *result = carry as u64; - carry >>= 64; - } - // Excess product. - (None, Some(b)) => { - carry += u128::from(a) * u128::from(b); - overflow |= carry as u64; - carry >>= 64; - } - // Fin. - (None, None) => { - break; - } - } - } - overflow |= carry as u64; - } - overflow != 0 -} - /// ⚠️ Computes `result += a * b` and checks for overflow. /// /// **Warning.** This function is not part of the stable API. @@ -62,7 +22,7 @@ pub fn addmul_ref(result: &mut [u64], a: &[u64], b: &[u64]) -> bool { /// assert_eq!(overflow, false); /// assert_eq!(result, [12]); /// ``` -#[inline] +#[inline(always)] pub fn addmul(mut lhs: &mut [u64], mut a: &[u64], mut b: &[u64]) -> bool { // Trim zeros from `a` while let [0, rest @ ..] = a { @@ -116,7 +76,7 @@ pub fn addmul(mut lhs: &mut [u64], mut a: &[u64], mut b: &[u64]) -> bool { } /// Computes `lhs += a` and returns the carry. -#[inline] +#[inline(always)] pub fn add_nx1(lhs: &mut [u64], mut a: u64) -> u64 { if a == 0 { return 0; @@ -223,7 +183,7 @@ fn mac(lhs: &mut u64, a: u64, b: u64, c: u64) -> u64 { } /// Computes `lhs *= a` and returns the carry. -#[inline] +#[inline(always)] pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 { let mut carry = 0; for lhs in &mut *lhs { @@ -244,7 +204,7 @@ pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 { /// \\\\ \mathsf{carry} &= \floor{\frac{\mathsf{lhs} + \mathsf{a} ⋅ \mathsf{b} /// }{2^{64⋅N}}} \end{aligned} /// $$ -#[inline] +#[inline(always)] pub fn addmul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 { debug_assert_eq!(lhs.len(), a.len()); let mut carry = 0; @@ -267,7 +227,7 @@ pub fn addmul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 { /// \mathsf{lhs}}{2^{64⋅N}}} \end{aligned} /// $$ // OPT: `carry` and `borrow` can probably be merged into a single var. -#[inline] +#[inline(always)] pub fn submul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 { debug_assert_eq!(lhs.len(), a.len()); let mut carry = 0; @@ -293,6 +253,44 @@ mod tests { use super::*; use proptest::{collection, num::u64, proptest}; + #[allow(clippy::cast_possible_truncation)] // Intentional truncation. + fn addmul_ref(result: &mut [u64], a: &[u64], b: &[u64]) -> bool { + let mut overflow = 0; + for (i, a) in a.iter().copied().enumerate() { + let mut result = result.iter_mut().skip(i); + let mut b = b.iter().copied(); + let mut carry = 0_u128; + loop { + match (result.next(), b.next()) { + // Partial product. + (Some(result), Some(b)) => { + carry += u128::from(*result) + u128::from(a) * u128::from(b); + *result = carry as u64; + carry >>= 64; + } + // Carry propagation. + (Some(result), None) => { + carry += u128::from(*result); + *result = carry as u64; + carry >>= 64; + } + // Excess product. + (None, Some(b)) => { + carry += u128::from(a) * u128::from(b); + overflow |= carry as u64; + carry >>= 64; + } + // Fin. + (None, None) => { + break; + } + } + } + overflow |= carry as u64; + } + overflow != 0 + } + #[test] fn test_addmul() { let any_vec = collection::vec(u64::ANY, 0..10);