From 2eafb9a8996a14088802ba5299e36946ad5fe7c4 Mon Sep 17 00:00:00 2001 From: Sergey Grebenshchikov Date: Sat, 12 Oct 2024 02:26:25 +0200 Subject: [PATCH] neon --- README.md | 30 +++++++- go.mod | 1 + go.sum | 2 + internal/neon/distance.go | 7 ++ internal/neon/distance_arm64.go | 15 ++++ internal/neon/distance_arm64.s | 102 +++++++++++++++++++++++++++ internal/neon/distance_arm64_test.go | 31 ++++++++ internal/neon/distance_generic.go | 13 ++++ internal/neon/distance_test.go | 31 ++++++++ model_wide.go | 33 ++++++++- model_wide_bench_test.go | 19 ++++- model_wide_test.go | 56 ++++++++++++++- nearest_wide.go | 60 ++++++++++++++++ 13 files changed, 393 insertions(+), 7 deletions(-) create mode 100644 internal/neon/distance.go create mode 100644 internal/neon/distance_arm64.go create mode 100644 internal/neon/distance_arm64.s create mode 100644 internal/neon/distance_arm64_test.go create mode 100644 internal/neon/distance_generic.go create mode 100644 internal/neon/distance_test.go diff --git a/README.md b/README.md index eb667a1..1ada7d6 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,8 @@ If you need to classify **binary feature vectors that fit into `uint64`s**, this If your vectors are **longer than 64 bits**, you can [pack](#packing-wide-data) them into `[]uint64` and classify them using the ["wide" model variants](#packing-wide-data). +On ARM64 with NEON vector instruction support, `bitknn` can be [a bit faster than otherwise](#arm64-neon-support) on wide data. + You can optionally weigh class votes by distance, or specify different vote values per data point. The sub-package [`lsh`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh) implements several [Locality-Sensitive Hashing (LSH)](https://en.m.wikipedia.org/wiki/Locality-sensitive_hashing) schemes for `uint64` feature vectors. 
@@ -24,6 +26,7 @@ The sub-package [`lsh`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh) - [Basic usage](#basic-usage) - [LSH](#lsh) - [Packing wide data](#packing-wide-data) + - [ARM64 NEON Support](#arm64-neon-support) - [Options](#options) - [Benchmarks](#benchmarks) - [License](#license) @@ -37,11 +40,11 @@ There are just three methods you'll typically need: Variants: [`bitknn.Fit`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#Fit), [`bitknn.FitWide`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#FitWide), [`lsh.Fit`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Fit), [`lsh.FitWide`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#FitWide) - **Find** *(k, point)*: Given a point, return the *k* nearest neighbor's indices and distances. - Variants: [`bitknn.Model.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#Model.Find), [`bitknn.WideModel.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.Find), [`lsh.Model.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Model.Find), [`lsh.WideModel.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#WideModel.Find) + Variants: [`bitknn.Model.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#Model.Find), [`bitknn.WideModel.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.Find), [`lsh.Model.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Model.Find), [`lsh.WideModel.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#WideModel.Find), [`bitknn.WideModel.FindV`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.FindV) (vectorized on ARM64 with NEON instructions) - **Predict** *(k, point, votes)*: Predict the label for a given point based on its nearest neighbors, write the label votes into the provided vote counter. 
- Variants: [`bitknn.Model.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#Model.Predict), [`bitknn.WideModel.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.Predict), [`lsh.Model.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Model.Predict), [`lsh.WideModel.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#WideModel.Predict) + Variants: [`bitknn.Model.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#Model.Predict), [`bitknn.WideModel.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.Predict), [`lsh.Model.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Model.Predict), [`lsh.WideModel.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#WideModel.Predict), [`bitknn.WideModel.PredictV`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.PredictV) (vectorized on ARM64 with NEON instructions). Each of the above methods is available on each model type. There are four model types in total: @@ -193,6 +196,29 @@ func main() { The wide model fitting function [`bitknn.FitWide`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#FitWide) accepts the same [Options](#options) as the "narrow" one. +### ARM64 NEON Support + +For ARM64 CPUs with NEON instructions, `bitknn` has a [vectorized distance function for `[]uint64`s](internal/neon/distance_arm64.s) that is about twice as fast as what the compiler generates. 
+ +When run on such a CPU, the **\*V** methods [`WideModel.FindV`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.FindV) and [`WideModel.PredictV`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.PredictV) are noticeably faster than the regular `Find`/`Predict`: + +| Bits | N | k | `Find` s/op | `FindV` s/op | diff | +|-------|---------|-----|--------------|--------------|------------------------| +| 128 | 1000 | 3 | 2.374µ ± 0% | 1.792µ ± 0% | -24.54% (p=0.000 n=10) | +| 128 | 1000 | 10 | 2.901µ ± 1% | 2.028µ ± 1% | -30.08% (p=0.000 n=10) | +| 128 | 1000 | 100 | 5.472µ ± 3% | 4.359µ ± 1% | -20.34% (p=0.000 n=10) | +| 128 | 1000000 | 3 | 2.273m ± 3% | 1.380m ± 2% | -39.27% (p=0.000 n=10) | +| 128 | 1000000 | 10 | 2.261m ± 1% | 1.406m ± 1% | -37.84% (p=0.000 n=10) | +| 128 | 1000000 | 100 | 2.289m ± 0% | 1.425m ± 2% | -37.76% (p=0.000 n=10) | +| 640 | 1000 | 3 | 6.201µ ± 1% | 3.716µ ± 0% | -40.07% (p=0.000 n=10) | +| 640 | 1000 | 10 | 6.728µ ± 1% | 3.973µ ± 1% | -40.96% (p=0.000 n=10) | +| 640 | 1000 | 100 | 10.855µ ± 2% | 6.917µ ± 1% | -36.28% (p=0.000 n=10) | +| 640 | 1000000 | 3 | 5.832m ± 2% | 3.337m ± 1% | -42.78% (p=0.000 n=10) | +| 640 | 1000000 | 10 | 5.830m ± 5% | 3.339m ± 1% | -42.73% (p=0.000 n=10) | +| 640 | 1000000 | 100 | 5.872m ± 1% | 3.361m ± 1% | -42.77% (p=0.000 n=10) | +| 8192 | 1000000 | 10 | 72.66m ± 1% | 30.96m ± 3% | -57.39% (p=0.000 n=10) | + + ## Options - [`bitknn.WithLinearDistanceWeighting()`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WithLinearDistanceWeighting): Apply linear distance weighting (`1 / (1 + dist)`). 
diff --git a/go.mod b/go.mod index 254da43..66bdffa 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,6 @@ go 1.23.0 require ( github.com/google/go-cmp v0.6.0 + golang.org/x/sys v0.26.0 pgregory.net/rapid v1.1.0 ) diff --git a/go.sum b/go.sum index 93587e3..5655815 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= +golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= pgregory.net/rapid v1.1.0 h1:CMa0sjHSru3puNx+J0MIAuiiEV4N0qj8/cMWGBBCsjw= pgregory.net/rapid v1.1.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04= diff --git a/internal/neon/distance.go b/internal/neon/distance.go new file mode 100644 index 0000000..7935bbb --- /dev/null +++ b/internal/neon/distance.go @@ -0,0 +1,7 @@ +//go:build !arm64 + +package neon + +func DistancesWide(a []uint64, bs [][]uint64, out []uint32) { + distancesWideGeneric(a, bs, out) +} diff --git a/internal/neon/distance_arm64.go b/internal/neon/distance_arm64.go new file mode 100644 index 0000000..94ba64c --- /dev/null +++ b/internal/neon/distance_arm64.go @@ -0,0 +1,15 @@ +package neon + +import ( + "golang.org/x/sys/cpu" +) + +func init() { + if cpu.ARM64.HasASIMD { + DistancesWide = DistancesWideNEON + } +} + +var DistancesWide = distancesWideGeneric + +func DistancesWideNEON(a []uint64, bs [][]uint64, out []uint32) diff --git a/internal/neon/distance_arm64.s b/internal/neon/distance_arm64.s new file mode 100644 index 0000000..b95b6e0 --- /dev/null +++ b/internal/neon/distance_arm64.s @@ -0,0 +1,102 @@ +#include "go_asm.h" +#include "textflag.h" + +// func DistancesWideNEON(a []uint64, b [][]uint64, out []uint32) +// +// Computes the Hamming distance between 'a' and each slice in 'b', +// storing the results in 'out'. 
+// +// Inputs: +// a_base+0(FP) : base address of slice a +// a_len+8(FP) : length of slice a +// (a_cap+16(FP) : capacity of slice a) +// bs_base+24(FP) : base address of slice b (slice of slices) +// bs_len+32(FP) : length of slice b (number of slices) +// (bs_cap+40(FP) : capacity of slice b) +// out_base+48(FP): base address of output slice +// (out_len+56(FP)): length of output slice +// (out_cap+64(FP)): capacity of output slice +// +// Assumes that all slices in 'b' have the same length as 'a', +// and that 'out' has at least 'bs_len' elements. + +//go:linkname DistancesWideNEON DistancesWideNEON +//go:noescape +TEXT ·DistancesWideNEON(SB), NOSPLIT, $0-72 + // Load input parameters + MOVD a_base+0(FP), R0 + MOVD a_len+8(FP), R1 + MOVD bs_base+24(FP), R2 + MOVD bs_len+32(FP), R3 + MOVD out_base+48(FP), R4 + + // Outer loop counter + MOVD R3, R5 + CBZ R5, done + +outer_loop: + MOVD a_base+0(FP), R0 + + // Load the base address of the current slice in 'b' + MOVD (R2), R6 + ADD $24, R2 // Move to the next slice in 'b' + + // Initialize the result for this slice to 0 + MOVD $0, R7 + + // Inner loop counter (number of uint64 in 'a') + MOVD R1, R8 + + VEOR V1.B16,V1.B16,V1.B16 + VEOR V2.B16,V2.B16,V2.B16 + VEOR V3.B16,V3.B16,V3.B16 + // Check if the length is at least 2 (16 bytes) + CMP $2, R8 + BLT inner_remainder + +inner_loop: + // Load 16 bytes (2 uint64s) from each slice + VLD1.P 16(R0), [V0.D2] + VLD1.P 16(R6), [V1.D2] + + // XOR the loaded vectors + VEOR V0.B16, V1.B16, V2.B16 + + // Count the set bits + VCNT V2.B16, V2.B16 + + // Sum up the counts + VUADDLV V2.B16, V3 + + // Add the result to the total + FMOVD F3, R9 + ADD R9, R7 + + // Decrement the counter by 2 and continue if there are more elements + SUB $2, R8 + CMP $2, R8 + BGE inner_loop + +inner_remainder: + // Handle the remaining element if the length is odd + CBZ R8, inner_done + MOVD (R0), R9 + MOVD (R6), R10 + EOR R9, R10, R9 + FMOVD R9, F0 + VCNT V0.B8, V0.B8 + VUADDLV V0.B8, V0 + FMOVD F0, 
R9 + ADD R9, R7 + +inner_done: + // Store the distance in the output slice + MOVW R7, (R4) + ADD $4, R4 // Move to the next element in 'out' + + // Decrement the outer loop counter and continue if there are more slices + SUB $1, R5 + CBNZ R5, outer_loop + +done: + RET diff --git a/internal/neon/distance_arm64_test.go b/internal/neon/distance_arm64_test.go new file mode 100644 index 0000000..5802742 --- /dev/null +++ b/internal/neon/distance_arm64_test.go @@ -0,0 +1,31 @@ +package neon + +import ( + "math/bits" + "testing" + + "pgregory.net/rapid" +) + +func TestDistancesWideNEON(t *testing.T) { + t.Run("DistancesWideGenericEquivBits", func(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + dims := rapid.IntRange(0, 10_000).Draw(t, "dims") + data := rapid.SliceOfN(rapid.SliceOfN(rapid.Uint64(), dims, dims), 16, 10_000).Draw(t, "data") + q := rapid.SliceOfN(rapid.Uint64(), dims, dims).Draw(t, "q") + for batchSize := range []int{0, 1, 2, len(data), len(data) - 1, len(data) * 2} { + out := make([]uint32, batchSize) + DistancesWideNEON(q, data[:batchSize], out) + for i, d := range out { + expected := 0 + for j, q := range q { + expected += bits.OnesCount64(q ^ data[i][j]) + } + if int(d) != expected { + t.Fatal(d, expected) + } + } + } + }) + }) +} diff --git a/internal/neon/distance_generic.go b/internal/neon/distance_generic.go new file mode 100644 index 0000000..2be70ea --- /dev/null +++ b/internal/neon/distance_generic.go @@ -0,0 +1,13 @@ +package neon + +import "math/bits" + +func distancesWideGeneric(a []uint64, bs [][]uint64, out []uint32) { + for i, b := range bs { + dist := 0 + for j, aj := range a { + dist += bits.OnesCount64(aj ^ b[j]) + } + out[i] = uint32(dist) + } +} diff --git a/internal/neon/distance_test.go b/internal/neon/distance_test.go new file mode 100644 index 0000000..578d1b7 --- /dev/null +++ b/internal/neon/distance_test.go @@ -0,0 +1,31 @@ +package neon + +import ( + "math/bits" + "testing" + + "pgregory.net/rapid" +) + +func 
TestDistancesWideGeneric(t *testing.T) { + t.Run("DistancesWideGenericEquivBits", func(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + dims := rapid.IntRange(0, 10_000).Draw(t, "dims") + data := rapid.SliceOfN(rapid.SliceOfN(rapid.Uint64(), dims, dims), 16, 10_000).Draw(t, "data") + q := rapid.SliceOfN(rapid.Uint64(), dims, dims).Draw(t, "q") + for batchSize := range []int{0, 1, 2, len(data), len(data) - 1, len(data) * 2} { + out := make([]uint32, batchSize) + distancesWideGeneric(q, data[:batchSize], out) + for i, d := range out { + expected := 0 + for j, q := range q { + expected += bits.OnesCount64(q ^ data[i][j]) + } + if int(d) != expected { + t.Fatal(d, expected) + } + } + } + }) + }) +} diff --git a/model_wide.go b/model_wide.go index 1e93520..f11a1ac 100644 --- a/model_wide.go +++ b/model_wide.go @@ -25,10 +25,17 @@ func (me *WideModel) PreallocateHeap(k int) { // Writes their distances and indices in the dataset into the pre-allocated slices. // Returns the distance and index slices, truncated to the actual number of neighbors found. func (me *WideModel) Find(k int, x []uint64) ([]int, []int) { - me.Narrow.PreallocateHeap(k) + me.PreallocateHeap(k) return me.FindInto(k, x, me.Narrow.HeapDistances, me.Narrow.HeapIndices) } +// FindV is [Find], but vectorizable (currently only on ARM64 with NEON instructions). +// The provided [batch] slice must have length >=k and is used to pre-compute batches of distances. +func (me *WideModel) FindV(k int, x []uint64, batch []uint32) ([]int, []int) { + me.PreallocateHeap(k) + return me.FindIntoV(k, x, batch, me.Narrow.HeapDistances, me.Narrow.HeapIndices) +} + // Finds the nearest neighbors of the given point. // Writes their distances and indices in the dataset into the provided slices. // The slices should be pre-allocated to length k+1. 
@@ -38,10 +45,17 @@ func (me *WideModel) FindInto(k int, x []uint64, distances []int, indices []int) return distances[:k], indices[:k] } +// FindIntoV is [FindInto], but vectorizable (currently only on ARM64 with NEON instructions). +// The provided [batch] slice must have length >=k and is used to pre-compute batches of distances. +func (me *WideModel) FindIntoV(k int, x []uint64, batch []uint32, distances []int, indices []int) ([]int, []int) { + k = NearestWideV(me.WideData, k, x, batch, distances, indices) + return distances[:k], indices[:k] +} + // Predicts the label of a single input point. Reuses two slices of length K+1 for the neighbor heap. // Returns the number of neighbors found. func (me *WideModel) Predict(k int, x []uint64, votes VoteCounter) int { - me.Narrow.PreallocateHeap(k) + me.PreallocateHeap(k) return me.PredictInto(k, x, me.Narrow.HeapDistances, me.Narrow.HeapIndices, votes) } @@ -52,3 +66,18 @@ func (me *WideModel) PredictInto(k int, x []uint64, distances []int, indices []i me.Narrow.Vote(k, distances, indices, votes) return k } + +// PredictV is [Predict], but vectorizable (currently only on ARM64 with NEON instructions). +// The provided [batch] slice must have length >=k and is used to pre-compute batches of distances. +func (me *WideModel) PredictV(k int, x []uint64, batch []uint32, votes VoteCounter) int { + me.PreallocateHeap(k) + return me.PredictIntoV(k, x, batch, me.Narrow.HeapDistances, me.Narrow.HeapIndices, votes) +} + +// PredictIntoV is [PredictInto], but vectorizable (currently only on ARM64 with NEON instructions). +// The provided [batch] slice must have length >=k and is used to pre-compute batches of distances. 
+func (me *WideModel) PredictIntoV(k int, x []uint64, batch []uint32, distances []int, indices []int, votes VoteCounter) int { + k = NearestWideV(me.WideData, k, x, batch, distances, indices) + me.Narrow.Vote(k, distances, indices, votes) + return k +} diff --git a/model_wide_bench_test.go b/model_wide_bench_test.go index e729878..39760e6 100644 --- a/model_wide_bench_test.go +++ b/model_wide_bench_test.go @@ -14,10 +14,13 @@ func BenchmarkWideModel(b *testing.B) { dim []int dataSize []int k []int + batch []int } benches := []bench{ - {dim: []int{1, 2, 10}, dataSize: []int{100}, k: []int{3, 10}}, - {dim: []int{1, 2, 10}, dataSize: []int{1000, 1_000_000}, k: []int{3, 10, 100}}, + {dim: []int{1, 2, 10}, dataSize: []int{100}, k: []int{3, 10}, batch: nil}, + {dim: []int{1}, dataSize: []int{1000, 1_000_000}, k: []int{3, 10, 100}, batch: nil}, + {dim: []int{2, 10}, dataSize: []int{1000, 1_000_000}, k: []int{3, 10, 100}, batch: []int{1000}}, + {dim: []int{128}, dataSize: []int{1_000_000}, k: []int{10}, batch: []int{1000}}, } for _, bench := range benches { for _, dim := range bench.dim { @@ -42,6 +45,18 @@ func BenchmarkWideModel(b *testing.B) { model.Find(k, query) } }) + for _, batchSize := range bench.batch { + batchSize = min(batchSize, dataSize) + batchSize = max(batchSize, k) + batch := make([]uint32, batchSize) + b.Run(fmt.Sprintf("Op=FindV_batch=%d_bits=%d_N=%d_k=%d", batchSize, dim*64, dataSize, k), func(b *testing.B) { + model.PreallocateHeap(k) + b.ResetTimer() + for n := 0; n < b.N; n++ { + model.FindV(k, query, batch) + } + }) + } } } } diff --git a/model_wide_test.go b/model_wide_test.go index 4568e1b..7530da8 100644 --- a/model_wide_test.go +++ b/model_wide_test.go @@ -10,7 +10,7 @@ import ( "pgregory.net/rapid" ) -func Test_Model_64bit_Equal_To_Narrow(t *testing.T) { +func TestModel_64bitWideEquivNarrow(t *testing.T) { id := func(a uint64) uint64 { return a } rapid.Check(t, func(t *rapid.T) { k := rapid.IntRange(3, 1001).Draw(t, "k") @@ -71,3 +71,57 @@ 
func Test_Model_64bit_Equal_To_Narrow(t *testing.T) { }) } + +func TestModel_FindV_Equiv_Find_0Remainder(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + k := rapid.IntRange(1, 100).Draw(t, "k") + n := rapid.IntRange(2, 100).Draw(t, "n") + dims := rapid.IntRange(1, 10).Draw(t, "dims") + data := rapid.SliceOfN(rapid.SliceOfN(rapid.Uint64(), dims, dims), n*k, n*k).Draw(t, "data") + q := rapid.SliceOfN(rapid.Uint64(), dims, dims).Draw(t, "q") + batchSize := k + m1 := bitknn.FitWide(data, nil) + m2 := bitknn.FitWide(data, nil) + batch := make([]uint32, batchSize) + vds, vis := m1.FindV(k, q, batch) + ds, is := m2.Find(k, q) + if !reflect.DeepEqual(vds, ds) { + t.Fatal(vds, ds) + } + if !reflect.DeepEqual(vis, is) { + t.Fatal(vis, is) + } + }) +} + +func TestModel_FindV_Equiv_Find(t *testing.T) { + rapid.Check(t, func(t *rapid.T) { + k := rapid.IntRange(0, 1000).Draw(t, "k") + dims := rapid.IntRange(1, 10_000).Draw(t, "dims") + data := rapid.SliceOf(rapid.SliceOfN(rapid.Uint64(), dims, dims)).Draw(t, "data") + batchSizes := []int{0, len(data), len(data) - 1, len(data) - 2, 2048, 100_000} + q := rapid.SliceOfN(rapid.Uint64(), dims, dims).Draw(t, "q") + for _, batchSize := range batchSizes { + batchSize = max(k, batchSize) + m1 := bitknn.FitWide(data, nil) + m2 := bitknn.FitWide(data, nil) + batch := make([]uint32, batchSize) + vds, vis := m1.FindV(k, q, batch) + ds, is := m2.Find(k, q) + if !reflect.DeepEqual(vds, ds) { + t.Fatal(vds, ds) + } + if !reflect.DeepEqual(vis, is) { + t.Fatal(vis, is) + } + batchAll := make([]uint32, batchSize) + vds, vis = m1.FindV(k, q, batchAll) + if !reflect.DeepEqual(vds, ds) { + t.Fatal(vds, ds) + } + if !reflect.DeepEqual(vis, is) { + t.Fatal(vis, is) + } + } + }) +} diff --git a/nearest_wide.go b/nearest_wide.go index 6b97904..840cb62 100644 --- a/nearest_wide.go +++ b/nearest_wide.go @@ -4,6 +4,7 @@ import ( "math/bits" "github.com/keilerkonzept/bitknn/internal/heap" + "github.com/keilerkonzept/bitknn/internal/neon" ) // 
[bitknn.Nearest], but for wide data. @@ -40,3 +41,62 @@ func NearestWide(data [][]uint64, k int, x []uint64, distances, indices []int) i } return k } + +// [NearestWide], but vectorizable (currently only on ARM64 with NEON instructions). +// The `batch` array must have at least length `k`, and is used to pre-compute batches of distances. +func NearestWideV(data [][]uint64, k int, x []uint64, batch []uint32, distances, indices []int) int { + if k == 0 || len(data) == 0 { + return 0 + } + _ = batch[k-1] + heap := heap.MakeMax(distances, indices) + distance0 := &distances[0] + + k0 := min(k, len(data)) + datak0 := data[:k0:k0] + + batchk0 := batch[:k0:k0] + neon.DistancesWide(x, datak0, batchk0) + + for i, dist := range batchk0 { + heap.Push(int(dist), i) + } + + if len(data) <= k { + return k0 + } + + maxDist := *distance0 + + b := len(batch) + _ = data[k] + i := k + for ; i <= len(data)-b; i += b { + neon.DistancesWide(x, data[i:i+b], batch) + for j := range batch { + dist := int(batch[j]) + if dist >= maxDist { + continue + } + heap.PushPop(dist, i+j) + maxDist = *distance0 + } + } + + remainder := len(data) - i + if remainder <= 0 { + return k + } + _ = batch[remainder-1] + + neon.DistancesWide(x, data[i:], batch) + for j := range remainder { + dist := int(batch[j]) + if dist >= maxDist { + continue + } + heap.PushPop(dist, i+j) + maxDist = *distance0 + } + return k +}