Skip to content

Commit

Permalink
implement SIMD-accelerated euclidean (#903)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz authored Dec 22, 2024
1 parent 0892aee commit dcfd453
Show file tree
Hide file tree
Showing 21 changed files with 1,586 additions and 1,594 deletions.
20 changes: 19 additions & 1 deletion base/floats/floats.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@

package floats

import "math"
import (
"math"

"github.com/chewxy/math32"
)

func dot(a, b []float32) (ret float32) {
for i := range a {
Expand All @@ -23,6 +27,13 @@ func dot(a, b []float32) (ret float32) {
return
}

// euclidean is the pure-Go routine for the Euclidean distance between a and
// b: the square root of the sum of squared element-wise differences. It walks
// the indices of a, so b must be at least as long as a.
func euclidean(a, b []float32) (ret float32) {
	for i, x := range a {
		diff := x - b[i]
		ret += diff * diff
	}
	return math32.Sqrt(ret)
}

func mulTo(a, b, c []float32) {
for i := range a {
c[i] = a[i] * b[i]
Expand Down Expand Up @@ -170,3 +181,10 @@ func Dot(a, b []float32) (ret float32) {
}
return impl.dot(a, b)
}

// Euclidean returns the Euclidean distance between a and b, computed by the
// implementation currently selected in impl. It panics if the two slices do
// not have the same length.
func Euclidean(a, b []float32) float32 {
	if na, nb := len(a), len(b); na != nb {
		panic("floats: slice lengths do not match")
	}
	dist := impl.euclidean(a, b)
	return dist
}
22 changes: 19 additions & 3 deletions base/floats/floats_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
package floats

import (
"github.com/klauspost/cpuid/v2"
"unsafe"

"github.com/klauspost/cpuid/v2"
)

//go:generate go run ../../cmd/goat src/floats_avx.c -O3 -mavx
//go:generate go run ../../cmd/goat src/floats_avx512.c -O3 -mavx -mfma -mavx512f -mavx512dq
//go:generate goat src/floats_avx.c -O3 -mavx
//go:generate goat src/floats_avx512.c -O3 -mavx -mfma -mavx512f -mavx512dq

var impl = Default

Expand Down Expand Up @@ -111,3 +112,18 @@ func (i implementation) dot(a, b []float32) float32 {
return dot(a, b)
}
}

// euclidean returns the Euclidean distance between a and b using the kernel
// selected by i: _mm256_euclidean for AVX, _mm512_euclidean for AVX512, and
// the pure-Go euclidean for every other implementation.
//
// The SIMD kernels are handed raw pointers to the first element of each
// slice, so empty input is answered here with a distance of 0 (the value the
// pure-Go routine yields) instead of panicking on &a[0]. The method does not
// verify that b is as long as a; the exported Euclidean checks that the
// lengths match before dispatching.
func (i implementation) euclidean(a, b []float32) float32 {
	if len(a) == 0 {
		return 0
	}
	switch i {
	case AVX:
		var ret float32
		// The element count travels in a pointer-typed argument; presumably
		// this matches the generated kernel stub's signature.
		_mm256_euclidean(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), unsafe.Pointer(uintptr(len(a))), unsafe.Pointer(&ret))
		return ret
	case AVX512:
		var ret float32
		_mm512_euclidean(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), unsafe.Pointer(uintptr(len(a))), unsafe.Pointer(&ret))
		return ret
	default:
		return euclidean(a, b)
	}
}
39 changes: 39 additions & 0 deletions base/floats/floats_amd64_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,17 @@ func TestAVX_Dot(t *testing.T) {
assert.Equal(t, expected, actual)
}

// TestAVX_Euclidean checks that the AVX euclidean kernel agrees with the
// Default implementation to within 1e-6 on a 10-element input. The test is
// skipped unless the CPU reports both AVX and FMA3.
func TestAVX_Euclidean(t *testing.T) {
	if !(cpuid.CPU.Supports(cpuid.AVX) && cpuid.CPU.Supports(cpuid.FMA3)) {
		t.Skip("AVX and FMA3 are not supported in the current CPU")
	}
	var (
		x = []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
		y = []float32{0, 2, 4, 6, 8, 10, 12, 14, 16, 18}
	)
	want := Default.euclidean(x, y)
	got := AVX.euclidean(x, y)
	assert.InDelta(t, want, got, 1e-6)
}

func TestAVX512_MulConstAddTo(t *testing.T) {
if !cpuid.CPU.Supports(cpuid.AVX512F) || !cpuid.CPU.Supports(cpuid.AVX512DQ) {
t.Skip("AVX512F and AVX512DQ are not supported in the current CPU")
Expand Down Expand Up @@ -141,6 +152,17 @@ func TestAVX512_Dot(t *testing.T) {
assert.Equal(t, expected, actual)
}

// TestAVX512_Euclidean checks that the AVX512 euclidean kernel agrees with
// the Default implementation to within 1e-6 on an 11-element input. The test
// is skipped unless the CPU reports both AVX512F and AVX512DQ.
func TestAVX512_Euclidean(t *testing.T) {
	if !(cpuid.CPU.Supports(cpuid.AVX512F) && cpuid.CPU.Supports(cpuid.AVX512DQ)) {
		t.Skip("AVX512F and AVX512DQ are not supported in the current CPU")
	}
	var (
		x = []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
		y = []float32{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20}
	)
	want := Default.euclidean(x, y)
	got := AVX512.euclidean(x, y)
	assert.InDelta(t, want, got, 1e-6)
}

func initializeFloat32Array(n int) []float32 {
x := make([]float32, n)
for i := 0; i < n; i++ {
Expand All @@ -166,6 +188,23 @@ func BenchmarkDot(b *testing.B) {
}
}

// BenchmarkEuclidean measures euclidean for the Default, AVX and AVX512
// implementations on vectors of 16, 32, 64 and 128 elements. Sub-benchmarks
// are named "<implementation>/<size>".
//
// The SIMD variants are skipped when the CPU lacks the required features,
// using the same feature checks as the corresponding unit tests; running
// those kernels on an unsupported CPU would execute illegal instructions.
func BenchmarkEuclidean(b *testing.B) {
	for _, impl := range []implementation{Default, AVX, AVX512} {
		b.Run(impl.String(), func(b *testing.B) {
			switch impl {
			case AVX:
				if !cpuid.CPU.Supports(cpuid.AVX) || !cpuid.CPU.Supports(cpuid.FMA3) {
					b.Skip("AVX and FMA3 are not supported in the current CPU")
				}
			case AVX512:
				if !cpuid.CPU.Supports(cpuid.AVX512F) || !cpuid.CPU.Supports(cpuid.AVX512DQ) {
					b.Skip("AVX512F and AVX512DQ are not supported in the current CPU")
				}
			}
			for size := 16; size <= 128; size *= 2 {
				b.Run(strconv.Itoa(size), func(b *testing.B) {
					v1 := initializeFloat32Array(size)
					v2 := initializeFloat32Array(size)
					// Keep input construction out of the timed region.
					b.ResetTimer()
					for n := 0; n < b.N; n++ {
						impl.euclidean(v1, v2)
					}
				})
			}
		})
	}
}

func BenchmarkMulConstAddTo(b *testing.B) {
for _, impl := range []implementation{Default, AVX, AVX512} {
b.Run(impl.String(), func(b *testing.B) {
Expand Down
12 changes: 11 additions & 1 deletion base/floats/floats_arm64.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package floats

import "unsafe"

//go:generate go run ../../cmd/goat src/floats_neon.c -O3
//go:generate goat src/floats_neon.c -O3

var impl = Neon

Expand Down Expand Up @@ -79,3 +79,13 @@ func (i implementation) dot(a, b []float32) float32 {
return dot(a, b)
}
}

// euclidean returns the Euclidean distance between a and b, using the
// veuclidean kernel when i is Neon and the pure-Go euclidean otherwise.
//
// The kernel is handed raw pointers to the first element of each slice, so
// empty input is answered here with a distance of 0 (the value the pure-Go
// routine yields) instead of panicking on &a[0]. The method does not verify
// that b is as long as a; callers are expected to pass slices of matching
// length.
func (i implementation) euclidean(a, b []float32) float32 {
	if len(a) == 0 {
		return 0
	}
	if i != Neon {
		return euclidean(a, b)
	}
	var ret float32
	// The element count travels in a pointer-typed argument; presumably this
	// matches the generated kernel stub's signature.
	veuclidean(unsafe.Pointer(&a[0]), unsafe.Pointer(&b[0]), unsafe.Pointer(uintptr(len(a))), unsafe.Pointer(&ret))
	return ret
}
25 changes: 25 additions & 0 deletions base/floats/floats_arm64_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ func TestNEON_Dot(t *testing.T) {
assert.Equal(t, expected, actual)
}

func TestNEON_Euclidean(t *testing.T) {
a := []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
b := []float32{10, 20, 30, 40, 50, 60, 70, 80, 90, 100}
actual := Neon.euclidean(a, b)
expected := Default.euclidean(a, b)
assert.Equal(t, expected, actual)
}

func initializeFloat32Array(n int) []float32 {
x := make([]float32, n)
for i := 0; i < n; i++ {
Expand All @@ -92,6 +100,23 @@ func BenchmarkDot(b *testing.B) {
}
}

// BenchmarkEuclidean measures euclidean for the Default and Neon
// implementations on vectors of 16, 32, 64 and 128 elements. Sub-benchmarks
// are named "<implementation>/<size>".
func BenchmarkEuclidean(b *testing.B) {
	impls := []implementation{Default, Neon}
	for _, impl := range impls {
		b.Run(impl.String(), func(b *testing.B) {
			for dim := 16; dim <= 128; dim *= 2 {
				b.Run(strconv.Itoa(dim), func(b *testing.B) {
					x := initializeFloat32Array(dim)
					y := initializeFloat32Array(dim)
					// Keep input construction out of the timed region.
					b.ResetTimer()
					for n := 0; n < b.N; n++ {
						impl.euclidean(x, y)
					}
				})
			}
		})
	}
}

func BenchmarkMulConstAddTo(b *testing.B) {
for _, impl := range []implementation{Default, Neon} {
b.Run(impl.String(), func(b *testing.B) {
Expand Down
6 changes: 4 additions & 2 deletions base/floats/floats_avx.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit dcfd453

Please sign in to comment.