Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement SIMD-accelerated euclidean #903

Merged
merged 6 commits into from
Dec 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion base/floats/floats.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@

package floats

import "math"
import (
"math"

"github.com/chewxy/math32"
)

func dot(a, b []float32) (ret float32) {
for i := range a {
Expand All @@ -23,6 +27,13 @@ func dot(a, b []float32) (ret float32) {
return
}

func euclidean(a, b []float32) (ret float32) {
for i := range a {
ret += (a[i] - b[i]) * (a[i] - b[i])
}
return math32.Sqrt(ret)
}

func mulTo(a, b, c []float32) {
for i := range a {
c[i] = a[i] * b[i]
Expand Down Expand Up @@ -170,3 +181,10 @@ func Dot(a, b []float32) (ret float32) {
}
return impl.dot(a, b)
}

func Euclidean(a, b []float32) float32 {
if len(a) != len(b) {
panic("floats: slice lengths do not match")
}
return impl.euclidean(a, b)
}
22 changes: 19 additions & 3 deletions base/floats/floats_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
package floats

import (
"github.com/klauspost/cpuid/v2"
"unsafe"

"github.com/klauspost/cpuid/v2"
)

//go:generate go run ../../cmd/goat src/floats_avx.c -O3 -mavx
//go:generate go run ../../cmd/goat src/floats_avx512.c -O3 -mavx -mfma -mavx512f -mavx512dq
//go:generate goat src/floats_avx.c -O3 -mavx
//go:generate goat src/floats_avx512.c -O3 -mavx -mfma -mavx512f -mavx512dq

var impl = Default

Expand Down Expand Up @@ -111,3 +112,18 @@ func (i implementation) dot(a, b []float32) float32 {
return dot(a, b)
}
}

func (i implementation) euclidean(a, b []float32) float32 {
switch i {
case AVX:
var ret float32
_mm256_euclidean(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), unsafe.Pointer(uintptr(len(a))), unsafe.Pointer(&ret))
return ret
case AVX512:
var ret float32
_mm512_euclidean(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), unsafe.Pointer(uintptr(len(a))), unsafe.Pointer(&ret))
return ret
default:
return euclidean(a, b)
}
}
39 changes: 39 additions & 0 deletions base/floats/floats_amd64_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,17 @@ func TestAVX_Dot(t *testing.T) {
assert.Equal(t, expected, actual)
}

func TestAVX_Euclidean(t *testing.T) {
if !cpuid.CPU.Supports(cpuid.AVX) || !cpuid.CPU.Supports(cpuid.FMA3) {
t.Skip("AVX and FMA3 are not supported in the current CPU")
}
a := []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
b := []float32{0, 2, 4, 6, 8, 10, 12, 14, 16, 18}
actual := AVX.euclidean(a, b)
expected := Default.euclidean(a, b)
assert.InDelta(t, expected, actual, 1e-6)
}

func TestAVX512_MulConstAddTo(t *testing.T) {
if !cpuid.CPU.Supports(cpuid.AVX512F) || !cpuid.CPU.Supports(cpuid.AVX512DQ) {
t.Skip("AVX512F and AVX512DQ are not supported in the current CPU")
Expand Down Expand Up @@ -141,6 +152,17 @@ func TestAVX512_Dot(t *testing.T) {
assert.Equal(t, expected, actual)
}

func TestAVX512_Euclidean(t *testing.T) {
if !cpuid.CPU.Supports(cpuid.AVX512F) || !cpuid.CPU.Supports(cpuid.AVX512DQ) {
t.Skip("AVX512F and AVX512DQ are not supported in the current CPU")
}
a := []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
b := []float32{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20}
actual := AVX512.euclidean(a, b)
expected := Default.euclidean(a, b)
assert.InDelta(t, expected, actual, 1e-6)
}

func initializeFloat32Array(n int) []float32 {
x := make([]float32, n)
for i := 0; i < n; i++ {
Expand All @@ -166,6 +188,23 @@ func BenchmarkDot(b *testing.B) {
}
}

func BenchmarkEuclidean(b *testing.B) {
for _, impl := range []implementation{Default, AVX, AVX512} {
b.Run(impl.String(), func(b *testing.B) {
for i := 16; i <= 128; i *= 2 {
b.Run(strconv.Itoa(i), func(b *testing.B) {
v1 := initializeFloat32Array(i)
v2 := initializeFloat32Array(i)
b.ResetTimer()
for i := 0; i < b.N; i++ {
impl.euclidean(v1, v2)
}
})
}
})
}
}

func BenchmarkMulConstAddTo(b *testing.B) {
for _, impl := range []implementation{Default, AVX, AVX512} {
b.Run(impl.String(), func(b *testing.B) {
Expand Down
12 changes: 11 additions & 1 deletion base/floats/floats_arm64.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package floats

import "unsafe"

//go:generate go run ../../cmd/goat src/floats_neon.c -O3
//go:generate goat src/floats_neon.c -O3

var impl = Neon

Expand Down Expand Up @@ -79,3 +79,13 @@ func (i implementation) dot(a, b []float32) float32 {
return dot(a, b)
}
}

func (i implementation) euclidean(a, b []float32) float32 {
if i == Neon {
var ret float32
veuclidean(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), unsafe.Pointer(uintptr(len(a))), unsafe.Pointer(&ret))
return ret
} else {
return euclidean(a, b)
}
}
25 changes: 25 additions & 0 deletions base/floats/floats_arm64_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ func TestNEON_Dot(t *testing.T) {
assert.Equal(t, expected, actual)
}

func TestNEON_Euclidean(t *testing.T) {
a := []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
b := []float32{10, 20, 30, 40, 50, 60, 70, 80, 90, 100}
actual := Neon.euclidean(a, b)
expected := Default.euclidean(a, b)
assert.Equal(t, expected, actual)
}

func initializeFloat32Array(n int) []float32 {
x := make([]float32, n)
for i := 0; i < n; i++ {
Expand All @@ -92,6 +100,23 @@ func BenchmarkDot(b *testing.B) {
}
}

func BenchmarkEuclidean(b *testing.B) {
for _, impl := range []implementation{Default, Neon} {
b.Run(impl.String(), func(b *testing.B) {
for i := 16; i <= 128; i *= 2 {
b.Run(strconv.Itoa(i), func(b *testing.B) {
v1 := initializeFloat32Array(i)
v2 := initializeFloat32Array(i)
b.ResetTimer()
for i := 0; i < b.N; i++ {
impl.euclidean(v1, v2)
}
})
}
})
}
}

func BenchmarkMulConstAddTo(b *testing.B) {
for _, impl := range []implementation{Default, Neon} {
b.Run(impl.String(), func(b *testing.B) {
Expand Down
6 changes: 4 additions & 2 deletions base/floats/floats_avx.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading