Skip to content

Commit

Permalink
neon
Browse files Browse the repository at this point in the history
  • Loading branch information
sgreben committed Oct 12, 2024
1 parent aa22ede commit 2eafb9a
Show file tree
Hide file tree
Showing 13 changed files with 393 additions and 7 deletions.
30 changes: 28 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ If you need to classify **binary feature vectors that fit into `uint64`s**, this

If your vectors are **longer than 64 bits**, you can [pack](#packing-wide-data) them into `[]uint64` and classify them using the ["wide" model variants](#packing-wide-data).

On ARM64 with NEON vector instruction support, `bitknn` can be [a bit faster than otherwise](#arm64-neon-support) on wide data.

You can optionally weigh class votes by distance, or specify different vote values per data point.

The sub-package [`lsh`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh) implements several [Locality-Sensitive Hashing (LSH)](https://en.m.wikipedia.org/wiki/Locality-sensitive_hashing) schemes for `uint64` feature vectors.
Expand All @@ -24,6 +26,7 @@ The sub-package [`lsh`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh)
- [Basic usage](#basic-usage)
- [LSH](#lsh)
- [Packing wide data](#packing-wide-data)
- [ARM64 NEON Support](#arm64-neon-support)
- [Options](#options)
- [Benchmarks](#benchmarks)
- [License](#license)
Expand All @@ -37,11 +40,11 @@ There are just three methods you'll typically need:
Variants: [`bitknn.Fit`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#Fit), [`bitknn.FitWide`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#FitWide), [`lsh.Fit`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Fit), [`lsh.FitWide`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#FitWide)
- **Find** *(k, point)*: Given a point, return the *k* nearest neighbor's indices and distances.

Variants: [`bitknn.Model.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#Model.Find), [`bitknn.WideModel.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.Find), [`lsh.Model.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Model.Find), [`lsh.WideModel.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#WideModel.Find)
Variants: [`bitknn.Model.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#Model.Find), [`bitknn.WideModel.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.Find), [`lsh.Model.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Model.Find), [`lsh.WideModel.Find`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#WideModel.Find), [`bitknn.WideModel.FindV`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.FindV) (vectorized on ARM64 with NEON instructions)

- **Predict** *(k, point, votes)*: Predict the label for a given point based on its nearest neighbors, write the label votes into the provided vote counter.

Variants: [`bitknn.Model.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#Model.Predict), [`bitknn.WideModel.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.Predict), [`lsh.Model.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Model.Predict), [`lsh.WideModel.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#WideModel.Predict)
Variants: [`bitknn.Model.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#Model.Predict), [`bitknn.WideModel.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.Predict), [`lsh.Model.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#Model.Predict), [`lsh.WideModel.Predict`](https://pkg.go.dev/github.com/keilerkonzept/bitknn/lsh#WideModel.Predict), [`bitknn.WideModel.PredictV`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.PredictV) (vectorized on ARM64 with NEON instructions).

Each of the above methods is available on each model type. There are four model types in total:

Expand Down Expand Up @@ -193,6 +196,29 @@ func main() {

The wide model fitting function [`bitknn.FitWide`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#FitWide) accepts the same [Options](#options) as the "narrow" one.

### ARM64 NEON Support

For ARM64 CPUs with NEON instructions, `bitknn` has a [vectorized distance function for `[]uint64s`s](internal/neon/distance_arm64.s) that is about twice as fast as what the compiler generates.

When run on such a CPU, the ***V** methods [`WideModel.FindV`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.FindV) and [`WideModel.PredictV`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WideModel.predictV) are noticeably faster than the regular `Find`/`Predict`:

| Bits | N | k | `Find` s/op | `FindV` s/op | diff |
|-------|---------|-----|--------------|--------------|------------------------|
| 128 | 1000 | 3 | 2.374µ ± 0% | 1.792µ ± 0% | -24.54% (p=0.000 n=10) |
| 128 | 1000 | 10 | 2.901µ ± 1% | 2.028µ ± 1% | -30.08% (p=0.000 n=10) |
| 128 | 1000 | 100 | 5.472µ ± 3% | 4.359µ ± 1% | -20.34% (p=0.000 n=10) |
| 128 | 1000000 | 3 | 2.273m ± 3% | 1.380m ± 2% | -39.27% (p=0.000 n=10) |
| 128 | 1000000 | 10 | 2.261m ± 1% | 1.406m ± 1% | -37.84% (p=0.000 n=10) |
| 128 | 1000000 | 100 | 2.289m ± 0% | 1.425m ± 2% | -37.76% (p=0.000 n=10) |
| 640 | 1000 | 3 | 6.201µ ± 1% | 3.716µ ± 0% | -40.07% (p=0.000 n=10) |
| 640 | 1000 | 10 | 6.728µ ± 1% | 3.973µ ± 1% | -40.96% (p=0.000 n=10) |
| 640 | 1000 | 100 | 10.855µ ± 2% | 6.917µ ± 1% | -36.28% (p=0.000 n=10) |
| 640 | 1000000 | 3 | 5.832m ± 2% | 3.337m ± 1% | -42.78% (p=0.000 n=10) |
| 640 | 1000000 | 10 | 5.830m ± 5% | 3.339m ± 1% | -42.73% (p=0.000 n=10) |
| 640 | 1000000 | 100 | 5.872m ± 1% | 3.361m ± 1% | -42.77% (p=0.000 n=10) |
| 8192 | 1000000 | 10 | 72.66m ± 1% | 30.96m ± 3% | -57.39% (p=0.000 n=10) |


## Options

- [`bitknn.WithLinearDistanceWeighting()`](https://pkg.go.dev/github.com/keilerkonzept/bitknn#WithLinearDistanceWeighting): Apply linear distance weighting (`1 / (1 + dist)`).
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@ go 1.23.0

require (
github.com/google/go-cmp v0.6.0
golang.org/x/sys v0.26.0
pgregory.net/rapid v1.1.0
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
pgregory.net/rapid v1.1.0 h1:CMa0sjHSru3puNx+J0MIAuiiEV4N0qj8/cMWGBBCsjw=
pgregory.net/rapid v1.1.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04=
7 changes: 7 additions & 0 deletions internal/neon/distance.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
//go:build !arm64

package neon

func DistancesWide(a []uint64, bs [][]uint64, out []uint32) {
distancesWideGeneric(a, bs, out)
}
15 changes: 15 additions & 0 deletions internal/neon/distance_arm64.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package neon

import (
"golang.org/x/sys/cpu"
)

func init() {
if cpu.ARM64.HasASIMD {
DistancesWide = DistancesWideNEON
}
}

var DistancesWide = distancesWideGeneric

func DistancesWideNEON(a []uint64, bs [][]uint64, out []uint32)
102 changes: 102 additions & 0 deletions internal/neon/distance_arm64.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#include "go_asm.h"
#include "textflag.h"

// func DistancesWideNEON(a []uint64, b [][]uint64, out []uint32)
//
// Computes the Hamming distance between 'a' and each slice in 'b',
// storing the results in 'out'.
//
// Inputs:
// a_base+0(FP) : base address of slice a
// a_len+8(FP) : length of slice a
// (a_cap+16(FP) : capacity of slice a)
// bs_base+24(FP) : base address of slice b (slice of slices)
// bs_len+32(FP) : length of slice b (number of slices)
// (bs_cap+40(FP) : capacity of slice b)
// out_base+48(FP): base address of output slice
// (out_len+56(FP)): length of output slice
// (out_cap+64(FP)): capacity of output slice
//
// Assumes that all slices in 'b' have the same length as 'a',
// and that 'out' has at least 'bs_len' elements.

//go:linkname DistancesWideNEON DistancesWideNEON
//go:noescape
TEXT ·DistancesWideNEON(SB), NOSPLIT, $0-72
// Load input parameters
MOVD a_base+0(FP), R0
MOVD a_len+8(FP), R1
MOVD bs_base+24(FP), R2
MOVD bs_len+32(FP), R3
MOVD out_base+48(FP), R4

// Outer loop counter
MOVD R3, R5
CBZ R5, done

outer_loop:
MOVD a_base+0(FP), R0

// Load the base address of the current slice in 'b'
MOVD (R2), R6
ADD $24, R2 // Move to the next slice in 'b'

// Initialize the result for this slice to 0
MOVD $0, R7

// Inner loop counter (number of uint64 in 'a')
MOVD R1, R8

VEOR V1.B16,V1.B16,V1.B16
VEOR V2.B16,V2.B16,V2.B16
VEOR V3.B16,V3.B16,V3.B16
// Check if the length is at least 2 (16 bytes)
CMP $2, R8
BLT inner_remainder

inner_loop:
// Load 16 bytes (2 uint64s) from each slice
VLD1.P 16(R0), [V0.D2]
VLD1.P 16(R6), [V1.D2]

// XOR the loaded vectors
VEOR V0.B16, V1.B16, V2.B16

// Count the set bits
VCNT V2.B16, V2.B16

// Sum up the counts
VUADDLV V2.B16, V3

// Add the result to the total
FMOVD F3, R9
ADD R9, R7

// Decrement the counter by 2 and continue if there are more elements
SUB $2, R8
CMP $2, R8
BGE inner_loop

inner_remainder:
// Handle the remaining element if the length is odd
CBZ R8, inner_done
MOVD (R0), R9
MOVD (R6), R10
EOR R9, R10, R9
FMOVD R9, F0
VCNT V0.B8, V0.B8
VUADDLV V0.B8, V0
FMOVD F0, R9
ADD R9, R7

inner_done:
// Store the distance in the output slice
MOVW R7, (R4)
ADD $4, R4 // Move to the next element in 'out'

// Decrement the outer loop counter and continue if there are more slices
SUB $1, R5
CBNZ R5, outer_loop

done:
RET
31 changes: 31 additions & 0 deletions internal/neon/distance_arm64_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package neon

import (
"math/bits"
"testing"

"pgregory.net/rapid"
)

func TestDistancesWideNEON(t *testing.T) {
t.Run("DistancesWideGenericEquivBits", func(t *testing.T) {
rapid.Check(t, func(t *rapid.T) {
dims := rapid.IntRange(0, 10_000).Draw(t, "dims")
data := rapid.SliceOfN(rapid.SliceOfN(rapid.Uint64(), dims, dims), 16, 10_000).Draw(t, "data")
q := rapid.SliceOfN(rapid.Uint64(), dims, dims).Draw(t, "q")
for batchSize := range []int{0, 1, 2, len(data), len(data) - 1, len(data) * 2} {
out := make([]uint32, batchSize)
DistancesWideNEON(q, data[:batchSize], out)
for i, d := range out {
expected := 0
for j, q := range q {
expected += bits.OnesCount64(q ^ data[i][j])
}
if int(d) != expected {
t.Fatal(d, expected)
}
}
}
})
})
}
13 changes: 13 additions & 0 deletions internal/neon/distance_generic.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package neon

import "math/bits"

func distancesWideGeneric(a []uint64, bs [][]uint64, out []uint32) {
for i, b := range bs {
dist := 0
for j, aj := range a {
dist += bits.OnesCount64(aj ^ b[j])
}
out[i] = uint32(dist)
}
}
31 changes: 31 additions & 0 deletions internal/neon/distance_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package neon

import (
"math/bits"
"testing"

"pgregory.net/rapid"
)

func TestDistancesWideGeneric(t *testing.T) {
t.Run("DistancesWideGenericEquivBits", func(t *testing.T) {
rapid.Check(t, func(t *rapid.T) {
dims := rapid.IntRange(0, 10_000).Draw(t, "dims")
data := rapid.SliceOfN(rapid.SliceOfN(rapid.Uint64(), dims, dims), 16, 10_000).Draw(t, "data")
q := rapid.SliceOfN(rapid.Uint64(), dims, dims).Draw(t, "q")
for batchSize := range []int{0, 1, 2, len(data), len(data) - 1, len(data) * 2} {
out := make([]uint32, batchSize)
distancesWideGeneric(q, data[:batchSize], out)
for i, d := range out {
expected := 0
for j, q := range q {
expected += bits.OnesCount64(q ^ data[i][j])
}
if int(d) != expected {
t.Fatal(d, expected)
}
}
}
})
})
}
33 changes: 31 additions & 2 deletions model_wide.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,17 @@ func (me *WideModel) PreallocateHeap(k int) {
// Writes their distances and indices in the dataset into the pre-allocated slices.
// Returns the distance and index slices, truncated to the actual number of neighbors found.
func (me *WideModel) Find(k int, x []uint64) ([]int, []int) {
me.Narrow.PreallocateHeap(k)
me.PreallocateHeap(k)
return me.FindInto(k, x, me.Narrow.HeapDistances, me.Narrow.HeapIndices)
}

// FindV is [Find], but vectorizable (currently only on ARM64 with NEON instructions).
// The provided [batch] slice must have length >=k and is used to pre-compute batches of distances.
func (me *WideModel) FindV(k int, x []uint64, batch []uint32) ([]int, []int) {
me.PreallocateHeap(k)
return me.FindIntoV(k, x, batch, me.Narrow.HeapDistances, me.Narrow.HeapIndices)
}

// Finds the nearest neighbors of the given point.
// Writes their distances and indices in the dataset into the provided slices.
// The slices should be pre-allocated to length k+1.
Expand All @@ -38,10 +45,17 @@ func (me *WideModel) FindInto(k int, x []uint64, distances []int, indices []int)
return distances[:k], indices[:k]
}

// FindIntoV is [FindInto], but vectorizable (currently only on ARM64 with NEON instructions).
// The provided [batch] slice must have length >=k and is used to pre-compute batches of distances.
func (me *WideModel) FindIntoV(k int, x []uint64, batch []uint32, distances []int, indices []int) ([]int, []int) {
k = NearestWideV(me.WideData, k, x, batch, distances, indices)
return distances[:k], indices[:k]
}

// Predicts the label of a single input point. Reuses two slices of length K+1 for the neighbor heap.
// Returns the number of neighbors found.
func (me *WideModel) Predict(k int, x []uint64, votes VoteCounter) int {
me.Narrow.PreallocateHeap(k)
me.PreallocateHeap(k)
return me.PredictInto(k, x, me.Narrow.HeapDistances, me.Narrow.HeapIndices, votes)
}

Expand All @@ -52,3 +66,18 @@ func (me *WideModel) PredictInto(k int, x []uint64, distances []int, indices []i
me.Narrow.Vote(k, distances, indices, votes)
return k
}

// PredictV is [Predict], but vectorizable (currently only on ARM64 with NEON instructions).
// The provided [batch] slice must have length >=k and is used to pre-compute batches of distances.
func (me *WideModel) PredictV(k int, x []uint64, batch []uint32, votes VoteCounter) int {
me.PreallocateHeap(k)
return me.PredictIntoV(k, x, batch, me.Narrow.HeapDistances, me.Narrow.HeapIndices, votes)
}

// PredictIntoV is [PredictInto], but vectorizable (currently only on ARM64 with NEON instructions).
// The provided [batch] slice must have length >=k and is used to pre-compute batches of distances.
func (me *WideModel) PredictIntoV(k int, x []uint64, batch []uint32, distances []int, indices []int, votes VoteCounter) int {
k = NearestWideV(me.WideData, k, x, batch, distances, indices)
me.Narrow.Vote(k, distances, indices, votes)
return k
}
19 changes: 17 additions & 2 deletions model_wide_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@ func BenchmarkWideModel(b *testing.B) {
dim []int
dataSize []int
k []int
batch []int
}
benches := []bench{
{dim: []int{1, 2, 10}, dataSize: []int{100}, k: []int{3, 10}},
{dim: []int{1, 2, 10}, dataSize: []int{1000, 1_000_000}, k: []int{3, 10, 100}},
{dim: []int{1, 2, 10}, dataSize: []int{100}, k: []int{3, 10}, batch: nil},
{dim: []int{1}, dataSize: []int{1000, 1_000_000}, k: []int{3, 10, 100}, batch: nil},
{dim: []int{2, 10}, dataSize: []int{1000, 1_000_000}, k: []int{3, 10, 100}, batch: []int{1000}},
{dim: []int{128}, dataSize: []int{1_000_000}, k: []int{10}, batch: []int{1000}},
}
for _, bench := range benches {
for _, dim := range bench.dim {
Expand All @@ -42,6 +45,18 @@ func BenchmarkWideModel(b *testing.B) {
model.Find(k, query)
}
})
for _, batchSize := range bench.batch {
batchSize = min(batchSize, dataSize)
batchSize = max(batchSize, k)
batch := make([]uint32, batchSize)
b.Run(fmt.Sprintf("Op=FindV_batch=%d_bits=%d_N=%d_k=%d", batchSize, dim*64, dataSize, k), func(b *testing.B) {
model.PreallocateHeap(k)
b.ResetTimer()
for n := 0; n < b.N; n++ {
model.FindV(k, query, batch)
}
})
}
}
}
}
Expand Down
Loading

0 comments on commit 2eafb9a

Please sign in to comment.