Skip to content

Commit

Permalink
Merge pull request #77 from philippgille/fix-query-embedding-is-not-n…
Browse files Browse the repository at this point in the history
…ormalized

Fix query embedding isn't normalized in Collection.QueryEmbedding() call
  • Loading branch information
philippgille authored May 17, 2024
2 parents 5496d32 + 1b10941 commit a0850b1
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
7 changes: 7 additions & 0 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int,
//
// - queryEmbedding: The embedding of the query to search for. It must be created
// with the same embedding model as the document embeddings in the collection.
// The embedding will be normalized if it's not the case yet.
// - nResults: The number of results to return. Must be > 0.
// - where: Conditional filtering on metadata. Optional.
// - whereDocument: Conditional filtering on documents. Optional.
Expand Down Expand Up @@ -382,6 +383,12 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3
return nil, nil
}

// Normalize embedding if not the case yet. We only support cosine similarity
// for now and all documents were already normalized when added to the collection.
if !isNormalized(queryEmbedding) {
queryEmbedding = normalizeVector(queryEmbedding)
}

// For the remaining documents, get the most similar docs.
nMaxDocs, err := getMostSimilarDocs(ctx, queryEmbedding, filteredDocs, nResults)
if err != nil {
Expand Down
4 changes: 1 addition & 3 deletions collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,6 @@ func TestCollection_Delete(t *testing.T) {

// Check number of files in the persist directory
d, err := os.ReadDir(c.persistDirectory)

if err != nil {
t.Fatal("expected nil, got", err)
}
Expand Down Expand Up @@ -507,7 +506,6 @@ func TestCollection_Delete(t *testing.T) {
}

checkCount(0)

}

// Global var for assignment in the benchmark to avoid compiler optimizations.
Expand Down Expand Up @@ -566,7 +564,7 @@ func benchmarkCollection_Query(b *testing.B, n int, withContent bool) {
for j := 0; j < d; j++ {
qv[j] = r.Float32()
}
// Most embeddings are normalized, so we normalize this one too
// The document embeddings are normalized, so the query must be normalized too.
qv = normalizeVector(qv)

// Create collection
Expand Down

0 comments on commit a0850b1

Please sign in to comment.