Skip to content

Commit

Permalink
Add {EdwardsPoint, RistrettoPoint}::mul_small_scalar
Browse files Browse the repository at this point in the history
There are many applications where scalars may be statically known to
be less than the full range of the Ristretto group. In those cases,
multiplication can be done more cheaply while still being
constant-time. mul_small_scalar exposes that functionality for scalars
strictly less than 2^127.
  • Loading branch information
jrose-signal committed Feb 6, 2024
1 parent a12ab4e commit b5463d2
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 17 deletions.
9 changes: 9 additions & 0 deletions curve25519-dalek/benches/dalek_benchmarks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ mod edwards_benches {
});
}

fn consttime_variable_base_scalar_mul_small<M: Measurement>(c: &mut BenchmarkGroup<M>) {
let B = &constants::ED25519_BASEPOINT_POINT;
let s = Scalar::from(897987897u64).invert();
c.bench_function("Constant-time variable-base scalar mul small", move |b| {
b.iter(|| B.mul_small_scalar(&s))
});
}

fn vartime_double_base_scalar_mul<M: Measurement>(c: &mut BenchmarkGroup<M>) {
c.bench_function("Variable-time aA+bB, A variable, B fixed", |bench| {
let mut rng = thread_rng();
Expand All @@ -65,6 +73,7 @@ mod edwards_benches {
decompress(&mut g);
consttime_fixed_base_scalar_mul(&mut g);
consttime_variable_base_scalar_mul(&mut g);
consttime_variable_base_scalar_mul_small(&mut g);
vartime_double_base_scalar_mul(&mut g);
}
}
Expand Down
2 changes: 1 addition & 1 deletion curve25519-dalek/src/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ pub fn variable_base_mul(point: &EdwardsPoint, scalar: &Scalar) -> EdwardsPoint
BackendKind::Avx512 => {
self::vector::scalar_mul::variable_base::spec_avx512ifma_avx512vl::mul(point, scalar)
}
BackendKind::Serial => self::serial::scalar_mul::variable_base::mul(point, scalar),
BackendKind::Serial => self::serial::scalar_mul::variable_base::mul::<64>(point, scalar),
}
}

