-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
393 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,5 +4,6 @@ go 1.23.0 | |
|
||
require ( | ||
github.com/google/go-cmp v0.6.0 | ||
golang.org/x/sys v0.26.0 | ||
pgregory.net/rapid v1.1.0 | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= | ||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= | ||
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= | ||
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= | ||
pgregory.net/rapid v1.1.0 h1:CMa0sjHSru3puNx+J0MIAuiiEV4N0qj8/cMWGBBCsjw= | ||
pgregory.net/rapid v1.1.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
//go:build !arm64 | ||
|
||
package neon | ||
|
||
// DistancesWide writes to out[i] the Hamming distance between a and bs[i]. | ||
// This is the build for targets other than arm64 (see the build constraint above); | ||
// it delegates directly to distancesWideGeneric. | ||
func DistancesWide(a []uint64, bs [][]uint64, out []uint32) { | ||
distancesWideGeneric(a, bs, out) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
package neon | ||
|
||
import ( | ||
"golang.org/x/sys/cpu" | ||
) | ||
|
||
func init() { | ||
if cpu.ARM64.HasASIMD { | ||
DistancesWide = DistancesWideNEON | ||
} | ||
} | ||
|
||
var DistancesWide = distancesWideGeneric | ||
|
||
func DistancesWideNEON(a []uint64, bs [][]uint64, out []uint32) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
#include "go_asm.h" | ||
#include "textflag.h" | ||
|
||
// func DistancesWideNEON(a []uint64, bs [][]uint64, out []uint32) | ||
// | ||
// Computes the Hamming distance between 'a' and each slice in 'bs', | ||
// storing the results in 'out'. | ||
// | ||
// Inputs (entries in parentheses are part of the frame but are not read): | ||
// a_base+0(FP) : base address of slice a | ||
// a_len+8(FP) : length of slice a | ||
// (a_cap+16(FP) : capacity of slice a) | ||
// bs_base+24(FP) : base address of slice bs (slice of slices) | ||
// bs_len+32(FP) : length of slice bs (number of slices) | ||
// (bs_cap+40(FP) : capacity of slice bs) | ||
// out_base+48(FP) : base address of output slice | ||
// (out_len+56(FP) : length of output slice) | ||
// (out_cap+64(FP) : capacity of output slice) | ||
// | ||
// Assumes that all slices in 'bs' have the same length as 'a', | ||
// and that 'out' has at least 'bs_len' elements. |
|
||
//go:linkname DistancesWideNEON DistancesWideNEON | ||
//go:noescape | ||
// DistancesWideNEON walks the bs_len slice headers starting at bs_base and, for each one, | ||
// accumulates popcount(a[j] XOR slice[j]) over a_len words, then stores the 32-bit total in 'out'. | ||
// Only the data pointer of each header is read; a_len is used as the length for every slice. | ||
// | ||
// Register use: | ||
// R0 cursor into 'a' | ||
// R1 a_len | ||
// R2 cursor over the slice headers | ||
// R3 bs_len | ||
// R4 cursor into 'out' | ||
// R5 slices still to process | ||
// R6 cursor into the current slice's data | ||
// R7 running distance for the current slice | ||
// R8 uint64 words still to process in the current slice | ||
// R9, R10 scratch | ||
// | ||
// Frame: $0-72 means no local stack space and 72 bytes of arguments (three 24-byte slice headers). | ||
TEXT ·DistancesWideNEON(SB), NOSPLIT, $0-72 | ||
// Load input parameters | ||
MOVD a_base+0(FP), R0 | ||
MOVD a_len+8(FP), R1 | ||
MOVD bs_base+24(FP), R2 | ||
MOVD bs_len+32(FP), R3 | ||
MOVD out_base+48(FP), R4 | ||
|
||
// Outer loop counter | ||
MOVD R3, R5 | ||
// Nothing to do when there are no slices | ||
CBZ R5, done | ||
|
||
outer_loop: | ||
// R0 is advanced by the post-incrementing loads below, so rewind it to the start of 'a' | ||
MOVD a_base+0(FP), R0 | ||
|
||
// Load the base address of the current slice in 'b' | ||
MOVD (R2), R6 | ||
// The stride is 24 bytes: the same base/len/cap layout as the headers documented for the arguments | ||
ADD $24, R2 // Move to the next slice in 'b' | ||
|
||
// Initialize the result for this slice to 0 | ||
MOVD $0, R7 | ||
|
||
// Inner loop counter (number of uint64 in 'a') | ||
MOVD R1, R8 | ||
|
||
// Zero the vector scratch registers V1-V3 | ||
VEOR V1.B16,V1.B16,V1.B16 | ||
VEOR V2.B16,V2.B16,V2.B16 | ||
VEOR V3.B16,V3.B16,V3.B16 | ||
// Check if the length is at least 2 (16 bytes) | ||
CMP $2, R8 | ||
BLT inner_remainder | ||
|
||
inner_loop: | ||
// Load 16 bytes (2 uint64s) from each slice | ||
// The .P form post-increments the address register by 16 after the load | ||
VLD1.P 16(R0), [V0.D2] | ||
VLD1.P 16(R6), [V1.D2] | ||
|
||
// XOR the loaded vectors | ||
VEOR V0.B16, V1.B16, V2.B16 | ||
|
||
// Count the set bits | ||
// VCNT produces one count per byte lane | ||
VCNT V2.B16, V2.B16 | ||
|
||
// Sum up the counts | ||
// Reduces the 16 per-byte counts to a single value held in V3 | ||
VUADDLV V2.B16, V3 | ||
|
||
// Add the result to the total | ||
// Move the reduced sum from the vector register to a general-purpose register first | ||
FMOVD F3, R9 | ||
ADD R9, R7 | ||
|
||
// Decrement the counter by 2 and continue if there are more elements | ||
SUB $2, R8 | ||
CMP $2, R8 | ||
BGE inner_loop | ||
|
||
inner_remainder: | ||
// Handle the remaining element if the length is odd | ||
CBZ R8, inner_done | ||
// R0 and R6 already point at the last word thanks to the post-incrementing loads | ||
MOVD (R0), R9 | ||
MOVD (R6), R10 | ||
EOR R9, R10, R9 | ||
// Popcount the single XORed word via the vector unit: move it in, count per byte, reduce, move it out | ||
FMOVD R9, F0 | ||
VCNT V0.B8, V0.B8 | ||
VUADDLV V0.B8, V0 | ||
FMOVD F0, R9 | ||
ADD R9, R7 | ||
|
||
inner_done: | ||
// Store the distance in the output slice | ||
// MOVW writes 4 bytes, matching the uint32 element size of 'out' | ||
MOVW R7, (R4) | ||
ADD $4, R4 // Move to the next element in 'out' | ||
|
||
// Decrement the outer loop counter and continue if there are more slices | ||
SUB $1, R5 | ||
CBNZ R5, outer_loop | ||
|
||
done: | ||
RET |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
package neon | ||
|
||
import ( | ||
"math/bits" | ||
"testing" | ||
|
||
"pgregory.net/rapid" | ||
) | ||
|
||
func TestDistancesWideNEON(t *testing.T) { | ||
t.Run("DistancesWideGenericEquivBits", func(t *testing.T) { | ||
rapid.Check(t, func(t *rapid.T) { | ||
dims := rapid.IntRange(0, 10_000).Draw(t, "dims") | ||
data := rapid.SliceOfN(rapid.SliceOfN(rapid.Uint64(), dims, dims), 16, 10_000).Draw(t, "data") | ||
q := rapid.SliceOfN(rapid.Uint64(), dims, dims).Draw(t, "q") | ||
for batchSize := range []int{0, 1, 2, len(data), len(data) - 1, len(data) * 2} { | ||
out := make([]uint32, batchSize) | ||
DistancesWideNEON(q, data[:batchSize], out) | ||
for i, d := range out { | ||
expected := 0 | ||
for j, q := range q { | ||
expected += bits.OnesCount64(q ^ data[i][j]) | ||
} | ||
if int(d) != expected { | ||
t.Fatal(d, expected) | ||
} | ||
} | ||
} | ||
}) | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
package neon | ||
|
||
import "math/bits" | ||
|
||
// distancesWideGeneric is the portable Hamming-distance kernel: for every row
// of bs it counts the bits in which that row differs from a and stores the
// count, converted to uint32, at the matching position of out.
//
// Each row is indexed for every position of a, and out is indexed for every
// row of bs, so a row shorter than a or an out shorter than bs panics.
func distancesWideGeneric(a []uint64, bs [][]uint64, out []uint32) {
	for row := 0; row < len(bs); row++ {
		candidate := bs[row]
		var total int
		for col := range a {
			total += bits.OnesCount64(a[col] ^ candidate[col])
		}
		out[row] = uint32(total)
	}
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
package neon | ||
|
||
import ( | ||
"math/bits" | ||
"testing" | ||
|
||
"pgregory.net/rapid" | ||
) | ||
|
||
func TestDistancesWideGeneric(t *testing.T) { | ||
t.Run("DistancesWideGenericEquivBits", func(t *testing.T) { | ||
rapid.Check(t, func(t *rapid.T) { | ||
dims := rapid.IntRange(0, 10_000).Draw(t, "dims") | ||
data := rapid.SliceOfN(rapid.SliceOfN(rapid.Uint64(), dims, dims), 16, 10_000).Draw(t, "data") | ||
q := rapid.SliceOfN(rapid.Uint64(), dims, dims).Draw(t, "q") | ||
for batchSize := range []int{0, 1, 2, len(data), len(data) - 1, len(data) * 2} { | ||
out := make([]uint32, batchSize) | ||
distancesWideGeneric(q, data[:batchSize], out) | ||
for i, d := range out { | ||
expected := 0 | ||
for j, q := range q { | ||
expected += bits.OnesCount64(q ^ data[i][j]) | ||
} | ||
if int(d) != expected { | ||
t.Fatal(d, expected) | ||
} | ||
} | ||
} | ||
}) | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.