-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use max-heap for query results instead of sorting huge slice
- Loading branch information
1 parent
54b1857
commit 86c681e
Showing
2 changed files
with
79 additions
and
27 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,12 @@ | ||
package chromem | ||
|
||
import ( | ||
"cmp" | ||
"container/heap" | ||
"context" | ||
"fmt" | ||
"runtime" | ||
"slices" | ||
"strings" | ||
"sync" | ||
) | ||
|
@@ -15,6 +18,70 @@ type docSim struct { | |
similarity float32 | ||
} | ||
|
||
// docMaxHeap is a max-heap of docSims, based on similarity. | ||
// See https://pkg.go.dev/container/[email protected]#example-package-IntHeap | ||
type docMaxHeap []docSim | ||
|
||
func (h docMaxHeap) Len() int { return len(h) } | ||
func (h docMaxHeap) Less(i, j int) bool { return h[i].similarity < h[j].similarity } | ||
func (h docMaxHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } | ||
|
||
func (h *docMaxHeap) Push(x any) { | ||
// Push and Pop use pointer receivers because they modify the slice's length, | ||
// not just its contents. | ||
*h = append(*h, x.(docSim)) | ||
} | ||
|
||
func (h *docMaxHeap) Pop() any { | ||
old := *h | ||
n := len(old) | ||
x := old[n-1] | ||
*h = old[0 : n-1] | ||
return x | ||
} | ||
|
||
// maxDocSims manages a max-heap of docSims with a fixed size, keeping the n highest | ||
// similarities. It's safe for concurrent use, but not the result of values(). | ||
// In our benchmarks this was faster than sorting a slice of docSims at the end. | ||
type maxDocSims struct { | ||
h docMaxHeap | ||
lock sync.RWMutex | ||
size int | ||
} | ||
|
||
// newMaxDocSims creates a new nMaxDocs with a fixed size. | ||
func newMaxDocSims(size int) *maxDocSims { | ||
return &maxDocSims{ | ||
h: make(docMaxHeap, 0, size), | ||
size: size, | ||
} | ||
} | ||
|
||
// add inserts a new docSim into the heap, keeping only the top n similarities. | ||
func (mds *maxDocSims) add(doc docSim) { | ||
mds.lock.Lock() | ||
defer mds.lock.Unlock() | ||
if mds.h.Len() < mds.size { | ||
heap.Push(&mds.h, doc) | ||
} else if mds.h.Len() > 0 && mds.h[0].similarity < doc.similarity { | ||
// Replace the smallest similarity if the new doc's similarity is higher | ||
heap.Pop(&mds.h) | ||
heap.Push(&mds.h, doc) | ||
} | ||
} | ||
|
||
// values returns the docSims in the heap, sorted by similarity (descending). | ||
// The call itself is safe for concurrent use with add(), but the result isn't. | ||
// Only work with the result after all calls to add() have finished. | ||
func (d *maxDocSims) values() []docSim { | ||
d.lock.RLock() | ||
defer d.lock.RUnlock() | ||
slices.SortFunc(d.h, func(i, j docSim) int { | ||
return cmp.Compare(j.similarity, i.similarity) | ||
}) | ||
return d.h | ||
} | ||
|
||
// filterDocs filters a map of documents by metadata and content. | ||
// It does this concurrently. | ||
func filterDocs(docs map[string]*Document, where, whereDocument map[string]string) []*Document { | ||
|
@@ -95,9 +162,8 @@ func documentMatchesFilters(document *Document, where, whereDocument map[string] | |
return true | ||
} | ||
|
||
func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document) ([]docSim, error) { | ||
similarities := make([]docSim, 0, len(docs)) | ||
similaritiesLock := sync.Mutex{} | ||
func getMostSimilarDocs(ctx context.Context, queryVectors []float32, docs []*Document, n int) ([]docSim, error) { | ||
nMaxDocs := newMaxDocSims(n) | ||
|
||
// Determine concurrency. Use number of docs or CPUs, whichever is smaller. | ||
numCPUs := runtime.NumCPU() | ||
|
@@ -152,10 +218,7 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu | |
return | ||
} | ||
|
||
similaritiesLock.Lock() | ||
// We don't defer the unlock because we want to unlock much earlier. | ||
similarities = append(similarities, docSim{docID: doc.ID, similarity: sim}) | ||
similaritiesLock.Unlock() | ||
nMaxDocs.add(docSim{docID: doc.ID, similarity: sim}) | ||
} | ||
}(docs[start:end]) | ||
} | ||
|
@@ -166,5 +229,5 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu | |
return nil, sharedErr | ||
} | ||
|
||
return similarities, nil | ||
return nMaxDocs.values(), nil | ||
} |