diff --git a/collection.go b/collection.go index 68be66f..2a8a5a0 100644 --- a/collection.go +++ b/collection.go @@ -327,7 +327,8 @@ type Result struct { // // - queryText: The text to search for. Its embedding will be created using the // collection's embedding function. -// - nResults: The number of results to return. Must be > 0. +// - nResults: The maximum number of results to return. Must be > 0. +// There can be fewer results if a filter is applied. // - where: Conditional filtering on metadata. Optional. // - whereDocument: Conditional filtering on documents. Optional. func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) { @@ -348,7 +349,8 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int, // - queryEmbedding: The embedding of the query to search for. It must be created // with the same embedding model as the document embeddings in the collection. // The embedding will be normalized if it's not the case yet. -// - nResults: The number of results to return. Must be > 0. +// - nResults: The maximum number of results to return. Must be > 0. +// There can be fewer results if a filter is applied. // - where: Conditional filtering on metadata. Optional. // - whereDocument: Conditional filtering on documents. Optional. func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where, whereDocument map[string]string) ([]Result, error) { @@ -389,14 +391,26 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3 queryEmbedding = normalizeVector(queryEmbedding) } + // If the filtering already reduced the number of documents to fewer than nResults, + // we only need to find the most similar docs among the filtered ones. + resLen := nResults + if len(filteredDocs) < nResults { + resLen = len(filteredDocs) + } + // For the remaining documents, get the most similar docs. - nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, nResults) + nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, resLen) if err != nil { return nil, fmt.Errorf("couldn't get most similar docs: %w", err) } - res := make([]Result, 0, nResults) - for i := 0; i < nResults; i++ { + // As long as we don't filter by threshold, resLen should match len(nMaxDocs). + if resLen != len(nMaxDocs) { + return nil, fmt.Errorf("internal error: expected %d results, got %d", resLen, len(nMaxDocs)) + } + + res := make([]Result, 0, resLen) + for i := 0; i < resLen; i++ { res = append(res, Result{ ID: nMaxDocs[i].docID, Metadata: c.documents[nMaxDocs[i].docID].Metadata, @@ -406,7 +420,6 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3 }) } - // Return the top nResults return res, nil }