diff --git a/collection.go b/collection.go index 22a50c2..485660d 100644 --- a/collection.go +++ b/collection.go @@ -213,13 +213,17 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error { m[k] = v } - // Create embedding if they don't exist + // Create the embedding if it doesn't exist, otherwise normalize it if necessary if len(doc.Embedding) == 0 { embedding, err := c.embed(ctx, doc.Content) if err != nil { return fmt.Errorf("couldn't create embedding of document: %w", err) } doc.Embedding = embedding + } else { + if !isNormalized(doc.Embedding) { + doc.Embedding = normalizeVector(doc.Embedding) + } } c.documentsLock.Lock() diff --git a/collection_test.go b/collection_test.go index 883fdab..274b271 100644 --- a/collection_test.go +++ b/collection_test.go @@ -13,7 +13,7 @@ func TestCollection_Add(t *testing.T) { ctx := context.Background() name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -109,7 +109,7 @@ func TestCollection_Add_Error(t *testing.T) { ctx := context.Background() name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -160,7 +160,7 @@ func TestCollection_AddConcurrently(t *testing.T) { ctx := context.Background() name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -256,7 +256,7 @@ func 
TestCollection_AddConcurrently_Error(t *testing.T) { ctx := context.Background() name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -313,8 +313,9 @@ func TestCollection_Count(t *testing.T) { db := NewDB() name := "test" metadata := map[string]string{"foo": "bar"} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { - return []float32{-0.1, 0.1, 0.2}, nil + return vectors, nil } c, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { diff --git a/db.go b/db.go index 42ab698..7c78b39 100644 --- a/db.go +++ b/db.go @@ -13,6 +13,9 @@ import ( // EmbeddingFunc is a function that creates embeddings for a given text. // chromem-go will use OpenAI`s "text-embedding-3-small" model by default, // but you can provide your own function, using any model you like. +// The function must return a *normalized* vector, i.e. the magnitude (L2 norm) of +// the vector must be 1. OpenAI's and Mistral's embedding models do this by default. +// Some others, like Nomic's "nomic-embed-text-v1.5", don't. type EmbeddingFunc func(ctx context.Context, text string) ([]float32, error) // DB is the chromem-go database. It holds collections, which hold documents. 
diff --git a/db_test.go b/db_test.go index ad04bbe..fb25373 100644 --- a/db_test.go +++ b/db_test.go @@ -10,7 +10,7 @@ func TestDB_CreateCollection(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -81,7 +81,7 @@ func TestDB_ListCollections(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -147,7 +147,7 @@ func TestDB_GetCollection(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -196,7 +196,7 @@ func TestDB_GetOrCreateCollection(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -299,7 +299,7 @@ func TestDB_DeleteCollection(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) 
([]float32, error) { return vectors, nil } @@ -331,7 +331,7 @@ func TestDB_Reset(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } diff --git a/document_test.go b/document_test.go index 668d6b0..4290c5a 100644 --- a/document_test.go +++ b/document_test.go @@ -10,7 +10,7 @@ func TestDocument_New(t *testing.T) { ctx := context.Background() id := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` content := "hello world" embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil diff --git a/embed_compat.go b/embed_compat.go index a3d18c9..c3cf4ed 100644 --- a/embed_compat.go +++ b/embed_compat.go @@ -9,9 +9,13 @@ const ( // NewEmbeddingFuncMistral returns a function that creates embeddings for a text // using the Mistral API. func NewEmbeddingFuncMistral(apiKey string) EmbeddingFunc { + // Mistral embeddings are normalized, see section "Distance Measures" on + // https://docs.mistral.ai/guides/embeddings/. + normalized := true + // The Mistral API docs don't mention the `encoding_format` as optional, // but it seems to be, just like OpenAI. So we reuse the OpenAI function. - return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral) + return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral, &normalized) } const baseURLJina = "https://api.jina.ai/v1" @@ -28,7 +32,7 @@ const ( // NewEmbeddingFuncJina returns a function that creates embeddings for a text // using the Jina API. 
func NewEmbeddingFuncJina(apiKey string, model EmbeddingModelJina) EmbeddingFunc { - return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model)) + return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model), nil) } const baseURLMixedbread = "https://api.mixedbread.ai" @@ -49,7 +53,7 @@ const ( // NewEmbeddingFuncMixedbread returns a function that creates embeddings for a text // using the mixedbread.ai API. func NewEmbeddingFuncMixedbread(apiKey string, model EmbeddingModelMixedbread) EmbeddingFunc { - return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model)) + return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model), nil) } const baseURLLocalAI = "http://localhost:8080/v1" @@ -64,5 +68,5 @@ const baseURLLocalAI = "http://localhost:8080/v1" // But other embedding models are supported as well. See the LocalAI documentation // for details. func NewEmbeddingFuncLocalAI(model string) EmbeddingFunc { - return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model) + return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model, nil) } diff --git a/embed_ollama.go b/embed_ollama.go index 7ed0573..d2231f7 100644 --- a/embed_ollama.go +++ b/embed_ollama.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "sync" ) // TODO: Turn into const and use as default, but allow user to pass custom URL @@ -28,6 +29,9 @@ func NewEmbeddingFuncOllama(model string) EmbeddingFunc { // and it might have to be a long timeout, depending on the text length. client := &http.Client{} + var checkedNormalized bool + checkNormalized := sync.Once{} + return func(ctx context.Context, text string) ([]float32, error) { // Prepare the request body. 
reqBody, err := json.Marshal(map[string]string{ @@ -74,6 +78,18 @@ func NewEmbeddingFuncOllama(model string) EmbeddingFunc { return nil, errors.New("no embeddings found in the response") } - return embeddingResponse.Embedding, nil + v := embeddingResponse.Embedding + checkNormalized.Do(func() { + if isNormalized(v) { + checkedNormalized = true + } else { + checkedNormalized = false + } + }) + if !checkedNormalized { + v = normalizeVector(v) + } + + return v, nil } } diff --git a/embed_ollama_test.go b/embed_ollama_test.go index 5ce6121..a3af70a 100644 --- a/embed_ollama_test.go +++ b/embed_ollama_test.go @@ -25,7 +25,7 @@ func TestNewEmbeddingFuncOllama(t *testing.T) { if err != nil { t.Fatal("unexpected error:", err) } - wantRes := []float32{-0.1, 0.1, 0.2} + wantRes := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` // Mock server ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/embed_openai.go b/embed_openai.go index 2d301c7..1d09e01 100644 --- a/embed_openai.go +++ b/embed_openai.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "os" + "sync" ) const BaseURLOpenAI = "https://api.openai.com/v1" @@ -39,7 +40,9 @@ func NewEmbeddingFuncDefault() EmbeddingFunc { // NewEmbeddingFuncOpenAI returns a function that creates embeddings for a text // using the OpenAI API. 
func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) EmbeddingFunc { - return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model)) + // OpenAI embeddings are normalized + normalized := true + return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model), &normalized) } // NewEmbeddingFuncOpenAICompat returns a function that creates embeddings for a text @@ -48,12 +51,20 @@ func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) Embedding // - LitLLM: https://github.com/BerriAI/litellm // - Ollama: https://github.com/ollama/ollama/blob/main/docs/openai.md // - etc. -func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string) EmbeddingFunc { +// +// The `normalized` parameter indicates whether the vectors returned by the embedding +// model are already normalized, as is the case for OpenAI's and Mistral's models. +// The flag is optional. If it's nil, it will be autodetected on the first request +// (which bears a small risk that the vector just happens to have a length of 1). +func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool) EmbeddingFunc { // We don't set a default timeout here, although it's usually a good idea. // In our case though, the library user can set the timeout on the context, // and it might have to be a long timeout, depending on the text length. client := &http.Client{} + var checkedNormalized bool + checkNormalized := sync.Once{} + return func(ctx context.Context, text string) ([]float32, error) { // Prepare the request body. 
reqBody, err := json.Marshal(map[string]string{ @@ -101,6 +112,24 @@ func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string) EmbeddingFunc { return nil, errors.New("no embeddings found in the response") } - return embeddingResponse.Data[0].Embedding, nil + v := embeddingResponse.Data[0].Embedding + if normalized != nil { + if *normalized { + return v, nil + } + return normalizeVector(v), nil + } + checkNormalized.Do(func() { + if isNormalized(v) { + checkedNormalized = true + } else { + checkedNormalized = false + } + }) + if !checkedNormalized { + v = normalizeVector(v) + } + + return v, nil } } diff --git a/embed_openai_test.go b/embed_openai_test.go index 03049b0..5243b81 100644 --- a/embed_openai_test.go +++ b/embed_openai_test.go @@ -33,7 +33,7 @@ func TestNewEmbeddingFuncOpenAICompat(t *testing.T) { if err != nil { t.Fatal("unexpected error:", err) } - wantRes := []float32{-0.1, 0.1, 0.2} + wantRes := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` // Mock server ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -75,7 +75,7 @@ func TestNewEmbeddingFuncOpenAICompat(t *testing.T) { defer ts.Close() baseURL := ts.URL + baseURLSuffix - f := chromem.NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model) + f := chromem.NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model, nil) res, err := f(context.Background(), input) if err != nil { t.Fatal("expected nil, got", err) diff --git a/persistence_test.go b/persistence_test.go index 469126e..34dc944 100644 --- a/persistence_test.go +++ b/persistence_test.go @@ -23,7 +23,7 @@ func TestPersistence(t *testing.T) { } obj := s{ Foo: "test", - Bar: []float32{-0.1, 0.1, 0.2}, + Bar: []float32{-0.40824828, 0.40824828, 0.81649655}, // normalized version of `{-0.1, 0.1, 0.2}` } persist(tempDir, obj) diff --git a/query.go b/query.go index f0544de..9da1be2 100644 --- a/query.go +++ b/query.go @@ -145,7 +145,8 @@ func calcDocSimilarity(ctx 
context.Context, queryVectors []float32, docs []*Docu return } - sim, err := cosineSimilarity(queryVectors, doc.Embedding) + // As the vectors are normalized, the dot product is the cosine similarity. + sim, err := dotProduct(queryVectors, doc.Embedding) if err != nil { setSharedErr(fmt.Errorf("couldn't calculate similarity for document '%s': %w", doc.ID, err)) return diff --git a/vector.go b/vector.go index 5caaa71..972b6b2 100644 --- a/vector.go +++ b/vector.go @@ -2,6 +2,7 @@ package chromem import ( "errors" + "fmt" "math" ) @@ -20,11 +21,30 @@ func cosineSimilarity(a, b []float32) (float32, error) { if !isNormalized(a) || !isNormalized(b) { a, b = normalizeVector(a), normalizeVector(b) } + dotProduct, err := dotProduct(a, b) + if err != nil { + return 0, fmt.Errorf("couldn't calculate dot product: %w", err) + } + + // Vectors are already normalized, so no need to divide by magnitudes + + return dotProduct, nil +} + +// dotProduct calculates the dot product between two vectors. +// It's the same as cosine similarity for normalized vectors. +// The resulting value represents the similarity, so a higher value means the +// vectors are more similar. +func dotProduct(a, b []float32) (float32, error) { + // The vectors must have the same length + if len(a) != len(b) { + return 0, errors.New("vectors must have the same length") + } + var dotProduct float32 for i := range a { dotProduct += a[i] * b[i] } - // Vectors are already normalized, so no need to divide by magnitudes return dotProduct, nil }