Skip to content

Commit

Permalink
add aarch64 to CI matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
llogiq committed Sep 17, 2023
1 parent 8fd5e20 commit 0a50aaa
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 29 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ jobs:
arch:
- i686
- x86_64
- aarch64
features:
- default
- runtime-dispatch-simd
Expand Down
103 changes: 74 additions & 29 deletions src/simd/aarch64.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use core::arch::aarch64::{
uint8x16_t, vaddlvq_u8, vandq_u8, vceqq_u8, vcgtq_u8, vdupq_n_u8, vld1q_u8, vmvnq_u8, vsubq_u8,
uint8x16_t, uint8x16x4_t, vaddlvq_u8, vandq_u8, vceqq_u8, vdupq_n_u8, vld1q_u8, vld1q_u8_x4,
vmvnq_u8, vsubq_u8,
};

const MASK: [u8; 32] = [
Expand All @@ -9,12 +10,29 @@ const MASK: [u8; 32] = [

/// Loads the 16 bytes `slice[offset..offset + 16]` into a NEON vector.
///
/// # Safety
///
/// The caller must guarantee `offset + 16 <= slice.len()`. This is only
/// verified by a `debug_assert!`, so in builds without debug assertions an
/// out-of-range `offset` reads past the end of `slice`.
#[target_feature(enable = "neon")]
unsafe fn u8x16_from_offset(slice: &[u8], offset: usize) -> uint8x16_t {
    // The assertion fails exactly when `offset + 16 > len`, so the message
    // reports a strict comparison (equality is the last valid position).
    debug_assert!(
        offset + 16 <= slice.len(),
        "{} + 16 > {}",
        offset,
        slice.len()
    );
    // TODO: does this need to be aligned?
    // NOTE(review): the pointer is a plain `*const u8`, so presumably byte
    // alignment is sufficient — confirm against the intrinsic documentation.
    vld1q_u8(slice.as_ptr().add(offset))
}

#[target_feature(enable = "neon")]
unsafe fn sum(u8s: &uint8x16_t) -> usize {
vaddlvq_u8(*u8s) as usize
/// Loads the 64 bytes `slice[offset..offset + 64]` as four consecutive
/// 16-byte NEON vectors.
///
/// # Safety
///
/// The caller must guarantee `offset + 64 <= slice.len()`. This is only
/// verified by a `debug_assert!`, so in builds without debug assertions an
/// out-of-range `offset` reads past the end of `slice`.
unsafe fn u8x16_x4_from_offset(slice: &[u8], offset: usize) -> uint8x16x4_t {
    // The assertion fails exactly when `offset + 64 > len`, so the message
    // reports a strict comparison (equality is the last valid position).
    debug_assert!(
        offset + 64 <= slice.len(),
        "{} + 64 > {}",
        offset,
        slice.len()
    );
    vld1q_u8_x4(slice.as_ptr().add(offset))
}

/// Adds up all sixteen `u8` lanes of `u8s` and returns the total as a `usize`.
///
/// # Safety
///
/// Requires the `neon` target feature to be available on the running CPU.
#[target_feature(enable = "neon")]
unsafe fn sum(u8s: uint8x16_t) -> usize {
    // The across-vector add yields a widened `u16`, which converts to
    // `usize` without loss.
    let lane_total: u16 = vaddlvq_u8(u8s);
    usize::from(lane_total)
}

#[target_feature(enable = "neon")]
Expand All @@ -26,38 +44,67 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {

let needles = vdupq_n_u8(needle);

// 4080
while haystack.len() >= offset + 16 * 255 {
let mut counts = vdupq_n_u8(0);
// 16320
while haystack.len() >= offset + 64 * 255 {
let (mut count1, mut count2, mut count3, mut count4) =
(vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0));
for _ in 0..255 {
counts = vsubq_u8(
counts,
vceqq_u8(u8x16_from_offset(haystack, offset), needles),
let uint8x16x4_t(h1, h2, h3, h4) = u8x16_x4_from_offset(haystack, offset);
let (eq1, eq2, eq3, eq4) = (
vceqq_u8(h1, needles),
vceqq_u8(h2, needles),
vceqq_u8(h3, needles),
vceqq_u8(h4, needles),
);
offset += 16;
count1 = vsubq_u8(count1, eq1);
count2 = vsubq_u8(count2, eq2);
count3 = vsubq_u8(count3, eq3);
count4 = vsubq_u8(count4, eq4);
offset += 64;
}
count += sum(&counts);
count += sum(count1) + sum(count2) + sum(count3) + sum(count4);
}

// 2048
if haystack.len() >= offset + 16 * 128 {
let mut counts = vdupq_n_u8(0);
for _ in 0..128 {
counts = vsubq_u8(
counts,
vceqq_u8(u8x16_from_offset(haystack, offset), needles),
// 4032
if haystack.len() >= offset + 64 * 63 {
let (mut count1, mut count2, mut count3, mut count4) =
(vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0));
for _ in 0..63 {
let uint8x16x4_t(h1, h2, h3, h4) = u8x16_x4_from_offset(haystack, offset);
let (eq1, eq2, eq3, eq4) = (
vceqq_u8(h1, needles),
vceqq_u8(h2, needles),
vceqq_u8(h3, needles),
vceqq_u8(h4, needles),
);
offset += 16;
count1 = vsubq_u8(count1, eq1);
count2 = vsubq_u8(count2, eq2);
count3 = vsubq_u8(count3, eq3);
count4 = vsubq_u8(count4, eq4);
offset += 64;
}
count += sum(&counts);
count += sum(count1) + sum(count2) + sum(count3) + sum(count4);
}

// 64
while haystack.len() >= offset + 64 {
let (mut count1, mut count2, mut count3, mut count4) =
(vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0), vdupq_n_u8(0));
let uint8x16x4_t(h1, h2, h3, h4) = u8x16_x4_from_offset(haystack, offset);
count1 = vsubq_u8(count1, vceqq_u8(h1, needles));
count2 = vsubq_u8(count2, vceqq_u8(h2, needles));
count3 = vsubq_u8(count3, vceqq_u8(h3, needles));
count4 = vsubq_u8(count4, vceqq_u8(h4, needles));
offset += 64;
count += sum(count1) + sum(count2) + sum(count3) + sum(count4);
}

// 16
let mut counts = vdupq_n_u8(0);
// 16
for i in 0..(haystack.len() - offset) / 16 {
counts = vsubq_u8(
counts,
vcgtq_u8(u8x16_from_offset(haystack, offset + i * 32), needles),
vceqq_u8(u8x16_from_offset(haystack, offset + i * 16), needles),
);
}
if haystack.len() % 16 != 0 {
Expand All @@ -69,9 +116,7 @@ pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize {
),
);
}
count += sum(&counts);

count
count + sum(counts)
}

#[target_feature(enable = "neon")]
Expand Down Expand Up @@ -100,7 +145,7 @@ pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
);
offset += 16;
}
count += sum(&counts);
count += sum(counts);
}

// 2048
Expand All @@ -113,15 +158,15 @@ pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
);
offset += 16;
}
count += sum(&counts);
count += sum(counts);
}

// 16
let mut counts = vdupq_n_u8(0);
for i in 0..(utf8_chars.len() - offset) / 16 {
counts = vsubq_u8(
counts,
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset + i * 32)),
is_leading_utf8_byte(u8x16_from_offset(utf8_chars, offset + i * 16)),
);
}
if utf8_chars.len() % 16 != 0 {
Expand All @@ -133,7 +178,7 @@ pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize {
),
);
}
count += sum(&counts);
count += sum(counts);

count
}

0 comments on commit 0a50aaa

Please sign in to comment.