Skip to content

Commit

Permalink
Merge branch 'main' into simplify-muls
Browse files Browse the repository at this point in the history
  • Loading branch information
prestwich authored Jun 15, 2024
2 parents 4b4436e + bd1205b commit 6e25d56
Show file tree
Hide file tree
Showing 11 changed files with 81 additions and 147 deletions.
3 changes: 2 additions & 1 deletion deny.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ allow = [
"Unicode-DFS-2016",
"Unlicense",
"MPL-2.0",
"CC0-1.0"
"CC0-1.0",
"Unicode-3.0",
]

[sources]
Expand Down
2 changes: 1 addition & 1 deletion src/algorithms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub use self::{
add::{adc_n, sbb_n},
div::div,
gcd::{gcd, gcd_extended, inv_mod, LehmerMatrix},
mul::{add_nx1, addmul, addmul_n, addmul_nx1, addmul_ref, mul_nx1, submul_nx1},
mul::{add_nx1, addmul, addmul_n, addmul_nx1, mul_nx1, submul_nx1},
ops::{adc, sbb},
shift::{shift_left_small, shift_right_small},
};
Expand Down
88 changes: 43 additions & 45 deletions src/algorithms/mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,6 @@

use crate::algorithms::{ops::sbb, DoubleWord};

/// Reference (schoolbook) implementation of `result += a * b`.
///
/// Each limb of `a` is multiplied by every limb of `b` and accumulated into
/// `result` starting at the limb index of the `a` limb. Returns `true` when
/// any non-zero bits were produced beyond the end of `result`.
#[inline]
#[allow(clippy::cast_possible_truncation)] // Intentional truncation.
#[allow(dead_code)] // Used for testing
pub fn addmul_ref(result: &mut [u64], a: &[u64], b: &[u64]) -> bool {
    // Bitwise OR of every limb that fell off the end of `result`.
    let mut spilled: u64 = 0;
    for (shift, &factor) in a.iter().enumerate() {
        let factor = u128::from(factor);
        // Limbs of `result` at or above this row's offset (possibly empty).
        let start = shift.min(result.len());
        let window = &mut result[start..];
        // Keep going until both the window and `b` are exhausted.
        let columns = window.len().max(b.len());
        let mut acc = 0_u128;
        for col in 0..columns {
            // Once `b` runs out only the carry is propagated.
            let product = b.get(col).map_or(0, |&limb| factor * u128::from(limb));
            if let Some(slot) = window.get_mut(col) {
                // Cannot overflow: (2^64 - 1)^2 + 2 * (2^64 - 1) == 2^128 - 1.
                acc += u128::from(*slot) + product;
                *slot = acc as u64;
            } else {
                // Past the end of `result`: record the excess instead.
                acc += product;
                spilled |= acc as u64;
            }
            acc >>= 64;
        }
        // Whatever carry remains did not fit either.
        spilled |= acc as u64;
    }
    spilled != 0
}