Expand Down
2 changes: 1 addition & 1 deletion curve25519-dalek/src/backend/serial/scalar_mul/straus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl MultiscalarMul for Straus {
#[cfg_attr(not(feature = "zeroize"), allow(unused_mut))]
let mut scalar_digits: Vec<_> = scalars
.into_iter()
.map(|s| s.borrow().as_radix_16())
.map(|s| s.borrow().as_radix_16::<64>())
.collect();

let mut Q = EdwardsPoint::identity();
Expand Down
11 changes: 7 additions & 4 deletions curve25519-dalek/src/backend/serial/scalar_mul/variable_base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@ use crate::traits::Identity;
use crate::window::LookupTable;

/// Perform constant-time, variable-base scalar multiplication.
///
/// MODIFIED BY SIGNAL: The generic parameter `N` is the maximum number of **nibbles** in `scalar`,
/// with the top bit of the top nibble clear. See [`Scalar::as_radix_16`].
#[rustfmt::skip] // keep alignment of explanatory comments
pub(crate) fn mul(point: &EdwardsPoint, scalar: &Scalar) -> EdwardsPoint {
pub(crate) fn mul<const N: usize>(point: &EdwardsPoint, scalar: &Scalar) -> EdwardsPoint {
// Construct a lookup table of [P,2P,3P,4P,5P,6P,7P,8P]
let lookup_table = LookupTable::<ProjectiveNielsPoint>::from(point);
// Setting s = scalar, compute
Expand All @@ -17,7 +20,7 @@ pub(crate) fn mul(point: &EdwardsPoint, scalar: &Scalar) -> EdwardsPoint {
//
// with `-8 ≤ s_i < 8` for `0 ≤ i < 63` and `-8 ≤ s_63 ≤ 8`.
// This decomposition requires s < 2^255, which is guaranteed by Scalar invariant #1.
let scalar_digits = scalar.as_radix_16();
let scalar_digits = scalar.as_radix_16::<N>();
// Compute s*P as
//
// s*P = P*(s_0 + s_1*16^1 + s_2*16^2 + ... + s_63*16^63)
Expand All @@ -29,9 +32,9 @@ pub(crate) fn mul(point: &EdwardsPoint, scalar: &Scalar) -> EdwardsPoint {
// Unwrap first loop iteration to save computing 16*identity
let mut tmp2;
let mut tmp3 = EdwardsPoint::identity();
let mut tmp1 = &tmp3 + &lookup_table.select(scalar_digits[63]);
let mut tmp1 = &tmp3 + &lookup_table.select(scalar_digits[N-1]);
// Now tmp1 = s_63*P in P1xP1 coords
for i in (0..63).rev() {
for i in (0..(N-1)).rev() {
tmp2 = tmp1.as_projective(); // tmp2 = (prev) in P2 coords
tmp1 = tmp2.double(); // tmp1 = 2*(prev) in P1xP1 coords
tmp2 = tmp1.as_projective(); // tmp2 = 2*(prev) in P2 coords
Expand Down
2 changes: 1 addition & 1 deletion curve25519-dalek/src/backend/vector/scalar_mul/straus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ pub mod spec {

let scalar_digits_vec: Vec<_> = scalars
.into_iter()
.map(|s| s.borrow().as_radix_16())
.map(|s| s.borrow().as_radix_16::<64>())
.collect();
// Pass ownership to a `Zeroizing` wrapper
#[cfg(feature = "zeroize")]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub mod spec {
// s = s_0 + s_1*16^1 + ... + s_63*16^63,
//
// with `-8 ≤ s_i < 8` for `0 ≤ i < 63` and `-8 ≤ s_63 ≤ 8`.
let scalar_digits = scalar.as_radix_16();
let scalar_digits = scalar.as_radix_16::<64>();
// Compute s*P as
//
// s*P = P*(s_0 + s_1*16^1 + s_2*16^2 + ... + s_63*16^63)
Expand Down
56 changes: 55 additions & 1 deletion curve25519-dalek/src/edwards.rs
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,19 @@ impl EdwardsPoint {
};
Self::mul_base(&s)
}

/// Limited scalar multiplication: compute `scalar * self` when `scalar` is known to be less
/// than 2^127.
///
/// This is still constant-time.
pub fn mul_small_scalar(&self, scalar: &Scalar) -> Self {
assert!(
scalar.as_bytes()[16..].iter().all(|&b| b == 0)
&& scalar.as_bytes()[15] & 0b1000_0000 == 0,
"scalar is out of range for mul_small_scalar"
);
crate::backend::serial::scalar_mul::variable_base::mul::<32>(self, scalar)
}
}

// ------------------------------------------------------------------------
Expand Down Expand Up @@ -1591,7 +1604,7 @@ impl CofactorGroup for EdwardsPoint {
#[cfg(test)]
mod test {
use super::*;
use crate::{field::FieldElement, scalar::Scalar};
use crate::{constants::ED25519_BASEPOINT_POINT, field::FieldElement, scalar::Scalar};
use subtle::ConditionallySelectable;

#[cfg(feature = "alloc")]
Expand Down Expand Up @@ -1929,6 +1942,47 @@ mod test {
}
}

/// Check that Mul and mul_small_scalar agree
#[test]
fn mul_small_scalar() {
let mut csprng = rand_core::OsRng;

// Make a random curve point in the curve. Give it torsion to make things interesting.
let random_point = {
let mut b = [0u8; 32];
csprng.fill_bytes(&mut b);
EdwardsPoint::mul_base_clamped(b) + constants::EIGHT_TORSION[1]
};

// Test agreement on random integers less than 2^127
for _ in 0..100 {
let mut bytes = [0u8; 32];
csprng.fill_bytes(&mut bytes[..16]);
bytes[15] &= 0b0111_1111;
let a = Scalar { bytes };

assert_eq!(random_point * a, random_point.mul_small_scalar(&a));
}
}

#[test]
#[should_panic]
fn mul_small_scalar_2_127() {
let mut bytes = [0u8; 32];
bytes[15] = 0b1000_0000;
let a = Scalar { bytes };
_ = ED25519_BASEPOINT_POINT.mul_small_scalar(&a);
}

#[test]
#[should_panic]
fn mul_small_scalar_2_128() {
let mut bytes = [0u8; 32];
bytes[16] = 1;
let a = Scalar { bytes };
_ = ED25519_BASEPOINT_POINT.mul_small_scalar(&a);
}

#[test]
#[cfg(feature = "alloc")]
fn impl_sum() {
Expand Down
8 changes: 8 additions & 0 deletions curve25519-dalek/src/ristretto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,14 @@ impl RistrettoPoint {
scalar * constants::RISTRETTO_BASEPOINT_TABLE
}
}

/// Limited scalar multiplication: compute `scalar * self` when `scalar` is known to be less
/// than 2^127.
///
/// This is still constant-time.
pub fn mul_small_scalar(&self, scalar: &Scalar) -> Self {
RistrettoPoint(self.0.mul_small_scalar(scalar))
}
}

define_mul_assign_variants!(LHS = RistrettoPoint, RHS = Scalar);
Expand Down
23 changes: 15 additions & 8 deletions curve25519-dalek/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -984,9 +984,16 @@ impl Scalar {
/// The largest value that can be decomposed like this is just over \\(2^{255}\\). Thus, in
/// order to not error, the top bit MUST NOT be set, i.e., `Self` MUST be less than
/// \\(2^{255}\\).
pub(crate) fn as_radix_16(&self) -> [i8; 64] {
debug_assert!(self[31] <= 127);
let mut output = [0i8; 64];
///
/// MODIFIED BY SIGNAL: The output can be truncated to any even number of nibbles `N` as long as
/// the top bit is not set in the last nibble, and no bits are set in the truncated nibbles.
pub(crate) fn as_radix_16<const N: usize>(&self) -> [i8; N] {
debug_assert!(N % 2 == 0);
debug_assert!(self[N/2 - 1] <= 127);
for i in (N/2)..32 {
debug_assert!(self[i] == 0);
}
let mut output = [0i8; N];

// Step 1: change radix.
// Convert from radix 256 (bytes) to radix 16 (nibbles)
Expand All @@ -1000,20 +1007,20 @@ impl Scalar {
(x >> 4) & 15
}

for i in 0..32 {
for i in 0..(N/2) {
output[2 * i] = bot_half(self[i]) as i8;
output[2 * i + 1] = top_half(self[i]) as i8;
}
// Precondition note: since self[31] <= 127, output[63] <= 7
// Precondition note: since self.last() <= 127, output.last() <= 7

// Step 2: recenter coefficients from [0,16) to [-8,8)
for i in 0..63 {
for i in 0..(N-1) {
let carry = (output[i] + 8) >> 4;
output[i] -= carry << 4;
output[i + 1] += carry;
}
// Precondition note: output[63] is not recentered. It
// increases by carry <= 1. Thus output[63] <= 8.
// Precondition note: output[N-1] is not recentered. It
// increases by carry <= 1. Thus output[N-1] <= 8.

output
}
Expand Down

0 comments on commit b5463d2

Please sign in to comment.