Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

libjade avx2 bench #266

Merged
merged 2 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@ wasm-bindgen = { version = "0.2.87", optional = true }
hax-lib-macros = { version = "0.1.0-pre.1", git = "https://github.com/hacspec/hax", branch = "main" }
hax-lib = { version = "0.1.0-pre.1", git = "https://github.com/hacspec/hax/", branch = "main" }

[target.'cfg(all(not(target_os = "windows"), target_arch = "x86_64", libjade))'.dependencies]
libjade-sys = { version = "=0.0.2-pre.2", path = "sys/libjade" }

[dev-dependencies]
libcrux = { path = ".", features = ["rand", "tests"] }
pretty_env_logger = "0.5"
Expand Down
3 changes: 3 additions & 0 deletions benchmarks/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ pqcrypto-kyber = { version = "0.8.0" }
[target.'cfg(all(not(windows), not(target_arch = "wasm32"), not(target_arch = "x86")))'.dev-dependencies]
openssl = "0.10"

[target.'cfg(all(not(target_os = "windows"), target_arch = "x86_64"))'.dev-dependencies]
libjade-sys = { version = "=0.0.2-pre.2", path = "../sys/libjade" }

[[bench]]
name = "sha2"
harness = false
Expand Down
104 changes: 104 additions & 0 deletions benchmarks/benches/kyber768.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,23 @@ pub fn comparisons_key_generation(c: &mut Criterion) {
let (_public_key, _secret_key) = pqcrypto_kyber::kyber768::keypair();
})
});

#[cfg(all(not(target_os = "windows"), target_arch = "x86_64"))]
group.bench_function("libjade kyber avx2", |b| {
b.iter(|| {
let mut seed = [0; 64];
rng.fill_bytes(&mut seed);
let mut public_key = [0u8; 1184];
let mut secret_key = [0u8; 2400];
unsafe {
libjade_sys::jade_kem_kyber_kyber768_amd64_avx2_keypair_derand(
public_key.as_mut_ptr(),
secret_key.as_mut_ptr(),
seed.as_ptr(),
)
};
})
});
}

pub fn comparisons_pk_validation(c: &mut Criterion) {
Expand Down Expand Up @@ -129,6 +146,45 @@ pub fn comparisons_encapsulation(c: &mut Criterion) {
BatchSize::SmallInput,
)
});

#[cfg(all(not(target_os = "windows"), target_arch = "x86_64"))]
group.bench_function("libjade kyber avx2", |b| {
b.iter_batched(
|| {
let mut rng = OsRng;
let mut seed = [0; 64];
rng.fill_bytes(&mut seed);
let mut public_key = [0u8; 1184];
let mut secret_key = [0u8; 2400];
unsafe {
libjade_sys::jade_kem_kyber_kyber768_amd64_avx2_keypair_derand(
public_key.as_mut_ptr(),
secret_key.as_mut_ptr(),
seed.as_ptr(),
)
};

(rng, public_key)
},
|(mut rng, public_key)| {
let mut seed = [0; 32];
rng.fill_bytes(&mut seed);

let mut ciphertext = [0u8; 1088];
let mut shared_secret = [0u8; 32];

unsafe {
libjade_sys::jade_kem_kyber_kyber768_amd64_avx2_enc_derand(
ciphertext.as_mut_ptr(),
shared_secret.as_mut_ptr(),
public_key.as_ptr(),
seed.as_ptr(),
);
}
},
BatchSize::SmallInput,
)
});
}

pub fn comparisons_decapsulation(c: &mut Criterion) {
Expand Down Expand Up @@ -167,6 +223,54 @@ pub fn comparisons_decapsulation(c: &mut Criterion) {
BatchSize::SmallInput,
)
});

#[cfg(all(not(target_os = "windows"), target_arch = "x86_64"))]
group.bench_function("libjade kyber avx2", |b| {
b.iter_batched(
|| {
let mut rng = OsRng;
let mut seed = [0; 64];
rng.fill_bytes(&mut seed);
let mut public_key = [0u8; 1184];
let mut secret_key = [0u8; 2400];
unsafe {
libjade_sys::jade_kem_kyber_kyber768_amd64_avx2_keypair_derand(
public_key.as_mut_ptr(),
secret_key.as_mut_ptr(),
seed.as_ptr(),
)
};

let mut seed = [0; 32];
rng.fill_bytes(&mut seed);

let mut ciphertext = [0u8; 1088];
let mut shared_secret = [0u8; 32];

unsafe {
libjade_sys::jade_kem_kyber_kyber768_amd64_avx2_enc_derand(
ciphertext.as_mut_ptr(),
shared_secret.as_mut_ptr(),
public_key.as_ptr(),
seed.as_ptr(),
);
}

(secret_key, ciphertext)
},
|(secret_key, ciphertext)| {
let mut shared_secret = [0u8; 32];
unsafe {
libjade_sys::jade_kem_kyber_kyber768_amd64_avx2_dec(
shared_secret.as_mut_ptr(),
ciphertext.as_ptr(),
secret_key.as_ptr(),
);
}
},
BatchSize::SmallInput,
)
});
}