/// ⚠️ Computes `result += a * b` and checks for overflow.
///
/// **Warning.** This function is not part of the stable API.
Expand All @@ -62,7 +22,7 @@ pub fn addmul_ref(result: &mut [u64], a: &[u64], b: &[u64]) -> bool {
/// assert_eq!(overflow, false);
/// assert_eq!(result, [12]);
/// ```
#[inline]
#[inline(always)]
pub fn addmul(mut lhs: &mut [u64], mut a: &[u64], mut b: &[u64]) -> bool {
// Trim zeros from `a`
while let [0, rest @ ..] = a {
Expand Down Expand Up @@ -116,7 +76,7 @@ pub fn addmul(mut lhs: &mut [u64], mut a: &[u64], mut b: &[u64]) -> bool {
}

/// Computes `lhs += a` and returns the carry.
#[inline]
#[inline(always)]
pub fn add_nx1(lhs: &mut [u64], mut a: u64) -> u64 {
if a == 0 {
return 0;
Expand Down Expand Up @@ -219,7 +179,7 @@ fn mac(lhs: &mut u64, a: u64, b: u64, c: u64) -> u64 {
}

/// Computes `lhs *= a` and returns the carry.
#[inline]
#[inline(always)]
pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 {
let mut carry = 0;
for lhs in lhs {
Expand All @@ -238,7 +198,7 @@ pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 {
/// \\\\ \mathsf{carry} &= \floor{\frac{\mathsf{lhs} + \mathsf{a} ⋅ \mathsf{b}
/// }{2^{64⋅N}}} \end{aligned}
/// $$
#[inline]
#[inline(always)]
pub fn addmul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
assume!(lhs.len() == a.len());
let mut carry = 0;
Expand All @@ -259,7 +219,7 @@ pub fn addmul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
/// \mathsf{lhs}}{2^{64⋅N}}} \end{aligned}
/// $$
// OPT: `carry` and `borrow` can probably be merged into a single var.
#[inline]
#[inline(always)]
pub fn submul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 {
assume!(lhs.len() == a.len());
let mut carry = 0;
Expand All @@ -280,6 +240,44 @@ mod tests {
use super::*;
use proptest::{collection, num::u64, proptest};

/// Straightforward schoolbook `result += a * b`, used as the oracle that the
/// optimized multiplication routines are compared against.
///
/// Returns `true` if any non-zero limb or carry landed past the end of
/// `result`.
#[allow(clippy::cast_possible_truncation)] // Intentional truncation.
fn addmul_ref(result: &mut [u64], a: &[u64], b: &[u64]) -> bool {
    // Bitwise OR of everything that did not fit into `result`.
    let mut excess: u64 = 0;
    for (row, &a_limb) in a.iter().enumerate() {
        let a_wide = u128::from(a_limb);
        // The part of `result` this row can write to (empty if out of range).
        let first = row.min(result.len());
        let tail = &mut result[first..];
        // Iterate until both `tail` and `b` have been fully consumed.
        let steps = tail.len().max(b.len());
        let mut carry = 0_u128;
        for k in 0..steps {
            // Zero once `b` is exhausted, leaving pure carry propagation.
            let partial = b.get(k).map_or(0, |&b_limb| a_wide * u128::from(b_limb));
            if let Some(dst) = tail.get_mut(k) {
                // Fits in u128: (2^64 - 1)^2 + 2 * (2^64 - 1) == 2^128 - 1.
                carry += u128::from(*dst) + partial;
                *dst = carry as u64;
            } else {
                // No destination limb left: this is overflow.
                carry += partial;
                excess |= carry as u64;
            }
            carry >>= 64;
        }
        // A leftover carry is overflow as well.
        excess |= carry as u64;
    }
    excess != 0
}

#[test]
fn test_addmul() {
let any_vec = collection::vec(u64::ANY, 0..10);
Expand Down
86 changes: 19 additions & 67 deletions src/bits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -254,53 +254,22 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
/// the shift is larger than `BITS` (which is IMHO not very useful).
#[inline]
#[must_use]
pub fn overflowing_shl(mut self, rhs: usize) -> (Self, bool) {
pub fn overflowing_shl(self, rhs: usize) -> (Self, bool) {
let (limbs, bits) = (rhs / 64, rhs % 64);
if limbs >= LIMBS {
return (Self::ZERO, self != Self::ZERO);
}
if bits == 0 {
// Check for overflow
let mut overflow = false;
for i in (LIMBS - limbs)..LIMBS {
overflow |= self.limbs[i] != 0;
}
if self.limbs[LIMBS - limbs - 1] > Self::MASK {
overflow = true;
}

// Shift
for i in (limbs..LIMBS).rev() {
assume!(i >= limbs && i - limbs < LIMBS);
self.limbs[i] = self.limbs[i - limbs];
}
self.limbs[..limbs].fill(0);
self.limbs[LIMBS - 1] &= Self::MASK;
return (self, overflow);
}

// Check for overflow
let mut overflow = false;
for i in (LIMBS - limbs)..LIMBS {
overflow |= self.limbs[i] != 0;
}
if self.limbs[LIMBS - limbs - 1] >> (64 - bits) != 0 {
overflow = true;
let word_bits = 64;
let mut r = Self::ZERO;
let mut carry = 0;
for i in 0..Self::LIMBS - limbs {
let x = self.limbs[i];
r.limbs[i + limbs] = (x << bits) | carry;
carry = (x >> (word_bits - bits - 1)) >> 1;
}
if self.limbs[LIMBS - limbs - 1] << bits > Self::MASK {
overflow = true;
}

// Shift
for i in (limbs + 1..LIMBS).rev() {
assume!(i - limbs < LIMBS && i - limbs - 1 < LIMBS);
self.limbs[i] = self.limbs[i - limbs] << bits;
self.limbs[i] |= self.limbs[i - limbs - 1] >> (64 - bits);
}
self.limbs[limbs] = self.limbs[0] << bits;
self.limbs[..limbs].fill(0);
self.limbs[LIMBS - 1] &= Self::MASK;
(self, overflow)
r.limbs[LIMBS - 1] &= Self::MASK;
(r, carry != 0)
}

/// Left shift by `rhs` bits.
Expand Down Expand Up @@ -349,38 +318,21 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
/// the shift is larger than `BITS` (which is IMHO not very useful).
#[inline]
#[must_use]
pub fn overflowing_shr(mut self, rhs: usize) -> (Self, bool) {
pub fn overflowing_shr(self, rhs: usize) -> (Self, bool) {
let (limbs, bits) = (rhs / 64, rhs % 64);
if limbs >= LIMBS {
return (Self::ZERO, self != Self::ZERO);
}
if bits == 0 {
// Check for overflow
let mut overflow = false;
for i in 0..limbs {
overflow |= self.limbs[i] != 0;
}

// Shift
for i in 0..(LIMBS - limbs) {
self.limbs[i] = self.limbs[i + limbs];
}
self.limbs[LIMBS - limbs..].fill(0);
return (self, overflow);
}

// Check for overflow
let overflow = self.limbs[LIMBS - limbs - 1] >> (bits - 1) & 1 != 0;

// Shift
for i in 0..(LIMBS - limbs - 1) {
assume!(i + limbs < LIMBS && i + limbs + 1 < LIMBS);
self.limbs[i] = self.limbs[i + limbs] >> bits;
self.limbs[i] |= self.limbs[i + limbs + 1] << (64 - bits);
let word_bits = 64;
let mut r = Self::ZERO;
let mut carry = 0;
for i in 0..LIMBS - limbs {
let x = self.limbs[LIMBS - 1 - i];
r.limbs[LIMBS - 1 - i - limbs] = (x >> bits) | carry;
carry = (x << (word_bits - bits - 1)) << 1;
}
self.limbs[LIMBS - limbs - 1] = self.limbs[LIMBS - 1] >> bits;
self.limbs[LIMBS - limbs..].fill(0);
(self, overflow)
(r, carry != 0)
}

/// Right shift by `rhs` bits.
Expand Down
8 changes: 3 additions & 5 deletions src/div.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
#[must_use]
#[allow(clippy::missing_const_for_fn)] // False positive
pub fn checked_div(self, rhs: Self) -> Option<Self> {
if rhs == Self::ZERO {
if rhs.is_zero() {
return None;
}
Some(self.div(rhs))
Expand All @@ -18,7 +18,7 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
#[must_use]
#[allow(clippy::missing_const_for_fn)] // False positive
pub fn checked_rem(self, rhs: Self) -> Option<Self> {
if rhs == Self::ZERO {
if rhs.is_zero() {
return None;
}
Some(self.rem(rhs))
Expand All @@ -33,9 +33,8 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
#[must_use]
#[track_caller]
pub fn div_ceil(self, rhs: Self) -> Self {
assert!(rhs != Self::ZERO, "Division by zero");
let (q, r) = self.div_rem(rhs);
if r == Self::ZERO {
if r.is_zero() {
q
} else {
q + Self::from(1)
Expand All @@ -51,7 +50,6 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
#[must_use]
#[track_caller]
pub fn div_rem(mut self, mut rhs: Self) -> (Self, Self) {
// Reject a zero divisor up front; `#[track_caller]` attributes the panic
// to the caller's location.
assert!(rhs != Self::ZERO, "Division by zero");
// Divides in place over both limb arrays.
// NOTE(review): presumably this leaves the quotient in `self.limbs` and the
// remainder in `rhs.limbs` (`div_ceil` destructures the result as `(q, r)`);
// confirm against `algorithms::div`.
algorithms::div(&mut self.limbs, &mut rhs.limbs);
(self, rhs)
}
Expand Down
2 changes: 1 addition & 1 deletion src/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
#[inline]
#[must_use]
pub fn checked_log(self, base: Self) -> Option<usize> {
if base < Self::from(2) || self == Self::ZERO {
if base < Self::from(2) || self.is_zero() {
return None;
}
Some(self.log(base))
Expand Down
2 changes: 2 additions & 0 deletions src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ macro_rules! impl_bin_op {
};
}

#[allow(unused)]
macro_rules! assume {
($e:expr $(,)?) => {
if !$e {
Expand All @@ -89,6 +90,7 @@ macro_rules! assume {
};
}

#[allow(unused)]
macro_rules! debug_unreachable {
($($t:tt)*) => {
if cfg!(debug_assertions) {
Expand Down
6 changes: 3 additions & 3 deletions src/modular.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
#[inline]
#[must_use]
pub fn reduce_mod(mut self, modulus: Self) -> Self {
if modulus == Self::ZERO {
if modulus.is_zero() {
return Self::ZERO;
}
if self >= modulus {
Expand Down Expand Up @@ -53,7 +53,7 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
#[inline]
#[must_use]
pub fn mul_mod(self, rhs: Self, mut modulus: Self) -> Self {
if modulus == Self::ZERO {
if modulus.is_zero() {
return Self::ZERO;
}

Expand Down Expand Up @@ -84,7 +84,7 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
#[inline]
#[must_use]
pub fn pow_mod(mut self, mut exp: Self, modulus: Self) -> Self {
if modulus == Self::ZERO || modulus <= Self::from(1) {
if modulus.is_zero() || modulus <= Self::from(1) {
// Also covers Self::BITS == 0
return Self::ZERO;
}
Expand Down
2 changes: 1 addition & 1 deletion src/root.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
assert!(degree > 0, "degree must be greater than zero");

// Handle zero case (including BITS == 0).
if self == Self::ZERO {
if self.is_zero() {
return Self::ZERO;
}

Expand Down
4 changes: 2 additions & 2 deletions src/special.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,11 @@ impl<const BITS: usize, const LIMBS: usize> Uint<BITS, LIMBS> {
#[inline]
#[must_use]
pub fn checked_next_multiple_of(self, rhs: Self) -> Option<Self> {
if rhs == Self::ZERO {
if rhs.is_zero() {
return None;
}
let (q, r) = self.div_rem(rhs);
if r == Self::ZERO {
if r.is_zero() {
return Some(self);
}
let q = q.checked_add(Self::from(1))?;
Expand Down
Loading

0 comments on commit 6e25d56

Please sign in to comment.