Skip to content

Commit

Permalink
portable rejection sampling for neon for now
Browse files Browse the repository at this point in the history
  • Loading branch information
franziskuskiefer committed May 23, 2024
1 parent 0c67a85 commit 76ab582
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 17 deletions.
14 changes: 13 additions & 1 deletion .github/workflows/mlkem.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ jobs:
if: ${{ matrix.os == 'macos-latest' }}
run: RUSTDOCFLAGS=-Zsanitizer=address RUSTFLAGS=-Zsanitizer=address cargo +nightly test --release --target aarch64-apple-darwin

# - name: ⬆ Upload build
# uses: ./.github/actions/upload_artifacts
# with:
# name: build_${{ matrix.os }}_${{ matrix.bits }}

# We get false positives here.
# TODO: Figure out what is going on here
# - name: 🏃🏻 Asan Linux
Expand Down Expand Up @@ -123,7 +128,6 @@ jobs:
cargo test --verbose $RUST_TARGET_FLAG
- name: 🏃🏻‍♀️ Test Release
if: ${{ matrix.os != 'macos-latest' }}
run: |
cargo clean
cargo test --verbose --release $RUST_TARGET_FLAG
Expand Down Expand Up @@ -195,6 +199,14 @@ jobs:
echo "RUST_TARGET_FLAG=--target=i686-unknown-linux-gnu" > $GITHUB_ENV
if: ${{ matrix.bits == 32 && matrix.os == 'ubuntu-latest' }}

# - name: 🔨 Build
# run: cargo build --benches

# - name: ⬆ Upload build
# uses: ./.github/actions/upload_artifacts
# with:
# name: benchmarks_${{ matrix.os }}_${{ matrix.bits }}

# Benchmarks ...

- name: 🏃🏻‍♀️ Benchmarks
Expand Down
30 changes: 28 additions & 2 deletions polynomials-aarch64/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use libcrux_traits::Operations;

mod neon;
mod rejsample;
// mod rejsample;
mod simd128ops;

pub use simd128ops::SIMD128Vector;
Expand Down Expand Up @@ -158,6 +158,32 @@ impl Operations for SIMD128Vector {
}

fn rej_sample(a: &[u8], out: &mut [i16]) -> usize {
rejsample::rej_sample(a, out)
// FIXME: The code in rejsample fails on the CI machines.
// We need to understand why and fix it before using it.
// We use the portable version in the meantime.
rej_sample(a, out)
}
}

