Skip to content

Commit

Permalink
refactor and make it extract
Browse files Browse the repository at this point in the history
  • Loading branch information
franziskuskiefer committed Mar 18, 2024
1 parent 49c50ae commit baf808e
Show file tree
Hide file tree
Showing 8 changed files with 463 additions and 364 deletions.
2 changes: 1 addition & 1 deletion kyber-crate.sh
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ rm src/kyber512.rs
rm src/kyber1024.rs

# Build & test
# cargo test
cargo test

# Extract
if [[ -z "$CHARON_HOME" ]]; then
Expand Down
435 changes: 239 additions & 196 deletions proofs/fstar/extraction-edited.patch

Large diffs are not rendered by default.

128 changes: 64 additions & 64 deletions proofs/fstar/extraction-secret-independent.patch

Large diffs are not rendered by default.

193 changes: 116 additions & 77 deletions proofs/fstar/extraction/Libcrux.Kem.Kyber.Serialize.fst
Original file line number Diff line number Diff line change
Expand Up @@ -666,83 +666,6 @@ let compress_then_serialize_ring_element_v
<:
Rust_primitives.Hax.t_Never)

let deserialize_ring_elementes_reduced (v_PUBLIC_KEY_SIZE v_K: usize) (public_key: t_Slice u8) =
let deserialized_pk:t_Array Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement v_K =
Rust_primitives.Hax.repeat Libcrux.Kem.Kyber.Arithmetic.impl__PolynomialRingElement__ZERO v_K
in
let deserialized_pk:t_Array Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement v_K =
Core.Iter.Traits.Iterator.f_fold (Core.Iter.Traits.Collect.f_into_iter (Core.Iter.Traits.Iterator.f_enumerate
(Core.Slice.impl__chunks_exact public_key
Libcrux.Kem.Kyber.Constants.v_BYTES_PER_RING_ELEMENT
<:
Core.Slice.Iter.t_ChunksExact u8)
<:
Core.Iter.Adapters.Enumerate.t_Enumerate (Core.Slice.Iter.t_ChunksExact u8))
<:
Core.Iter.Adapters.Enumerate.t_Enumerate (Core.Slice.Iter.t_ChunksExact u8))
deserialized_pk
(fun deserialized_pk temp_1_ ->
let deserialized_pk:t_Array Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement v_K =
deserialized_pk
in
let i, ring_element:(usize & t_Slice u8) = temp_1_ in
let re:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement =
Libcrux.Kem.Kyber.Arithmetic.impl__PolynomialRingElement__ZERO
in
let re:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement =
Core.Iter.Traits.Iterator.f_fold (Core.Iter.Traits.Collect.f_into_iter (Core.Iter.Traits.Iterator.f_enumerate
(Core.Slice.impl__chunks_exact ring_element (sz 3)
<:
Core.Slice.Iter.t_ChunksExact u8)
<:
Core.Iter.Adapters.Enumerate.t_Enumerate (Core.Slice.Iter.t_ChunksExact u8))
<:
Core.Iter.Adapters.Enumerate.t_Enumerate (Core.Slice.Iter.t_ChunksExact u8))
re
(fun re temp_1_ ->
let re:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement = re in
let i, bytes:(usize & t_Slice u8) = temp_1_ in
let byte1:i32 = cast (bytes.[ sz 0 ] <: u8) <: i32 in
let byte2:i32 = cast (bytes.[ sz 1 ] <: u8) <: i32 in
let byte3:i32 = cast (bytes.[ sz 2 ] <: u8) <: i32 in
let re:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement =
{
re with
Libcrux.Kem.Kyber.Arithmetic.f_coefficients
=
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize re
.Libcrux.Kem.Kyber.Arithmetic.f_coefficients
(sz 2 *! i <: usize)
((((byte2 &. 15l <: i32) <<! 8l <: i32) |. (byte1 &. 255l <: i32) <: i32) %!
Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS
<:
i32)
}
<:
Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement
in
let re:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement =
{
re with
Libcrux.Kem.Kyber.Arithmetic.f_coefficients
=
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize re
.Libcrux.Kem.Kyber.Arithmetic.f_coefficients
((sz 2 *! i <: usize) +! sz 1 <: usize)
(((byte3 <<! 4l <: i32) |. ((byte2 >>! 4l <: i32) &. 15l <: i32) <: i32) %!
Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS
<:
i32)
}
<:
Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement
in
re)
in
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize deserialized_pk i re)
in
deserialized_pk

