Skip to content

Commit

Permalink
Add some NTT functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
xvzcf committed May 28, 2024
1 parent f22a356 commit 1bab36e
Show file tree
Hide file tree
Showing 4 changed files with 341 additions and 8 deletions.
51 changes: 48 additions & 3 deletions libcrux-ml-dsa/src/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
use crate::constants::{BITS_IN_LOWER_PART_OF_T, COEFFICIENTS_IN_RING_ELEMENT, FIELD_MODULUS};

/// Values having this type hold a representative 'x' of the ML-DSA field.
pub(crate) type FieldElement = i32;

#[derive(Clone, Copy)]
pub struct PolynomialRingElement {
pub(crate) coefficients: [FieldElement; COEFFICIENTS_IN_RING_ELEMENT],
Expand All @@ -15,6 +12,46 @@ impl PolynomialRingElement {
};
}

pub(crate) fn get_n_least_significant_bits(n: u8, value: u64) -> u64 {
value & ((1 << n) - 1)
}

/// Values having this type hold a representative 'x' of the ML-DSA field.
pub(crate) type FieldElement = i32;

/// If 'x' denotes a value of type `fe`, values having this type hold a
/// representative y ≡ x·MONTGOMERY_R^(-1) (mod FIELD_MODULUS).
/// We use 'mfe' as a shorthand for this type
pub(crate) type MontgomeryFieldElement = i32;

/// If 'x' denotes a value of type `fe`, values having this type hold a
/// representative y ≡ x·MONTGOMERY_R (mod FIELD_MODULUS).
/// We use 'fer' as a shorthand for this type.
pub(crate) type FieldElementTimesMontgomeryR = i32;

const MONTGOMERY_SHIFT: u8 = 32;
const INVERSE_OF_MODULUS_MOD_MONTGOMERY_R: u64 = 58_728_449; // FIELD_MODULUS^{-1} mod 2^32
pub(crate) fn montgomery_reduce(value: i64) -> MontgomeryFieldElement {
let t = get_n_least_significant_bits(MONTGOMERY_SHIFT, value as u64)
* INVERSE_OF_MODULUS_MOD_MONTGOMERY_R;
let k = get_n_least_significant_bits(MONTGOMERY_SHIFT, t) as i32;

let k_times_modulus = (k as i64) * (FIELD_MODULUS as i64);

let c = (k_times_modulus >> MONTGOMERY_SHIFT) as i32;
let value_high = (value >> MONTGOMERY_SHIFT) as i32;

value_high - c
}

#[inline(always)]
pub(crate) fn montgomery_multiply_fe_by_fer(
fe: FieldElement,
fer: FieldElementTimesMontgomeryR,
) -> FieldElement {
montgomery_reduce((fe as i64) * (fer as i64))
}

