Skip to content

Commit

Permalink
set float16 in clang
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyxu committed Sep 15, 2024
1 parent dff7c10 commit ac76fa3
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 62 deletions.
27 changes: 0 additions & 27 deletions rust/lance-core/src/utils/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,33 +48,6 @@ lazy_static! {
};

Check warning on line 48 in rust/lance-core/src/utils/cpu.rs

View workflow job for this annotation

GitHub Actions / linux-arm

Diff in /runner/_work/lance/lance/rust/lance-core/src/utils/cpu.rs
}

// Inspired by https://github.com/RustCrypto/utils/blob/master/cpufeatures/src/aarch64.rs
// aarch64 doesn't have userspace feature detection built in, so we have to call
// into OS-specific functions to check for features.
#[cfg(all(target_arch = "aarch64", target_os = "macos"))]
mod aarch64 {
    /// Answer reported on macOS/aarch64. No runtime probe is performed here;
    /// support is simply assumed to be present on this platform.
    const ASSUME_NEON_F16: bool = true;

    /// Reports whether NEON half-precision (f16) support is available.
    ///
    /// On macOS/aarch64 this always returns `true`.
    pub fn has_neon_f16_support() -> bool {
        ASSUME_NEON_F16
    }
}

#[cfg(all(target_arch = "aarch64", target_os = "linux"))]
mod aarch64 {
    /// Reports whether the kernel advertises half-precision (f16) support.
    ///
    /// Reads the `AT_HWCAP` entry of the auxiliary vector and returns `true`
    /// when the `HWCAP_FPHP` bit is set in it.
    pub fn has_neon_f16_support() -> bool {
        // See: https://github.com/rust-lang/libc/blob/7ce81ca7aeb56aae7ca0237ef9353d58f3d7d2f1/src/unix/linux_like/linux/gnu/b64/aarch64/mod.rs#L533
        // SAFETY: `getauxval` takes a plain integer key and returns an integer;
        // no pointers are passed in or dereferenced by this call.
        let hwcap = unsafe { libc::getauxval(libc::AT_HWCAP) };
        (hwcap & libc::HWCAP_FPHP) != 0
    }
}

#[cfg(all(target_arch = "aarch64", target_os = "windows"))]
mod aarch64 {
    /// Answer reported on Windows/aarch64. No runtime probe is performed here.
    /// Background: https://github.com/lancedb/lance/issues/2411
    const ASSUME_NEON_F16: bool = false;

    /// Reports whether NEON half-precision (f16) support is available.
    ///
    /// On Windows/aarch64 this always returns `false`.
    pub fn has_neon_f16_support() -> bool {
        ASSUME_NEON_F16
    }
}

#[cfg(target_arch = "loongarch64")]
mod loongarch64 {
Expand Down
17 changes: 7 additions & 10 deletions rust/lance-linalg/src/distance/dot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ use arrow_array::{cast::AsArray, types::Float32Type, Array, FixedSizeListArray,
use arrow_schema::DataType;
use half::{bf16, f16};
use lance_arrow::{ArrowFloatType, FloatArray};
use lance_core::utils::cpu::SimdSupport;
use lance_core::utils::cpu::FP16_SIMD_SUPPORT;
use num_traits::{real::Real, AsPrimitive, Num};

use crate::simd::{
Expand Down Expand Up @@ -126,17 +124,16 @@ impl Dot for f16 {
}
}

match *FP16_SIMD_SUPPORT {
#[cfg(all(feature = "fp16kernels", target_arch = "loongarch64"))]
SimdSupport::Lasx => unsafe {
#[cfg(target_arch = "loongarch64")]
{
if loongarch64::has_lasx_support() {
kernel::dot_f16_lasx(x.as_ptr(), y.as_ptr(), x.len() as u32)
},
#[cfg(all(feature = "fp16kernels", target_arch = "loongarch64"))]
SimdSupport::Lsx => unsafe {
} else if loongarch64::has_lsx_support() {
kernel::dot_f16_lsx(x.as_ptr(), y.as_ptr(), x.len() as u32)
},
_ => dot_scalar::<Self, f32, 16>(x, y),
}
}

dot_scalar::<Self, f32, 16>(x, y)
}
}

Expand Down
39 changes: 14 additions & 25 deletions rust/lance-linalg/src/simd/f16.c
Original file line number Diff line number Diff line change
@@ -1,16 +1,5 @@
// Copyright 2023 Lance Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

#include <stddef.h>
#include <stdint.h>
Expand All @@ -24,11 +13,12 @@

// TODO: I wonder if we could re-purpose this macro to compile bf16 kernels?
#if defined(__clang__)
// Note: we use __fp16 instead of _Float16 because Clang < 15.0.0 does not
// support _Float16 well on most targets. __fp16 works for our purposes here
// since we always cast it to float anyway, so the choice makes no difference
// to the compiled assembly code for these functions.
#if __FLT16_MANT_DIG__
// Clang supports _Float16
#define FP16 _Float16
#else
#define FP16 __fp16
#endif
#elif defined(__GNUC__) || defined(__GNUG__)
#define FP16 _Float16
#endif
Expand All @@ -54,23 +44,22 @@ float FUNC(dot_f16)(const FP16 *x, const FP16 *y, uint32_t dimension) {

#pragma clang loop unroll(enable) interleave(enable) vectorize(enable)
for (uint32_t i = 0; i < dimension; i++) {
sum += (float) x[i] * (float) y[i];
// Use float32 as the accumulator to avoid overflow.
sum += x[i] * y[i];
}
return sum;
}

float FUNC(l2_f16)(const FP16 *x, const FP16 *y, uint32_t dimension) {
float x2 = 0.0;
float y2 = 0.0;
float xy = 0.0;
float sum = 0;

#pragma clang loop unroll(enable) interleave(enable) vectorize(enable)
for (uint32_t i = 0; i < dimension; i++) {
x2 += x[i] * x[i];
y2 += y[i] * y[i];
xy += x[i] * y[i];
FP16 sub = x[i] - y[i];
// Use float32 as the accumulator to avoid overflow.
sum += sub * sub;
}
return x2 + y2 - 2 * xy;
return sum;
}

float FUNC(cosine_f16)(const FP16 *x, float x_norm, const FP16 *y, uint32_t dimension) {
Expand Down

0 comments on commit ac76fa3

Please sign in to comment.