Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: re-inline always addmul #381

Merged
merged 2 commits into from
Jun 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/algorithms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub use self::{
add::{adc_n, sbb_n},
div::div,
gcd::{gcd, gcd_extended, inv_mod, LehmerMatrix},
mul::{add_nx1, addmul, addmul_n, addmul_nx1, addmul_ref, mul_nx1, submul_nx1},
mul::{add_nx1, addmul, addmul_n, addmul_nx1, mul_nx1, submul_nx1},
ops::{adc, sbb},
shift::{shift_left_small, shift_right_small},
};
Expand Down
88 changes: 43 additions & 45 deletions src/algorithms/mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,6 @@

use crate::algorithms::{ops::sbb, DoubleWord};

/// Schoolbook reference implementation of `result += a * b`.
///
/// Limbs are little-endian `u64` words. Every limb of `a` is multiplied by
/// all of `b` and accumulated into `result`, starting at the position of that
/// `a` limb. Returns `true` if any non-zero bits of the sum did not fit into
/// `result`.
#[inline]
#[allow(clippy::cast_possible_truncation)] // Intentional truncation.
#[allow(dead_code)] // Used for testing
pub fn addmul_ref(result: &mut [u64], a: &[u64], b: &[u64]) -> bool {
    let mut spill = 0_u64;
    for (offset, &limb_a) in a.iter().enumerate() {
        let wide_a = u128::from(limb_a);
        let mut acc = 0_u128;

        // Part of `result` this row lands on (empty once `offset` runs past the end).
        let start = offset.min(result.len());
        let window = &mut result[start..];
        let overlap = window.len().min(b.len());
        let (products, tail) = window.split_at_mut(overlap);

        // Partial products: positions where both `result` and `b` have a limb.
        for (slot, &limb_b) in products.iter_mut().zip(b) {
            acc += u128::from(*slot) + wide_a * u128::from(limb_b);
            *slot = acc as u64;
            acc >>= 64;
        }

        // Carry propagation: remaining `result` limbs once `b` is exhausted.
        for slot in tail {
            acc += u128::from(*slot);
            *slot = acc as u64;
            acc >>= 64;
        }

        // Excess product: `b` limbs that have no `result` limb to land in.
        for &limb_b in &b[overlap..] {
            acc += wide_a * u128::from(limb_b);
            spill |= acc as u64;
            acc >>= 64;
        }

        // Whatever is still carried after the row is lost as well.
        spill |= acc as u64;
    }
    spill != 0
}

/// ⚠️ Computes `result += a * b` and checks for overflow.
///
/// **Warning.** This function is not part of the stable API.
Expand All @@ -62,7 +22,7 @@ pub fn addmul_ref(result: &mut [u64], a: &[u64], b: &[u64]) -> bool {
/// assert_eq!(overflow, false);
/// assert_eq!(result, [12]);
/// ```
#[inline]
#[inline(always)]
pub fn addmul(mut lhs: &mut [u64], mut a: &[u64], mut b: &[u64]) -> bool {
// Trim zeros from `a`
while let [0, rest @ ..] = a {
Expand Down Expand Up @@ -116,7 +76,7 @@ pub fn addmul(mut lhs: &mut [u64], mut a: &[u64], mut b: &[u64]) -> bool {
}

/// Computes `lhs += a` and returns the carry.
#[inline]
#[inline(always)]
pub fn add_nx1(lhs: &mut [u64], mut a: u64) -> u64 {
if a == 0 {
return 0;
Expand Down Expand Up @@ -223,7 +183,7 @@ fn mac(lhs: &mut u64, a: u64, b: u64, c: u64) -> u64 {
}

/// Computes `lhs *= a` and returns the carry.
#[inline]
#[inline(always)]
pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 {
let mut carry = 0;
for lhs in &mut *lhs {
Expand All @@ -244,7 +204,7 @@ pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 {
/// \\\\ \mathsf{carry} &= \floor{\frac{\mathsf{lhs} + \mathsf{a} ⋅ \mathsf{b}
/// }{2^{64⋅N}}} \end{aligned}
/// $$
#[inline]
#[inline(always)]
pub fn addmul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
debug_assert_eq!(lhs.len(), a.len());
let mut carry = 0;
Expand All @@ -267,7 +227,7 @@ pub fn addmul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
/// \mathsf{lhs}}{2^{64⋅N}}} \end{aligned}
/// $$
// OPT: `carry` and `borrow` can probably be merged into a single var.
#[inline]
#[inline(always)]
pub fn submul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
debug_assert_eq!(lhs.len(), a.len());
let mut carry = 0;
Expand All @@ -293,6 +253,44 @@ mod tests {
use super::*;
use proptest::{collection, num::u64, proptest};

/// Schoolbook reference for `result += a * b` over little-endian `u64` limbs.
///
/// Returns `true` if any non-zero bits of the sum fall outside `result`.
#[allow(clippy::cast_possible_truncation)] // Intentional truncation.
fn addmul_ref(result: &mut [u64], a: &[u64], b: &[u64]) -> bool {
    let mut lost_bits = 0_u64;
    for (shift, &multiplier) in a.iter().enumerate() {
        let mut carry = 0_u128;
        // A row runs until both `b` and the shifted part of `result` are exhausted.
        let span = b.len().max(result.len().saturating_sub(shift));
        for j in 0..span {
            // A missing `b` limb contributes nothing (pure carry propagation).
            let product = b
                .get(j)
                .map_or(0, |&limb| u128::from(multiplier) * u128::from(limb));
            if let Some(slot) = result.get_mut(shift + j) {
                // Inside `result`: accumulate and store the low word.
                carry += u128::from(*slot) + product;
                *slot = carry as u64;
            } else {
                // Past the end of `result`: any non-zero low word is overflow.
                carry += product;
                lost_bits |= carry as u64;
            }
            carry >>= 64;
        }
        // A carry left over after the row has nowhere to go.
        lost_bits |= carry as u64;
    }
    lost_bits != 0
}

#[test]
fn test_addmul() {
let any_vec = collection::vec(u64::ANY, 0..10);
Expand Down
Loading