Skip to content

Commit

Permalink
enable all
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyxu committed Sep 15, 2024
1 parent 651135f commit dff7c10
Showing 1 changed file with 21 additions and 19 deletions.
40 changes: 21 additions & 19 deletions rust/lance-linalg/src/distance/dot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use arrow_array::{cast::AsArray, types::Float32Type, Array, FixedSizeListArray,
use arrow_schema::DataType;
use half::{bf16, f16};
use lance_arrow::{ArrowFloatType, FloatArray};
#[cfg(feature = "fp16kernels")]
use lance_core::utils::cpu::SimdSupport;

Check warning on line 16 in rust/lance-linalg/src/distance/dot.rs

View workflow job for this annotation

GitHub Actions / linux-arm

unused import: `lance_core::utils::cpu::SimdSupport`
use lance_core::utils::cpu::FP16_SIMD_SUPPORT;
use num_traits::{real::Real, AsPrimitive, Num};
Expand Down Expand Up @@ -85,7 +84,6 @@ impl Dot for bf16 {
}
}

#[cfg(feature = "fp16kernels")]
mod kernel {
use super::*;

Expand All @@ -94,7 +92,7 @@ mod kernel {
extern "C" {
#[cfg(target_arch = "aarch64")]
pub fn dot_f16_neon(ptr1: *const f16, ptr2: *const f16, len: u32) -> f32;
#[cfg(all(kernel_support = "avx512", target_arch = "x86_64"))]
#[cfg(target_arch = "x86_64")]
pub fn dot_f16_avx512(ptr1: *const f16, ptr2: *const f16, len: u32) -> f32;
#[cfg(target_arch = "x86_64")]
pub fn dot_f16_avx2(ptr1: *const f16, ptr2: *const f16, len: u32) -> f32;
Expand All @@ -108,23 +106,27 @@ mod kernel {
impl Dot for f16 {
#[inline]
fn dot(x: &[Self], y: &[Self]) -> f32 {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512fp16") {
return unsafe { kernel::dot_f16_avx512(x.as_ptr(), y.as_ptr(), x.len() as u32) };
}
if is_x86_feature_detected!("avx2") {
return unsafe { kernel::dot_f16_avx2(x.as_ptr(), y.as_ptr(), x.len() as u32) };
}
}

#[cfg(target_arch = "aarch64")]
{
// TODO: add SVE
if std::arch::is_aarch64_feature_detected!("neon")
&& std::arch::is_aarch64_feature_detected!("fp16")
{
return unsafe { kernel::dot_f16_neon(x.as_ptr(), y.as_ptr(), x.len() as u32) };
}
}

match *FP16_SIMD_SUPPORT {
// #[cfg(all(feature = "fp16kernels", target_arch = "aarch64"))]
// SimdSupport::Neon => unsafe {
// kernel::dot_f16_neon(x.as_ptr(), y.as_ptr(), x.len() as u32)
// },
#[cfg(all(
feature = "fp16kernels",
kernel_support = "avx512",
target_arch = "x86_64"
))]
SimdSupport::Avx512 => unsafe {
kernel::dot_f16_avx512(x.as_ptr(), y.as_ptr(), x.len() as u32)
},
#[cfg(all(feature = "fp16kernels", target_arch = "x86_64"))]
SimdSupport::Avx2 => unsafe {
kernel::dot_f16_avx2(x.as_ptr(), y.as_ptr(), x.len() as u32)
},
#[cfg(all(feature = "fp16kernels", target_arch = "loongarch64"))]
SimdSupport::Lasx => unsafe {
kernel::dot_f16_lasx(x.as_ptr(), y.as_ptr(), x.len() as u32)
Expand Down

0 comments on commit dff7c10

Please sign in to comment.