From 3598ff015b5bae79be857fee3e1c91ffcd5b3de6 Mon Sep 17 00:00:00 2001 From: Erik Dubbelboer Date: Sun, 19 May 2024 15:03:55 +0200 Subject: [PATCH 1/2] Make normalizeVector(v) in place All usage of normalizeVector was in the form of v = normalizeVector(v) In that case it's better to just do the replacement in place and not have to allocate a new slice. --- collection.go | 4 ++-- collection_test.go | 4 ++-- embed_cohere.go | 2 +- embed_ollama.go | 2 +- embed_openai.go | 5 +++-- vector.go | 10 ++++------ 6 files changed, 13 insertions(+), 14 deletions(-) diff --git a/collection.go b/collection.go index 2a8a5a0..115428c 100644 --- a/collection.go +++ b/collection.go @@ -225,7 +225,7 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error { doc.Embedding = embedding } else { if !isNormalized(doc.Embedding) { - doc.Embedding = normalizeVector(doc.Embedding) + normalizeVector(doc.Embedding) } } @@ -388,7 +388,7 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3 // Normalize embedding if not the case yet. We only support cosine similarity // for now and all documents were already normalized when added to the collection. if !isNormalized(queryEmbedding) { - queryEmbedding = normalizeVector(queryEmbedding) + normalizeVector(queryEmbedding) } // If the filtering already reduced the number of documents to fewer than nResults, diff --git a/collection_test.go b/collection_test.go index c652655..86f949a 100644 --- a/collection_test.go +++ b/collection_test.go @@ -567,7 +567,7 @@ func benchmarkCollection_Query(b *testing.B, n int, withContent bool) { qv[j] = r.Float32() } // The document embeddings are normalized, so the query must be normalized too. 
- qv = normalizeVector(qv) + normalizeVector(qv) // Create collection db := NewDB() @@ -590,7 +590,7 @@ func benchmarkCollection_Query(b *testing.B, n int, withContent bool) { for j := 0; j < d; j++ { v[j] = r.Float32() } - v = normalizeVector(v) + normalizeVector(v) // Add document with some metadata and content depending on parameter. // When providing embeddings, the embedding func is not called. diff --git a/embed_cohere.go b/embed_cohere.go index 0891924..ab4a341 100644 --- a/embed_cohere.go +++ b/embed_cohere.go @@ -159,7 +159,7 @@ func NewEmbeddingFuncCohere(apiKey string, model EmbeddingModelCohere) Embedding } }) if !checkedNormalized { - v = normalizeVector(v) + normalizeVector(v) } return v, nil diff --git a/embed_ollama.go b/embed_ollama.go index 019c669..561fc72 100644 --- a/embed_ollama.go +++ b/embed_ollama.go @@ -91,7 +91,7 @@ func NewEmbeddingFuncOllama(model string, baseURLOllama string) EmbeddingFunc { } }) if !checkedNormalized { - v = normalizeVector(v) + normalizeVector(v) } return v, nil diff --git a/embed_openai.go b/embed_openai.go index de3de82..4133061 100644 --- a/embed_openai.go +++ b/embed_openai.go @@ -142,7 +142,8 @@ func newEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo if *normalized { return v, nil } - return normalizeVector(v), nil + normalizeVector(v) + return v, nil } checkNormalized.Do(func() { if isNormalized(v) { @@ -152,7 +153,7 @@ func newEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo } }) if !checkedNormalized { - v = normalizeVector(v) + normalizeVector(v) } return v, nil diff --git a/vector.go b/vector.go index 972b6b2..506842b 100644 --- a/vector.go +++ b/vector.go @@ -19,7 +19,8 @@ func cosineSimilarity(a, b []float32) (float32, error) { } if !isNormalized(a) || !isNormalized(b) { - a, b = normalizeVector(a), normalizeVector(b) + normalizeVector(a) + normalizeVector(b) } dotProduct, err := dotProduct(a, b) if err != nil { @@ -49,19 +50,16 @@ func dotProduct(a, b 
[]float32) (float32, error) { return dotProduct, nil } -func normalizeVector(v []float32) []float32 { +func normalizeVector(v []float32) { var norm float32 for _, val := range v { norm += val * val } norm = float32(math.Sqrt(float64(norm))) - res := make([]float32, len(v)) for i, val := range v { - res[i] = val / norm + v[i] = val / norm } - - return res } // isNormalized checks if the vector is normalized. From 43e1a07713135b6743470319b99d0aedba1ba363 Mon Sep 17 00:00:00 2001 From: Erik Dubbelboer Date: Sun, 26 May 2024 08:31:03 +0200 Subject: [PATCH 2/2] Don't normalize user supplied vectors in place Except for in AddDocument as the caller hands off control of the Document to us there. --- collection.go | 4 ++-- collection_test.go | 4 ++-- embed_cohere.go | 2 +- embed_ollama.go | 2 +- embed_openai.go | 4 ++-- vector.go | 13 ++++++++++--- 6 files changed, 18 insertions(+), 11 deletions(-) diff --git a/collection.go b/collection.go index 115428c..3e2efc4 100644 --- a/collection.go +++ b/collection.go @@ -225,7 +225,7 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error { doc.Embedding = embedding } else { if !isNormalized(doc.Embedding) { - normalizeVector(doc.Embedding) + normalizeVectorInPlace(doc.Embedding) } } @@ -388,7 +388,7 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3 // Normalize embedding if not the case yet. We only support cosine similarity // for now and all documents were already normalized when added to the collection. 
if !isNormalized(queryEmbedding) { - normalizeVector(queryEmbedding) + queryEmbedding = normalizeVector(queryEmbedding) } // If the filtering already reduced the number of documents to fewer than nResults, diff --git a/collection_test.go b/collection_test.go index 86f949a..3db67d3 100644 --- a/collection_test.go +++ b/collection_test.go @@ -567,7 +567,7 @@ func benchmarkCollection_Query(b *testing.B, n int, withContent bool) { qv[j] = r.Float32() } // The document embeddings are normalized, so the query must be normalized too. - normalizeVector(qv) + normalizeVectorInPlace(qv) // Create collection db := NewDB() @@ -590,7 +590,7 @@ func benchmarkCollection_Query(b *testing.B, n int, withContent bool) { for j := 0; j < d; j++ { v[j] = r.Float32() } - normalizeVector(v) + normalizeVectorInPlace(v) // Add document with some metadata and content depending on parameter. // When providing embeddings, the embedding func is not called. diff --git a/embed_cohere.go b/embed_cohere.go index ab4a341..5bee627 100644 --- a/embed_cohere.go +++ b/embed_cohere.go @@ -159,7 +159,7 @@ func NewEmbeddingFuncCohere(apiKey string, model EmbeddingModelCohere) Embedding } }) if !checkedNormalized { - normalizeVector(v) + normalizeVectorInPlace(v) } return v, nil diff --git a/embed_ollama.go b/embed_ollama.go index 561fc72..f2bfb2f 100644 --- a/embed_ollama.go +++ b/embed_ollama.go @@ -91,7 +91,7 @@ func NewEmbeddingFuncOllama(model string, baseURLOllama string) EmbeddingFunc { } }) if !checkedNormalized { - normalizeVector(v) + normalizeVectorInPlace(v) } return v, nil diff --git a/embed_openai.go b/embed_openai.go index 4133061..d302fc0 100644 --- a/embed_openai.go +++ b/embed_openai.go @@ -142,7 +142,7 @@ func newEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo if *normalized { return v, nil } - normalizeVector(v) + normalizeVectorInPlace(v) return v, nil } checkNormalized.Do(func() { @@ -153,7 +153,7 @@ func newEmbeddingFuncOpenAICompat(baseURL, apiKey, model 
string, normalized *boo } }) if !checkedNormalized { - normalizeVector(v) + normalizeVectorInPlace(v) } return v, nil diff --git a/vector.go b/vector.go index 506842b..99caff9 100644 --- a/vector.go +++ b/vector.go @@ -19,8 +19,8 @@ func cosineSimilarity(a, b []float32) (float32, error) { } if !isNormalized(a) || !isNormalized(b) { - normalizeVector(a) - normalizeVector(b) + a = normalizeVector(a) + b = normalizeVector(b) } dotProduct, err := dotProduct(a, b) if err != nil { @@ -50,7 +50,7 @@ func dotProduct(a, b []float32) (float32, error) { return dotProduct, nil } -func normalizeVector(v []float32) { +func normalizeVectorInPlace(v []float32) { var norm float32 for _, val := range v { norm += val * val @@ -62,6 +62,13 @@ func normalizeVector(v []float32) { } } +func normalizeVector(v []float32) []float32 { + r := make([]float32, len(v)) + copy(r, v) + normalizeVectorInPlace(r) + return r +} + // isNormalized checks if the vector is normalized. func isNormalized(v []float32) bool { var sqSum float64