From 86c681e8e66e8c7d8e007c628131f356fc92df73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sun, 17 Mar 2024 15:28:07 +0100 Subject: [PATCH] Use max-heap for query results instead of sorting huge slice --- collection.go | 27 ++++++------------ query.go | 79 +++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 79 insertions(+), 27 deletions(-) diff --git a/collection.go b/collection.go index b4186b0..52d8844 100644 --- a/collection.go +++ b/collection.go @@ -1,7 +1,6 @@ package chromem import ( - "cmp" "context" "errors" "fmt" @@ -323,30 +322,20 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3 return nil, nil } - // For the remaining documents, calculate cosine similarity. - docSims, err := calcDocSimilarity(ctx, queryEmbedding, filteredDocs) + // For the remaining documents, get the most similar docs. + nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, nResults) if err != nil { - return nil, fmt.Errorf("couldn't calculate cosine similarity: %w", err) + return nil, fmt.Errorf("couldn't get most similar docs: %w", err) } - // Sort by similarity - slices.SortFunc(docSims, func(i, j docSim) int { - // i, j; for descending order - return cmp.Compare(j.similarity, i.similarity) - }) - - // Return the top nResults or len(docSim), whichever is smaller - if len(docSims) < nResults { - nResults = len(docSims) - } res := make([]Result, 0, nResults) for i := 0; i < nResults; i++ { res = append(res, Result{ - ID: docSims[i].docID, - Metadata: c.documents[docSims[i].docID].Metadata, - Embedding: c.documents[docSims[i].docID].Embedding, - Content: c.documents[docSims[i].docID].Content, - Similarity: docSims[i].similarity, + ID: nMaxDocs[i].docID, + Metadata: c.documents[nMaxDocs[i].docID].Metadata, + Embedding: c.documents[nMaxDocs[i].docID].Embedding, + Content: c.documents[nMaxDocs[i].docID].Content, + Similarity: nMaxDocs[i].similarity, }) } diff --git a/query.go b/query.go index 9da1be2..240060c 100644 --- a/query.go +++ b/query.go @@ -1,9 +1,12 @@ package chromem import ( + "cmp" + "container/heap" "context" "fmt" "runtime" + "slices" "strings" "sync" ) @@ -15,6 +18,70 @@ type docSim struct { similarity float32 } +// docMaxHeap is a max-heap of docSims, based on similarity. +// See https://pkg.go.dev/container/heap@go1.22#example-package-IntHeap +type docMaxHeap []docSim + +func (h docMaxHeap) Len() int { return len(h) } +func (h docMaxHeap) Less(i, j int) bool { return h[i].similarity < h[j].similarity } +func (h docMaxHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } + +func (h *docMaxHeap) Push(x any) { + // Push and Pop use pointer receivers because they modify the slice's length, + // not just its contents. + *h = append(*h, x.(docSim)) +} + +func (h *docMaxHeap) Pop() any { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} + +// maxDocSims manages a max-heap of docSims with a fixed size, keeping the n highest +// similarities. It's safe for concurrent use, but not the result of values(). +// In our benchmarks this was faster than sorting a slice of docSims at the end. +type maxDocSims struct { + h docMaxHeap + lock sync.RWMutex + size int +} + +// newMaxDocSims creates a new nMaxDocs with a fixed size. +func newMaxDocSims(size int) *maxDocSims { + return &maxDocSims{ + h: make(docMaxHeap, 0, size), + size: size, + } +} + +// add inserts a new docSim into the heap, keeping only the top n similarities. +func (mds *maxDocSims) add(doc docSim) { + mds.lock.Lock() + defer mds.lock.Unlock() + if mds.h.Len() < mds.size { + heap.Push(&mds.h, doc) + } else if mds.h.Len() > 0 && mds.h[0].similarity < doc.similarity { + // Replace the smallest similarity if the new doc's similarity is higher + heap.Pop(&mds.h) + heap.Push(&mds.h, doc) + } +} + +// values returns the docSims in the heap, sorted by similarity (descending). +// The call itself is safe for concurrent use with add(), but the result isn't. +// Only work with the result after all calls to add() have finished. +func (d *maxDocSims) values() []docSim { + d.lock.RLock() + defer d.lock.RUnlock() + slices.SortFunc(d.h, func(i, j docSim) int { + return cmp.Compare(j.similarity, i.similarity) + }) + return d.h +} + // filterDocs filters a map of documents by metadata and content. // It does this concurrently. func filterDocs(docs map[string]*Document, where, whereDocument map[string]string) []*Document { @@ -95,9 +162,8 @@ func documentMatchesFilters(document *Document, where, whereDocument map[string] return true } -func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document) ([]docSim, error) { - similarities := make([]docSim, 0, len(docs)) - similaritiesLock := sync.Mutex{} +func getMostSimilarDocs(ctx context.Context, queryVectors []float32, docs []*Document, n int) ([]docSim, error) { + nMaxDocs := newMaxDocSims(n) // Determine concurrency. Use number of docs or CPUs, whichever is smaller. numCPUs := runtime.NumCPU() @@ -152,10 +218,7 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu return } - similaritiesLock.Lock() - // We don't defer the unlock because we want to unlock much earlier. - similarities = append(similarities, docSim{docID: doc.ID, similarity: sim}) - similaritiesLock.Unlock() + nMaxDocs.add(docSim{docID: doc.ID, similarity: sim}) } }(docs[start:end]) } @@ -166,5 +229,5 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu return nil, sharedErr } - return similarities, nil + return nMaxDocs.values(), nil }