// Splits 0 ≤ t < Q into t0 and t1 with a = t1*2ᴰ + t0
// and -2ᴰ⁻¹ < t0 < 2ᴰ⁻¹. Returns t0 and t1 computed as.
//
Expand Down Expand Up @@ -54,6 +91,14 @@ pub(crate) fn t0_to_unsigned_representative(t0: i32) -> i32 {
mod tests {
use super::*;

#[test]
fn test_montgomery_reduce() {
assert_eq!(montgomery_reduce(10933346042510), -1553279);
assert_eq!(montgomery_reduce(-20392060523118), 1331779);
assert_eq!(montgomery_reduce(13704140696092), -1231016);
assert_eq!(montgomery_reduce(-631922212176), -2580954);
}

#[test]
fn test_power2round() {
assert_eq!(power2round(2898283), (-1685, 354));
Expand Down
4 changes: 2 additions & 2 deletions libcrux-ml-dsa/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ mod arithmetic;
mod constants;
mod hash_functions;
mod matrix;
mod ml_dsa_generic;
mod ntt;
mod sample;
mod serialize;
mod utils;

mod ml_dsa_generic;

pub mod ml_dsa_65;
285 changes: 285 additions & 0 deletions libcrux-ml-dsa/src/ntt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
use super::{
arithmetic::{
montgomery_multiply_fe_by_fer, montgomery_reduce, FieldElementTimesMontgomeryR,
PolynomialRingElement,
},
constants::COEFFICIENTS_IN_RING_ELEMENT,
};

const ZETAS_TIMES_MONTGOMERY_R: [FieldElementTimesMontgomeryR; 256] = [
0, 25847, -2608894, -518909, 237124, -777960, -876248, 466468, 1826347, 2353451, -359251,
-2091905, 3119733, -2884855, 3111497, 2680103, 2725464, 1024112, -1079900, 3585928, -549488,
-1119584, 2619752, -2108549, -2118186, -3859737, -1399561, -3277672, 1757237, -19422, 4010497,
280005, 2706023, 95776, 3077325, 3530437, -1661693, -3592148, -2537516, 3915439, -3861115,
-3043716, 3574422, -2867647, 3539968, -300467, 2348700, -539299, -1699267, -1643818, 3505694,
-3821735, 3507263, -2140649, -1600420, 3699596, 811944, 531354, 954230, 3881043, 3900724,
-2556880, 2071892, -2797779, -3930395, -1528703, -3677745, -3041255, -1452451, 3475950,
2176455, -1585221, -1257611, 1939314, -4083598, -1000202, -3190144, -3157330, -3632928, 126922,
3412210, -983419, 2147896, 2715295, -2967645, -3693493, -411027, -2477047, -671102, -1228525,
-22981, -1308169, -381987, 1349076, 1852771, -1430430, -3343383, 264944, 508951, 3097992,
44288, -1100098, 904516, 3958618, -3724342, -8578, 1653064, -3249728, 2389356, -210977, 759969,
-1316856, 189548, -3553272, 3159746, -1851402, -2409325, -177440, 1315589, 1341330, 1285669,
-1584928, -812732, -1439742, -3019102, -3881060, -3628969, 3839961, 2091667, 3407706, 2316500,
3817976, -3342478, 2244091, -2446433, -3562462, 266997, 2434439, -1235728, 3513181, -3520352,
-3759364, -1197226, -3193378, 900702, 1859098, 909542, 819034, 495491, -1613174, -43260,
-522500, -655327, -3122442, 2031748, 3207046, -3556995, -525098, -768622, -3595838, 342297,
286988, -2437823, 4108315, 3437287, -3342277, 1735879, 203044, 2842341, 2691481, -2590150,
1265009, 4055324, 1247620, 2486353, 1595974, -3767016, 1250494, 2635921, -3548272, -2994039,
1869119, 1903435, -1050970, -1333058, 1237275, -3318210, -1430225, -451100, 1312455, 3306115,
-1962642, -1279661, 1917081, -2546312, -1374803, 1500165, 777191, 2235880, 3406031, -542412,
-2831860, -1671176, -1846953, -2584293, -3724270, 594136, -3776993, -2013608, 2432395, 2454455,
-164721, 1957272, 3369112, 185531, -1207385, -3183426, 162844, 1616392, 3014001, 810149,
1652634, -3694233, -1799107, -3038916, 3523897, 3866901, 269760, 2213111, -975884, 1717735,
472078, -426683, 1723600, -1803090, 1910376, -1667432, -1104333, -260646, -3833893, -2939036,
-2235985, -420899, -2286327, 183443, -976891, 1612842, -3545687, -554416, 3919660, -48306,
-1362209, 3937738, 1400424, -846154, 1976782,
];

#[inline(always)]
fn ntt_at_layer(
zeta_i: &mut usize,
mut re: PolynomialRingElement,
layer: usize,
) -> PolynomialRingElement {
let step = 1 << layer;

for round in 0..(128 >> layer) {
*zeta_i += 1;

let offset = round * step * 2;

for j in offset..offset + step {
let t = montgomery_multiply_fe_by_fer(
re.coefficients[j + step],
ZETAS_TIMES_MONTGOMERY_R[*zeta_i],
);
re.coefficients[j + step] = re.coefficients[j] - t;
re.coefficients[j] = re.coefficients[j] + t;
}
}

re
}

#[inline(always)]
pub(crate) fn ntt(mut re: PolynomialRingElement) -> PolynomialRingElement {
let mut zeta_i = 0;

re = ntt_at_layer(&mut zeta_i, re, 7);
re = ntt_at_layer(&mut zeta_i, re, 6);
re = ntt_at_layer(&mut zeta_i, re, 5);
re = ntt_at_layer(&mut zeta_i, re, 4);
re = ntt_at_layer(&mut zeta_i, re, 3);
re = ntt_at_layer(&mut zeta_i, re, 2);
re = ntt_at_layer(&mut zeta_i, re, 1);
re = ntt_at_layer(&mut zeta_i, re, 0);

re
}

#[inline(always)]
fn invert_ntt_at_layer(
zeta_i: &mut usize,
mut re: PolynomialRingElement,
layer: usize,
) -> PolynomialRingElement {
let step = 1 << layer;

for round in 0..(128 >> layer) {
*zeta_i -= 1;

let offset = round * step * 2;

for j in offset..offset + step {
let a_minus_b = re.coefficients[j + step] - re.coefficients[j];

re.coefficients[j] = re.coefficients[j] + re.coefficients[j + step];
re.coefficients[j + step] =
montgomery_multiply_fe_by_fer(a_minus_b, ZETAS_TIMES_MONTGOMERY_R[*zeta_i]);
}
}

re
}
#[inline(always)]
pub(crate) fn invert_ntt_montgomery(mut re: PolynomialRingElement) -> PolynomialRingElement {
let mut zeta_i = COEFFICIENTS_IN_RING_ELEMENT;

re = invert_ntt_at_layer(&mut zeta_i, re, 0);
re = invert_ntt_at_layer(&mut zeta_i, re, 1);
re = invert_ntt_at_layer(&mut zeta_i, re, 2);
re = invert_ntt_at_layer(&mut zeta_i, re, 3);
re = invert_ntt_at_layer(&mut zeta_i, re, 4);
re = invert_ntt_at_layer(&mut zeta_i, re, 5);
re = invert_ntt_at_layer(&mut zeta_i, re, 6);
re = invert_ntt_at_layer(&mut zeta_i, re, 7);

for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
// TODO: We could probably skip this multiplication, revisit this
// after key-generation is working.
re.coefficients[i] = montgomery_reduce(41978 * (re.coefficients[i] as i64));
}
re
}

pub(crate) fn ntt_multiply(
lhs: &PolynomialRingElement,
rhs: &PolynomialRingElement,
) -> PolynomialRingElement {
let mut out = PolynomialRingElement::ZERO;

for i in 0..COEFFICIENTS_IN_RING_ELEMENT {
out.coefficients[i] =
montgomery_reduce((lhs.coefficients[i] as i64) * (rhs.coefficients[i] as i64));
}

out
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_ntt() {
let re = PolynomialRingElement {
coefficients: [
245230, -429681, -35753, 256940, 138755, -82158, -453212, -296769, 106884, -496329,
-275542, 350156, 295061, 462432, 162727, 219494, 43263, -84315, -100731, 5560,
-38846, 343612, 76881, 427547, 165700, -361163, -18964, 270770, -289948, -326181,
-17540, -376674, -101359, 324588, 265493, -376942, -270029, -201717, -350446,
222164, -314686, -60609, 172509, -199265, 391809, 375196, 333441, -433240, -28862,
274251, -218805, 400627, -408915, 131269, -305167, 78967, -487687, 98675, -430105,
293491, 317484, -180888, -333359, -263010, 258853, -84618, -350795, 334736,
-438451, 479262, -265874, -115692, -521929, -220715, -456043, -24131, 94695,
473893, -503297, 75679, -129421, 83315, -248504, -64226, -24884, 316438, 264565,
-248440, 222228, -386736, 89534, 196079, -196063, 434306, -388976, -29596, 424028,
290804, -348654, 208245, 394447, -105640, -522040, 250479, -443666, -503110,
299944, 497539, -28052, 30579, -332034, 492009, -327080, -173581, -94157, -126088,
388734, 468785, -120589, -146970, -291234, 337402, 311007, -289990, 506654,
-431388, 410292, -376624, 422627, 246536, -273872, 443039, -265954, -250947,
451185, -386654, -19185, -171927, -128698, -277965, 142565, -229030, -470985,
511916, -68612, 272580, -293969, 151888, -53429, -115171, 234680, -482360, -399860,
-268942, 146734, -414798, 502035, -157203, -328592, 266628, 95760, 107840, 354606,
-367167, 396086, -287062, 57888, 140152, 442747, -217984, -69604, -136006, -56581,
202803, -440282, -290558, -192319, -49121, -76454, 426678, 433484, -93094, 244295,
-195275, -262446, -169118, 187824, -60480, -206921, -204671, -407794, -139194,
-182819, 133480, 520760, -17757, -444106, -214471, 457449, 29697, -149734, -497293,
-177518, -266611, -133962, -40139, 9030, 37706, 300290, -370302, 257446, -290991,
353260, 393727, -269498, 249049, -166327, -354566, -309239, 481747, 82459, -425894,
107583, 10935, -498533, 437188, -121594, -90890, -261475, -44165, 394580, -392499,
206781, -222053, -334528, -194081, -373973, -356982, -27220, -444980, 449174,
-391807, 392057, -132521, -441664, -349459, -373059, -296519, 274235, 42417, 47385,
-104540, 142532, 246380, -515363, -422665,
],
};

let expected_coefficients = [
-17129289, -17188287, -11027856, -7293060, -14589541, -12369669, -1420304, -9409026,
-2745174, -2813844, -1829426, 2574100, -5026817, -9781421, -9951567, -7272515, 4818335,
-3195023, -6970219, -7364953, 1800133, -219955, 5457527, -2421101, -2719347, 4851863,
-5375188, -6373272, -6881235, 1470681, 2364683, 4847471, 2424421, -2276079, 2780402,
3720484, 6345079, -150847, 4499295, 3841925, -4612874, 227272, -1650880, -4068714,
1238348, -6241908, 674916, 8597432, 1045161, 2838309, -4022618, -8710072, -3036374,
-3401044, -6864890, -4717312, -3844346, 3755766, 4699242, -1232858, -1007843, -2372141,
-5151898, 2215126, 5056427, 5704699, 11731990, 12381420, 2784890, -2861996, 1452131,
5933279, 4031780, 5298922, 3626052, 4969414, 3453854, -4627414, -1023658, -5769310,
1437156, 1156658, -2817787, -8761943, -2668956, -9522412, -12938019, -10322153,
-9811386, -8779334, 2078963, -4674611, -4110129, 2451543, -4834924, -2503578, -5536189,
-1677443, -6867926, -4019342, -10584384, -7739886, -6447026, -13889812, -6819207,
587959, -7563216, -14153360, -5061746, -11893138, -2225507, -1089121, -1869464,
3296810, 6674836, -1150818, 324295, -509763, -1197550, -5578514, -5136666, -4382368,
3113889, -3428119, 235128, 4223510, 70873, -1793487, 1662772, 7347100, 15227445,
9348419, 9598008, 9940972, 7506539, 9092233, 4526452, 9976840, 6619274, 8638534,
8098748, 4080374, 9497479, 9356635, -239442, 6155758, -2930736, -4891836, 2066938,
7359172, 597336, 7980226, -1781310, -5283606, 596800, 3537228, -8539373, -4044371,
-1411916, 4051564, 2598458, 9958426, 1194732, 9002276, -926584, 5985194, 980962,
856944, 6456619, 8929175, 9047642, 12797200, 11248612, 4324864, 18190009, 10462927,
4906049, 2341517, 3945796, 8377830, 5195877, 10702083, -247762, 4149842, -4852089,
-1576975, 516061, 1908067, 2840273, -4492477, 9446409, 3700267, 346209, 2692483,
-7029253, -5625659, 4093774, -3922644, 2578212, 6694254, -1244120, -1475796, -9388817,
-5401831, -6934520, -8620440, -5385728, -6961628, -8648379, -2747757, -10439151,
-5664161, -1208977, -8828047, -1715189, 5918789, 2038973, -5412689, 4197315, -3211379,
12103869, 4104929, 3182052, 6094506, 1986313, -481257, -3678130, -673934, 2320744,
1656034, -5630954, -3497176, 6334075, 11828589, 6053995, -1775095, 6687195, 7765831,
7946592, 7821130, -2626065, 4613455, 10127838, 3728296, 9154301, 11337805, 8531104,
15979738, 1459696, 8351548, 3335586, 1150210, -2462074, -4642922, 4538634, 1858098,
];

assert_eq!(ntt(re).coefficients, expected_coefficients);
}

#[test]
fn test_invert_ntt_montgomery() {
let re = PolynomialRingElement {
coefficients: [
-1799977, -2102152, -2642101, -635466, -1853482, 642462, 1199623, -2231752,
-3968977, 1443304, 1461464, 2556315, 4140492, -1725885, 4153465, -556916, -2133612,
1372025, 3676714, 3519610, 706947, 788194, 3622849, -2734117, 3727454, -2190265,
208958, 2555531, -1748893, -256927, 1863384, -135807, -2321243, 2766307, -707368,
1548297, 1449797, 3892466, 3417597, -2439676, 1503122, -1273655, 4077536, 6648,
-1763675, 2151278, -2147862, -4095105, 1452564, -194892, 150869, -168533, 3172154,
-4157552, 950081, 1263047, 1073782, -2392089, 3913931, 2548808, -3641576, -133231,
-345656, -3993400, -772374, 650580, 11714, 2745379, 4102554, -1493814, -1216073,
-3687790, -520554, 674428, -2363771, -2062557, 1353060, 421679, -1736183, 3309070,
3705199, -1614110, -1885448, 1502936, -3904361, 2506554, 101679, -2500045, 2538220,
3019542, 1264486, 1681771, 889126, 951808, -3807112, -1917333, 2530518, 2276961,
3921082, 1553244, -2044159, -2836376, 498383, 2971233, 3286160, 1149491, -3659209,
-1963092, -2566288, 2114154, 34024, -2989138, 424058, 3042007, -135014, -2866292,
-1138173, -3010844, -2893275, -2118818, 3839605, 2956371, 2356470, 2895933, 390703,
-1316703, -2882388, 3833928, 2118987, -2371764, 7210, 2032760, 770491, -2466615,
-3672908, -2397815, -3106703, 3523515, 1794988, 2551854, -134246, -4189103,
3840541, -3703204, -2229747, 1599893, 1611447, -1126296, -1497526, 1422269,
-1183163, 861126, -155866, 1642344, 3459388, 2621579, 1200190, -1791368, -2396064,
3313131, -1704442, -1632644, 3659167, 3290628, 1933900, 475446, 1952630, -847369,
2639611, 2205667, -767651, 2248190, -1679262, 2250674, 3194928, 3674776, 2014792,
-1384769, -2579573, 1424682, 1150591, 1027245, -2676627, 3620918, -2364392, 971022,
170291, -16161, 3517252, -4070880, -2207879, -577017, 2484069, 927714, 2453609,
-2953744, 3140280, 3160147, 1667259, -3082713, -4047424, 124404, 3473451, 1419723,
161430, 483773, 2459342, 1207398, 3486346, -2400797, 3217001, 2022150, -694480,
-919315, -3442035, -1734814, -3231832, 2955471, 2104900, 1922217, 1829070, 1605538,
3862195, 1423572, 3831618, 2188925, -967302, 677729, 3187197, 1048944, 1276467,
-3329616, 3735664, 3795986, 4038386, -3516780, -1902194, -880027, -1787327,
-869158, 3693240, 494399, -3852589, -3881813, 2536840, -2924666, 2425664, 2635292,
2752536, -136653, 4057087, -633680, 3039079, -2733512, 1734173, -2109687,
],
};

let expected_coefficients = [
3966085, -2067161, 579114, -3597478, 2232818, -17588, 1194752, -1205114, -4058138,
1212005, -523747, -3757135, -2096288, 1564176, -2621702, 3098337, 1686358, 3045166,
-190650, -3650792, -4016863, -3509278, 53081, -2698465, -3058034, 1934801, -489614,
2562002, 135070, -561684, 1429883, 2143581, -2641675, -2638118, -2881420, 951375,
-1178399, 3905449, -909202, 2293747, -977585, 2405262, -2582841, -3503339, 372978,
-2217708, -2992060, -1261148, 1429205, -1436912, 1169879, 2688127, -1902970, -818037,
1527388, 515446, -1660913, -1628614, 1155517, -2384683, 2424576, -207150, 3423525,
196083, -1457572, 3843617, 670886, -3116174, -630147, 3833721, 162664, 1173694,
3200069, 994675, -354381, 2157831, 3701560, -3878865, 3783818, -3698782, 2695001,
3599085, 2818801, 1802598, 277871, 1672290, -928625, -1037863, 787843, -361648, 182577,
3733189, -2641972, -3072669, 3466030, -2878519, -1137138, -2234722, 892883, -209264,
-3945665, 1153968, -1994007, 1819301, -647462, 831906, 1571924, -1135087, 1990613,
2944454, -2464655, 522799, -3957487, -3013253, -1137760, 1106259, 3564711, -2315418,
-548862, -119514, -3611453, -3293829, 1519241, 1021839, -1511635, 1732685, 702257,
-3656778, 3962669, -944275, 3309609, -1039174, 2265306, -3153610, -410668, -393039,
-2356731, -4083957, 1859494, -2076440, -3697967, -2186461, 931673, -995414, -3309480,
-1686811, -3134252, -680168, 378149, 3825792, -2700073, -1365989, 3367427, -2958570,
2528663, 3774899, -1901765, -2885701, 286776, 2131014, 298955, -4037068, -3037990,
2455918, 704405, -1034441, -1506766, 2257217, -3924289, 1910829, 2993230, -1731617,
-4161994, 3826182, 3755775, 1410753, 302082, 4181290, 369048, -3123017, 1037659,
2483263, -3207, 3709718, -1929249, 1661215, -1332343, 3103632, -1390082, 3718199,
-2596263, 43403, -2068945, 1769551, 1148998, 3519758, 3484982, 1229675, -3917179,
-1790200, 2942297, -367881, -2727323, 1780713, -972875, -4100902, 2103216, 1089969,
3802059, -3600967, 3714015, 2262528, -836384, -4049058, -881894, -3019639, 1871325,
-1127081, 2468781, 396133, -2210254, 1879546, 1761434, -2556875, -2260147, 2063043,
593247, -1102091, 4017202, 1982539, 4103863, -3242, -1313258, -1128572, -1496667,
3051626, 2072070, 1085473, 455772, 311363, -4073347, -1058544, 1001208, -3106675,
1281322, 1054592, 2483921, -893262, 392334, 3052309, -3717274, 1212358, -4009407,
-3909173, 1453538, -4079655,
];

assert_eq!(
invert_ntt_montgomery(re).coefficients,
expected_coefficients
);
}
}
Loading

0 comments on commit 1bab36e

Please sign in to comment.