pub fn comparisons(c: &mut Criterion) {
Expand Down
6 changes: 5 additions & 1 deletion sys/libjade/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@ fn create_bindings(platform: Platform, home_dir: &Path) {
.allowlist_var("JADE_STREAM_CHACHA_CHACHA20_.*")
.allowlist_function("jade_kem_kyber_kyber768_.*")
.allowlist_var("JADE_KEM_KYBER_KYBER768_.*")
.allowlist_function("randombytes")
.allowlist_function("__jasmin_syscall_randombytes__")
// Block everything we don't need or define ourselves.
.blocklist_type("__.*")
// Disable tests to avoid warnings and keep it portable
.layout_tests(false)
// Generate bindings
.parse_callbacks(Box::new(bindgen::CargoCallbacks))
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
.use_core()
.generate()
.expect("Unable to generate bindings");
Expand Down Expand Up @@ -91,6 +93,7 @@ fn build(platform: Platform, home_path: &Path) {
"chacha20_ref.s",
"poly1305_ref.s",
"kyber_kyber768_ref.s",
"randombytes.c",
];
compile_files("jade", &files, home_path, &args);

Expand All @@ -102,6 +105,7 @@ fn build(platform: Platform, home_path: &Path) {
"sha3_512_avx2.s",
"chacha20_avx2.s",
"poly1305_avx2.s",
"kyber_kyber768_avx2.s",
];

let mut simd256_flags = args.clone();
Expand Down
56 changes: 56 additions & 0 deletions sys/libjade/jazz/include/kyber_kyber768_avx2.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#ifndef JADE_KEM_kyber_kyber768_amd64_avx2_API_H
#define JADE_KEM_kyber_kyber768_amd64_avx2_API_H

#if defined(__cplusplus)
extern "C"
{
#endif

#include <stdint.h>

#define JADE_KEM_kyber_kyber768_amd64_avx2_SECRETKEYBYTES 2400
#define JADE_KEM_kyber_kyber768_amd64_avx2_PUBLICKEYBYTES 1184
#define JADE_KEM_kyber_kyber768_amd64_avx2_CIPHERTEXTBYTES 1088
#define JADE_KEM_kyber_kyber768_amd64_avx2_KEYPAIRCOINBYTES 64
#define JADE_KEM_kyber_kyber768_amd64_avx2_ENCCOINBYTES 32
#define JADE_KEM_kyber_kyber768_amd64_avx2_BYTES 32

#define JADE_KEM_kyber_kyber768_amd64_avx2_ALGNAME "Kyber768"
#define JADE_KEM_kyber_kyber768_amd64_avx2_ARCH "amd64"
#define JADE_KEM_kyber_kyber768_amd64_avx2_IMPL "avx2"

int jade_kem_kyber_kyber768_amd64_avx2_keypair_derand(
uint8_t *public_key,
uint8_t *secret_key,
const uint8_t *coins
);

int jade_kem_kyber_kyber768_amd64_avx2_keypair(
uint8_t *public_key,
uint8_t *secret_key
);

int jade_kem_kyber_kyber768_amd64_avx2_enc_derand(
uint8_t *ciphertext,
uint8_t *shared_secret,
const uint8_t *public_key,
const uint8_t *coins
);

int jade_kem_kyber_kyber768_amd64_avx2_enc(
uint8_t *ciphertext,
uint8_t *shared_secret,
const uint8_t *public_key
);

int jade_kem_kyber_kyber768_amd64_avx2_dec(
uint8_t *shared_secret,
const uint8_t *ciphertext,
const uint8_t *secret_key
);

#if defined(__cplusplus)
}
#endif

#endif
2 changes: 2 additions & 0 deletions sys/libjade/jazz/include/libjade.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "randombytes.h"
#include "sha256.h"
#include "x25519_ref.h"
#include "x25519_mulx.h"
Expand All @@ -18,6 +19,7 @@
#include "sha3_512_avx2.h"
#include "poly1305_avx2.h"
#include "chacha20_avx2.h"
#include "kyber_kyber768_avx2.h"
#endif

#ifdef SIMD128
Expand Down
9 changes: 9 additions & 0 deletions sys/libjade/jazz/include/randombytes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#ifndef LIBJADE_RANDOMBYTES_H
#define LIBJADE_RANDOMBYTES_H

#include <stdint.h>

uint8_t* __jasmin_syscall_randombytes__(uint8_t* _x, uint64_t xlen) __asm__("__jasmin_syscall_randombytes__");
void randombytes(uint8_t* _x, uint64_t xlen);

#endif
Loading
Loading