#[inline(always)]
pub(crate) fn rej_sample(a: &[u8], result: &mut [i16]) -> usize {
let mut sampled = 0;
for bytes in a.chunks(3) {
let b1 = bytes[0] as i16;
let b2 = bytes[1] as i16;
let b3 = bytes[2] as i16;

let d1 = ((b2 & 0xF) << 8) | b1;
let d2 = (b3 << 4) | (b2 >> 4);

if d1 < libcrux_traits::FIELD_MODULUS && sampled < 16 {
result[sampled] = d1;
sampled += 1
}
if d2 < libcrux_traits::FIELD_MODULUS && sampled < 16 {
result[sampled] = d2;
sampled += 1
}
}
sampled
}
10 changes: 5 additions & 5 deletions polynomials-aarch64/src/neon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ pub(crate) fn _vdupq_n_s16(i: i16) -> int16x8_t {

#[inline(always)]
pub(crate) fn _vst1q_s16(out: &mut [i16], v: int16x8_t) {
unsafe { vst1q_s16(out.as_mut_ptr() as *mut i16, v) }
unsafe { vst1q_s16(out.as_mut_ptr(), v) }
}

#[inline(always)]
pub(crate) fn _vld1q_s16(array: &[i16]) -> int16x8_t {
unsafe { vld1q_s16(array.as_ptr() as *const i16) }
unsafe { vld1q_s16(array.as_ptr()) }
}

#[inline(always)]
Expand Down Expand Up @@ -193,7 +193,7 @@ pub(crate) fn _vmlal_high_s16(a: int32x4_t, b: int16x8_t, c: int16x8_t) -> int32
}
#[inline(always)]
pub(crate) fn _vld1q_u8(ptr: &[u8]) -> uint8x16_t {
unsafe { vld1q_u8(ptr.as_ptr() as *const u8) }
unsafe { vld1q_u8(ptr.as_ptr()) }
}
#[inline(always)]
pub(crate) fn _vreinterpretq_u8_s16(a: int16x8_t) -> uint8x16_t {
Expand Down Expand Up @@ -254,7 +254,7 @@ pub(super) fn _vreinterpretq_u8_s64(a: int64x2_t) -> uint8x16_t {

#[inline(always)]
pub(super) fn _vst1q_u8(out: &mut [u8], v: uint8x16_t) {
unsafe { vst1q_u8(out.as_mut_ptr() as *mut u8, v) }
unsafe { vst1q_u8(out.as_mut_ptr(), v) }
}
#[inline(always)]
pub(crate) fn _vdupq_n_u16(value: u16) -> uint16x8_t {
Expand All @@ -270,7 +270,7 @@ pub(crate) fn _vreinterpretq_u16_u8(a: uint8x16_t) -> uint16x8_t {
}
#[inline(always)]
pub(crate) fn _vld1q_u16(ptr: &[u16]) -> uint16x8_t {
unsafe { vld1q_u16(ptr.as_ptr() as *const u16) }
unsafe { vld1q_u16(ptr.as_ptr()) }
}
#[inline(always)]
pub(crate) fn _vcleq_s16(a: int16x8_t, b: int16x8_t) -> uint16x8_t {
Expand Down
13 changes: 8 additions & 5 deletions polynomials-aarch64/src/rejsample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -776,18 +776,21 @@ pub(crate) fn rej_sample(a: &[u8], out: &mut [i16]) -> usize {
let input = super::simd128ops::deserialize_12(a);
let mask0 = _vcleq_s16(input.low, fm);
let mask1 = _vcleq_s16(input.high, fm);
let used0 = _vaddvq_u16(_vandq_u16(mask0, bits));
let used1 = _vaddvq_u16(_vandq_u16(mask1, bits));
let masked = _vandq_u16(mask0, bits);
let used0 = _vaddvq_u16(masked);
let masked = _vandq_u16(mask1, bits);
let used1 = _vaddvq_u16(masked);
let pick0 = used0.count_ones();
let pick1 = used1.count_ones();

let index_vec0 = _vld1q_u8(&IDX_TABLE[used0 as usize]);
// XXX: the indices used0 and used1 must be < 256.
let index_vec0 = _vld1q_u8(&IDX_TABLE[(used0 as u8) as usize]);
let shifted0 = _vreinterpretq_s16_u8(_vqtbl1q_u8(_vreinterpretq_u8_s16(input.low), index_vec0));
let index_vec1 = _vld1q_u8(&IDX_TABLE[used1 as usize]);
let index_vec1 = _vld1q_u8(&IDX_TABLE[(used1 as u8) as usize]);
let shifted1 =
_vreinterpretq_s16_u8(_vqtbl1q_u8(_vreinterpretq_u8_s16(input.high), index_vec1));

let idx0 = pick0 as usize;
let idx0 = usize::try_from(pick0).unwrap();
_vst1q_s16(&mut out[0..8], shifted0);
_vst1q_s16(&mut out[idx0..idx0 + 8], shifted1);
(pick0 + pick1) as usize
Expand Down
5 changes: 1 addition & 4 deletions polynomials-aarch64/src/simd128ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -848,8 +848,5 @@ pub(crate) fn deserialize_12(v: &[u8]) -> SIMD128Vector {
let shifted1 = _vshlq_u16(moved1, shift_vec);
let high = _vreinterpretq_s16_u16(_vandq_u16(shifted1, mask12));

SIMD128Vector {
low: low,
high: high,
}
SIMD128Vector { low, high }
}

0 comments on commit 76ab582

Please sign in to comment.