Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions ml-dsa/src/algebra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,53 @@ impl<K: ArraySize> AlgebraExt for Vector<K> {
)
}
}

#[cfg(test)]
mod test {
use super::*;

use crate::{MlDsa65, ParameterSet};

type Mod = <MlDsa65 as ParameterSet>::TwoGamma2;
const MOD: u32 = Mod::U32;
const MOD_ELEM: Elem = Elem::new(MOD);

#[test]
fn mod_plus_minus() {
for x in 0..MOD {
// BaseField::Q {
let x = Elem::new(x);
let x0 = x.mod_plus_minus::<Mod>();

// Outputs from mod+- should be in the half-open interval (-gamma2, gamma2]
let positive_bound = x0.0 <= MOD / 2;
let negative_bound = x0.0 > BaseField::Q - MOD / 2;
assert!(positive_bound || negative_bound);

// The output should be equivalent to the input, mod 2 * gamma2. We add 2 * gamma2
// before comparing so that both values are "positive", avoiding interactions between
// the mod-Q and mod-M operations.
let xn = x + MOD_ELEM;
let x0n = x0 + MOD_ELEM;
assert_eq!(xn.0 % MOD, x0n.0 % MOD);
}
}

#[test]
fn decompose() {
for x in 0..MOD {
let x = Elem::new(x);
let (x1, x0) = x.decompose::<Mod>();

// The low-order output from decompose() is a mod+- output, optionally minus one. So
// they should be in the closed interval [-gamma2, gamma2].
let positive_bound = x0.0 <= MOD / 2;
let negative_bound = x0.0 >= BaseField::Q - MOD / 2;
assert!(positive_bound || negative_bound);

// The low-order and high-order outputs should combine to form the input.
let xx = (MOD * x1.0 + x0.0) % BaseField::Q;
assert_eq!(xx, x.0);
}
}
}
8 changes: 4 additions & 4 deletions ml-dsa/src/hint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,18 @@ fn use_hint<TwoGamma2: Unsigned>(h: bool, r: Elem) -> Elem {
let gamma2 = TwoGamma2::U32 / 2;
if h && r0.0 <= gamma2 {
Elem::new((r1.0 + 1) % m)
} else if h && r0.0 > BaseField::Q - gamma2 {
} else if h && r0.0 >= BaseField::Q - gamma2 {
Elem::new((r1.0 + m - 1) % m)
} else if h {
// We use the Elem encoding even for signed integers. Since r0 is computed
// mod+- 2*gamma2, it is guaranteed to be in (gamma2, gamma2].
// mod+- 2*gamma2 (possibly minus 1), it is guaranteed to be in [-gamma2, gamma2].
unreachable!();
} else {
r1
}
}

#[derive(Clone, PartialEq)]
#[derive(Clone, PartialEq, Debug)]
pub struct Hint<P>(pub Array<Array<bool, U256>, P::K>)
where
P: SignatureParams;
Expand Down Expand Up @@ -116,7 +116,7 @@ where
}

fn monotonic(a: &[usize]) -> bool {
a.iter().enumerate().all(|(i, x)| i == 0 || a[i - 1] < *x)
a.iter().enumerate().all(|(i, x)| i == 0 || a[i - 1] <= *x)
}

pub fn bit_unpack(y: &EncodedHint<P>) -> Option<Self> {
Expand Down
40 changes: 39 additions & 1 deletion ml-dsa/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ pub use crate::util::B32;
pub use signature::Error;

/// An ML-DSA signature
#[derive(Clone, PartialEq)]
#[derive(Clone, PartialEq, Debug)]
pub struct Signature<P: MlDsaParams> {
c_tilde: Array<u8, P::Lambda>,
z: Vector<P::L>,
Expand Down Expand Up @@ -899,4 +899,42 @@ mod test {
sign_verify_round_trip_test::<MlDsa65>();
sign_verify_round_trip_test::<MlDsa87>();
}

fn many_round_trip_test<P>()
where
P: MlDsaParams,
{
use rand::Rng;

const ITERATIONS: usize = 1000;

let mut rng = rand::thread_rng();
let mut seed = B32::default();

for _i in 0..ITERATIONS {
let seed_data: &mut [u8] = seed.as_mut();
rng.fill(seed_data);

let kp = P::key_gen_internal(&seed);
let sk = kp.signing_key;
let vk = kp.verifying_key;

let M = b"Hello world";
let rnd = Array([0u8; 32]);
let sig = sk.sign_internal(&[M], &rnd);

let sig_enc = sig.encode();
let sig_dec = Signature::<P>::decode(&sig_enc).unwrap();

assert_eq!(sig_dec, sig);
assert!(vk.verify_internal(&[M], &sig_dec));
}
}

#[test]
fn many_round_trip() {
many_round_trip_test::<MlDsa44>();
many_round_trip_test::<MlDsa65>();
many_round_trip_test::<MlDsa87>();
}
}