From 4b4436e551ce06446fff154547736925b4212624 Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Sat, 15 Jun 2024 11:44:18 +0200 Subject: [PATCH] chore: simplify algorithms::mul* --- src/algorithms/mul.rs | 61 +++++++++++++++++-------------------------- 1 file changed, 24 insertions(+), 37 deletions(-) diff --git a/src/algorithms/mul.rs b/src/algorithms/mul.rs index 2c5436e..16e47a8 100644 --- a/src/algorithms/mul.rs +++ b/src/algorithms/mul.rs @@ -122,9 +122,7 @@ pub fn add_nx1(lhs: &mut [u64], mut a: u64) -> u64 { return 0; } for lhs in lhs { - let sum = u128::add(*lhs, a); - *lhs = sum.low(); - a = sum.high(); + (*lhs, a) = u128::add(*lhs, a).split(); if a == 0 { return 0; } @@ -147,18 +145,16 @@ pub fn addmul_n(lhs: &mut [u64], a: &[u64], b: &[u64]) { 2 => addmul_2(lhs, a, b), 3 => addmul_3(lhs, a, b), 4 => addmul_4(lhs, a, b), - _ => { - let _ = addmul(lhs, a, b); - } + _ => _ = addmul(lhs, a, b), } } /// Computes `lhs += a * b` for 1 limb. #[inline(always)] fn addmul_1(lhs: &mut [u64], a: &[u64], b: &[u64]) { - assert_eq!(lhs.len(), 1); - assert_eq!(a.len(), 1); - assert_eq!(b.len(), 1); + assume!(lhs.len() == 1); + assume!(a.len() == 1); + assume!(b.len() == 1); mac(&mut lhs[0], a[0], b[0], 0); } @@ -166,9 +162,9 @@ fn addmul_1(lhs: &mut [u64], a: &[u64], b: &[u64]) { /// Computes `lhs += a * b` for 2 limbs. #[inline(always)] fn addmul_2(lhs: &mut [u64], a: &[u64], b: &[u64]) { - assert_eq!(lhs.len(), 2); - assert_eq!(a.len(), 2); - assert_eq!(b.len(), 2); + assume!(lhs.len() == 2); + assume!(a.len() == 2); + assume!(b.len() == 2); let carry = mac(&mut lhs[0], a[0], b[0], 0); mac(&mut lhs[1], a[0], b[1], carry); @@ -179,9 +175,9 @@ fn addmul_2(lhs: &mut [u64], a: &[u64], b: &[u64]) { /// Computes `lhs += a * b` for 3 limbs. #[inline(always)] fn addmul_3(lhs: &mut [u64], a: &[u64], b: &[u64]) { - assert_eq!(lhs.len(), 3); - assert_eq!(a.len(), 3); - assert_eq!(b.len(), 3); + assume!(lhs.len() == 3); + assume!(a.len() == 3); + assume!(b.len() == 3); let carry = mac(&mut lhs[0], a[0], b[0], 0); let carry = mac(&mut lhs[1], a[0], b[1], carry); @@ -196,9 +192,9 @@ fn addmul_3(lhs: &mut [u64], a: &[u64], b: &[u64]) { /// Computes `lhs += a * b` for 4 limbs. #[inline(always)] fn addmul_4(lhs: &mut [u64], a: &[u64], b: &[u64]) { - assert_eq!(lhs.len(), 4); - assert_eq!(a.len(), 4); - assert_eq!(b.len(), 4); + assume!(lhs.len() == 4); + assume!(a.len() == 4); + assume!(b.len() == 4); let carry = mac(&mut lhs[0], a[0], b[0], 0); let carry = mac(&mut lhs[1], a[0], b[1], carry); @@ -226,10 +222,8 @@ fn mac(lhs: &mut u64, a: u64, b: u64, c: u64) -> u64 { #[inline] pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 { let mut carry = 0; - for lhs in &mut *lhs { - let product = u128::muladd(*lhs, a, carry); - *lhs = product.low(); - carry = product.high(); + for lhs in lhs { + (*lhs, carry) = u128::muladd(*lhs, a, carry).split(); } carry } @@ -246,12 +240,10 @@ pub fn mul_nx1(lhs: &mut [u64], a: u64) -> u64 { /// $$ #[inline] pub fn addmul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 { - debug_assert_eq!(lhs.len(), a.len()); + assume!(lhs.len() == a.len()); let mut carry = 0; - for (lhs, a) in lhs.iter_mut().zip(a.iter().copied()) { - let product = u128::muladd2(a, b, carry, *lhs); - *lhs = product.low(); - carry = product.high(); + for i in 0..a.len() { + (lhs[i], carry) = u128::muladd2(a[i], b, carry, lhs[i]).split(); } carry } @@ -269,21 +261,16 @@ pub fn addmul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 { // OPT: `carry` and `borrow` can probably be merged into a single var. #[inline] pub fn submul_nx1(lhs: &mut [u64], a: &[u64], b: u64) -> u64 { - debug_assert_eq!(lhs.len(), a.len()); + assume!(lhs.len() == a.len()); let mut carry = 0; let mut borrow = 0; - for (lhs, a) in lhs.iter_mut().zip(a.iter().copied()) { + for i in 0..a.len() { // Compute product limbs - let limb = { - let product = u128::muladd(a, b, carry); - carry = product.high(); - product.low() - }; + let limb; + (limb, carry) = u128::muladd(a[i], b, carry).split(); // Subtract - let (new, b) = sbb(*lhs, limb, borrow); - *lhs = new; - borrow = b; + (lhs[i], borrow) = sbb(lhs[i], limb, borrow); } borrow + carry }