From 05c4f764c4bba63c5a225892ca61ba87d38fad95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sat, 16 Mar 2024 12:26:14 +0100 Subject: [PATCH] Revert "Add "normalized" parameter to skip check if normalization is known" This reverts commit ff28a3807c27ebf65b511d06614cc87b4e50b78b. --- README.md | 2 +- collection.go | 12 ++++------ collection_test.go | 12 +++++----- db.go | 14 ++++------- db_test.go | 18 +++++++------- examples/rag-wikipedia-ollama/main.go | 5 +--- examples/semantic-search-arxiv-openai/main.go | 5 +--- query.go | 4 ++-- vector.go | 24 ++++--------------- 9 files changed, 34 insertions(+), 62 deletions(-) diff --git a/README.md b/README.md index ded8c5f..cb6881a 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,7 @@ func main() { db := chromem.NewDB() // Create collection. GetCollection, GetOrCreateCollection, DeleteCollection also available! - collection, _ := db.CreateCollection("all-my-documents", nil, nil, nil) + collection, _ := db.CreateCollection("all-my-documents", nil, nil) // Add docs to the collection. Update and delete will be added in the future. // Can be multi-threaded with AddConcurrently()! diff --git a/collection.go b/collection.go index 6e1bf4a..22a50c2 100644 --- a/collection.go +++ b/collection.go @@ -22,12 +22,11 @@ type Collection struct { documents map[string]*Document documentsLock sync.RWMutex embed EmbeddingFunc - normalized *bool } // We don't export this yet to keep the API surface to the bare minimum. // Users create collections via [Client.CreateCollection]. -func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, normalized *bool, dir string) (*Collection, error) { +func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dir string) (*Collection, error) { // We copy the metadata to avoid data races in case the caller modifies the // map after creating the collection while we range over it. m := make(map[string]string, len(metadata)) @@ -38,10 +37,9 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, c := &Collection{ Name: name, - metadata: m, - documents: make(map[string]*Document), - embed: embed, - normalized: normalized, + metadata: m, + documents: make(map[string]*Document), + embed: embed, } // Persistence @@ -303,7 +301,7 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int, } // For the remaining documents, calculate cosine similarity. - docSim, err := calcDocSimilarity(ctx, queryVectors, filteredDocs, c.normalized) + docSim, err := calcDocSimilarity(ctx, queryVectors, filteredDocs) if err != nil { return nil, fmt.Errorf("couldn't calculate cosine similarity: %w", err) } diff --git a/collection_test.go b/collection_test.go index 2202c38..883fdab 100644 --- a/collection_test.go +++ b/collection_test.go @@ -20,7 +20,7 @@ func TestCollection_Add(t *testing.T) { // Create collection db := NewDB() - c, err := db.CreateCollection(name, metadata, embeddingFunc, nil) + c, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -116,7 +116,7 @@ func TestCollection_Add_Error(t *testing.T) { // Create collection db := NewDB() - c, err := db.CreateCollection(name, metadata, embeddingFunc, nil) + c, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -167,7 +167,7 @@ func TestCollection_AddConcurrently(t *testing.T) { // Create collection db := NewDB() - c, err := db.CreateCollection(name, metadata, embeddingFunc, nil) + c, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -263,7 +263,7 @@ func TestCollection_AddConcurrently_Error(t *testing.T) { // Create collection db := NewDB() - c, err := db.CreateCollection(name, metadata, embeddingFunc, nil) + c, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -316,7 +316,7 @@ func TestCollection_Count(t *testing.T) { embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return []float32{-0.1, 0.1, 0.2}, nil } - c, err := db.CreateCollection(name, metadata, embeddingFunc, nil) + c, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -407,7 +407,7 @@ func benchmarkCollection_Query(b *testing.B, n int, withContent bool) { // Create collection db := NewDB() name := "test" - c, err := db.CreateCollection(name, nil, embeddingFunc, &trueVal) + c, err := db.CreateCollection(name, nil, embeddingFunc) if err != nil { b.Fatal("expected no error, got", err) } diff --git a/db.go b/db.go index ba39b0b..42ab698 100644 --- a/db.go +++ b/db.go @@ -142,17 +142,14 @@ func NewPersistentDB(path string) (*DB, error) { // - metadata: Optional metadata to associate with the collection. // - embeddingFunc: Optional function to use to embed documents. // Uses the default embedding function if not provided. -// - normalized: Optional flag to indicate if the embeddings of the collection -// are normalized (when you add embeddings yourself, or the embeddings created -// by the embeddingFunc). If nil it will be autodetected. -func (db *DB) CreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc, normalized *bool) (*Collection, error) { +func (db *DB) CreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) { if name == "" { return nil, errors.New("collection name is empty") } if embeddingFunc == nil { embeddingFunc = NewEmbeddingFuncDefault() } - collection, err := newCollection(name, metadata, embeddingFunc, normalized, db.persistDirectory) + collection, err := newCollection(name, metadata, embeddingFunc, db.persistDirectory) if err != nil { return nil, fmt.Errorf("couldn't create collection: %w", err) } @@ -216,15 +213,12 @@ func (db *DB) GetCollection(name string, embeddingFunc EmbeddingFunc) *Collectio // - metadata: Optional metadata to associate with the collection. // - embeddingFunc: Optional function to use to embed documents. // Uses the default embedding function if not provided. -// - normalized: Optional flag to indicate if the embeddings of the collection -// are normalized (when you add embeddings yourself, or the embeddings created -// by the embeddingFunc). If nil it will be autodetected. -func (db *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc, normalized *bool) (*Collection, error) { +func (db *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) { // No need to lock here, because the methods we call do that. collection := db.GetCollection(name, embeddingFunc) if collection == nil { var err error - collection, err = db.CreateCollection(name, metadata, embeddingFunc, normalized) + collection, err = db.CreateCollection(name, metadata, embeddingFunc) if err != nil { return nil, fmt.Errorf("couldn't create collection: %w", err) } diff --git a/db_test.go b/db_test.go index 2248451..ad04bbe 100644 --- a/db_test.go +++ b/db_test.go @@ -18,7 +18,7 @@ func TestDB_CreateCollection(t *testing.T) { db := NewDB() t.Run("OK", func(t *testing.T) { - c, err := db.CreateCollection(name, metadata, embeddingFunc, nil) + c, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -70,7 +70,7 @@ func TestDB_CreateCollection(t *testing.T) { }) t.Run("NOK - Empty name", func(t *testing.T) { - _, err := db.CreateCollection("", metadata, embeddingFunc, nil) + _, err := db.CreateCollection("", metadata, embeddingFunc) if err == nil { t.Fatal("expected error, got nil") } @@ -89,7 +89,7 @@ func TestDB_ListCollections(t *testing.T) { // Create initial collection db := NewDB() // We ignore the return value. CreateCollection is tested elsewhere. - _, err := db.CreateCollection(name, metadata, embeddingFunc, nil) + _, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -155,7 +155,7 @@ func TestDB_GetCollection(t *testing.T) { // Create initial collection db := NewDB() // We ignore the return value. CreateCollection is tested elsewhere. - _, err := db.CreateCollection(name, metadata, embeddingFunc, nil) + _, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -207,7 +207,7 @@ func TestDB_GetOrCreateCollection(t *testing.T) { // Create collection so that the GetOrCreateCollection() call below only // gets it. // We ignore the return value. CreateCollection is tested elsewhere. - _, err := db.CreateCollection(name, metadata, embeddingFunc, nil) + _, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -215,7 +215,7 @@ func TestDB_GetOrCreateCollection(t *testing.T) { // Call GetOrCreateCollection() with the same name to only get it. We pass // nil for the metadata and embeddingFunc so we can check that the returned // collection is the original one, and not a new one. - c, err := db.GetOrCreateCollection(name, nil, embeddingFunc, nil) + c, err := db.GetOrCreateCollection(name, nil, nil) if err != nil { t.Fatal("expected no error, got", err) } @@ -257,7 +257,7 @@ func TestDB_GetOrCreateCollection(t *testing.T) { db := NewDB() // Call GetOrCreateCollection() - c, err := db.GetOrCreateCollection(name, metadata, embeddingFunc, nil) + c, err := db.GetOrCreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -307,7 +307,7 @@ func TestDB_DeleteCollection(t *testing.T) { // Create initial collection db := NewDB() // We ignore the return value. CreateCollection is tested elsewhere. - _, err := db.CreateCollection(name, metadata, embeddingFunc, nil) + _, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -339,7 +339,7 @@ func TestDB_Reset(t *testing.T) { // Create initial collection db := NewDB() // We ignore the return value. CreateCollection is tested elsewhere. - _, err := db.CreateCollection(name, metadata, embeddingFunc, nil) + _, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } diff --git a/examples/rag-wikipedia-ollama/main.go b/examples/rag-wikipedia-ollama/main.go index 982e908..4d451ae 100644 --- a/examples/rag-wikipedia-ollama/main.go +++ b/examples/rag-wikipedia-ollama/main.go @@ -19,9 +19,6 @@ const ( embeddingModel = "nomic-embed-text" ) -// The nomic-embed-text-v1.5 model doesn't return normalized embeddings -var normalized = false - func main() { ctx := context.Background() @@ -52,7 +49,7 @@ func main() { // variable to be set. // For this example we choose to use a locally running embedding model though. // It requires Ollama to serve its API at "http://localhost:11434/api". - collection, err := db.GetOrCreateCollection("Wikipedia", nil, chromem.NewEmbeddingFuncOllama(embeddingModel), &normalized) + collection, err := db.GetOrCreateCollection("Wikipedia", nil, chromem.NewEmbeddingFuncOllama(embeddingModel)) if err != nil { panic(err) } diff --git a/examples/semantic-search-arxiv-openai/main.go b/examples/semantic-search-arxiv-openai/main.go index dd55b18..e0d341b 100644 --- a/examples/semantic-search-arxiv-openai/main.go +++ b/examples/semantic-search-arxiv-openai/main.go @@ -16,9 +16,6 @@ import ( const searchTerm = "semantic search with vector databases" -// OpenAI embeddings are already normalized. -var normalized = true - func main() { ctx := context.Background() @@ -33,7 +30,7 @@ func main() { // We pass nil as embedding function to use the default (OpenAI text-embedding-3-small), // which is very good and cheap. It requires the OPENAI_API_KEY environment // variable to be set. - collection, err := db.GetOrCreateCollection("arXiv cs.CL 2023", nil, nil, &normalized) + collection, err := db.GetOrCreateCollection("arXiv cs.CL 2023", nil, nil) if err != nil { panic(err) } diff --git a/query.go b/query.go index c8a6144..f0544de 100644 --- a/query.go +++ b/query.go @@ -95,7 +95,7 @@ func documentMatchesFilters(document *Document, where, whereDocument map[string] return true } -func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document, isNormalized *bool) ([]docSim, error) { +func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document) ([]docSim, error) { similarities := make([]docSim, 0, len(docs)) similaritiesLock := sync.Mutex{} @@ -145,7 +145,7 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu return } - sim, err := cosineSimilarity(queryVectors, doc.Embedding, isNormalized) + sim, err := cosineSimilarity(queryVectors, doc.Embedding) if err != nil { setSharedErr(fmt.Errorf("couldn't calculate similarity for document '%s': %w", doc.ID, err)) return diff --git a/vector.go b/vector.go index 9dfba02..5caaa71 100644 --- a/vector.go +++ b/vector.go @@ -7,33 +7,19 @@ import ( const isNormalizedPrecisionTolerance = 1e-6 -var ( - falseVal = false - trueVal = true -) - // cosineSimilarity calculates the cosine similarity between two vectors. -// Pass isNormalized=true if the vectors are already normalized, false -// to normalize them, and nil to autodetect. +// Vectors are normalized first. // The resulting value represents the similarity, so a higher value means the // vectors are more similar. -func cosineSimilarity(a, b []float32, isNormalized *bool) (float32, error) { +func cosineSimilarity(a, b []float32) (float32, error) { // The vectors must have the same length if len(a) != len(b) { return 0, errors.New("vectors must have the same length") } - if isNormalized == nil { - if !checkNormalized(a) || !checkNormalized(b) { - isNormalized = &falseVal - } else { - isNormalized = &trueVal - } - } - if !*isNormalized { + if !isNormalized(a) || !isNormalized(b) { a, b = normalizeVector(a), normalizeVector(b) } - var dotProduct float32 for i := range a { dotProduct += a[i] * b[i] @@ -58,8 +44,8 @@ func normalizeVector(v []float32) []float32 { return res } -// checkNormalized checks if the vector is normalized. -func checkNormalized(v []float32) bool { +// isNormalized checks if the vector is normalized. +func isNormalized(v []float32) bool { var sqSum float64 for _, val := range v { sqSum += float64(val) * float64(val)