Skip to content

Commit

Permalink
Make normalizeVector(v) in place
Browse files Browse the repository at this point in the history
All usage of normalizeVector was in the form of

    v = normalizeVector(v)

In that case it's better to just do the replacement in place and not
having to allocate a new slice.
  • Loading branch information
erikdubbelboer committed May 19, 2024
1 parent 1de154b commit fb1beb6
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 14 deletions.
4 changes: 2 additions & 2 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
doc.Embedding = embedding
} else {
if !isNormalized(doc.Embedding) {
doc.Embedding = normalizeVector(doc.Embedding)
normalizeVector(doc.Embedding)
}
}

Expand Down Expand Up @@ -386,7 +386,7 @@ func (c *Collection) QueryEmbedding(ctx context.Context, queryEmbedding []float3
// Normalize embedding if not the case yet. We only support cosine similarity
// for now and all documents were already normalized when added to the collection.
if !isNormalized(queryEmbedding) {
queryEmbedding = normalizeVector(queryEmbedding)
normalizeVector(queryEmbedding)
}

// For the remaining documents, get the most similar docs.
Expand Down
4 changes: 2 additions & 2 deletions collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ func benchmarkCollection_Query(b *testing.B, n int, withContent bool) {
qv[j] = r.Float32()
}
// The document embeddings are normalized, so the query must be normalized too.
qv = normalizeVector(qv)
normalizeVector(qv)

// Create collection
db := NewDB()
Expand All @@ -590,7 +590,7 @@ func benchmarkCollection_Query(b *testing.B, n int, withContent bool) {
for j := 0; j < d; j++ {
v[j] = r.Float32()
}
v = normalizeVector(v)
normalizeVector(v)

// Add document with some metadata and content depending on parameter.
// When providing embeddings, the embedding func is not called.
Expand Down
2 changes: 1 addition & 1 deletion embed_cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func NewEmbeddingFuncCohere(apiKey string, model EmbeddingModelCohere) Embedding
}
})
if !checkedNormalized {
v = normalizeVector(v)
normalizeVector(v)
}

return v, nil
Expand Down
2 changes: 1 addition & 1 deletion embed_ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func NewEmbeddingFuncOllama(model string, baseURLOllama string) EmbeddingFunc {
}
})
if !checkedNormalized {
v = normalizeVector(v)
normalizeVector(v)
}

return v, nil
Expand Down
5 changes: 3 additions & 2 deletions embed_openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ func newEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo
if *normalized {
return v, nil
}
return normalizeVector(v), nil
normalizeVector(v)
return v, nil
}
checkNormalized.Do(func() {
if isNormalized(v) {
Expand All @@ -152,7 +153,7 @@ func newEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *boo
}
})
if !checkedNormalized {
v = normalizeVector(v)
normalizeVector(v)
}

return v, nil
Expand Down
10 changes: 4 additions & 6 deletions vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ func cosineSimilarity(a, b []float32) (float32, error) {
}

if !isNormalized(a) || !isNormalized(b) {
a, b = normalizeVector(a), normalizeVector(b)
normalizeVector(a)
normalizeVector(b)
}
dotProduct, err := dotProduct(a, b)
if err != nil {
Expand Down Expand Up @@ -49,19 +50,16 @@ func dotProduct(a, b []float32) (float32, error) {
return dotProduct, nil
}

func normalizeVector(v []float32) []float32 {
func normalizeVector(v []float32) {
var norm float32
for _, val := range v {
norm += val * val
}
norm = float32(math.Sqrt(float64(norm)))

res := make([]float32, len(v))
for i, val := range v {
res[i] = val / norm
v[i] = val / norm
}

return res
}

// isNormalized checks if the vector is normalized.
Expand Down

0 comments on commit fb1beb6

Please sign in to comment.