diff --git a/collection.go b/collection.go index 2a8a5a0..3e2efc4 100644 --- a/collection.go +++ b/collection.go @@ -225,7 +225,7 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error { doc.Embedding = embedding } else { if !isNormalized(doc.Embedding) { - doc.Embedding = normalizeVector(doc.Embedding) + normalizeVectorInPlace(doc.Embedding) } } diff --git a/collection_test.go b/collection_test.go index c652655..3db67d3 100644 --- a/collection_test.go +++ b/collection_test.go @@ -567,7 +567,7 @@ func benchmarkCollection_Query(b *testing.B, n int, withContent bool) { qv[j] = r.Float32() } // The document embeddings are normalized, so the query must be normalized too. - qv = normalizeVector(qv) + normalizeVectorInPlace(qv) // Create collection db := NewDB() @@ -590,7 +590,7 @@ func benchmarkCollection_Query(b *testing.B, n int, withContent bool) { for j := 0; j < d; j++ { v[j] = r.Float32() } - v = normalizeVector(v) + normalizeVectorInPlace(v) // Add document with some metadata and content depending on parameter. // When providing embeddings, the embedding func is not called. diff --git a/embed_cohere.go b/embed_cohere.go index 0891924..5bee627 100644 --- a/embed_cohere.go +++ b/embed_cohere.go @@ -159,7 +159,7 @@ func NewEmbeddingFuncCohere(apiKey string, model EmbeddingModelCohere) Embedding } }) if !checkedNormalized { - v = normalizeVector(v) + normalizeVectorInPlace(v) } return v, nil diff --git a/embed_ollama.go b/embed_ollama.go index 019c669..f2bfb2f 100644 --- a/embed_ollama.go +++ b/embed_ollama.go @@ -91,7 +91,7 @@ func NewEmbeddingFuncOllama(model string, baseURLOllama string) EmbeddingFunc { } }) if !checkedNormalized { - v = normalizeVector(v) + normalizeVectorInPlace(v) } return v, nil diff --git a/embed_openai.go b/embed_openai.go index de3de82..d302fc0 100644 --- a/embed_openai.go +++ b/embed_openai.go @@ -142,7 +142,8 @@ func newEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo if *normalized { return v, nil } - return normalizeVector(v), nil + normalizeVectorInPlace(v) + return v, nil } checkNormalized.Do(func() { if isNormalized(v) { @@ -152,7 +153,7 @@ func newEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo } }) if !checkedNormalized { - v = normalizeVector(v) + normalizeVectorInPlace(v) } return v, nil diff --git a/vector.go b/vector.go index 972b6b2..99caff9 100644 --- a/vector.go +++ b/vector.go @@ -19,7 +19,8 @@ func cosineSimilarity(a, b []float32) (float32, error) { } if !isNormalized(a) || !isNormalized(b) { - a, b = normalizeVector(a), normalizeVector(b) + a = normalizeVector(a) + b = normalizeVector(b) } dotProduct, err := dotProduct(a, b) if err != nil { @@ -49,19 +50,23 @@ func dotProduct(a, b []float32) (float32, error) { return dotProduct, nil } -func normalizeVector(v []float32) []float32 { +func normalizeVectorInPlace(v []float32) { var norm float32 for _, val := range v { norm += val * val } norm = float32(math.Sqrt(float64(norm))) - res := make([]float32, len(v)) for i, val := range v { - res[i] = val / norm + v[i] = val / norm } +} - return res +func normalizeVector(v []float32) []float32 { + r := make([]float32, len(v)) + copy(r, v) + normalizeVectorInPlace(r) + return r } // isNormalized checks if the vector is normalized.