From 896f62e0dfd1b18984b1c520fd3c1d8ada98db87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Fri, 17 May 2024 23:11:42 +0200 Subject: [PATCH] Fix out of range panic on query And clarify in Godoc that nResults is the *maximum* number of results. --- collection.go | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/collection.go b/collection.go index 68be66f..2a8a5a0 100644 --- a/collection.go +++ b/collection.go @@ -327,7 +327,8 @@ type Result struct { // // - queryText: The text to search for. Its embedding will be created using the // collection's embedding function. -// - nResults: The number of results to return. Must be > 0. +// - nResults: The maximum number of results to return. Must be > 0. +// There can be fewer results if a filter is applied. // - where: Conditional filtering on metadata. Optional. // - whereDocument: Conditional filtering on documents. Optional. func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) { @@ -348,7 +349,8 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int, // - queryEmbedding: The embedding of the query to search for. It must be created // with the same embedding model as the document embeddings in the collection. // The embedding will be normalized if it's not the case yet. -// - nResults: The number of results to return. Must be > 0. +// - nResults: The maximum number of results to return. Must be > 0. +// There can be fewer results if a filter is applied. // - where: Conditional filtering on metadata. Optional. // - whereDocument: Conditional filtering on documents. Optional. func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where, whereDocument map[string]string) ([]Result, error) { @@ -389,14 +391,26 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3 queryEmbedding = normalizeVector(queryEmbedding) } + // If the filtering already reduced the number of documents to fewer than nResults, + // we only need to find the most similar docs among the filtered ones. + resLen := nResults + if len(filteredDocs) < nResults { + resLen = len(filteredDocs) + } + // For the remaining documents, get the most similar docs. - nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, nResults) + nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, resLen) if err != nil { return nil, fmt.Errorf("couldn't get most similar docs: %w", err) } - res := make([]Result, 0, nResults) - for i := 0; i < nResults; i++ { + // As long as we don't filter by threshold, resLen should match len(nMaxDocs). + if resLen != len(nMaxDocs) { + return nil, fmt.Errorf("internal error: expected %d results, got %d", resLen, len(nMaxDocs)) + } + + res := make([]Result, 0, resLen) + for i := 0; i < resLen; i++ { res = append(res, Result{ ID: nMaxDocs[i].docID, Metadata: c.documents[nMaxDocs[i].docID].Metadata, @@ -406,7 +420,6 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3 }) } - // Return the top nResults return res, nil }