Skip to content

Commit

Permalink
ann: refactor HNSW (#904)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz authored Dec 23, 2024
1 parent dcfd453 commit 9eabf0a
Show file tree
Hide file tree
Showing 8 changed files with 537 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,14 @@ jobs:
wget https://cdn.gorse.io/datasets/frappe.zip -P ~/.gorse/download
wget https://cdn.gorse.io/datasets/ml-tag.zip -P ~/.gorse/download
wget https://cdn.gorse.io/datasets/criteo.zip -P ~/.gorse/download
wget https://pub-64226d9f34c64d6f829f5b63a5540d27.r2.dev/datasets/mnist.zip -P ~/.gorse/download
unzip ~/.gorse/download/ml-100k.zip -d ~/.gorse/dataset
unzip ~/.gorse/download/ml-1m.zip -d ~/.gorse/dataset
unzip ~/.gorse/download/pinterest-20.zip -d ~/.gorse/dataset
unzip ~/.gorse/download/frappe.zip -d ~/.gorse/dataset
unzip ~/.gorse/download/ml-tag.zip -d ~/.gorse/dataset
unzip ~/.gorse/download/criteo.zip -d ~/.gorse/dataset
unzip ~/.gorse/download/mnist.zip -d ~/.gorse/dataset
- run:
name: Upgrade Go
command: |
Expand Down
6 changes: 6 additions & 0 deletions .github/workflows/build_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,14 @@ jobs:
wget https://cdn.gorse.io/datasets/frappe.zip -P ~/.gorse/download
wget https://cdn.gorse.io/datasets/ml-tag.zip -P ~/.gorse/download
wget https://cdn.gorse.io/datasets/criteo.zip -P ~/.gorse/download
wget https://pub-64226d9f34c64d6f829f5b63a5540d27.r2.dev/datasets/mnist.zip -P ~/.gorse/download
unzip ~/.gorse/download/ml-100k.zip -d ~/.gorse/dataset
unzip ~/.gorse/download/ml-1m.zip -d ~/.gorse/dataset
unzip ~/.gorse/download/pinterest-20.zip -d ~/.gorse/dataset
unzip ~/.gorse/download/frappe.zip -d ~/.gorse/dataset
unzip ~/.gorse/download/ml-tag.zip -d ~/.gorse/dataset
unzip ~/.gorse/download/criteo.zip -d ~/.gorse/dataset
unzip ~/.gorse/download/mnist.zip -d ~/.gorse/dataset
- name: Set up Go 1.23.x
uses: actions/setup-go@v4
Expand Down Expand Up @@ -123,12 +125,14 @@ jobs:
wget https://cdn.gorse.io/datasets/frappe.zip -P ~/.gorse/download
wget https://cdn.gorse.io/datasets/ml-tag.zip -P ~/.gorse/download
wget https://cdn.gorse.io/datasets/criteo.zip -P ~/.gorse/download
wget https://pub-64226d9f34c64d6f829f5b63a5540d27.r2.dev/datasets/mnist.zip -P ~/.gorse/download
unzip ~/.gorse/download/ml-100k.zip -d ~/.gorse/dataset
unzip ~/.gorse/download/ml-1m.zip -d ~/.gorse/dataset
unzip ~/.gorse/download/pinterest-20.zip -d ~/.gorse/dataset
unzip ~/.gorse/download/frappe.zip -d ~/.gorse/dataset
unzip ~/.gorse/download/ml-tag.zip -d ~/.gorse/dataset
unzip ~/.gorse/download/criteo.zip -d ~/.gorse/dataset
unzip ~/.gorse/download/mnist.zip -d ~/.gorse/dataset
- name: Set up Go 1.23.x
uses: actions/setup-go@v4
Expand Down Expand Up @@ -156,12 +160,14 @@ jobs:
Invoke-WebRequest https://cdn.gorse.io/datasets/frappe.zip -OutFile ~/.gorse/download/frappe.zip
Invoke-WebRequest https://cdn.gorse.io/datasets/ml-tag.zip -OutFile ~/.gorse/download/ml-tag.zip
Invoke-WebRequest https://cdn.gorse.io/datasets/criteo.zip -OutFile ~/.gorse/download/criteo.zip
Invoke-WebRequest https://pub-64226d9f34c64d6f829f5b63a5540d27.r2.dev/datasets/mnist.zip -OutFile ~/.gorse/download/mnist.zip
Expand-Archive ~/.gorse/download/ml-100k.zip -DestinationPath ~/.gorse/dataset
Expand-Archive ~/.gorse/download/ml-1m.zip -DestinationPath ~/.gorse/dataset
Expand-Archive ~/.gorse/download/pinterest-20.zip -DestinationPath ~/.gorse/dataset
Expand-Archive ~/.gorse/download/frappe.zip -DestinationPath ~/.gorse/dataset
Expand-Archive ~/.gorse/download/ml-tag.zip -DestinationPath ~/.gorse/dataset
Expand-Archive ~/.gorse/download/criteo.zip -DestinationPath ~/.gorse/dataset
Expand-Archive ~/.gorse/download/mnist.zip -DestinationPath ~/.gorse/dataset
- name: Set up Go 1.23.x
uses: actions/setup-go@v4
Expand Down
25 changes: 25 additions & 0 deletions common/ann/ann.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Copyright 2024 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package search

import (
"github.com/samber/lo"
)

type Index interface {
Add(v []float32) (int, error)
SearchIndex(q, k int, prune0 bool) ([]lo.Tuple2[int, float32], error)
SearchVector(q []float32, k int, prune0 bool) ([]lo.Tuple2[int, float32], error)
}
148 changes: 148 additions & 0 deletions common/ann/ann_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
// Copyright 2024 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package search

import (
"bufio"
mapset "github.com/deckarep/golang-set/v2"
"github.com/samber/lo"
"github.com/stretchr/testify/assert"
"github.com/zhenghaoz/gorse/base/floats"
"github.com/zhenghaoz/gorse/common/dataset"
"github.com/zhenghaoz/gorse/common/util"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
)

const (
trainSize = 6000
testSize = 1000
)

func recall(gt, pred []lo.Tuple2[int, float32]) float64 {
s := mapset.NewSet[int]()
for _, pair := range gt {
s.Add(pair.A)
}
hit := 0
for _, pair := range pred {
if s.Contains(pair.A) {
hit++
}
}
return float64(hit) / float64(len(gt))
}

type MNIST struct {
TrainImages [][]float32
TrainLabels []uint8
TestImages [][]float32
TestLabels []uint8
}

func mnist() (*MNIST, error) {
// Download and unzip dataset
path, err := dataset.DownloadAndUnzip("mnist")
if err != nil {
return nil, err
}
// Open dataset
m := new(MNIST)
m.TrainImages, m.TrainLabels, err = m.openFile(filepath.Join(path, "train.libfm"))
if err != nil {
return nil, err
}
m.TestImages, m.TestLabels, err = m.openFile(filepath.Join(path, "test.libfm"))
if err != nil {
return nil, err
}
return m, nil
}

func (m *MNIST) openFile(path string) ([][]float32, []uint8, error) {
// Open file
f, err := os.Open(path)
if err != nil {
return nil, nil, err
}
defer f.Close()
// Read data line by line
var (
images [][]float32
labels []uint8
)
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
splits := strings.Split(line, " ")
// Parse label
label, err := util.ParseUInt8(splits[0])
if err != nil {
return nil, nil, err
}
labels = append(labels, label)
// Parse image
image := make([]float32, 784)
for _, split := range splits[1:] {
kv := strings.Split(split, ":")
index, err := strconv.Atoi(kv[0])
if err != nil {
return nil, nil, err
}
value, err := util.ParseFloat32(kv[1])
if err != nil {
return nil, nil, err
}
image[index] = value
}
images = append(images, image)
}
return images, labels, nil
}

