From 83fecdbdfce424c910feeb99b5f3cb3b81e9aedf Mon Sep 17 00:00:00 2001 From: Erik Dubbelboer Date: Sun, 19 May 2024 14:54:48 +0200 Subject: [PATCH] Add QueryWithNegative QueryWithNegative allows you to pass a string that will be excluded from the search results. There are tree ways to implement this: 1. Find the document matching the normal query. Then re-order the result by multiplying the similarity by the dot product with the negative vector and some constant. 2. Find the document matching the normal query. Exclude documents where the dot product with the negative vector are above a constant. 3. The simpler method I implemented which just subtracts the negative vector from the positive one and re-normalizes the result. I have done some simple tests and the results look good. I'm not sure if the extra function is a nice API. It could also be added as extra argument to Query. Or maybe Query should get a struct for it's argument as the number of arguments seems to keep increasing. --- collection.go | 32 ++++++++++++++++++++++++++++++++ vector.go | 9 +++++++++ 2 files changed, 41 insertions(+) diff --git a/collection.go b/collection.go index 68be66f..e0e41fb 100644 --- a/collection.go +++ b/collection.go @@ -343,6 +343,38 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int, return c.QueryEmbedding(ctx, queryVectors, nResults, where, whereDocument) } +// Performs an exhaustive nearest neighbor search on the collection. +// +// - queryText: The text to search for. Its embedding will be created using the +// collection's embedding function. +// - negativeText: The text to subtract from the query embedding. Its embedding +// will be created using the collection's embedding function. +// - nResults: The number of results to return. Must be > 0. +// - where: Conditional filtering on metadata. Optional. +// - whereDocument: Conditional filtering on documents. Optional. +func (c *Collection) QueryWithNegative(ctx context.Context, queryText string, negativeText string, nResults int, where, whereDocument map[string]string) ([]Result, error) { + if queryText == "" { + return nil, errors.New("queryText is empty") + } + + queryVectors, err := c.embed(ctx, queryText) + if err != nil { + return nil, fmt.Errorf("couldn't create embedding of query: %w", err) + } + + if negativeText != "" { + negativeVectors, err := c.embed(ctx, negativeText) + if err != nil { + return nil, fmt.Errorf("couldn't create embedding of negative: %w", err) + } + + queryVectors = subtractVector(queryVectors, negativeVectors) + queryVectors = normalizeVector(queryVectors) + } + + return c.QueryEmbedding(ctx, queryVectors, nResults, where, whereDocument) +} + // Performs an exhaustive nearest neighbor search on the collection. // // - queryEmbedding: The embedding of the query to search for. It must be created diff --git a/vector.go b/vector.go index 972b6b2..558126c 100644 --- a/vector.go +++ b/vector.go @@ -64,6 +64,15 @@ func normalizeVector(v []float32) []float32 { return res } +// subtractVector subtracts vector b from vector a in place. +func subtractVector(a, b []float32) []float32 { + for i := range a { + a[i] -= b[i] + } + + return a +} + // isNormalized checks if the vector is normalized. func isNormalized(v []float32) bool { var sqSum float64