let deserialize_then_decompress_10_ (serialized: t_Slice u8) =
let _:Prims.unit = () <: Prims.unit in
let re:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement =
Expand Down Expand Up @@ -1281,6 +1204,122 @@ let deserialize_then_decompress_ring_element_v
<:
Rust_primitives.Hax.t_Never)

let deserialize_to_reduced_ring_element (ring_element: t_Slice u8) =
let _:Prims.unit = () <: Prims.unit in
let re:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement =
Libcrux.Kem.Kyber.Arithmetic.impl__PolynomialRingElement__ZERO
in
let re:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement =
Core.Iter.Traits.Iterator.f_fold (Core.Iter.Traits.Collect.f_into_iter (Core.Iter.Traits.Iterator.f_enumerate
(Core.Slice.impl__chunks_exact ring_element (sz 3) <: Core.Slice.Iter.t_ChunksExact u8
)
<:
Core.Iter.Adapters.Enumerate.t_Enumerate (Core.Slice.Iter.t_ChunksExact u8))
<:
Core.Iter.Adapters.Enumerate.t_Enumerate (Core.Slice.Iter.t_ChunksExact u8))
re
(fun re temp_1_ ->
let re:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement = re in
let i, bytes:(usize & t_Slice u8) = temp_1_ in
let byte1:i32 = cast (bytes.[ sz 0 ] <: u8) <: i32 in
let byte2:i32 = cast (bytes.[ sz 1 ] <: u8) <: i32 in
let byte3:i32 = cast (bytes.[ sz 2 ] <: u8) <: i32 in
let re:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement =
{
re with
Libcrux.Kem.Kyber.Arithmetic.f_coefficients
=
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize re
.Libcrux.Kem.Kyber.Arithmetic.f_coefficients
(sz 2 *! i <: usize)
(((byte2 &. 15l <: i32) <<! 8l <: i32) |. (byte1 &. 255l <: i32) <: i32)
}
<:
Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement
in
let tmp:i32 =
(re.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ sz 2 *! i <: usize ] <: i32) %! 3329l
in
let re:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement =
{
re with
Libcrux.Kem.Kyber.Arithmetic.f_coefficients
=
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize re
.Libcrux.Kem.Kyber.Arithmetic.f_coefficients
(sz 2 *! i <: usize)
tmp
}
<:
Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement
in
let re:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement =
{
re with
Libcrux.Kem.Kyber.Arithmetic.f_coefficients
=
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize re
.Libcrux.Kem.Kyber.Arithmetic.f_coefficients
((sz 2 *! i <: usize) +! sz 1 <: usize)
((byte3 <<! 4l <: i32) |. ((byte2 >>! 4l <: i32) &. 15l <: i32) <: i32)
}
<:
Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement
in
let tmp:i32 =
(re.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ (sz 2 *! i <: usize) +! sz 1 <: usize
]
<:
i32) %!
3329l
in
let re:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement =
{
re with
Libcrux.Kem.Kyber.Arithmetic.f_coefficients
=
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize re
.Libcrux.Kem.Kyber.Arithmetic.f_coefficients
((sz 2 *! i <: usize) +! sz 1 <: usize)
tmp
}
<:
Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement
in
re)
in
re

