Skip to content

Commit

Permalink
Add missing arithmetic ops
Browse files Browse the repository at this point in the history
survived committed Feb 27, 2024

Verified

This commit was signed with the committer’s verified signature.
survived Denis Varlakov
1 parent dd78eb1 commit c0ec6d0
Showing 2 changed files with 190 additions and 28 deletions.
178 changes: 150 additions & 28 deletions generic-ec/src/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub};
use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};

use crate::{Curve, Generator, NonZero, Point, Scalar, SecretScalar};

@@ -54,6 +54,38 @@ mod laws {
Point::from_raw_unchecked(result)
}

/// If $A$ is valid `Point<E>`, then $A + G$ is valid `Point<E>`
#[inline]
pub fn sum_of_point_and_generator_is_valid_point<E: Curve>(
a: &Point<E>,
g: &Generator<E>,
) -> Point<E> {
sum_of_points_is_valid_point(a, &g.to_point())
}
/// If $A$ is valid `Point<E>`, then $G + A$ is valid `Point<E>`
#[inline]
pub fn sum_of_generator_and_point_is_valid_point<E: Curve>(
g: &Generator<E>,
a: &Point<E>,
) -> Point<E> {
sum_of_points_is_valid_point(&g.to_point(), a)
}

/// If $A$ is valid `Point<E>`, then $A - G$ is valid `Point<E>`
pub fn sub_of_point_and_generator_is_valid_point<E: Curve>(
a: &Point<E>,
g: &Generator<E>,
) -> Point<E> {
sub_of_points_is_valid_point(a, &g.to_point())
}
/// If $A$ is valid `Point<E>`, then $G - A$ is valid `Point<E>`
pub fn sub_of_generator_and_point_is_valid_point<E: Curve>(
g: &Generator<E>,
a: &Point<E>,
) -> Point<E> {
sub_of_points_is_valid_point(&g.to_point(), a)
}

