Skip to content

Commit

Permalink
Fix out of range panic on query
Browse files Browse the repository at this point in the history
And clarify in Godoc that nResults is the *maximum* number of results.
  • Loading branch information
philippgille committed May 17, 2024
1 parent 1de154b commit 896f62e
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,8 @@ type Result struct {
//
// - queryText: The text to search for. Its embedding will be created using the
// collection's embedding function.
// - nResults: The number of results to return. Must be > 0.
// - nResults: The maximum number of results to return. Must be > 0.
// There can be fewer results if a filter is applied.
// - where: Conditional filtering on metadata. Optional.
// - whereDocument: Conditional filtering on documents. Optional.
func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) {
Expand All @@ -348,7 +349,8 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int,
// - queryEmbedding: The embedding of the query to search for. It must be created
// with the same embedding model as the document embeddings in the collection.
// The embedding will be normalized if it's not the case yet.
// - nResults: The number of results to return. Must be > 0.
// - nResults: The maximum number of results to return. Must be > 0.
// There can be fewer results if a filter is applied.
// - where: Conditional filtering on metadata. Optional.
// - whereDocument: Conditional filtering on documents. Optional.
func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where, whereDocument map[string]string) ([]Result, error) {
Expand Down Expand Up @@ -389,14 +391,26 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3
queryEmbedding = normalizeVector(queryEmbedding)
}

// If the filtering already reduced the number of documents to fewer than nResults,
// we only need to find the most similar docs among the filtered ones.
resLen := nResults
if len(filteredDocs) < nResults {
resLen = len(filteredDocs)
}

// For the remaining documents, get the most similar docs.
nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, nResults)
nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, resLen)
if err != nil {
return nil, fmt.Errorf("couldn't get most similar docs: %w", err)
}

res := make([]Result, 0, nResults)
for i := 0; i < nResults; i++ {
// As long as we don't filter by threshold, resLen should match len(nMaxDocs).
if resLen != len(nMaxDocs) {
return nil, fmt.Errorf("internal error: expected %d results, got %d", resLen, len(nMaxDocs))
}

res := make([]Result, 0, resLen)
for i := 0; i < resLen; i++ {
res = append(res, Result{
ID: nMaxDocs[i].docID,
Metadata: c.documents[nMaxDocs[i].docID].Metadata,
Expand All @@ -406,7 +420,6 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3
})
}

// Return the top nResults
return res, nil
}

Expand Down

0 comments on commit 896f62e

Please sign in to comment.