let deserialize_ring_elements_reduced (v_PUBLIC_KEY_SIZE v_K: usize) (public_key: t_Slice u8) =
let deserialized_pk:t_Array Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement v_K =
Rust_primitives.Hax.repeat Libcrux.Kem.Kyber.Arithmetic.impl__PolynomialRingElement__ZERO v_K
in
let deserialized_pk:t_Array Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement v_K =
Core.Iter.Traits.Iterator.f_fold (Core.Iter.Traits.Collect.f_into_iter (Core.Iter.Traits.Iterator.f_enumerate
(Core.Slice.impl__chunks_exact public_key
Libcrux.Kem.Kyber.Constants.v_BYTES_PER_RING_ELEMENT
<:
Core.Slice.Iter.t_ChunksExact u8)
<:
Core.Iter.Adapters.Enumerate.t_Enumerate (Core.Slice.Iter.t_ChunksExact u8))
<:
Core.Iter.Adapters.Enumerate.t_Enumerate (Core.Slice.Iter.t_ChunksExact u8))
deserialized_pk
(fun deserialized_pk temp_1_ ->
let deserialized_pk:t_Array Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement v_K =
deserialized_pk
in
let i, ring_element:(usize & t_Slice u8) = temp_1_ in
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize deserialized_pk
i
(deserialize_to_reduced_ring_element ring_element
<:
Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement)
<:
t_Array Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement v_K)
in
deserialized_pk

let deserialize_to_uncompressed_ring_element (serialized: t_Slice u8) =
let _:Prims.unit = () <: Prims.unit in
let re:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement =
Expand Down
15 changes: 10 additions & 5 deletions proofs/fstar/extraction/Libcrux.Kem.Kyber.Serialize.fsti
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,6 @@ val compress_then_serialize_ring_element_v
(re: Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement)
: Prims.Pure (t_Array u8 v_OUT_LEN) Prims.l_True (fun _ -> Prims.l_True)

val deserialize_ring_elementes_reduced (v_PUBLIC_KEY_SIZE v_K: usize) (public_key: t_Slice u8)
: Prims.Pure (t_Array Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement v_K)
Prims.l_True
(fun _ -> Prims.l_True)

val deserialize_then_decompress_10_ (serialized: t_Slice u8)
: Prims.Pure Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement
Prims.l_True
Expand Down Expand Up @@ -115,6 +110,16 @@ val deserialize_then_decompress_ring_element_v
Prims.l_True
(fun _ -> Prims.l_True)

val deserialize_to_reduced_ring_element (ring_element: t_Slice u8)
: Prims.Pure Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement
Prims.l_True
(fun _ -> Prims.l_True)

val deserialize_ring_elements_reduced (v_PUBLIC_KEY_SIZE v_K: usize) (public_key: t_Slice u8)
: Prims.Pure (t_Array Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement v_K)
Prims.l_True
(fun _ -> Prims.l_True)