func TestMNIST(t *testing.T) {
dat, err := mnist()
assert.NoError(t, err)

// Create brute-force index
bf := NewBruteforce(floats.Euclidean)
for _, image := range dat.TrainImages[:trainSize] {
_, err := bf.Add(image)
assert.NoError(t, err)
}

// Create HNSW index
hnsw := NewHNSW(floats.Euclidean)
for _, image := range dat.TrainImages[:trainSize] {
_, err := hnsw.Add(image)
assert.NoError(t, err)
}

// Test search
r := 0.0
for _, image := range dat.TestImages[:testSize] {
gt, err := bf.SearchVector(image, 100, false)
assert.NoError(t, err)
assert.Len(t, gt, 100)
scores, err := hnsw.SearchVector(image, 100, false)
assert.NoError(t, err)
assert.Len(t, scores, 100)
r += recall(gt, scores)
}
assert.Greater(t, r, 0.99)
}
90 changes: 90 additions & 0 deletions common/ann/bruteforce.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright 2024 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package search

import (
"github.com/juju/errors"
"github.com/samber/lo"
"github.com/zhenghaoz/gorse/base/heap"
)

// Bruteforce is a naive implementation of vector index.
type Bruteforce[T any] struct {
distanceFunc func(a, b []T) float32
dimension int
vectors [][]T
}

func NewBruteforce[T any](distanceFunc func(a, b []T) float32) *Bruteforce[T] {
return &Bruteforce[T]{distanceFunc: distanceFunc}
}

func (b *Bruteforce[T]) Add(v []T) (int, error) {
// Check dimension
if b.dimension == 0 {
b.dimension = len(v)
} else if b.dimension != len(v) {
return 0, errors.Errorf("dimension mismatch: %v != %v", b.dimension, len(v))
}
// Add vector
b.vectors = append(b.vectors, v)
return len(b.vectors) - 1, nil
}

func (b *Bruteforce[T]) SearchIndex(q, k int, prune0 bool) ([]lo.Tuple2[int, float32], error) {
// Check index
if q < 0 || q >= len(b.vectors) {
return nil, errors.Errorf("index out of range: %v", q)
}
// Search
pq := heap.NewPriorityQueue(true)
for i, vec := range b.vectors {
if i != q {
pq.Push(int32(i), b.distanceFunc(b.vectors[q], vec))
if pq.Len() > k {
pq.Pop()
}
}
}
pq = pq.Reverse()
scores := make([]lo.Tuple2[int, float32], 0)
for pq.Len() > 0 {
value, score := pq.Pop()
if !prune0 || score < 0 {
scores = append(scores, lo.Tuple2[int, float32]{A: int(value), B: score})
}
}
return scores, nil
}

func (b *Bruteforce[T]) SearchVector(q []T, k int, prune0 bool) ([]lo.Tuple2[int, float32], error) {
// Search
pq := heap.NewPriorityQueue(true)
for i, vec := range b.vectors {
pq.Push(int32(i), b.distanceFunc(q, vec))
if pq.Len() > k {
pq.Pop()
}
}
pq = pq.Reverse()
scores := make([]lo.Tuple2[int, float32], 0)
for pq.Len() > 0 {
value, score := pq.Pop()
if !prune0 || score < 0 {
scores = append(scores, lo.Tuple2[int, float32]{A: int(value), B: score})
}
}
return scores, nil
}
Loading

0 comments on commit 9eabf0a

Please sign in to comment.