Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow hax extraction for ML-DSA #558

Merged
merged 13 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/hax.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,7 @@ jobs:
HAX_HOME=${{ github.workspace }}/hax \
PATH="${PATH}:${{ github.workspace }}/fstar/bin" \
./hax.py prove --admit

- name: 🏃 Extract ML-DSA crate
working-directory: libcrux-ml-dsa
run: cargo hax into fstar
132 changes: 127 additions & 5 deletions libcrux-intrinsics/src/avx2_extract.rs
jschneider-bensch marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,27 @@

pub type Vec256 = u8;
pub type Vec128 = u8;
pub type Vec256Float = u8;

pub fn mm256_storeu_si256_u8(output: &mut [u8], vector: Vec256) {
debug_assert_eq!(output.len(), 32);
unimplemented!()
}
pub fn mm256_storeu_si256_i16(output: &mut [i16], vector: Vec256) {
debug_assert_eq!(output.len(), 16);
unimplemented!()
}

pub fn mm256_storeu_si256_u8(output: &mut [u8], vector: Vec256) {
debug_assert_eq!(output.len(), 32);
pub fn mm256_storeu_si256_i32(output: &mut [i32], vector: Vec256) {
debug_assert_eq!(output.len(), 8);
unimplemented!()
}

pub fn mm_storeu_si128(output: &mut [i16], vector: Vec128) {
// debug_assert_eq!(output.len(), 8);
debug_assert!(output.len() >= 8);
unimplemented!()
}
pub fn mm_storeu_si128_i32(output: &mut [i32], vector: Vec128) {
debug_assert_eq!(output.len(), 4);
unimplemented!()
}

Expand All @@ -34,15 +43,21 @@ pub fn mm256_loadu_si256_u8(input: &[u8]) -> Vec256 {
debug_assert_eq!(input.len(), 32);
unimplemented!()
}

pub fn mm256_loadu_si256_i16(input: &[i16]) -> Vec256 {
debug_assert_eq!(input.len(), 16);
unimplemented!()
}
pub fn mm256_loadu_si256_i32(input: &[i32]) -> Vec256 {
debug_assert_eq!(input.len(), 8);
unimplemented!()
}

pub fn mm256_setzero_si256() -> Vec256 {
unimplemented!()
}
pub fn mm256_set_m128i(hi: Vec128, lo: Vec128) -> Vec256 {
unimplemented!()
}

pub fn mm_set_epi8(
byte15: u8,
Expand Down Expand Up @@ -126,13 +141,21 @@ pub fn mm256_set_epi16(
unimplemented!()
}

#[inline(always)]
pub fn mm_set1_epi16(constant: i16) -> Vec128 {
unimplemented!()
}

#[inline(always)]
pub fn mm256_set1_epi32(constant: i32) -> Vec256 {
unimplemented!()
}

#[inline(always)]
pub fn mm_set_epi32(input3: i32, input2: i32, input1: i32, input0: i32) -> Vec128 {
unimplemented!()
}
#[inline(always)]
pub fn mm256_set_epi32(
input7: i32,
input6: i32,
Expand All @@ -146,22 +169,40 @@ pub fn mm256_set_epi32(
unimplemented!()
}

#[inline(always)]
pub fn mm_add_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 {
unimplemented!()
}
#[inline(always)]
pub fn mm256_add_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 {
unimplemented!()
}
#[inline(always)]
pub fn mm256_madd_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 {
unimplemented!()
}
#[inline(always)]
pub fn mm256_add_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 {
unimplemented!()
}

#[inline(always)]
pub fn mm256_add_epi64(lhs: Vec256, rhs: Vec256) -> Vec256 {
unimplemented!()
}

#[inline(always)]
pub fn mm256_abs_epi32(a: Vec256) -> Vec256 {
unimplemented!()
}

pub fn mm256_sub_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 {
unimplemented!()
}
pub fn mm256_sub_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 {
unimplemented!()
}

pub fn mm_sub_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 {
unimplemented!()
}
Expand All @@ -174,9 +215,33 @@ pub fn mm_mullo_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 {
unimplemented!()
}

#[inline(always)]
pub fn mm256_cmpgt_epi16(lhs: Vec256, rhs: Vec256) -> Vec256 {
unimplemented!()
}
#[inline(always)]
pub fn mm256_cmpgt_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 {
unimplemented!()
}
#[inline(always)]
pub fn mm256_cmpeq_epi32(a: Vec256, b: Vec256) -> Vec256 {
unimplemented!()
}

#[inline(always)]
pub fn mm256_sign_epi32(a: Vec256, b: Vec256) -> Vec256 {
unimplemented!()
}

#[inline(always)]
pub fn mm256_castsi256_ps(a: Vec256) -> Vec256Float {
unimplemented!()
}

#[inline(always)]
pub fn mm256_movemask_ps(a: Vec256Float) -> i32 {
unimplemented!()
}

pub fn mm_mulhi_epi16(lhs: Vec128, rhs: Vec128) -> Vec128 {
unimplemented!()
Expand All @@ -194,10 +259,25 @@ pub fn mm256_mul_epu32(lhs: Vec256, rhs: Vec256) -> Vec256 {
unimplemented!()
}

#[inline(always)]
pub fn mm256_mul_epi32(lhs: Vec256, rhs: Vec256) -> Vec256 {
unimplemented!()
}

#[inline(always)]
pub fn mm256_and_si256(lhs: Vec256, rhs: Vec256) -> Vec256 {
unimplemented!()
}

#[inline(always)]
pub fn mm256_or_si256(a: Vec256, b: Vec256) -> Vec256 {
unimplemented!()
}

pub fn mm256_testz_si256(lhs: Vec256, rhs: Vec256) -> i32 {
unimplemented!()
}

pub fn mm256_xor_si256(lhs: Vec256, rhs: Vec256) -> Vec256 {
unimplemented!()
}
Expand All @@ -220,6 +300,10 @@ pub fn mm256_srli_epi32<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
unimplemented!()
}

pub fn mm_srli_epi64<const SHIFT_BY: i32>(vector: Vec128) -> Vec128 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 64);
unimplemented!()
}
pub fn mm256_srli_epi64<const SHIFT_BY: i32>(vector: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY >= 0 && SHIFT_BY < 64);
unimplemented!()
Expand Down Expand Up @@ -291,19 +375,47 @@ pub fn mm256_inserti128_si256<const CONTROL: i32>(vector: Vec256, vector_i128: V
unimplemented!()
}

#[inline(always)]
pub fn mm256_blend_epi16<const CONTROL: i32>(lhs: Vec256, rhs: Vec256) -> Vec256 {
debug_assert!(CONTROL >= 0 && CONTROL < 256);
unimplemented!()
}

#[inline(always)]
pub fn mm256_blend_epi32<const CONTROL: i32>(lhs: Vec256, rhs: Vec256) -> Vec256 {
debug_assert!(CONTROL >= 0 && CONTROL < 256);
unimplemented!()
}

// This is essentially _mm256_blendv_ps adapted for use with the Vec256 type.
// It is not offered by the AVX2 instruction set.
#[inline(always)]
pub fn vec256_blendv_epi32(a: Vec256, b: Vec256, mask: Vec256) -> Vec256 {
unimplemented!()
}

#[inline(always)]
pub fn mm_movemask_epi8(vector: Vec128) -> i32 {
unimplemented!()
}

#[inline(always)]
pub fn mm256_permutevar8x32_epi32(vector: Vec256, control: Vec256) -> Vec256 {
unimplemented!()
}

#[inline(always)]
pub fn mm256_srlv_epi32(vector: Vec256, counts: Vec256) -> Vec256 {
unimplemented!()
}
#[inline(always)]
pub fn mm256_srlv_epi64(vector: Vec256, counts: Vec256) -> Vec256 {
unimplemented!()
}

pub fn mm_sllv_epi32(vector: Vec128, counts: Vec128) -> Vec128 {
unimplemented!()
}
pub fn mm256_sllv_epi32(vector: Vec256, counts: Vec256) -> Vec256 {
unimplemented!()
}
Expand All @@ -313,6 +425,12 @@ pub fn mm256_slli_epi64<const LEFT: i32>(x: Vec256) -> Vec256 {
unimplemented!()
}

#[inline(always)]
pub fn mm256_bsrli_epi128<const SHIFT_BY: i32>(x: Vec256) -> Vec256 {
debug_assert!(SHIFT_BY > 0 && SHIFT_BY < 16);
unimplemented!()
}

#[inline(always)]
pub fn mm256_andnot_si256(a: Vec256, b: Vec256) -> Vec256 {
unimplemented!()
Expand All @@ -322,6 +440,10 @@ pub fn mm256_andnot_si256(a: Vec256, b: Vec256) -> Vec256 {
pub fn mm256_set1_epi64x(a: i64) -> Vec256 {
unimplemented!()
}
#[inline(always)]
pub fn mm256_set_epi64x(input3: i64, input2: i64, input1: i64, input0: i64) -> Vec256 {
unimplemented!()
}

#[inline(always)]
pub fn mm256_unpacklo_epi64(a: Vec256, b: Vec256) -> Vec256 {
Expand Down
5 changes: 3 additions & 2 deletions libcrux-ml-dsa/examples/verify_65.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ fn main() {
let message = random_array::<1023>();

let keypair = ml_dsa_65::generate_key_pair(key_generation_seed);
let signature = ml_dsa_65::sign(&keypair.signing_key, &message, signing_randomness);
let signature = ml_dsa_65::sign(&keypair.signing_key, &message, signing_randomness)
.expect("Rejection sampling failure probability is < 2⁻¹²⁸");

for _i in 0..100_000 {
ml_dsa_65::verify(&keypair.verification_key, &message, &signature).unwrap();
let _ = ml_dsa_65::verify(&keypair.verification_key, &message, &signature);
}
}
1 change: 1 addition & 0 deletions libcrux-ml-dsa/src/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ pub(crate) const MESSAGE_REPRESENTATIVE_SIZE: usize = 64;
pub(crate) const MASK_SEED_SIZE: usize = 64;

pub(crate) const VERIFIER_CHALLENGE_SEED_SIZE: usize = 32;
pub(crate) const REJECTION_SAMPLE_BOUND: usize = 576;
2 changes: 1 addition & 1 deletion libcrux-ml-dsa/src/encoding/commitment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ fn serialize<SIMDUnit: Operations, const OUTPUT_SIZE: usize>(
) -> [u8; OUTPUT_SIZE] {
let mut serialized = [0u8; OUTPUT_SIZE];

match OUTPUT_SIZE {
match OUTPUT_SIZE as u8 {
128 => {
// The commitment has coefficients in [0,15] => each coefficient occupies
// 4 bits. Each SIMD unit contains 8 elements, which means each
Expand Down
4 changes: 2 additions & 2 deletions libcrux-ml-dsa/src/encoding/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub(crate) fn serialize<SIMDUnit: Operations, const ETA: usize, const OUTPUT_SIZ
) -> [u8; OUTPUT_SIZE] {
let mut serialized = [0u8; OUTPUT_SIZE];

match ETA {
match ETA as u8 {
2 => {
const OUTPUT_BYTES_PER_SIMD_UNIT: usize = 3;

Expand Down Expand Up @@ -41,7 +41,7 @@ pub(crate) fn serialize<SIMDUnit: Operations, const ETA: usize, const OUTPUT_SIZ
fn deserialize<SIMDUnit: Operations, const ETA: usize>(
serialized: &[u8],
) -> PolynomialRingElement<SIMDUnit> {
let mut serialized_chunks = match ETA {
let mut serialized_chunks = match ETA as u8 {
2 => serialized.chunks(3),
4 => serialized.chunks(4),
_ => unreachable!(),
Expand Down
4 changes: 2 additions & 2 deletions libcrux-ml-dsa/src/encoding/gamma1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub(crate) fn serialize<
) -> [u8; OUTPUT_BYTES] {
let mut serialized = [0u8; OUTPUT_BYTES];

match GAMMA1_EXPONENT {
match GAMMA1_EXPONENT as u8 {
17 => {
const OUTPUT_BYTES_PER_SIMD_UNIT: usize = 18;

Expand Down Expand Up @@ -43,7 +43,7 @@ pub(crate) fn serialize<
pub(crate) fn deserialize<SIMDUnit: Operations, const GAMMA1_EXPONENT: usize>(
serialized: &[u8],
) -> PolynomialRingElement<SIMDUnit> {
let mut serialized_chunks = match GAMMA1_EXPONENT {
let mut serialized_chunks = match GAMMA1_EXPONENT as u8 {
17 => serialized.chunks(18),
19 => serialized.chunks(20),
_ => unreachable!(),
Expand Down
Loading
Loading