Skip to content

Commit

Permalink
Don't normalize user supplied vectors in place
Browse files Browse the repository at this point in the history
Except in AddDocument, where the caller hands off control of the Document
to us.
  • Loading branch information
erikdubbelboer committed May 26, 2024
1 parent 3598ff0 commit 43e1a07
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 11 deletions.
4 changes: 2 additions & 2 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
doc.Embedding = embedding
} else {
if !isNormalized(doc.Embedding) {
normalizeVector(doc.Embedding)
normalizeVectorInPlace(doc.Embedding)
}
}

Expand Down Expand Up @@ -388,7 +388,7 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3
// Normalize embedding if not the case yet. We only support cosine similarity
// for now and all documents were already normalized when added to the collection.
if !isNormalized(queryEmbedding) {
normalizeVector(queryEmbedding)
queryEmbedding = normalizeVector(queryEmbedding)
}

// If the filtering already reduced the number of documents to fewer than nResults,
Expand Down
4 changes: 2 additions & 2 deletions collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ func benchmarkCollection_Query(b *testing.B, n int, withContent bool) {
qv[j] = r.Float32()
}
// The document embeddings are normalized, so the query must be normalized too.
normalizeVector(qv)
normalizeVectorInPlace(qv)

// Create collection
db := NewDB()
Expand All @@ -590,7 +590,7 @@ func benchmarkCollection_Query(b *testing.B, n int, withContent bool) {
for j := 0; j < d; j++ {
v[j] = r.Float32()
}
normalizeVector(v)
normalizeVectorInPlace(v)

// Add document with some metadata and content depending on parameter.
// When providing embeddings, the embedding func is not called.
Expand Down
2 changes: 1 addition & 1 deletion embed_cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func NewEmbeddingFuncCohere(apiKey string, model EmbeddingModelCohere) Embedding
}
})
if !checkedNormalized {
normalizeVector(v)
normalizeVectorInPlace(v)
}

return v, nil
Expand Down
2 changes: 1 addition & 1 deletion embed_ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func NewEmbeddingFuncOllama(model string, baseURLOllama string) EmbeddingFunc {
}
})
if !checkedNormalized {
normalizeVector(v)
normalizeVectorInPlace(v)
}

return v, nil
Expand Down
4 changes: 2 additions & 2 deletions embed_openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func newEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo
if *normalized {
return v, nil
}
normalizeVector(v)
normalizeVectorInPlace(v)
return v, nil
}
checkNormalized.Do(func() {
Expand All @@ -153,7 +153,7 @@ func newEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo
}
})
if !checkedNormalized {
normalizeVector(v)
normalizeVectorInPlace(v)
}

return v, nil
Expand Down
13 changes: 10 additions & 3 deletions vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ func cosineSimilarity(a, b []float32) (float32, error) {
}

if !isNormalized(a) || !isNormalized(b) {
normalizeVector(a)
normalizeVector(b)
a = normalizeVector(a)
b = normalizeVector(b)
}
dotProduct, err := dotProduct(a, b)
if err != nil {
Expand Down Expand Up @@ -50,7 +50,7 @@ func dotProduct(a, b []float32) (float32, error) {
return dotProduct, nil
}

func normalizeVector(v []float32) {
func normalizeVectorInPlace(v []float32) {
var norm float32
for _, val := range v {
norm += val * val
Expand All @@ -62,6 +62,13 @@ func normalizeVector(v []float32) {
}
}

// normalizeVector returns a freshly allocated copy of v that has been passed
// through normalizeVectorInPlace. The caller's slice is never written to, so
// it is safe to use on user-supplied vectors.
func normalizeVector(v []float32) []float32 {
	// Pre-size the destination so the append fills it without regrowing and
	// never aliases v's backing array.
	normalized := append(make([]float32, 0, len(v)), v...)
	normalizeVectorInPlace(normalized)
	return normalized
}

// isNormalized checks if the vector is normalized.
func isNormalized(v []float32) bool {
var sqSum float64
Expand Down

0 comments on commit 43e1a07

Please sign in to comment.