diff --git a/collection.go b/collection.go index 68be66f..d30f26f 100644 --- a/collection.go +++ b/collection.go @@ -327,7 +327,8 @@ type Result struct { // // - queryText: The text to search for. Its embedding will be created using the // collection's embedding function. -// - nResults: The number of results to return. Must be > 0. +// - nResults: The maximum number of results to return. Must be > 0. +// There can be fewer results if a filter is applied. // - where: Conditional filtering on metadata. Optional. // - whereDocument: Conditional filtering on documents. Optional. func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) { @@ -348,7 +349,8 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int, // - queryEmbedding: The embedding of the query to search for. It must be created // with the same embedding model as the document embeddings in the collection. // The embedding will be normalized if it's not the case yet. -// - nResults: The number of results to return. Must be > 0. +// - nResults: The maximum number of results to return. Must be > 0. +// There can be fewer results if a filter is applied. // - where: Conditional filtering on metadata. Optional. // - whereDocument: Conditional filtering on documents. Optional. func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where, whereDocument map[string]string) ([]Result, error) { @@ -395,8 +397,16 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3 return nil, fmt.Errorf("couldn't get most similar docs: %w", err) } - res := make([]Result, 0, nResults) - for i := 0; i < nResults; i++ { + // We return nResults or len(nMaxDocs), whichever is smaller. + // The latter can be the case if the collection has fewer documents than nResults + // or if the filtering reduced the number of documents. + resNum := nResults + if len(nMaxDocs) < nResults { + resNum = len(nMaxDocs) + } + + res := make([]Result, 0, resNum) + for i := 0; i < resNum; i++ { res = append(res, Result{ ID: nMaxDocs[i].docID, Metadata: c.documents[nMaxDocs[i].docID].Metadata, @@ -406,7 +416,6 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3 }) } - // Return the top nResults return res, nil }