diff --git a/rust/lance-linalg/src/distance/dot.rs b/rust/lance-linalg/src/distance/dot.rs index abf1ce2e62..7977d24478 100644 --- a/rust/lance-linalg/src/distance/dot.rs +++ b/rust/lance-linalg/src/distance/dot.rs @@ -13,7 +13,6 @@ use arrow_array::{cast::AsArray, types::Float32Type, Array, FixedSizeListArray, use arrow_schema::DataType; use half::{bf16, f16}; use lance_arrow::{ArrowFloatType, FloatArray}; -#[cfg(feature = "fp16kernels")] use lance_core::utils::cpu::SimdSupport; use lance_core::utils::cpu::FP16_SIMD_SUPPORT; use num_traits::{real::Real, AsPrimitive, Num}; @@ -85,7 +84,6 @@ impl Dot for bf16 { } } -#[cfg(feature = "fp16kernels")] mod kernel { use super::*; @@ -94,7 +92,7 @@ mod kernel { extern "C" { #[cfg(target_arch = "aarch64")] pub fn dot_f16_neon(ptr1: *const f16, ptr2: *const f16, len: u32) -> f32; - #[cfg(all(kernel_support = "avx512", target_arch = "x86_64"))] + #[cfg(target_arch = "x86_64")] pub fn dot_f16_avx512(ptr1: *const f16, ptr2: *const f16, len: u32) -> f32; #[cfg(target_arch = "x86_64")] pub fn dot_f16_avx2(ptr1: *const f16, ptr2: *const f16, len: u32) -> f32; @@ -108,23 +106,27 @@ mod kernel { impl Dot for f16 { #[inline] fn dot(x: &[Self], y: &[Self]) -> f32 { + #[cfg(target_arch = "x86_64")] + { + if is_x86_feature_detected!("avx512fp16") { + return unsafe { kernel::dot_f16_avx512(x.as_ptr(), y.as_ptr(), x.len() as u32) }; + } + if is_x86_feature_detected!("avx2") { + return unsafe { kernel::dot_f16_avx2(x.as_ptr(), y.as_ptr(), x.len() as u32) }; + } + } + + #[cfg(target_arch = "aarch64")] + { + // TODO: add SVE + if std::arch::is_aarch64_feature_detected!("neon") + && std::arch::is_aarch64_feature_detected!("fp16") + { + return unsafe { kernel::dot_f16_neon(x.as_ptr(), y.as_ptr(), x.len() as u32) }; + } + } + match *FP16_SIMD_SUPPORT { - // #[cfg(all(feature = "fp16kernels", target_arch = "aarch64"))] - // SimdSupport::Neon => unsafe { - // kernel::dot_f16_neon(x.as_ptr(), y.as_ptr(), x.len() as u32) - // }, - #[cfg(all( - feature = "fp16kernels", - kernel_support = "avx512", - target_arch = "x86_64" - ))] - SimdSupport::Avx512 => unsafe { - kernel::dot_f16_avx512(x.as_ptr(), y.as_ptr(), x.len() as u32) - }, - #[cfg(all(feature = "fp16kernels", target_arch = "x86_64"))] - SimdSupport::Avx2 => unsafe { - kernel::dot_f16_avx2(x.as_ptr(), y.as_ptr(), x.len() as u32) - }, #[cfg(all(feature = "fp16kernels", target_arch = "loongarch64"))] SimdSupport::Lasx => unsafe { kernel::dot_f16_lasx(x.as_ptr(), y.as_ptr(), x.len() as u32)