Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix out of range panic on query #79

Merged
merged 1 commit into from
May 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,8 @@ type Result struct {
//
// - queryText: The text to search for. Its embedding will be created using the
// collection's embedding function.
- // - nResults: The number of results to return. Must be > 0.
+ // - nResults: The maximum number of results to return. Must be > 0.
+ // There can be fewer results if a filter is applied.
// - where: Conditional filtering on metadata. Optional.
// - whereDocument: Conditional filtering on documents. Optional.
func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) {
Expand All @@ -348,7 +349,8 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int,
// - queryEmbedding: The embedding of the query to search for. It must be created
// with the same embedding model as the document embeddings in the collection.
// The embedding will be normalized if it's not the case yet.
- // - nResults: The number of results to return. Must be > 0.
+ // - nResults: The maximum number of results to return. Must be > 0.
+ // There can be fewer results if a filter is applied.
// - where: Conditional filtering on metadata. Optional.
// - whereDocument: Conditional filtering on documents. Optional.
func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float32, nResults int, where, whereDocument map[string]string) ([]Result, error) {
Expand Down Expand Up @@ -389,14 +391,26 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3
queryEmbedding = normalizeVector(queryEmbedding)
}

+ // If the filtering already reduced the number of documents to fewer than nResults,
+ // we only need to find the most similar docs among the filtered ones.
+ resLen := nResults
+ if len(filteredDocs) < nResults {
+ resLen = len(filteredDocs)
+ }
+
// For the remaining documents, get the most similar docs.
- nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, nResults)
+ nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, resLen)
if err != nil {
return nil, fmt.Errorf("couldn't get most similar docs: %w", err)
}

- res := make([]Result, 0, nResults)
- for i := 0; i < nResults; i++ {
+ // As long as we don't filter by threshold, resLen should match len(nMaxDocs).
+ if resLen != len(nMaxDocs) {
+ return nil, fmt.Errorf("internal error: expected %d results, got %d", resLen, len(nMaxDocs))
+ }
+
+ res := make([]Result, 0, resLen)
+ for i := 0; i < resLen; i++ {
res = append(res, Result{
ID: nMaxDocs[i].docID,
Metadata: c.documents[nMaxDocs[i].docID].Metadata,
Expand All @@ -406,7 +420,6 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3
})
}

- // Return the top nResults
return res, nil
}

Expand Down