/// If $A$ is a valid `Point<E>`, then $-A$ is a valid `Point<E>`
///
/// [`sub_of_points_is_valid_point`] proves that subtraction of two valid `Point<E>` is a
@@ -212,7 +244,7 @@ mod laws {

mod scalar {
use crate::as_raw::{AsRaw, FromRaw};
use crate::core::*;
use crate::{core::*, SecretScalar};
use crate::{NonZero, Scalar};

#[inline]
@@ -245,6 +277,14 @@ mod scalar {
// Correctness: since `a` is not zero, `-a` is not zero by definition
NonZero::new_unchecked(neg)
}

#[inline]
pub fn neg_nonzero_secret<E: Curve>(a: &NonZero<SecretScalar<E>>) -> NonZero<SecretScalar<E>> {
let mut a: Scalar<E> = *a.as_ref();
a *= -Scalar::one();
// Correctness: since `a` is not zero, `-a` is not zero by definition
NonZero::new_unchecked(SecretScalar::new(&mut a))
}
}

macro_rules! impl_binary_ops {
@@ -329,6 +369,11 @@ impl_binary_ops! {
Add (Point<E>, add, Point<E> = Point<E>) laws::sum_of_points_is_valid_point,
Sub (Point<E>, sub, Point<E> = Point<E>) laws::sub_of_points_is_valid_point,

Add (Point<E>, add, Generator<E> = Point<E>) laws::sum_of_point_and_generator_is_valid_point,
Add (Generator<E>, add, Point<E> = Point<E>) laws::sum_of_generator_and_point_is_valid_point,
Sub (Point<E>, sub, Generator<E> = Point<E>) laws::sub_of_point_and_generator_is_valid_point,
Sub (Generator<E>, sub, Point<E> = Point<E>) laws::sub_of_generator_and_point_is_valid_point,

Add (Scalar<E>, add, Scalar<E> = Scalar<E>) scalar::add,
Sub (Scalar<E>, sub, Scalar<E> = Scalar<E>) scalar::sub,
Mul (Scalar<E>, mul, Scalar<E> = Scalar<E>) scalar::mul,
@@ -432,13 +477,42 @@ impl_unary_ops! {
Neg (neg Scalar<E>) scalar::neg,
Neg (neg NonZero<Point<E>>) laws::neg_nonzero_point_is_nonzero_point,
Neg (neg NonZero<Scalar<E>>) scalar::neg_nonzero,
Neg (neg NonZero<SecretScalar<E>>) scalar::neg_nonzero_secret,
}

impl_op_assign! {
Point<E>, AddAssign, Point<E>, add_assign, +,
Point<E>, AddAssign, NonZero<Point<E>>, add_assign, +,
Point<E>, AddAssign, Generator<E>, add_assign, +,

Point<E>, SubAssign, Point<E>, sub_assign, +,
Point<E>, SubAssign, NonZero<Point<E>>, sub_assign, +,
Point<E>, SubAssign, Generator<E>, sub_assign, +,

Point<E>, MulAssign, Scalar<E>, mul_assign, *,
Point<E>, MulAssign, NonZero<Scalar<E>>, mul_assign, *,
Point<E>, MulAssign, SecretScalar<E>, mul_assign, *,
Point<E>, MulAssign, NonZero<SecretScalar<E>>, mul_assign, *,

Scalar<E>, AddAssign, Scalar<E>, add_assign, +,
Scalar<E>, AddAssign, NonZero<Scalar<E>>, add_assign, +,
Scalar<E>, AddAssign, SecretScalar<E>, add_assign, +,
Scalar<E>, AddAssign, NonZero<SecretScalar<E>>, add_assign, +,

Scalar<E>, SubAssign, Scalar<E>, sub_assign, +,
Scalar<E>, SubAssign, NonZero<Scalar<E>>, sub_assign, +,
Scalar<E>, SubAssign, SecretScalar<E>, sub_assign, +,
Scalar<E>, SubAssign, NonZero<SecretScalar<E>>, sub_assign, +,

Scalar<E>, MulAssign, Scalar<E>, mul_assign, *,
Scalar<E>, MulAssign, NonZero<Scalar<E>>, mul_assign, *,
Scalar<E>, MulAssign, SecretScalar<E>, mul_assign, *,
Scalar<E>, MulAssign, NonZero<SecretScalar<E>>, mul_assign, *,

NonZero<Point<E>>, MulAssign, NonZero<Scalar<E>>, mul_assign, *,
NonZero<Point<E>>, MulAssign, NonZero<SecretScalar<E>>, mul_assign, *,
NonZero<Scalar<E>>, MulAssign, NonZero<Scalar<E>>, mul_assign, *,
NonZero<Scalar<E>>, MulAssign, NonZero<SecretScalar<E>>, mul_assign, *,
}

#[cfg(test)]
@@ -453,38 +527,46 @@ fn ensure_ops_implemented<E: Curve>(
non_zero_secret_scalar: NonZero<SecretScalar<E>>,
) {
macro_rules! assert_binary_ops {
($($a:ident $op:tt $b:expr => $out:ty),+,) => {$(
let _: $out = $a $op $b;
let _: $out = &$a $op $b;
let _: $out = $a $op &$b;
($($a:ident $op:tt $b:ident => $out:ty),+,) => {$(
let _: $out = $a.clone() $op $b.clone();
let _: $out = &$a $op $b.clone();
let _: $out = $a.clone() $op &$b;
let _: $out = &$a $op &$b;

let _: $out = $b $op $a;
let _: $out = &$b $op $a;
let _: $out = $b $op &$a;
let _: $out = $b.clone() $op $a.clone();
let _: $out = &$b $op $a.clone();
let _: $out = $b.clone() $op &$a;
let _: $out = &$b $op &$a;
)+};
}
macro_rules! assert_unary_ops {
($($op:tt $a:ident => $out:ty),+,) => {$(
let _: $out = $op $a;
let _: $out = $op $a.clone();
let _: $out = $op &$a;
)+};
}

macro_rules! assert_op_assign {
($($a:ident $op:tt $b:ident);+;) => {{$(
let mut a = $a.clone();
a $op $b.clone();
a $op &$b;
)+}};
}

assert_binary_ops!(
g * scalar => Point<E>,
point * scalar => Point<E>,
g * non_zero_scalar => NonZero<Point<E>>,
non_zero_point * non_zero_scalar => NonZero<Point<E>>,

g * secret_scalar.clone() => Point<E>,
point * secret_scalar.clone() => Point<E>,
non_zero_point * secret_scalar.clone() => Point<E>,
g * secret_scalar => Point<E>,
point * secret_scalar => Point<E>,
non_zero_point * secret_scalar => Point<E>,

g * non_zero_secret_scalar.clone() => NonZero<Point<E>>,
point * non_zero_secret_scalar.clone() => Point<E>,
non_zero_point * non_zero_secret_scalar.clone() => NonZero<Point<E>>,
g * non_zero_secret_scalar => NonZero<Point<E>>,
point * non_zero_secret_scalar => Point<E>,
non_zero_point * non_zero_secret_scalar => NonZero<Point<E>>,

point + point => Point<E>,
point + non_zero_point => Point<E>,
@@ -498,37 +580,77 @@ fn ensure_ops_implemented<E: Curve>(
scalar + non_zero_scalar => Scalar<E>,
non_zero_scalar + non_zero_scalar => Scalar<E>,

scalar + secret_scalar.clone() => Scalar<E>,
non_zero_scalar + secret_scalar.clone() => Scalar<E>,
scalar + secret_scalar => Scalar<E>,
non_zero_scalar + secret_scalar => Scalar<E>,

scalar + non_zero_secret_scalar.clone() => Scalar<E>,
non_zero_scalar + non_zero_secret_scalar.clone() => Scalar<E>,
scalar + non_zero_secret_scalar => Scalar<E>,
non_zero_scalar + non_zero_secret_scalar => Scalar<E>,

scalar - scalar => Scalar<E>,
scalar - non_zero_scalar => Scalar<E>,
non_zero_scalar - non_zero_scalar => Scalar<E>,

scalar - secret_scalar.clone() => Scalar<E>,
non_zero_scalar - secret_scalar.clone() => Scalar<E>,
scalar - secret_scalar => Scalar<E>,
non_zero_scalar - secret_scalar => Scalar<E>,

scalar - non_zero_secret_scalar.clone() => Scalar<E>,
non_zero_scalar - non_zero_secret_scalar.clone() => Scalar<E>,
scalar - non_zero_secret_scalar => Scalar<E>,
non_zero_scalar - non_zero_secret_scalar => Scalar<E>,

scalar * scalar => Scalar<E>,
scalar * non_zero_scalar => Scalar<E>,
non_zero_scalar * non_zero_scalar => NonZero<Scalar<E>>,

scalar * secret_scalar.clone() => Scalar<E>,
non_zero_scalar * secret_scalar.clone() => Scalar<E>,
scalar * secret_scalar => Scalar<E>,
non_zero_scalar * secret_scalar => Scalar<E>,

scalar * non_zero_secret_scalar => Scalar<E>,
non_zero_scalar * non_zero_secret_scalar => NonZero<Scalar<E>>,

scalar * non_zero_secret_scalar.clone() => Scalar<E>,
non_zero_scalar * non_zero_secret_scalar.clone() => NonZero<Scalar<E>>,
non_zero_secret_scalar + non_zero_secret_scalar => Scalar<E>,
non_zero_secret_scalar - non_zero_secret_scalar => Scalar<E>,
non_zero_secret_scalar * non_zero_secret_scalar => NonZero<Scalar<E>>,
);

assert_unary_ops!(
-point => Point<E>,
-non_zero_point => NonZero<Point<E>>,
-scalar => Scalar<E>,
-non_zero_scalar => NonZero<Scalar<E>>,
-non_zero_secret_scalar => NonZero<SecretScalar<E>>,
);

assert_op_assign!(
point += point;
point += non_zero_point;
point += g;

point -= point;
point -= non_zero_point;
point -= g;

point *= scalar;
point *= non_zero_scalar;
point *= secret_scalar;
point *= non_zero_scalar;

non_zero_point *= non_zero_scalar;

scalar += scalar;
scalar -= scalar;
scalar *= scalar;

scalar += non_zero_scalar;
scalar -= non_zero_scalar;
scalar *= non_zero_scalar;

scalar += secret_scalar;
scalar -= secret_scalar;
scalar *= secret_scalar;

scalar += non_zero_secret_scalar;
scalar -= non_zero_secret_scalar;
scalar *= non_zero_secret_scalar;

non_zero_scalar *= non_zero_scalar;
);
}
40 changes: 40 additions & 0 deletions generic-ec/src/non_zero/mod.rs
Original file line number Diff line number Diff line change
@@ -96,6 +96,14 @@ impl<E: Curve> NonZero<Scalar<E>> {
// Correctness: `inv` is nonzero by definition
Self::new_unchecked(inv)
}

/// Upgrades the non-zero scalar into non-zero [`SecretScalar`]
pub fn into_secret(self) -> NonZero<SecretScalar<E>> {
let mut scalar = self.into_inner();
let secret_scalar = SecretScalar::new(&mut scalar);
// Correctness: `scalar` was checked to be nonzero
NonZero::new_unchecked(secret_scalar)
}
}

impl<E: Curve> NonZero<SecretScalar<E>> {
@@ -213,6 +221,22 @@ impl<'s, E: Curve> Sum<&'s NonZero<Scalar<E>>> for Scalar<E> {
}
}

impl<'s, E: Curve> Sum<&'s NonZero<SecretScalar<E>>> for SecretScalar<E> {
fn sum<I: Iterator<Item = &'s NonZero<SecretScalar<E>>>>(iter: I) -> Self {
let mut out = Scalar::zero();
iter.for_each(|x| out += x);
SecretScalar::new(&mut out)
}
}

impl<E: Curve> Sum<NonZero<SecretScalar<E>>> for SecretScalar<E> {
fn sum<I: Iterator<Item = NonZero<SecretScalar<E>>>>(iter: I) -> Self {
let mut out = Scalar::zero();
iter.for_each(|x| out += x);
SecretScalar::new(&mut out)
}
}

impl<E: Curve> Product<NonZero<Scalar<E>>> for NonZero<Scalar<E>> {
fn product<I: Iterator<Item = NonZero<Scalar<E>>>>(iter: I) -> Self {
iter.fold(Self::one(), |acc, x| acc * x)
@@ -225,6 +249,22 @@ impl<'s, E: Curve> Product<&'s NonZero<Scalar<E>>> for NonZero<Scalar<E>> {
}
}

impl<'s, E: Curve> Product<&'s NonZero<SecretScalar<E>>> for NonZero<SecretScalar<E>> {
fn product<I: Iterator<Item = &'s NonZero<SecretScalar<E>>>>(iter: I) -> Self {
let mut out = NonZero::<Scalar<E>>::one();
iter.for_each(|x| out *= x);
out.into_secret()
}
}

impl<E: Curve> Product<NonZero<SecretScalar<E>>> for NonZero<SecretScalar<E>> {
fn product<I: Iterator<Item = NonZero<SecretScalar<E>>>>(iter: I) -> Self {
let mut out = NonZero::<Scalar<E>>::one();
iter.for_each(|x| out *= x);
out.into_secret()
}
}

impl<E: Curve> crate::traits::Samplable for NonZero<Scalar<E>> {
fn random<R: RngCore>(rng: &mut R) -> Self {
Self::random(rng)

0 comments on commit c0ec6d0

Please sign in to comment.