val deserialize_to_uncompressed_ring_element (serialized: t_Slice u8)
: Prims.Pure Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement
Prims.l_True
Expand Down
2 changes: 1 addition & 1 deletion proofs/fstar/extraction/Libcrux.Kem.Kyber.fst
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ let validate_public_key
(public_key: t_Array u8 v_PUBLIC_KEY_SIZE)
=
let deserialized_pk:t_Array Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement v_K =
Libcrux.Kem.Kyber.Serialize.deserialize_ring_elementes_reduced v_PUBLIC_KEY_SIZE
Libcrux.Kem.Kyber.Serialize.deserialize_ring_elements_reduced v_PUBLIC_KEY_SIZE
v_K
(public_key.[ { Core.Ops.Range.f_end = v_RANKED_BYTES_PER_RING_ELEMENT }
<:
Expand Down
4 changes: 2 additions & 2 deletions src/kem/kyber.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use self::{
constants::{CPA_PKE_KEY_GENERATION_SEED_SIZE, H_DIGEST_SIZE, SHARED_SECRET_SIZE},
hash_functions::{G, H, PRF},
ind_cpa::{into_padded_array, serialize_public_key},
serialize::deserialize_ring_elementes_reduced,
serialize::deserialize_ring_elements_reduced,
};

/// Seed size for key generation
Expand Down Expand Up @@ -74,7 +74,7 @@ pub(super) fn validate_public_key<
>(
public_key: &[u8; PUBLIC_KEY_SIZE],
) -> bool {
let deserialized_pk = deserialize_ring_elementes_reduced::<PUBLIC_KEY_SIZE, K>(
let deserialized_pk = deserialize_ring_elements_reduced::<PUBLIC_KEY_SIZE, K>(
&public_key[..RANKED_BYTES_PER_RING_ELEMENT],
);

Expand Down
48 changes: 30 additions & 18 deletions src/kem/kyber/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use super::{
},
constants::{BYTES_PER_RING_ELEMENT, SHARED_SECRET_SIZE},
};
use crate::cloop;
use crate::hax_utils::hax_debug_assert;
use crate::{cloop, kem::kyber::constants::FIELD_MODULUS};

#[cfg(not(hax))]
use super::constants::COEFFICIENTS_IN_RING_ELEMENT;
Expand Down Expand Up @@ -101,12 +101,39 @@ pub(super) fn deserialize_to_uncompressed_ring_element(serialized: &[u8]) -> Pol
re
}

#[inline(always)]
fn deserialize_to_reduced_ring_element(ring_element: &[u8]) -> PolynomialRingElement {
hax_debug_assert!(ring_element.len() == BYTES_PER_RING_ELEMENT);

let mut re = PolynomialRingElement::ZERO;

cloop! {
for (i, bytes) in ring_element.chunks_exact(3).enumerate() {
let byte1 = bytes[0] as FieldElement;
let byte2 = bytes[1] as FieldElement;
let byte3 = bytes[2] as FieldElement;

// The modulus here is ok because the input must be public.
// XXX: The awkward code here is necessary to work around Charon shortcomings.
re.coefficients[2 * i] = (byte2 & 0x0F) << 8 | (byte1 & 0xFF);
let tmp = re.coefficients[2 * i] % 3329; // FIELD_MODULUS
re.coefficients[2 * i] = tmp;

re.coefficients[2 * i + 1] = (byte3 << 4) | ((byte2 >> 4) & 0x0F);
let tmp = re.coefficients[2 * i + 1] % 3329; // FIELD_MODULUS
re.coefficients[2 * i + 1] = tmp;
}
}

re
}

/// This function deserializes ring elements and reduces the result by the field
/// modulus.
///
/// This function MUST NOT be used on secret inputs.
#[inline(always)]
pub(super) fn deserialize_ring_elementes_reduced<const PUBLIC_KEY_SIZE: usize, const K: usize>(
pub(super) fn deserialize_ring_elements_reduced<const PUBLIC_KEY_SIZE: usize, const K: usize>(
public_key: &[u8],
) -> [PolynomialRingElement; K] {
let mut deserialized_pk = [PolynomialRingElement::ZERO; K];
Expand All @@ -115,22 +142,7 @@ pub(super) fn deserialize_ring_elementes_reduced<const PUBLIC_KEY_SIZE: usize, c
.chunks_exact(BYTES_PER_RING_ELEMENT)
.enumerate()
{
deserialized_pk[i] = {
let mut re = PolynomialRingElement::ZERO;
cloop! {
for (i, bytes) in ring_element.chunks_exact(3).enumerate() {
let byte1 = bytes[0] as FieldElement;
let byte2 = bytes[1] as FieldElement;
let byte3 = bytes[2] as FieldElement;

// The modulus here is ok because the input must be public.
re.coefficients[2 * i] = ((byte2 & 0x0F) << 8 | (byte1 & 0xFF)) % FIELD_MODULUS;
re.coefficients[2 * i + 1] = ((byte3 << 4) | ((byte2 >> 4) & 0x0F)) % FIELD_MODULUS;
}
}

re
}
deserialized_pk[i] =deserialize_to_reduced_ring_element(ring_element);
}
}
deserialized_pk
Expand Down

0 comments on commit baf808e

Please sign in to comment.