Skip to content

Commit

Permalink
Add QueryWithNegative
Browse files Browse the repository at this point in the history
QueryWithNegative allows you to pass a string that will be excluded from
the search results.

There are tree ways to implement this:
1. Find the document matching the normal query. Then re-order the result
   by multiplying the similarity by the dot product with the negative
   vector and some constant.
2. Find the document matching the normal query. Exclude documents where
   the dot product with the negative vector are above a constant.
3. The simpler method I implemented which just subtracts the negative
   vector from the positive one and re-normalizes the result.

I have done some simple tests and the results look good.

I'm not sure if the extra function is a nice API. It could also be added
as extra argument to Query. Or maybe Query should get a struct for it's
argument as the number of arguments seems to keep increasing.
  • Loading branch information
erikdubbelboer committed May 19, 2024
1 parent 1de154b commit 83fecdb
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
32 changes: 32 additions & 0 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,38 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int,
return c.QueryEmbedding(ctx, queryVectors, nResults, where, whereDocument)
}

// Performs an exhaustive nearest neighbor search on the collection.
//
// - queryText: The text to search for. Its embedding will be created using the
// collection's embedding function.
// - negativeText: The text to subtract from the query embedding. Its embedding
// will be created using the collection's embedding function.
// - nResults: The number of results to return. Must be > 0.
// - where: Conditional filtering on metadata. Optional.
// - whereDocument: Conditional filtering on documents. Optional.
func (c *Collection) QueryWithNegative(ctx context.Context, queryText string, negativeText string, nResults int, where, whereDocument map[string]string) ([]Result, error) {
if queryText == "" {
return nil, errors.New("queryText is empty")
}

queryVectors, err := c.embed(ctx, queryText)
if err != nil {
return nil, fmt.Errorf("couldn't create embedding of query: %w", err)
}

if negativeText != "" {
negativeVectors, err := c.embed(ctx, negativeText)
if err != nil {
return nil, fmt.Errorf("couldn't create embedding of negative: %w", err)
}

queryVectors = subtractVector(queryVectors, negativeVectors)
queryVectors = normalizeVector(queryVectors)
}

return c.QueryEmbedding(ctx, queryVectors, nResults, where, whereDocument)
}

// Performs an exhaustive nearest neighbor search on the collection.
//
// - queryEmbedding: The embedding of the query to search for. It must be created
Expand Down
9 changes: 9 additions & 0 deletions vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ func normalizeVector(v []float32) []float32 {
return res
}

// subtractVector subtracts vector b from vector a in place.
func subtractVector(a, b []float32) []float32 {
for i := range a {
a[i] -= b[i]
}

return a
}

// isNormalized checks if the vector is normalized.
func isNormalized(v []float32) bool {
var sqSum float64
Expand Down

0 comments on commit 83fecdb

Please sign in to comment.