From 1b1094197bc6cd096b695407a818e5f9e2f8fb62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sun, 12 May 2024 15:57:04 +0200 Subject: [PATCH] Fix query embedding isn't normalized Only when calling Collection.QueryEmbedding() and passing existing embeddings that aren't normalized yet. --- collection.go | 7 +++++++ collection_test.go | 4 +--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/collection.go b/collection.go index 2a2c859..68be66f 100644 --- a/collection.go +++ b/collection.go @@ -347,6 +347,7 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int, // // - queryEmbedding: The embedding of the query to search for. It must be created // with the same embedding model as the document embeddings in the collection. +// The embedding will be normalized if it's not the case yet. // - nResults: The number of results to return. Must be > 0. // - where: Conditional filtering on metadata. Optional. // - whereDocument: Conditional filtering on documents. Optional. @@ -382,6 +383,12 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3 return nil, nil } + // Normalize embedding if not the case yet. We only support cosine similarity + // for now and all documents were already normalized when added to the collection. + if !isNormalized(queryEmbedding) { + queryEmbedding = normalizeVector(queryEmbedding) + } + // For the remaining documents, get the most similar docs. nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, nResults) if err != nil { diff --git a/collection_test.go b/collection_test.go index e37e511..0823748 100644 --- a/collection_test.go +++ b/collection_test.go @@ -461,7 +461,6 @@ func TestCollection_Delete(t *testing.T) { // Check number of files in the persist directory d, err := os.ReadDir(c.persistDirectory) - if err != nil { t.Fatal("expected nil, got", err) } @@ -507,7 +506,6 @@ func TestCollection_Delete(t *testing.T) { } checkCount(0) - } // Global var for assignment in the benchmark to avoid compiler optimizations. @@ -566,7 +564,7 @@ func benchmarkCollection_Query(b *testing.B, n int, withContent bool) { for j := 0; j < d; j++ { qv[j] = r.Float32() } - // Most embeddings are normalized, so we normalize this one too + // The document embeddings are normalized, so the query must be normalized too. qv = normalizeVector(qv) // Create collection