Skip to content
This repository has been archived by the owner on Aug 16, 2021. It is now read-only.

Commit

Permalink
Merge pull request #29 from mathetake/develop
Browse files Browse the repository at this point in the history
add utility functions to save and read GannIndex to/from disk
  • Loading branch information
mathetake authored Jul 2, 2018
2 parents 1352add + e29844c commit 6a4c754
Show file tree
Hide file tree
Showing 14 changed files with 268 additions and 176 deletions.
43 changes: 1 addition & 42 deletions Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 21 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ The implemented algorithm is truly inspired by Annoy (https://github.com/spotify
```golang
import (
"fmt"
"github.com/mathetake/gann"
"github.com/mathetake/gann/index"
"math/rand"
"time"
)
Expand All @@ -39,26 +39,34 @@ func main() {
}

// create index
gIDx, err := gann.GetIndex(rawItems, dim, nTrees, k, true)
if err != nil {
panic(err)
}
// build index
gIDx := index.GetIndex(rawItems, dim, nTrees, k, true)
gIDx.Build()

// do search
q := []float32{0.1, 0.02, 0.001}
ann, err := gIDx.GetANNbyVector(q, 5, 10)
ann, _ := gIDx.GetANNbyVector(q, 5, 10)
fmt.Println("result:", ann)
}
```
# interfaces


You can also save and load your index to/from disk:

```golang
type GannIndex interface {
Build() error
GetANNbyItemID(id int64, num int, bucketScale float32) (ann []int64, err error)
GetANNbyVector(v []float32, num int, bucketScale float32) (ann []int64, err error)
gIDx := index.GetIndex(rawItems, dim, nTrees, k, true)
gIDx.Build()

var path = "foo.gann"

err := gIDx.Save(path)
if err != nil {
panic(err)
}

var idx = &index.Index{}
err := idx.Load(path)
if err != nil {
panic(err)
}
```

Expand All @@ -75,4 +83,4 @@ https://mathetake.github.io/blogs/gann.html

# License

MIT
MIT
22 changes: 16 additions & 6 deletions gann.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,29 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/

package gann

import (
"github.com/mathetake/gann/index"
)

type GannIndex interface {
Build() error // build search trees.
// Index ... an interface for gann's index in `index` package (only used for interface declaration on its methods)
type Index interface {
// Build ... build gann's index
Build() error

// GetANNbyItemID ... search ANNs by a given itemID
GetANNbyItemID(id int64, num int, bucketScale float32) (ann []int64, err error)

// GetANNbyVector ... search ANNs by a given query vector
GetANNbyVector(v []float32, num int, bucketScale float32) (ann []int64, err error)
}

// GetIndex ... get index (composed of trees, nodes, etc.)
func GetIndex(items [][]float32, d int, nT int, k int, normalize bool) (GannIndex, error) {
return index.Initialize(items, d, nT, k, normalize)
// Load ... load index from disk
Load(path string) error

// Save ... save index to disk
Save(path string) error
}

var _ Index = &index.Index{}
28 changes: 11 additions & 17 deletions gann_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"math/rand"
"testing"
"time"

"github.com/mathetake/gann/index"
)

type benchTemplate struct {
Expand Down Expand Up @@ -79,14 +81,12 @@ func BenchmarkGetANNByVector3(b *testing.B) {
}
}

func _getTestIndex(tmpl *benchTemplate) GannIndex {
func _getTestIndex(tmpl *benchTemplate) *index.Index {
its := _getItems(tmpl.dim, tmpl.nItem)

// create index
gIDx, err := GetIndex(its, tmpl.dim, tmpl.nTree, tmpl.k, true)
if err != nil {
panic(err)
}
gIDx := index.GetIndex(its, tmpl.dim, tmpl.nTree, tmpl.k, true)

// build index
gIDx.Build()
return gIDx
Expand Down Expand Up @@ -125,10 +125,8 @@ func BenchmarkBuildIndex1(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
// create index
gIDx, err := GetIndex(its, tmpl.dim, tmpl.nTree, tmpl.k, true)
if err != nil {
panic(err)
}
gIDx := index.GetIndex(its, tmpl.dim, tmpl.nTree, tmpl.k, true)

// build index
gIDx.Build()
}
Expand All @@ -148,10 +146,8 @@ func BenchmarkBuildIndex2(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
// create index
gIDx, err := GetIndex(its, tmpl.dim, tmpl.nTree, tmpl.k, true)
if err != nil {
panic(err)
}
gIDx := index.GetIndex(its, tmpl.dim, tmpl.nTree, tmpl.k, true)

// build index
gIDx.Build()
}
Expand All @@ -171,10 +167,8 @@ func BenchmarkBuildIndex3(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
// create index
gIDx, err := GetIndex(its, tmpl.dim, tmpl.nTree, tmpl.k, true)
if err != nil {
panic(err)
}
gIDx := index.GetIndex(its, tmpl.dim, tmpl.nTree, tmpl.k, true)

// build index
gIDx.Build()
}
Expand Down
38 changes: 19 additions & 19 deletions index/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ import (
"github.com/pkg/errors"
)

// GetANNbyItem ... get ANNs by a item.Item
// GetANNbyItemID ... get ANNs by a item.Item
func (idx *Index) GetANNbyItemID(id int64, num int, searchBucket float32) (ann []int64, err error) {
it, ok := idx.itemIDToItem[id]
it, ok := idx.ItemIDToItem[id]
if !ok {
return ann, errors.Errorf("Item not found for %v", id)
}
Expand All @@ -36,7 +36,7 @@ func (idx *Index) getANNbyVector(v []float32, num int, bucketScale float32) ([]i
5. Return the top `num` ones.
*/

if len(idx.roots) == 0 {
if len(idx.Roots) == 0 {
return []int64{}, errors.Errorf("Please build Index before searching.")
}

Expand All @@ -46,7 +46,7 @@ func (idx *Index) getANNbyVector(v []float32, num int, bucketScale float32) ([]i
pq := node.PriorityQueue{}

// 1.
for i, r := range idx.roots {
for i, r := range idx.Roots {
n := &node.QueueItem{
Value: r.ID,
Index: i,
Expand All @@ -61,7 +61,7 @@ func (idx *Index) getANNbyVector(v []float32, num int, bucketScale float32) ([]i
for {
q := heap.Pop(&pq).(*node.QueueItem)
d := q.Priority
n, ok := idx.nodeIDToNode[q.Value]
n, ok := idx.NodeIDToNode[q.Value]
if !ok {
panic("wrong item set in priority queue")
}
Expand Down Expand Up @@ -92,7 +92,7 @@ func (idx *Index) getANNbyVector(v []float32, num int, bucketScale float32) ([]i
ann := make([]int64, 0, len(annMap))
for id := range annMap {
ann = append(ann, id)
idToDist[id] = item.DotProduct(idx.itemIDToItem[id].Vec, v)
idToDist[id] = item.DotProduct(idx.ItemIDToItem[id].Vec, v)
}

// 4.
Expand All @@ -113,47 +113,47 @@ func (idx *Index) Build() error {

var wg sync.WaitGroup
var m sync.Map
for i := range idx.roots {
for i := range idx.Roots {
wg.Add(1)
ii := i
go func() {
idx.roots[ii].Build(idx.items, idx.k, idx.dim, &m)
idx.Roots[ii].Build(idx.Items, idx.K, idx.Dim, &m)
wg.Done()
}()
}
wg.Wait()

m.Range(func(key, _ interface{}) bool {
n := key.(*node.Node)
idx.nodes = append(idx.nodes, n)
idx.Nodes = append(idx.Nodes, n)
return true
})

if len(idx.nodes) == 0 {
if len(idx.Nodes) == 0 {
panic("# of nodes is zero.")
}

// build nodeIDToNode map
for _, n := range idx.nodes {
idx.nodeIDToNode[n.ID] = n
for _, n := range idx.Nodes {
idx.NodeIDToNode[n.ID] = n
}
return nil
}

func (idx *Index) initRootNodes() {
vecs := make([]item.Vector, len(idx.itemIDToItem))
for i, it := range idx.items {
vecs := make([]item.Vector, len(idx.ItemIDToItem))
for i, it := range idx.Items {
vecs[i] = it.Vec
}
for i := 0; i < idx.nTree; i++ {
nv := item.GetNormalVectorOfSplittingHyperPlane(vecs, idx.dim)
for i := 0; i < idx.NTree; i++ {
nv := item.GetNormalVectorOfSplittingHyperPlane(vecs, idx.Dim)
r := &node.Node{
ID: uuid.New().String(),
Vec: nv,
NDescendants: len(idx.items),
NDescendants: len(idx.Items),
}
idx.roots = append(idx.roots, r)
idx.nodes = append(idx.nodes, r)
idx.Roots = append(idx.Roots, r)
idx.Nodes = append(idx.Nodes, r)
}
}

Expand Down
Loading

0 comments on commit 6a4c754

Please sign in to comment.