diff --git a/collection.go b/collection.go index 68be66f..e0e41fb 100644 --- a/collection.go +++ b/collection.go @@ -343,6 +343,38 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int, return c.QueryEmbedding(ctx, queryVectors, nResults, where, whereDocument) } +// Performs an exhaustive nearest neighbor search on the collection. +// +// - queryText: The text to search for. Its embedding will be created using the +// collection's embedding function. +// - negativeText: The text to subtract from the query embedding. Its embedding +// will be created using the collection's embedding function. +// - nResults: The number of results to return. Must be > 0. +// - where: Conditional filtering on metadata. Optional. +// - whereDocument: Conditional filtering on documents. Optional. +func (c *Collection) QueryWithNegative(ctx context.Context, queryText string, negativeText string, nResults int, where, whereDocument map[string]string) ([]Result, error) { + if queryText == "" { + return nil, errors.New("queryText is empty") + } + + queryVectors, err := c.embed(ctx, queryText) + if err != nil { + return nil, fmt.Errorf("couldn't create embedding of query: %w", err) + } + + if negativeText != "" { + negativeVectors, err := c.embed(ctx, negativeText) + if err != nil { + return nil, fmt.Errorf("couldn't create embedding of negative: %w", err) + } + + queryVectors = subtractVector(queryVectors, negativeVectors) + queryVectors = normalizeVector(queryVectors) + } + + return c.QueryEmbedding(ctx, queryVectors, nResults, where, whereDocument) +} + // Performs an exhaustive nearest neighbor search on the collection. // // - queryEmbedding: The embedding of the query to search for. It must be created diff --git a/vector.go b/vector.go index 972b6b2..558126c 100644 --- a/vector.go +++ b/vector.go @@ -64,6 +64,15 @@ func normalizeVector(v []float32) []float32 { return res } +// subtractVector subtracts vector b from vector a in place. +func subtractVector(a, b []float32) []float32 { + for i := range a { + a[i] -= b[i] + } + + return a +} + // isNormalized checks if the vector is normalized. func isNormalized(v []float32) bool { var sqSum float64