Skip to content

Commit

Permalink
Merge pull request #79 from philippgille/fix-out-of-range-panic-on-query
Browse files (browse the repository at this point in the history)
Fix out of range panic on query
  • Loading branch information
philippgille authored May 20, 2024
2 parents 1de154b + 896f62e commit 3e0fa51
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,8 @@ type Result struct {
//
// - queryText: The text to search for. Its embedding will be created using the
// collection's embedding function.
// - nResults: The number of results to return. Must be > 0.
// - nResults: The maximum number of results to return. Must be > 0.
// There can be fewer results if a filter is applied.
// - where: Conditional filtering on metadata. Optional.
// - whereDocument: Conditional filtering on documents. Optional.
func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) {
Expand All @@ -348,7 +349,8 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int,
// - queryEmbedding: The embedding of the query to search for. It must be created
// with the same embedding model as the document embeddings in the collection.
// The embedding will be normalized if it isn't already.
// - nResults: The number of results to return. Must be > 0.
// - nResults: The maximum number of results to return. Must be > 0.
// There can be fewer results if a filter is applied.
// - where: Conditional filtering on metadata. Optional.
// - whereDocument: Conditional filtering on documents. Optional.
func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where, whereDocument map[string]string) ([]Result, error) {
Expand Down Expand Up @@ -389,14 +391,26 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3
queryEmbedding = normalizeVector(queryEmbedding)
}

// If the filtering already reduced the number of documents to fewer than nResults,
// we only need to find the most similar docs among the filtered ones.
resLen := nResults
if len(filteredDocs) < nResults {
resLen = len(filteredDocs)
}

// For the remaining documents, get the most similar docs.
nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, nResults)
nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, resLen)
if err != nil {
return nil, fmt.Errorf("couldn't get most similar docs: %w", err)
}

res := make([]Result, 0, nResults)
for i := 0; i < nResults; i++ {
// As long as we don't filter by threshold, resLen should match len(nMaxDocs).
if resLen != len(nMaxDocs) {
return nil, fmt.Errorf("internal error: expected %d results, got %d", resLen, len(nMaxDocs))
}

res := make([]Result, 0, resLen)
for i := 0; i < resLen; i++ {
res = append(res, Result{
ID: nMaxDocs[i].docID,
Metadata: c.documents[nMaxDocs[i].docID].Metadata,
Expand All @@ -406,7 +420,6 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3
})
}

// Return the top nResults
return res, nil
}

Expand Down

0 comments on commit 3e0fa51

Please sign in to comment.