From 914a1b9f9702019c6c8de4d2aa1a08a710055ad7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Tue, 12 Mar 2024 22:54:26 +0100 Subject: [PATCH 01/12] Unlock documents lock earlier Not relevant for single query, but for concurrent ones --- collection.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/collection.go b/collection.go index fc1ae0d..21d6428 100644 --- a/collection.go +++ b/collection.go @@ -262,7 +262,7 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int, } c.documentsLock.RLock() - defer c.documentsLock.RUnlock() + // We don't defer the unlock because we want to do it earlier. if len(c.documents) == 0 { return nil, nil } @@ -276,6 +276,7 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int, // Filter docs by metadata and content filteredDocs := filterDocs(c.documents, where, whereDocument) + c.documentsLock.RUnlock() // No need to continue if the filters got rid of all documents if len(filteredDocs) == 0 { From 4db073268989f6f30853d5d4b0c1b850a502edc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Wed, 13 Mar 2024 00:13:41 +0100 Subject: [PATCH 02/12] Stop copying all doc structs when returning doc similarities --- collection.go | 39 +++++++++++++++++++++++++++++++++------ query.go | 35 ++++++++++------------------------- 2 files changed, 43 insertions(+), 31 deletions(-) diff --git a/collection.go b/collection.go index 21d6428..22a50c2 100644 --- a/collection.go +++ b/collection.go @@ -247,6 +247,19 @@ func (c *Collection) Count() int { return len(c.documents) } +// Result represents a single result from a query. +type Result struct { + ID string + Metadata map[string]string + Embedding []float32 + Content string + + // The cosine similarity between the query and the document. + // The higher the value, the more similar the document is to the query. + // The value is in the range [-1, 1]. + Similarity float32 +} + // Performs an exhaustive nearest neighbor search on the collection. // // - queryText: The text to search for. @@ -262,7 +275,7 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int, } c.documentsLock.RLock() - // We don't defer the unlock because we want to do it earlier. + defer c.documentsLock.RUnlock() if len(c.documents) == 0 { return nil, nil } @@ -276,7 +289,6 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int, // Filter docs by metadata and content filteredDocs := filterDocs(c.documents, where, whereDocument) - c.documentsLock.RUnlock() // No need to continue if the filters got rid of all documents if len(filteredDocs) == 0 { @@ -289,17 +301,32 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int, } // For the remaining documents, calculate cosine similarity. - res, err := calcDocSimilarity(ctx, queryVectors, filteredDocs) + docSim, err := calcDocSimilarity(ctx, queryVectors, filteredDocs) if err != nil { return nil, fmt.Errorf("couldn't calculate cosine similarity: %w", err) } // Sort by similarity - sort.Slice(res, func(i, j int) bool { + sort.Slice(docSim, func(i, j int) bool { // The `less` function would usually use `<`, but we want to sort descending. 
- return res[i].Similarity > res[j].Similarity + return docSim[i].similarity > docSim[j].similarity }) + // Return the top nResults or len(docSim), whichever is smaller + if len(docSim) < nResults { + nResults = len(docSim) + } + res := make([]Result, 0, nResults) + for i := 0; i < nResults; i++ { + res = append(res, Result{ + ID: docSim[i].docID, + Metadata: c.documents[docSim[i].docID].Metadata, + Embedding: c.documents[docSim[i].docID].Embedding, + Content: c.documents[docSim[i].docID].Content, + Similarity: docSim[i].similarity, + }) + } + // Return the top nResults - return res[:nResults], nil + return res, nil } diff --git a/query.go b/query.go index e7e5cd3..8399618 100644 --- a/query.go +++ b/query.go @@ -10,17 +10,9 @@ import ( var supportedFilters = []string{"$contains", "$not_contains"} -// Result represents a single result from a query. -type Result struct { - ID string - Metadata map[string]string - Embedding []float32 - Content string - - // The cosine similarity between the query and the document. - // The higher the value, the more similar the document is to the query. - // The value is in the range [-1, 1]. - Similarity float32 +type docSim struct { + docID string + similarity float32 } // filterDocs filters a map of documents by metadata and content. @@ -103,9 +95,9 @@ func documentMatchesFilters(document *Document, where, whereDocument map[string] return true } -func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document) ([]Result, error) { - res := make([]Result, 0, len(docs)) - resLock := sync.Mutex{} +func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document) ([]*docSim, error) { + similarities := make([]*docSim, 0, len(docs)) + similaritiesLock := sync.Mutex{} // Determine concurrency. Use number of docs or CPUs, whichever is smaller. numCPUs := runtime.NumCPU() @@ -148,17 +140,10 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu return } - resLock.Lock() + similaritiesLock.Lock() // We don't defer the unlock because we want to unlock much earlier. - res = append(res, Result{ - ID: doc.ID, - Metadata: doc.Metadata, - Embedding: doc.Embedding, - Content: doc.Content, - - Similarity: sim, - }) - resLock.Unlock() + similarities = append(similarities, &docSim{docID: doc.ID, similarity: sim}) + similaritiesLock.Unlock() } }() } @@ -184,5 +169,5 @@ OuterLoop: return nil, sharedErr } - return res, nil + return similarities, nil } From ff7e80fb0f476f0985d0efad35c7ab9ade136b6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Wed, 13 Mar 2024 00:53:45 +0100 Subject: [PATCH 03/12] Use sub slices instead of channel to pass documents into goroutines --- query.go | 34 +++++++++++++++------------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/query.go b/query.go index 8399618..336a086 100644 --- a/query.go +++ b/query.go @@ -123,12 +123,23 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu } wg := sync.WaitGroup{} - docChan := make(chan *Document, concurrency*2) + // Instead of using a channel to pass documents into the goroutines, we just + // split the slice into sub-slices and pass those to the goroutines. + // This turned out to be faster in the query benchmarks. + subSliceSize := len(docs) / concurrency // Can leave remainder, e.g. 
10/3 = 3; leaves 1 + rem := len(docs) % concurrency for i := 0; i < concurrency; i++ { + start := i * subSliceSize + end := start + subSliceSize + // Add remainder to last goroutine + if i == concurrency-1 { + end += rem + } + wg.Add(1) - go func() { + go func(subSlice []*Document) { defer wg.Done() - for doc := range docChan { + for _, doc := range subSlice { // Stop work if another goroutine encountered an error. if ctx.Err() != nil { return @@ -145,24 +156,9 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu similarities = append(similarities, &docSim{docID: doc.ID, similarity: sim}) similaritiesLock.Unlock() } - }() + }(docs[start:end]) } -OuterLoop: - for _, doc := range docs { - // The doc channel has limited capacity, so writing to the channel blocks - // when a goroutine runs into an error and then all goroutines stop processing - // the channel and it gets full. - // To avoid a deadlock we check for ctx.Done() here, which is closed by - // the goroutine that encountered the error. - select { - case docChan <- doc: - case <-ctx.Done(): - break OuterLoop - } - } - close(docChan) - wg.Wait() if sharedErr != nil { From cc86f2c34a3d345029760b308f54fe55e3a632ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Wed, 13 Mar 2024 01:05:48 +0100 Subject: [PATCH 04/12] Use channel for goroutine results instead of shared slice + lock Also gets rid of sync.WaitGroup --- query.go | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/query.go b/query.go index 336a086..846e79f 100644 --- a/query.go +++ b/query.go @@ -96,9 +96,6 @@ func documentMatchesFilters(document *Document, where, whereDocument map[string] } func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document) ([]*docSim, error) { - similarities := make([]*docSim, 0, len(docs)) - similaritiesLock := sync.Mutex{} - // Determine concurrency. Use number of docs or CPUs, whichever is smaller. numCPUs := runtime.NumCPU() numDocs := len(docs) @@ -122,7 +119,7 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu } } - wg := sync.WaitGroup{} + resChan := make(chan *docSim, concurrency*2) // Instead of using a channel to pass documents into the goroutines, we just // split the slice into sub-slices and pass those to the goroutines. // This turned out to be faster in the query benchmarks. @@ -136,9 +133,7 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu end += rem } - wg.Add(1) go func(subSlice []*Document) { - defer wg.Done() for _, doc := range subSlice { // Stop work if another goroutine encountered an error. if ctx.Err() != nil { return @@ -151,18 +146,19 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu return } - similaritiesLock.Lock() - // We don't defer the unlock because we want to unlock much earlier.
- similarities = append(similarities, &docSim{docID: doc.ID, similarity: sim}) - similaritiesLock.Unlock() + resChan <- &docSim{docID: doc.ID, similarity: sim} } }(docs[start:end]) } - wg.Wait() - - if sharedErr != nil { - return nil, sharedErr + similarities := make([]*docSim, 0, len(docs)) + for i := 0; i < len(docs); i++ { + select { + case res := <-resChan: + similarities = append(similarities, res) + case <-ctx.Done(): + return nil, sharedErr + } } return similarities, nil From 089007b0114f198bc421a86620ec24477fecb001 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Wed, 13 Mar 2024 01:08:53 +0100 Subject: [PATCH 05/12] Revert "Use channel for goroutine results instead of shared slice + lock" This reverts commit cc86f2c34a3d345029760b308f54fe55e3a632ce. --- query.go | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/query.go b/query.go index 846e79f..336a086 100644 --- a/query.go +++ b/query.go @@ -96,6 +96,9 @@ func documentMatchesFilters(document *Document, where, whereDocument map[string] } func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document) ([]*docSim, error) { + similarities := make([]*docSim, 0, len(docs)) + similaritiesLock := sync.Mutex{} + // Determine concurrency. Use number of docs or CPUs, whichever is smaller. numCPUs := runtime.NumCPU() numDocs := len(docs) @@ -119,7 +122,7 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu } } - resChan := make(chan *docSim, concurrency*2) + wg := sync.WaitGroup{} // Instead of using a channel to pass documents into the goroutines, we just // split the slice into sub-slices and pass those to the goroutines. // This turned out to be faster in the query benchmarks. @@ -133,7 +136,9 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu end += rem } + wg.Add(1) go func(subSlice []*Document) { + defer wg.Done() for _, doc := range subSlice { // Stop work if another goroutine encountered an error. if ctx.Err() != nil { @@ -146,19 +151,18 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu return } - resChan <- &docSim{docID: doc.ID, similarity: sim} + similaritiesLock.Lock() + // We don't defer the unlock because we want to unlock much earlier. + similarities = append(similarities, &docSim{docID: doc.ID, similarity: sim}) + similaritiesLock.Unlock() } }(docs[start:end]) } - similarities := make([]*docSim, 0, len(docs)) - for i := 0; i < len(docs); i++ { - select { - case res := <-resChan: - similarities = append(similarities, res) - case <-ctx.Done(): - return nil, sharedErr - } + wg.Wait() + + if sharedErr != nil { + return nil, sharedErr } return similarities, nil From 15d1858899df54606d2dfa0cb11993dad287c944 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sat, 16 Mar 2024 11:08:16 +0100 Subject: [PATCH 06/12] Only normalize vector if it's not normalized yet For now we check this by computing the length. In the future we could pass a flag if it's already known whether a vector is normalized, which is the case for many embedding models. --- vector.go | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/vector.go b/vector.go index 674b584..5caaa71 100644 --- a/vector.go +++ b/vector.go @@ -5,6 +5,8 @@ import ( "math" ) +const isNormalizedPrecisionTolerance = 1e-6 + // cosineSimilarity calculates the cosine similarity between two vectors. // Vectors are normalized first. 
// The resulting value represents the similarity, so a higher value means the @@ -15,10 +17,12 @@ func cosineSimilarity(a, b []float32) (float32, error) { return 0, errors.New("vectors must have the same length") } - x, y := normalizeVector(a), normalizeVector(b) + if !isNormalized(a) || !isNormalized(b) { + a, b = normalizeVector(a), normalizeVector(b) + } var dotProduct float32 - for i := range x { - dotProduct += x[i] * y[i] + for i := range a { + dotProduct += a[i] * b[i] } // Vectors are already normalized, so no need to divide by magnitudes @@ -39,3 +43,13 @@ func normalizeVector(v []float32) []float32 { return res } + +// isNormalized checks if the vector is normalized. +func isNormalized(v []float32) bool { + var sqSum float64 + for _, val := range v { + sqSum += float64(val) * float64(val) + } + magnitude := math.Sqrt(sqSum) + return math.Abs(magnitude-1) < isNormalizedPrecisionTolerance +} From 503c3cecc849be3e463c4420852ca1b60e4b3f53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sat, 16 Mar 2024 11:21:40 +0100 Subject: [PATCH 07/12] Turn slice of pointers to slice of structs Greatly reduces number of allocations. For a query of 5,000 documents from ~5000 allocations to ~50. Number of allocations are also now constant, i.e. 50 for querying 100,000 documents. --- query.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/query.go b/query.go index 336a086..f0544de 100644 --- a/query.go +++ b/query.go @@ -95,8 +95,8 @@ func documentMatchesFilters(document *Document, where, whereDocument map[string] return true } -func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document) ([]*docSim, error) { - similarities := make([]*docSim, 0, len(docs)) +func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document) ([]docSim, error) { + similarities := make([]docSim, 0, len(docs)) similaritiesLock := sync.Mutex{} // Determine concurrency. Use number of docs or CPUs, whichever is smaller. @@ -153,7 +153,7 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu similaritiesLock.Lock() // We don't defer the unlock because we want to unlock much earlier. - similarities = append(similarities, &docSim{docID: doc.ID, similarity: sim}) + similarities = append(similarities, docSim{docID: doc.ID, similarity: sim}) similaritiesLock.Unlock() } }(docs[start:end]) From ff28a3807c27ebf65b511d06614cc87b4e50b78b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sat, 16 Mar 2024 12:25:57 +0100 Subject: [PATCH 08/12] Add "normalized" parameter to skip check if normalization is known --- README.md | 2 +- collection.go | 12 ++++++---- collection_test.go | 12 +++++----- db.go | 14 +++++++---- db_test.go | 18 +++++++------- examples/rag-wikipedia-ollama/main.go | 5 +++- examples/semantic-search-arxiv-openai/main.go | 5 +++- query.go | 4 ++-- vector.go | 24 +++++++++++++++---- 9 files changed, 62 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index cb6881a..ded8c5f 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,7 @@ func main() { db := chromem.NewDB() // Create collection. GetCollection, GetOrCreateCollection, DeleteCollection also available! - collection, _ := db.CreateCollection("all-my-documents", nil, nil) + collection, _ := db.CreateCollection("all-my-documents", nil, nil, nil) // Add docs to the collection. Update and delete will be added in the future. // Can be multi-threaded with AddConcurrently()! 
diff --git a/collection.go b/collection.go index 22a50c2..6e1bf4a 100644 --- a/collection.go +++ b/collection.go @@ -22,11 +22,12 @@ type Collection struct { documents map[string]*Document documentsLock sync.RWMutex embed EmbeddingFunc + normalized *bool } // We don't export this yet to keep the API surface to the bare minimum. // Users create collections via [Client.CreateCollection]. -func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dir string) (*Collection, error) { +func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, normalized *bool, dir string) (*Collection, error) { // We copy the metadata to avoid data races in case the caller modifies the // map after creating the collection while we range over it. m := make(map[string]string, len(metadata)) @@ -37,9 +38,10 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, c := &Collection{ Name: name, - metadata: m, - documents: make(map[string]*Document), - embed: embed, + metadata: m, + documents: make(map[string]*Document), + embed: embed, + normalized: normalized, } // Persistence @@ -301,7 +303,7 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int, } // For the remaining documents, calculate cosine similarity. - docSim, err := calcDocSimilarity(ctx, queryVectors, filteredDocs) + docSim, err := calcDocSimilarity(ctx, queryVectors, filteredDocs, c.normalized) if err != nil { return nil, fmt.Errorf("couldn't calculate cosine similarity: %w", err) } diff --git a/collection_test.go b/collection_test.go index 883fdab..2202c38 100644 --- a/collection_test.go +++ b/collection_test.go @@ -20,7 +20,7 @@ func TestCollection_Add(t *testing.T) { // Create collection db := NewDB() - c, err := db.CreateCollection(name, metadata, embeddingFunc) + c, err := db.CreateCollection(name, metadata, embeddingFunc, nil) if err != nil { t.Fatal("expected no error, got", err) } @@ -116,7 +116,7 @@ func TestCollection_Add_Error(t *testing.T) { // Create collection db := NewDB() - c, err := db.CreateCollection(name, metadata, embeddingFunc) + c, err := db.CreateCollection(name, metadata, embeddingFunc, nil) if err != nil { t.Fatal("expected no error, got", err) } @@ -167,7 +167,7 @@ func TestCollection_AddConcurrently(t *testing.T) { // Create collection db := NewDB() - c, err := db.CreateCollection(name, metadata, embeddingFunc) + c, err := db.CreateCollection(name, metadata, embeddingFunc, nil) if err != nil { t.Fatal("expected no error, got", err) } @@ -263,7 +263,7 @@ func TestCollection_AddConcurrently_Error(t *testing.T) { // Create collection db := NewDB() - c, err := db.CreateCollection(name, metadata, embeddingFunc) + c, err := db.CreateCollection(name, metadata, embeddingFunc, nil) if err != nil { t.Fatal("expected no error, got", err) } @@ -316,7 +316,7 @@ func TestCollection_Count(t *testing.T) { embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return []float32{-0.1, 0.1, 0.2}, nil } - c, err := db.CreateCollection(name, metadata, embeddingFunc) + c, err := db.CreateCollection(name, metadata, embeddingFunc, nil) if err != nil { t.Fatal("expected no error, got", err) } @@ -407,7 +407,7 @@ func benchmarkCollection_Query(b *testing.B, n int, withContent bool) { // Create collection db := NewDB() name := "test" - c, err := db.CreateCollection(name, nil, embeddingFunc) + c, err := db.CreateCollection(name, nil, embeddingFunc, &trueVal) if err != nil { b.Fatal("expected no error, got", err) } diff --git a/db.go b/db.go 
index 42ab698..ba39b0b 100644 --- a/db.go +++ b/db.go @@ -142,14 +142,17 @@ func NewPersistentDB(path string) (*DB, error) { // - metadata: Optional metadata to associate with the collection. // - embeddingFunc: Optional function to use to embed documents. // Uses the default embedding function if not provided. -func (db *DB) CreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) { +// - normalized: Optional flag to indicate if the embeddings of the collection +// are normalized (when you add embeddings yourself, or the embeddings created +// by the embeddingFunc). If nil it will be autodetected. +func (db *DB) CreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc, normalized *bool) (*Collection, error) { if name == "" { return nil, errors.New("collection name is empty") } if embeddingFunc == nil { embeddingFunc = NewEmbeddingFuncDefault() } - collection, err := newCollection(name, metadata, embeddingFunc, db.persistDirectory) + collection, err := newCollection(name, metadata, embeddingFunc, normalized, db.persistDirectory) if err != nil { return nil, fmt.Errorf("couldn't create collection: %w", err) } @@ -213,12 +216,15 @@ func (db *DB) GetCollection(name string, embeddingFunc EmbeddingFunc) *Collectio // - metadata: Optional metadata to associate with the collection. // - embeddingFunc: Optional function to use to embed documents. // Uses the default embedding function if not provided. -func (db *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) { +// - normalized: Optional flag to indicate if the embeddings of the collection +// are normalized (when you add embeddings yourself, or the embeddings created +// by the embeddingFunc). If nil it will be autodetected. +func (db *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc, normalized *bool) (*Collection, error) { // No need to lock here, because the methods we call do that. collection := db.GetCollection(name, embeddingFunc) if collection == nil { var err error - collection, err = db.CreateCollection(name, metadata, embeddingFunc) + collection, err = db.CreateCollection(name, metadata, embeddingFunc, normalized) if err != nil { return nil, fmt.Errorf("couldn't create collection: %w", err) } diff --git a/db_test.go b/db_test.go index ad04bbe..2248451 100644 --- a/db_test.go +++ b/db_test.go @@ -18,7 +18,7 @@ func TestDB_CreateCollection(t *testing.T) { db := NewDB() t.Run("OK", func(t *testing.T) { - c, err := db.CreateCollection(name, metadata, embeddingFunc) + c, err := db.CreateCollection(name, metadata, embeddingFunc, nil) if err != nil { t.Fatal("expected no error, got", err) } @@ -70,7 +70,7 @@ func TestDB_CreateCollection(t *testing.T) { }) t.Run("NOK - Empty name", func(t *testing.T) { - _, err := db.CreateCollection("", metadata, embeddingFunc) + _, err := db.CreateCollection("", metadata, embeddingFunc, nil) if err == nil { t.Fatal("expected error, got nil") } @@ -89,7 +89,7 @@ func TestDB_ListCollections(t *testing.T) { // Create initial collection db := NewDB() // We ignore the return value. CreateCollection is tested elsewhere. 
- _, err := db.CreateCollection(name, metadata, embeddingFunc) + _, err := db.CreateCollection(name, metadata, embeddingFunc, nil) if err != nil { t.Fatal("expected no error, got", err) } @@ -155,7 +155,7 @@ func TestDB_GetCollection(t *testing.T) { // Create initial collection db := NewDB() // We ignore the return value. CreateCollection is tested elsewhere. - _, err := db.CreateCollection(name, metadata, embeddingFunc) + _, err := db.CreateCollection(name, metadata, embeddingFunc, nil) if err != nil { t.Fatal("expected no error, got", err) } @@ -207,7 +207,7 @@ func TestDB_GetOrCreateCollection(t *testing.T) { // Create collection so that the GetOrCreateCollection() call below only // gets it. // We ignore the return value. CreateCollection is tested elsewhere. - _, err := db.CreateCollection(name, metadata, embeddingFunc) + _, err := db.CreateCollection(name, metadata, embeddingFunc, nil) if err != nil { t.Fatal("expected no error, got", err) } @@ -215,7 +215,7 @@ func TestDB_GetOrCreateCollection(t *testing.T) { // Call GetOrCreateCollection() with the same name to only get it. We pass // nil for the metadata and embeddingFunc so we can check that the returned // collection is the original one, and not a new one. - c, err := db.GetOrCreateCollection(name, nil, nil) + c, err := db.GetOrCreateCollection(name, nil, embeddingFunc, nil) if err != nil { t.Fatal("expected no error, got", err) } @@ -257,7 +257,7 @@ func TestDB_GetOrCreateCollection(t *testing.T) { db := NewDB() // Call GetOrCreateCollection() - c, err := db.GetOrCreateCollection(name, metadata, embeddingFunc) + c, err := db.GetOrCreateCollection(name, metadata, embeddingFunc, nil) if err != nil { t.Fatal("expected no error, got", err) } @@ -307,7 +307,7 @@ func TestDB_DeleteCollection(t *testing.T) { // Create initial collection db := NewDB() // We ignore the return value. CreateCollection is tested elsewhere. - _, err := db.CreateCollection(name, metadata, embeddingFunc) + _, err := db.CreateCollection(name, metadata, embeddingFunc, nil) if err != nil { t.Fatal("expected no error, got", err) } @@ -339,7 +339,7 @@ func TestDB_Reset(t *testing.T) { // Create initial collection db := NewDB() // We ignore the return value. CreateCollection is tested elsewhere. - _, err := db.CreateCollection(name, metadata, embeddingFunc) + _, err := db.CreateCollection(name, metadata, embeddingFunc, nil) if err != nil { t.Fatal("expected no error, got", err) } diff --git a/examples/rag-wikipedia-ollama/main.go b/examples/rag-wikipedia-ollama/main.go index 4d451ae..982e908 100644 --- a/examples/rag-wikipedia-ollama/main.go +++ b/examples/rag-wikipedia-ollama/main.go @@ -19,6 +19,9 @@ const ( embeddingModel = "nomic-embed-text" ) +// The nomic-embed-text-v1.5 model doesn't return normalized embeddings +var normalized = false + func main() { ctx := context.Background() @@ -49,7 +52,7 @@ func main() { // variable to be set. // For this example we choose to use a locally running embedding model though. // It requires Ollama to serve its API at "http://localhost:11434/api". 
- collection, err := db.GetOrCreateCollection("Wikipedia", nil, chromem.NewEmbeddingFuncOllama(embeddingModel)) + collection, err := db.GetOrCreateCollection("Wikipedia", nil, chromem.NewEmbeddingFuncOllama(embeddingModel), &normalized) if err != nil { panic(err) } diff --git a/examples/semantic-search-arxiv-openai/main.go b/examples/semantic-search-arxiv-openai/main.go index e0d341b..dd55b18 100644 --- a/examples/semantic-search-arxiv-openai/main.go +++ b/examples/semantic-search-arxiv-openai/main.go @@ -16,6 +16,9 @@ import ( const searchTerm = "semantic search with vector databases" +// OpenAI embeddings are already normalized. +var normalized = true + func main() { ctx := context.Background() @@ -30,7 +33,7 @@ func main() { // We pass nil as embedding function to use the default (OpenAI text-embedding-3-small), // which is very good and cheap. It requires the OPENAI_API_KEY environment // variable to be set. - collection, err := db.GetOrCreateCollection("arXiv cs.CL 2023", nil, nil) + collection, err := db.GetOrCreateCollection("arXiv cs.CL 2023", nil, nil, &normalized) if err != nil { panic(err) } diff --git a/query.go b/query.go index f0544de..c8a6144 100644 --- a/query.go +++ b/query.go @@ -95,7 +95,7 @@ func documentMatchesFilters(document *Document, where, whereDocument map[string] return true } -func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document) ([]docSim, error) { +func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document, isNormalized *bool) ([]docSim, error) { similarities := make([]docSim, 0, len(docs)) similaritiesLock := sync.Mutex{} @@ -145,7 +145,7 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu return } - sim, err := cosineSimilarity(queryVectors, doc.Embedding) + sim, err := cosineSimilarity(queryVectors, doc.Embedding, isNormalized) if err != nil { setSharedErr(fmt.Errorf("couldn't calculate similarity for document '%s': %w", doc.ID, err)) return diff --git a/vector.go b/vector.go index 5caaa71..9dfba02 100644 --- a/vector.go +++ b/vector.go @@ -7,19 +7,33 @@ import ( const isNormalizedPrecisionTolerance = 1e-6 +var ( + falseVal = false + trueVal = true +) + // cosineSimilarity calculates the cosine similarity between two vectors. -// Vectors are normalized first. +// Pass isNormalized=true if the vectors are already normalized, false +// to normalize them, and nil to autodetect. // The resulting value represents the similarity, so a higher value means the // vectors are more similar. -func cosineSimilarity(a, b []float32) (float32, error) { +func cosineSimilarity(a, b []float32, isNormalized *bool) (float32, error) { // The vectors must have the same length if len(a) != len(b) { return 0, errors.New("vectors must have the same length") } - if !isNormalized(a) || !isNormalized(b) { + if isNormalized == nil { + if !checkNormalized(a) || !checkNormalized(b) { + isNormalized = &falseVal + } else { + isNormalized = &trueVal + } + } + if !*isNormalized { a, b = normalizeVector(a), normalizeVector(b) } + var dotProduct float32 for i := range a { dotProduct += a[i] * b[i] @@ -44,8 +58,8 @@ func normalizeVector(v []float32) []float32 { return res } -// isNormalized checks if the vector is normalized. -func isNormalized(v []float32) bool { +// checkNormalized checks if the vector is normalized. 
+func checkNormalized(v []float32) bool { var sqSum float64 for _, val := range v { sqSum += float64(val) * float64(val) From 05c4f764c4bba63c5a225892ca61ba87d38fad95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sat, 16 Mar 2024 12:26:14 +0100 Subject: [PATCH 09/12] Revert "Add "normalized" parameter to skip check if normalization is known" This reverts commit ff28a3807c27ebf65b511d06614cc87b4e50b78b. --- README.md | 2 +- collection.go | 12 ++++------ collection_test.go | 12 +++++----- db.go | 14 ++++------- db_test.go | 18 +++++++------- examples/rag-wikipedia-ollama/main.go | 5 +--- examples/semantic-search-arxiv-openai/main.go | 5 +--- query.go | 4 ++-- vector.go | 24 ++++--------------- 9 files changed, 34 insertions(+), 62 deletions(-) diff --git a/README.md b/README.md index ded8c5f..cb6881a 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,7 @@ func main() { db := chromem.NewDB() // Create collection. GetCollection, GetOrCreateCollection, DeleteCollection also available! - collection, _ := db.CreateCollection("all-my-documents", nil, nil, nil) + collection, _ := db.CreateCollection("all-my-documents", nil, nil) // Add docs to the collection. Update and delete will be added in the future. // Can be multi-threaded with AddConcurrently()! diff --git a/collection.go b/collection.go index 6e1bf4a..22a50c2 100644 --- a/collection.go +++ b/collection.go @@ -22,12 +22,11 @@ type Collection struct { documents map[string]*Document documentsLock sync.RWMutex embed EmbeddingFunc - normalized *bool } // We don't export this yet to keep the API surface to the bare minimum. // Users create collections via [Client.CreateCollection]. -func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, normalized *bool, dir string) (*Collection, error) { +func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dir string) (*Collection, error) { // We copy the metadata to avoid data races in case the caller modifies the // map after creating the collection while we range over it. m := make(map[string]string, len(metadata)) @@ -38,10 +37,9 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, c := &Collection{ Name: name, - metadata: m, - documents: make(map[string]*Document), - embed: embed, - normalized: normalized, + metadata: m, + documents: make(map[string]*Document), + embed: embed, } // Persistence @@ -303,7 +301,7 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int, } // For the remaining documents, calculate cosine similarity. 
- docSim, err := calcDocSimilarity(ctx, queryVectors, filteredDocs, c.normalized) + docSim, err := calcDocSimilarity(ctx, queryVectors, filteredDocs) if err != nil { return nil, fmt.Errorf("couldn't calculate cosine similarity: %w", err) } diff --git a/collection_test.go b/collection_test.go index 2202c38..883fdab 100644 --- a/collection_test.go +++ b/collection_test.go @@ -20,7 +20,7 @@ func TestCollection_Add(t *testing.T) { // Create collection db := NewDB() - c, err := db.CreateCollection(name, metadata, embeddingFunc, nil) + c, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -116,7 +116,7 @@ func TestCollection_Add_Error(t *testing.T) { // Create collection db := NewDB() - c, err := db.CreateCollection(name, metadata, embeddingFunc, nil) + c, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -167,7 +167,7 @@ func TestCollection_AddConcurrently(t *testing.T) { // Create collection db := NewDB() - c, err := db.CreateCollection(name, metadata, embeddingFunc, nil) + c, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -263,7 +263,7 @@ func TestCollection_AddConcurrently_Error(t *testing.T) { // Create collection db := NewDB() - c, err := db.CreateCollection(name, metadata, embeddingFunc, nil) + c, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -316,7 +316,7 @@ func TestCollection_Count(t *testing.T) { embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return []float32{-0.1, 0.1, 0.2}, nil } - c, err := db.CreateCollection(name, metadata, embeddingFunc, nil) + c, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -407,7 +407,7 @@ func benchmarkCollection_Query(b *testing.B, n int, withContent bool) { // Create collection db := NewDB() name := "test" - c, err := db.CreateCollection(name, nil, embeddingFunc, &trueVal) + c, err := db.CreateCollection(name, nil, embeddingFunc) if err != nil { b.Fatal("expected no error, got", err) } diff --git a/db.go b/db.go index ba39b0b..42ab698 100644 --- a/db.go +++ b/db.go @@ -142,17 +142,14 @@ func NewPersistentDB(path string) (*DB, error) { // - metadata: Optional metadata to associate with the collection. // - embeddingFunc: Optional function to use to embed documents. // Uses the default embedding function if not provided. -// - normalized: Optional flag to indicate if the embeddings of the collection -// are normalized (when you add embeddings yourself, or the embeddings created -// by the embeddingFunc). If nil it will be autodetected. 
-func (db *DB) CreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc, normalized *bool) (*Collection, error) { +func (db *DB) CreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) { if name == "" { return nil, errors.New("collection name is empty") } if embeddingFunc == nil { embeddingFunc = NewEmbeddingFuncDefault() } - collection, err := newCollection(name, metadata, embeddingFunc, normalized, db.persistDirectory) + collection, err := newCollection(name, metadata, embeddingFunc, db.persistDirectory) if err != nil { return nil, fmt.Errorf("couldn't create collection: %w", err) } @@ -216,15 +213,12 @@ func (db *DB) GetCollection(name string, embeddingFunc EmbeddingFunc) *Collectio // - metadata: Optional metadata to associate with the collection. // - embeddingFunc: Optional function to use to embed documents. // Uses the default embedding function if not provided. -// - normalized: Optional flag to indicate if the embeddings of the collection -// are normalized (when you add embeddings yourself, or the embeddings created -// by the embeddingFunc). If nil it will be autodetected. -func (db *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc, normalized *bool) (*Collection, error) { +func (db *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) { // No need to lock here, because the methods we call do that. collection := db.GetCollection(name, embeddingFunc) if collection == nil { var err error - collection, err = db.CreateCollection(name, metadata, embeddingFunc, normalized) + collection, err = db.CreateCollection(name, metadata, embeddingFunc) if err != nil { return nil, fmt.Errorf("couldn't create collection: %w", err) } diff --git a/db_test.go b/db_test.go index 2248451..ad04bbe 100644 --- a/db_test.go +++ b/db_test.go @@ -18,7 +18,7 @@ func TestDB_CreateCollection(t *testing.T) { db := NewDB() t.Run("OK", func(t *testing.T) { - c, err := db.CreateCollection(name, metadata, embeddingFunc, nil) + c, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -70,7 +70,7 @@ func TestDB_CreateCollection(t *testing.T) { }) t.Run("NOK - Empty name", func(t *testing.T) { - _, err := db.CreateCollection("", metadata, embeddingFunc, nil) + _, err := db.CreateCollection("", metadata, embeddingFunc) if err == nil { t.Fatal("expected error, got nil") } @@ -89,7 +89,7 @@ func TestDB_ListCollections(t *testing.T) { // Create initial collection db := NewDB() // We ignore the return value. CreateCollection is tested elsewhere. - _, err := db.CreateCollection(name, metadata, embeddingFunc, nil) + _, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -155,7 +155,7 @@ func TestDB_GetCollection(t *testing.T) { // Create initial collection db := NewDB() // We ignore the return value. CreateCollection is tested elsewhere. - _, err := db.CreateCollection(name, metadata, embeddingFunc, nil) + _, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -207,7 +207,7 @@ func TestDB_GetOrCreateCollection(t *testing.T) { // Create collection so that the GetOrCreateCollection() call below only // gets it. // We ignore the return value. CreateCollection is tested elsewhere. 
- _, err := db.CreateCollection(name, metadata, embeddingFunc, nil) + _, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -215,7 +215,7 @@ func TestDB_GetOrCreateCollection(t *testing.T) { // Call GetOrCreateCollection() with the same name to only get it. We pass // nil for the metadata and embeddingFunc so we can check that the returned // collection is the original one, and not a new one. - c, err := db.GetOrCreateCollection(name, nil, embeddingFunc, nil) + c, err := db.GetOrCreateCollection(name, nil, nil) if err != nil { t.Fatal("expected no error, got", err) } @@ -257,7 +257,7 @@ func TestDB_GetOrCreateCollection(t *testing.T) { db := NewDB() // Call GetOrCreateCollection() - c, err := db.GetOrCreateCollection(name, metadata, embeddingFunc, nil) + c, err := db.GetOrCreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -307,7 +307,7 @@ func TestDB_DeleteCollection(t *testing.T) { // Create initial collection db := NewDB() // We ignore the return value. CreateCollection is tested elsewhere. - _, err := db.CreateCollection(name, metadata, embeddingFunc, nil) + _, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -339,7 +339,7 @@ func TestDB_Reset(t *testing.T) { // Create initial collection db := NewDB() // We ignore the return value. CreateCollection is tested elsewhere. - _, err := db.CreateCollection(name, metadata, embeddingFunc, nil) + _, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } diff --git a/examples/rag-wikipedia-ollama/main.go b/examples/rag-wikipedia-ollama/main.go index 982e908..4d451ae 100644 --- a/examples/rag-wikipedia-ollama/main.go +++ b/examples/rag-wikipedia-ollama/main.go @@ -19,9 +19,6 @@ const ( embeddingModel = "nomic-embed-text" ) -// The nomic-embed-text-v1.5 model doesn't return normalized embeddings -var normalized = false - func main() { ctx := context.Background() @@ -52,7 +49,7 @@ func main() { // variable to be set. // For this example we choose to use a locally running embedding model though. // It requires Ollama to serve its API at "http://localhost:11434/api". - collection, err := db.GetOrCreateCollection("Wikipedia", nil, chromem.NewEmbeddingFuncOllama(embeddingModel), &normalized) + collection, err := db.GetOrCreateCollection("Wikipedia", nil, chromem.NewEmbeddingFuncOllama(embeddingModel)) if err != nil { panic(err) } diff --git a/examples/semantic-search-arxiv-openai/main.go b/examples/semantic-search-arxiv-openai/main.go index dd55b18..e0d341b 100644 --- a/examples/semantic-search-arxiv-openai/main.go +++ b/examples/semantic-search-arxiv-openai/main.go @@ -16,9 +16,6 @@ import ( const searchTerm = "semantic search with vector databases" -// OpenAI embeddings are already normalized. -var normalized = true - func main() { ctx := context.Background() @@ -33,7 +30,7 @@ func main() { // We pass nil as embedding function to use the default (OpenAI text-embedding-3-small), // which is very good and cheap. It requires the OPENAI_API_KEY environment // variable to be set. 
- collection, err := db.GetOrCreateCollection("arXiv cs.CL 2023", nil, nil, &normalized) + collection, err := db.GetOrCreateCollection("arXiv cs.CL 2023", nil, nil) if err != nil { panic(err) } diff --git a/query.go b/query.go index c8a6144..f0544de 100644 --- a/query.go +++ b/query.go @@ -95,7 +95,7 @@ func documentMatchesFilters(document *Document, where, whereDocument map[string] return true } -func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document, isNormalized *bool) ([]docSim, error) { +func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document) ([]docSim, error) { similarities := make([]docSim, 0, len(docs)) similaritiesLock := sync.Mutex{} @@ -145,7 +145,7 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu return } - sim, err := cosineSimilarity(queryVectors, doc.Embedding, isNormalized) + sim, err := cosineSimilarity(queryVectors, doc.Embedding) if err != nil { setSharedErr(fmt.Errorf("couldn't calculate similarity for document '%s': %w", doc.ID, err)) return diff --git a/vector.go b/vector.go index 9dfba02..5caaa71 100644 --- a/vector.go +++ b/vector.go @@ -7,33 +7,19 @@ import ( const isNormalizedPrecisionTolerance = 1e-6 -var ( - falseVal = false - trueVal = true -) - // cosineSimilarity calculates the cosine similarity between two vectors. -// Pass isNormalized=true if the vectors are already normalized, false -// to normalize them, and nil to autodetect. +// Vectors are normalized first. // The resulting value represents the similarity, so a higher value means the // vectors are more similar. -func cosineSimilarity(a, b []float32, isNormalized *bool) (float32, error) { +func cosineSimilarity(a, b []float32) (float32, error) { // The vectors must have the same length if len(a) != len(b) { return 0, errors.New("vectors must have the same length") } - if isNormalized == nil { - if !checkNormalized(a) || !checkNormalized(b) { - isNormalized = &falseVal - } else { - isNormalized = &trueVal - } - } - if !*isNormalized { + if !isNormalized(a) || !isNormalized(b) { a, b = normalizeVector(a), normalizeVector(b) } - var dotProduct float32 for i := range a { dotProduct += a[i] * b[i] @@ -58,8 +44,8 @@ func normalizeVector(v []float32) []float32 { return res } -// checkNormalized checks if the vector is normalized. -func checkNormalized(v []float32) bool { +// isNormalized checks if the vector is normalized. 
+func isNormalized(v []float32) bool { var sqSum float64 for _, val := range v { sqSum += float64(val) * float64(val) From 579fd46fcc9ff81eab582fc58730a3fb772ef381 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sat, 16 Mar 2024 14:04:10 +0100 Subject: [PATCH 10/12] Normalize vectors on embedding creation instead of querying - Normalizes only once instead of each time - Embedding creation takes time anyway, while query should be as fast as possible --- collection.go | 6 +++++- collection_test.go | 11 ++++++----- db.go | 3 +++ db_test.go | 12 ++++++------ document_test.go | 2 +- embed_compat.go | 12 ++++++++---- embed_ollama.go | 18 +++++++++++++++++- embed_ollama_test.go | 2 +- embed_openai.go | 35 ++++++++++++++++++++++++++++++++--- embed_openai_test.go | 4 ++-- persistence_test.go | 2 +- query.go | 3 ++- vector.go | 22 +++++++++++++++++++++- 13 files changed, 105 insertions(+), 27 deletions(-) diff --git a/collection.go b/collection.go index 22a50c2..485660d 100644 --- a/collection.go +++ b/collection.go @@ -213,13 +213,17 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error { m[k] = v } - // Create embedding if they don't exist + // Create embedding if they don't exist, otherwise normalize if necessary if len(doc.Embedding) == 0 { embedding, err := c.embed(ctx, doc.Content) if err != nil { return fmt.Errorf("couldn't create embedding of document: %w", err) } doc.Embedding = embedding + } else { + if !isNormalized(doc.Embedding) { + doc.Embedding = normalizeVector(doc.Embedding) + } } c.documentsLock.Lock() diff --git a/collection_test.go b/collection_test.go index 883fdab..274b271 100644 --- a/collection_test.go +++ b/collection_test.go @@ -13,7 +13,7 @@ func TestCollection_Add(t *testing.T) { ctx := context.Background() name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -109,7 +109,7 @@ func TestCollection_Add_Error(t *testing.T) { ctx := context.Background() name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -160,7 +160,7 @@ func TestCollection_AddConcurrently(t *testing.T) { ctx := context.Background() name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -256,7 +256,7 @@ func TestCollection_AddConcurrently_Error(t *testing.T) { ctx := context.Background() name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -313,8 +313,9 @@ func TestCollection_Count(t *testing.T) { db := NewDB() name := "test" metadata := map[string]string{"foo": "bar"} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ 
context.Context, _ string) ([]float32, error) { - return []float32{-0.1, 0.1, 0.2}, nil + return vectors, nil } c, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { diff --git a/db.go b/db.go index 42ab698..7c78b39 100644 --- a/db.go +++ b/db.go @@ -13,6 +13,9 @@ import ( // EmbeddingFunc is a function that creates embeddings for a given text. // chromem-go will use OpenAI`s "text-embedding-3-small" model by default, // but you can provide your own function, using any model you like. +// The function must return a *normalized* vector, i.e. the length of the vector +// must be 1. OpenAI's and Mistral's embedding models do this by default. Some +// others like Nomic's "nomic-embed-text-v1.5" don't. type EmbeddingFunc func(ctx context.Context, text string) ([]float32, error) // DB is the chromem-go database. It holds collections, which hold documents. diff --git a/db_test.go b/db_test.go index ad04bbe..fb25373 100644 --- a/db_test.go +++ b/db_test.go @@ -10,7 +10,7 @@ func TestDB_CreateCollection(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -81,7 +81,7 @@ func TestDB_ListCollections(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -147,7 +147,7 @@ func TestDB_GetCollection(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -196,7 +196,7 @@ func TestDB_GetOrCreateCollection(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -299,7 +299,7 @@ func TestDB_DeleteCollection(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -331,7 +331,7 @@ func TestDB_Reset(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } diff --git a/document_test.go b/document_test.go index 668d6b0..4290c5a 100644 --- a/document_test.go +++ b/document_test.go @@ -10,7 +10,7 @@ func TestDocument_New(t *testing.T) { ctx := context.Background() id := "test" metadata := 
map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` content := "hello world" embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil diff --git a/embed_compat.go b/embed_compat.go index a3d18c9..c3cf4ed 100644 --- a/embed_compat.go +++ b/embed_compat.go @@ -9,9 +9,13 @@ const ( // NewEmbeddingFuncMistral returns a function that creates embeddings for a text // using the Mistral API. func NewEmbeddingFuncMistral(apiKey string) EmbeddingFunc { + // Mistral embeddings are normalized, see section "Distance Measures" on + // https://docs.mistral.ai/guides/embeddings/. + normalized := true + // The Mistral API docs don't mention the `encoding_format` as optional, // but it seems to be, just like OpenAI. So we reuse the OpenAI function. - return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral) + return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral, &normalized) } const baseURLJina = "https://api.jina.ai/v1" @@ -28,7 +32,7 @@ const ( // NewEmbeddingFuncJina returns a function that creates embeddings for a text // using the Jina API. func NewEmbeddingFuncJina(apiKey string, model EmbeddingModelJina) EmbeddingFunc { - return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model)) + return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model), nil) } const baseURLMixedbread = "https://api.mixedbread.ai" @@ -49,7 +53,7 @@ const ( // NewEmbeddingFuncMixedbread returns a function that creates embeddings for a text // using the mixedbread.ai API. func NewEmbeddingFuncMixedbread(apiKey string, model EmbeddingModelMixedbread) EmbeddingFunc { - return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model)) + return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model), nil) } const baseURLLocalAI = "http://localhost:8080/v1" @@ -64,5 +68,5 @@ const baseURLLocalAI = "http://localhost:8080/v1" // But other embedding models are supported as well. See the LocalAI documentation // for details. func NewEmbeddingFuncLocalAI(model string) EmbeddingFunc { - return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model) + return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model, nil) } diff --git a/embed_ollama.go b/embed_ollama.go index 7ed0573..d2231f7 100644 --- a/embed_ollama.go +++ b/embed_ollama.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "net/http" + "sync" ) // TODO: Turn into const and use as default, but allow user to pass custom URL @@ -28,6 +29,9 @@ func NewEmbeddingFuncOllama(model string) EmbeddingFunc { // and it might have to be a long timeout, depending on the text length. client := &http.Client{} + var checkedNormalized bool + checkNormalized := sync.Once{} + return func(ctx context.Context, text string) ([]float32, error) { // Prepare the request body. 
reqBody, err := json.Marshal(map[string]string{ @@ -74,6 +78,18 @@ func NewEmbeddingFuncOllama(model string) EmbeddingFunc { return nil, errors.New("no embeddings found in the response") } - return embeddingResponse.Embedding, nil + v := embeddingResponse.Embedding + checkNormalized.Do(func() { + if isNormalized(v) { + checkedNormalized = true + } else { + checkedNormalized = false + } + }) + if !checkedNormalized { + v = normalizeVector(v) + } + + return v, nil } } diff --git a/embed_ollama_test.go b/embed_ollama_test.go index 5ce6121..a3af70a 100644 --- a/embed_ollama_test.go +++ b/embed_ollama_test.go @@ -25,7 +25,7 @@ func TestNewEmbeddingFuncOllama(t *testing.T) { if err != nil { t.Fatal("unexpected error:", err) } - wantRes := []float32{-0.1, 0.1, 0.2} + wantRes := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` // Mock server ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/embed_openai.go b/embed_openai.go index 2d301c7..1d09e01 100644 --- a/embed_openai.go +++ b/embed_openai.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "os" + "sync" ) const BaseURLOpenAI = "https://api.openai.com/v1" @@ -39,7 +40,9 @@ func NewEmbeddingFuncDefault() EmbeddingFunc { // NewEmbeddingFuncOpenAI returns a function that creates embeddings for a text // using the OpenAI API. func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) EmbeddingFunc { - return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model)) + // OpenAI embeddings are normalized + normalized := true + return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model), &normalized) } // NewEmbeddingFuncOpenAICompat returns a function that creates embeddings for a text @@ -48,12 +51,20 @@ func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) Embedding // - LitLLM: https://github.com/BerriAI/litellm // - Ollama: https://github.com/ollama/ollama/blob/main/docs/openai.md // - etc. -func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string) EmbeddingFunc { +// +// The `normalized` parameter indicates whether the vectors returned by the embedding +// model are already normalized, as is the case for OpenAI's and Mistral's models. +// The flag is optional. If it's nil, it will be autodetected on the first request +// (which bears a small risk that the vector just happens to have a length of 1). +func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool) EmbeddingFunc { // We don't set a default timeout here, although it's usually a good idea. // In our case though, the library user can set the timeout on the context, // and it might have to be a long timeout, depending on the text length. client := &http.Client{} + var checkedNormalized bool + checkNormalized := sync.Once{} + return func(ctx context.Context, text string) ([]float32, error) { // Prepare the request body. 
reqBody, err := json.Marshal(map[string]string{ @@ -101,6 +112,24 @@ func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string) EmbeddingFunc { return nil, errors.New("no embeddings found in the response") } - return embeddingResponse.Data[0].Embedding, nil + v := embeddingResponse.Data[0].Embedding + if normalized != nil { + if *normalized { + return v, nil + } + return normalizeVector(v), nil + } + checkNormalized.Do(func() { + if isNormalized(v) { + checkedNormalized = true + } else { + checkedNormalized = false + } + }) + if !checkedNormalized { + v = normalizeVector(v) + } + + return v, nil } } diff --git a/embed_openai_test.go b/embed_openai_test.go index 03049b0..5243b81 100644 --- a/embed_openai_test.go +++ b/embed_openai_test.go @@ -33,7 +33,7 @@ func TestNewEmbeddingFuncOpenAICompat(t *testing.T) { if err != nil { t.Fatal("unexpected error:", err) } - wantRes := []float32{-0.1, 0.1, 0.2} + wantRes := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` // Mock server ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -75,7 +75,7 @@ func TestNewEmbeddingFuncOpenAICompat(t *testing.T) { defer ts.Close() baseURL := ts.URL + baseURLSuffix - f := chromem.NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model) + f := chromem.NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model, nil) res, err := f(context.Background(), input) if err != nil { t.Fatal("expected nil, got", err) diff --git a/persistence_test.go b/persistence_test.go index 469126e..c5799af 100644 --- a/persistence_test.go +++ b/persistence_test.go @@ -23,7 +23,7 @@ func TestPersistence(t *testing.T) { } obj := s{ Foo: "test", - Bar: []float32{-0.1, 0.1, 0.2}, + Bar: []float32{-0.40824828, 0.40824828, 0.81649655}, // normalized version of `{-0.1, 0.1, 0.2}` } persist(tempDir, obj) diff --git a/query.go b/query.go index f0544de..9da1be2 100644 --- a/query.go +++ b/query.go @@ -145,7 +145,8 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu return } - sim, err := cosineSimilarity(queryVectors, doc.Embedding) + // As the vectors are normalized, the dot product is the cosine similarity. + sim, err := dotProduct(queryVectors, doc.Embedding) if err != nil { setSharedErr(fmt.Errorf("couldn't calculate similarity for document '%s': %w", doc.ID, err)) return diff --git a/vector.go b/vector.go index 5caaa71..972b6b2 100644 --- a/vector.go +++ b/vector.go @@ -2,6 +2,7 @@ package chromem import ( "errors" + "fmt" "math" ) @@ -20,11 +21,30 @@ func cosineSimilarity(a, b []float32) (float32, error) { if !isNormalized(a) || !isNormalized(b) { a, b = normalizeVector(a), normalizeVector(b) } + dotProduct, err := dotProduct(a, b) + if err != nil { + return 0, fmt.Errorf("couldn't calculate dot product: %w", err) + } + + // Vectors are already normalized, so no need to divide by magnitudes + + return dotProduct, nil +} + +// dotProduct calculates the dot product between two vectors. +// It's the same as cosine similarity for normalized vectors. +// The resulting value represents the similarity, so a higher value means the +// vectors are more similar. 
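+//
+// For example, for the normalized vectors a = (0.6, 0.8) and b = (0.8, 0.6)
+// the result is 0.6*0.8 + 0.8*0.6 = 0.96, which is also their cosine
+// similarity.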
+func dotProduct(a, b []float32) (float32, error) { + // The vectors must have the same length + if len(a) != len(b) { + return 0, errors.New("vectors must have the same length") + } + var dotProduct float32 for i := range a { dotProduct += a[i] * b[i] } - // Vectors are already normalized, so no need to divide by magnitudes return dotProduct, nil } From 5656523436c60031aee3d89dd6691eddecdd4f3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sat, 16 Mar 2024 18:11:58 +0100 Subject: [PATCH 11/12] Clarify query duration in examples --- examples/rag-wikipedia-ollama/README.md | 2 +- examples/rag-wikipedia-ollama/main.go | 4 ++-- examples/semantic-search-arxiv-openai/README.md | 6 +++--- examples/semantic-search-arxiv-openai/main.go | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/rag-wikipedia-ollama/README.md b/examples/rag-wikipedia-ollama/README.md index 539e90c..60612c8 100644 --- a/examples/rag-wikipedia-ollama/README.md +++ b/examples/rag-wikipedia-ollama/README.md @@ -29,7 +29,7 @@ The output can differ slightly on each run, but it's along the lines of: 2024/03/02 20:02:34 Reading JSON lines... 2024/03/02 20:02:34 Adding documents to chromem-go, including creating their embeddings via Ollama API... 2024/03/02 20:03:11 Querying chromem-go... -2024/03/02 20:03:11 Search took 231.672667ms +2024/03/02 20:03:11 Search (incl query embedding) took 231.672667ms 2024/03/02 20:03:11 Document 1 (similarity: 0.723627): "Malleable Iron Range Company was a company that existed from 1896 to 1985 and primarily produced kitchen ranges made of malleable iron but also produced a variety of other related products. The company's primary trademark was 'Monarch' and was colloquially often referred to as the Monarch Company or just Monarch." 2024/03/02 20:03:11 Document 2 (similarity: 0.550584): "The American Motor Car Company was a short-lived company in the automotive industry founded in 1906 lasting until 1913. It was based in Indianapolis Indiana United States. The American Motor Car Company pioneered the underslung design." 2024/03/02 20:03:11 Asking LLM with augmented question... diff --git a/examples/rag-wikipedia-ollama/main.go b/examples/rag-wikipedia-ollama/main.go index 4d451ae..ce1cc47 100644 --- a/examples/rag-wikipedia-ollama/main.go +++ b/examples/rag-wikipedia-ollama/main.go @@ -104,7 +104,7 @@ func main() { if err != nil { panic(err) } - log.Println("Search took", time.Since(start)) + log.Println("Search (incl query embedding) took", time.Since(start)) // Here you could filter out any documents whose similarity is below a certain threshold. // if docRes[...].Similarity < 0.5 { ... @@ -129,7 +129,7 @@ func main() { 2024/03/02 20:02:34 Reading JSON lines... 2024/03/02 20:02:34 Adding documents to chromem-go, including creating their embeddings via Ollama API... 2024/03/02 20:03:11 Querying chromem-go... - 2024/03/02 20:03:11 Search took 231.672667ms + 2024/03/02 20:03:11 Search (incl query embedding) took 231.672667ms 2024/03/02 20:03:11 Document 1 (similarity: 0.723627): "Malleable Iron Range Company was a company that existed from 1896 to 1985 and primarily produced kitchen ranges made of malleable iron but also produced a variety of other related products. The company's primary trademark was 'Monarch' and was colloquially often referred to as the Monarch Company or just Monarch." 
2024/03/02 20:03:11 Document 2 (similarity: 0.550584): "The American Motor Car Company was a short-lived company in the automotive industry founded in 1906 lasting until 1913. It was based in Indianapolis Indiana United States. The American Motor Car Company pioneered the underslung design." 2024/03/02 20:03:11 Asking LLM with augmented question... diff --git a/examples/semantic-search-arxiv-openai/README.md b/examples/semantic-search-arxiv-openai/README.md index 1292fcf..829024f 100644 --- a/examples/semantic-search-arxiv-openai/README.md +++ b/examples/semantic-search-arxiv-openai/README.md @@ -12,8 +12,8 @@ This is not a retrieval augmented generation (RAG) app, because after *retrievin 1. Ensure you have [ripgrep](https://github.com/BurntSushi/ripgrep) installed, or adapt the following commands to use grep 2. Run `rg '"categories":"cs.CL"' ~/Downloads/arxiv-metadata-oai-snapshot.json | rg '"update_date":"2023' > /tmp/arxiv_cs-cl_2023.jsonl` (adapt input file path if necessary) 3. Check the data - 1. `wc -l arxiv_cs-cl_2023.jsonl` should show ~5,000 lines - 2. `du -h arxiv_cs-cl_2023.jsonl` should show ~8.8 MB + 1. `wc -l /tmp/arxiv_cs-cl_2023.jsonl` should show ~5,000 lines + 2. `du -h /tmp/arxiv_cs-cl_2023.jsonl` should show ~8.8 MB 2. Set the OpenAI API key in your env as `OPENAI_API_KEY` 3. Run the example: `go run .` @@ -27,7 +27,7 @@ The output can differ slightly on each run, but it's along the lines of: 2024/03/10 18:23:55 Read and parsed 5006 documents. 2024/03/10 18:23:55 Adding documents to chromem-go, including creating their embeddings via OpenAI API... 2024/03/10 18:28:12 Querying chromem-go... - 2024/03/10 18:28:12 Search took 529.451163ms + 2024/03/10 18:28:12 Search (incl query embedding) took 529.451163ms 2024/03/10 18:28:12 Search results: 1) Similarity 0.488895: URL: https://arxiv.org/abs/2209.15469 diff --git a/examples/semantic-search-arxiv-openai/main.go b/examples/semantic-search-arxiv-openai/main.go index e0d341b..c52c366 100644 --- a/examples/semantic-search-arxiv-openai/main.go +++ b/examples/semantic-search-arxiv-openai/main.go @@ -93,7 +93,7 @@ func main() { if err != nil { panic(err) } - log.Println("Search took", time.Since(start)) + log.Println("Search (incl query embedding) took", time.Since(start)) // Here you could filter out any documents whose similarity is below a certain threshold. // if docRes[...].Similarity < 0.5 { ... @@ -117,7 +117,7 @@ func main() { 2024/03/10 18:23:55 Read and parsed 5006 documents. 2024/03/10 18:23:55 Adding documents to chromem-go, including creating their embeddings via OpenAI API... 2024/03/10 18:28:12 Querying chromem-go... - 2024/03/10 18:28:12 Search took 529.451163ms + 2024/03/10 18:28:12 Search (incl query embedding) took 529.451163ms 2024/03/10 18:28:12 Search results: 1) Similarity 0.488895: URL: https://arxiv.org/abs/2209.15469 From 141061251a8a84e7be36d42e987a74f426a450a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sat, 16 Mar 2024 18:38:58 +0100 Subject: [PATCH 12/12] Update README --- README.md | 59 +++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 40 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index cb6881a..3ce683f 100644 --- a/README.md +++ b/README.md @@ -11,9 +11,7 @@ Because `chromem-go` is embeddable it enables you to add retrieval augmented gen It's *not* a library to connect to Chroma and also not a reimplementation of it in Go. It's a database on its own. -The focus is not scale or number of features, but simplicity. 
- -Performance has not been a priority yet. Without optimizations (except some parallelization with goroutines) querying 5,000 documents takes ~500ms on a mid-range laptop CPU (11th Gen Intel i5-1135G7, like in the first generation Framework Laptop 13). +The focus is not scale (millions of documents) or number of features, but simplicity and performance for the most common use cases. On a mid-range 2020 Intel laptop CPU you can query 1,000 documents in 0.5 ms and 100,000 documents in 56 ms, both with just 44 memory allocations. See [Benchmarks](#benchmarks) for details. > ⚠️ The project is in beta, under heavy construction, and may introduce breaking changes in releases before `v1.0.0`. All changes are documented in the [`CHANGELOG`](./CHANGELOG.md). @@ -23,8 +21,9 @@ Performance has not been a priority yet. Without optimizations (except some para 2. [Interface](#interface) 3. [Features](#features) 4. [Usage](#usage) -5. [Motivation](#motivation) -6. [Related projects](#related-projects) +5. [Benchmarks](#benchmarks) +6. [Motivation](#motivation) +7. [Related projects](#related-projects) ## Use cases @@ -156,25 +155,25 @@ See the Godoc for details: