Skip to content

Commit

Permalink
Use max-heap for query results instead of sorting huge slice
Browse files Browse the repository at this point in the history
  • Loading branch information
philippgille committed Mar 17, 2024
1 parent 54b1857 commit 86c681e
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 27 deletions.
27 changes: 8 additions & 19 deletions collection.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package chromem

import (
"cmp"
"context"
"errors"
"fmt"
Expand Down Expand Up @@ -323,30 +322,20 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3
return nil, nil
}

// For the remaining documents, calculate cosine similarity.
docSims, err := calcDocSimilarity(ctx, queryEmbedding, filteredDocs)
// For the remaining documents, get the most similar docs.
nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, nResults)
if err != nil {
return nil, fmt.Errorf("couldn't calculate cosine similarity: %w", err)
return nil, fmt.Errorf("couldn't get most similar docs: %w", err)
}

// Sort by similarity
slices.SortFunc(docSims, func(i, j docSim) int {
// i, j; for descending order
return cmp.Compare(j.similarity, i.similarity)
})

// Return the top nResults or len(docSim), whichever is smaller
if len(docSims) < nResults {
nResults = len(docSims)
}
res := make([]Result, 0, nResults)
for i := 0; i < nResults; i++ {
res = append(res, Result{
ID: docSims[i].docID,
Metadata: c.documents[docSims[i].docID].Metadata,
Embedding: c.documents[docSims[i].docID].Embedding,
Content: c.documents[docSims[i].docID].Content,
Similarity: docSims[i].similarity,
ID: nMaxDocs[i].docID,
Metadata: c.documents[nMaxDocs[i].docID].Metadata,
Embedding: c.documents[nMaxDocs[i].docID].Embedding,
Content: c.documents[nMaxDocs[i].docID].Content,
Similarity: nMaxDocs[i].similarity,
})
}

Expand Down
79 changes: 71 additions & 8 deletions query.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package chromem

import (
"cmp"
"container/heap"
"context"
"fmt"
"runtime"
"slices"
"strings"
"sync"
)
Expand All @@ -15,6 +18,70 @@ type docSim struct {
similarity float32
}

// docMaxHeap is a max-heap of docSims, based on similarity.
// See https://pkg.go.dev/container/[email protected]#example-package-IntHeap
type docMaxHeap []docSim

func (h docMaxHeap) Len() int { return len(h) }
func (h docMaxHeap) Less(i, j int) bool { return h[i].similarity < h[j].similarity }
func (h docMaxHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }

func (h *docMaxHeap) Push(x any) {
// Push and Pop use pointer receivers because they modify the slice's length,
// not just its contents.
*h = append(*h, x.(docSim))
}

func (h *docMaxHeap) Pop() any {
old := *h
n := len(old)
x := old[n-1]
*h = old[0 : n-1]
return x
}

// maxDocSims manages a max-heap of docSims with a fixed size, keeping the n highest
// similarities. It's safe for concurrent use, but not the result of values().
// In our benchmarks this was faster than sorting a slice of docSims at the end.
type maxDocSims struct {
h docMaxHeap
lock sync.RWMutex
size int
}

// newMaxDocSims creates a new nMaxDocs with a fixed size.
func newMaxDocSims(size int) *maxDocSims {
return &maxDocSims{
h: make(docMaxHeap, 0, size),
size: size,
}
}

// add inserts a new docSim into the heap, keeping only the top n similarities.
func (mds *maxDocSims) add(doc docSim) {
mds.lock.Lock()
defer mds.lock.Unlock()
if mds.h.Len() < mds.size {
heap.Push(&mds.h, doc)
} else if mds.h.Len() > 0 && mds.h[0].similarity < doc.similarity {
// Replace the smallest similarity if the new doc's similarity is higher
heap.Pop(&mds.h)
heap.Push(&mds.h, doc)
}
}

// values returns the docSims in the heap, sorted by similarity (descending).
// The call itself is safe for concurrent use with add(), but the result isn't.
// Only work with the result after all calls to add() have finished.
func (d *maxDocSims) values() []docSim {
d.lock.RLock()
defer d.lock.RUnlock()
slices.SortFunc(d.h, func(i, j docSim) int {
return cmp.Compare(j.similarity, i.similarity)
})
return d.h
}

// filterDocs filters a map of documents by metadata and content.
// It does this concurrently.
func filterDocs(docs map[string]*Document, where, whereDocument map[string]string) []*Document {
Expand Down Expand Up @@ -95,9 +162,8 @@ func documentMatchesFilters(document *Document, where, whereDocument map[string]
return true
}

func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document) ([]docSim, error) {
similarities := make([]docSim, 0, len(docs))
similaritiesLock := sync.Mutex{}
func getMostSimilarDocs(ctx context.Context, queryVectors []float32, docs []*Document, n int) ([]docSim, error) {
nMaxDocs := newMaxDocSims(n)

// Determine concurrency. Use number of docs or CPUs, whichever is smaller.
numCPUs := runtime.NumCPU()
Expand Down Expand Up @@ -152,10 +218,7 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu
return
}

similaritiesLock.Lock()
// We don't defer the unlock because we want to unlock much earlier.
similarities = append(similarities, docSim{docID: doc.ID, similarity: sim})
similaritiesLock.Unlock()
nMaxDocs.add(docSim{docID: doc.ID, similarity: sim})
}
}(docs[start:end])
}
Expand All @@ -166,5 +229,5 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu
return nil, sharedErr
}

return similarities, nil
return nMaxDocs.values(), nil
}

0 comments on commit 86c681e

Please sign in to comment.