diff --git a/collection.go b/collection.go
index 68be66f..08db442 100644
--- a/collection.go
+++ b/collection.go
@@ -225,7 +225,7 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
 		doc.Embedding = embedding
 	} else {
 		if !isNormalized(doc.Embedding) {
-			doc.Embedding = normalizeVector(doc.Embedding)
+			normalizeVector(doc.Embedding)
 		}
 	}
 
@@ -386,7 +386,7 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3
 	// Normalize embedding if not the case yet. We only support cosine similarity
 	// for now and all documents were already normalized when added to the collection.
 	if !isNormalized(queryEmbedding) {
-		queryEmbedding = normalizeVector(queryEmbedding)
+		normalizeVector(queryEmbedding)
 	}
 
 	// For the remaining documents, get the most similar docs.
diff --git a/collection_test.go b/collection_test.go
index c652655..86f949a 100644
--- a/collection_test.go
+++ b/collection_test.go
@@ -567,7 +567,7 @@ func benchmarkCollection_Query(b *testing.B, n int, withContent bool) {
 		qv[j] = r.Float32()
 	}
 	// The document embeddings are normalized, so the query must be normalized too.
-	qv = normalizeVector(qv)
+	normalizeVector(qv)
 
 	// Create collection
 	db := NewDB()
@@ -590,7 +590,7 @@ func benchmarkCollection_Query(b *testing.B, n int, withContent bool) {
 		for j := 0; j < d; j++ {
 			v[j] = r.Float32()
 		}
-		v = normalizeVector(v)
+		normalizeVector(v)
 
 		// Add document with some metadata and content depending on parameter.
 		// When providing embeddings, the embedding func is not called.
diff --git a/embed_cohere.go b/embed_cohere.go
index 0891924..ab4a341 100644
--- a/embed_cohere.go
+++ b/embed_cohere.go
@@ -159,7 +159,7 @@ func NewEmbeddingFuncCohere(apiKey string, model EmbeddingModelCohere) Embedding
 			}
 		})
 		if !checkedNormalized {
-			v = normalizeVector(v)
+			normalizeVector(v)
 		}
 
 		return v, nil
diff --git a/embed_ollama.go b/embed_ollama.go
index 019c669..561fc72 100644
--- a/embed_ollama.go
+++ b/embed_ollama.go
@@ -91,7 +91,7 @@ func NewEmbeddingFuncOllama(model string, baseURLOllama string) EmbeddingFunc {
 			}
 		})
 		if !checkedNormalized {
-			v = normalizeVector(v)
+			normalizeVector(v)
 		}
 
 		return v, nil
diff --git a/embed_openai.go b/embed_openai.go
index de3de82..4133061 100644
--- a/embed_openai.go
+++ b/embed_openai.go
@@ -142,7 +142,8 @@ func newEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo
 			if *normalized {
 				return v, nil
 			}
-			return normalizeVector(v), nil
+			normalizeVector(v)
+			return v, nil
 		}
 		checkNormalized.Do(func() {
 			if isNormalized(v) {
@@ -152,7 +153,7 @@ func newEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo
 			}
 		})
 		if !checkedNormalized {
-			v = normalizeVector(v)
+			normalizeVector(v)
 		}
 
 		return v, nil
diff --git a/vector.go b/vector.go
index 972b6b2..1fa083c 100644
--- a/vector.go
+++ b/vector.go
@@ -49,19 +49,16 @@ func dotProduct(a, b []float32) (float32, error) {
 	return dotProduct, nil
 }
 
-func normalizeVector(v []float32) []float32 {
+func normalizeVector(v []float32) {
 	var norm float32
 	for _, val := range v {
 		norm += val * val
 	}
 	norm = float32(math.Sqrt(float64(norm)))
 
-	res := make([]float32, len(v))
 	for i, val := range v {
-		res[i] = val / norm
+		v[i] = val / norm
 	}
-
-	return res
 }
 
 // isNormalized checks if the vector is normalized.
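
For reference, a minimal standalone sketch of the behavioural change this diff makes: `normalizeVector` now scales the caller's slice in place instead of allocating and returning a new one, so call sites drop the reassignment. The `main` function and sample values below are illustrative only and not part of the repository.

```go
package main

import (
	"fmt"
	"math"
)

// normalizeVector scales v to unit (L2) length in place, mirroring the new
// signature in vector.go: no new slice is allocated or returned.
func normalizeVector(v []float32) {
	var norm float32
	for _, val := range v {
		norm += val * val
	}
	norm = float32(math.Sqrt(float64(norm)))

	for i, val := range v {
		v[i] = val / norm
	}
}

func main() {
	v := []float32{3, 4}
	normalizeVector(v) // mutates v directly; callers no longer reassign the result
	fmt.Println(v)     // [0.6 0.8]
}
```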