From 015604c5c2aae903900d7d42d317cc1573a07902 Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Thu, 27 Jun 2024 15:18:50 +0200 Subject: [PATCH 1/4] add: option to only import/export selected collections to/from a DB --- db.go | 51 +++++++++++++++++++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/db.go b/db.go index 7253e94..7b08439 100644 --- a/db.go +++ b/db.go @@ -8,6 +8,7 @@ import ( "io/fs" "os" "path/filepath" + "slices" "strings" "sync" ) @@ -198,9 +199,11 @@ func (db *DB) Import(filePath string, encryptionKey string) error { // This works for both the in-memory and persistent DBs. // Existing collections are overwritten. // -// - filePath: Mandatory, must not be empty -// - encryptionKey: Optional, must be 32 bytes long if provided -func (db *DB) ImportFromFile(filePath string, encryptionKey string) error { +// - filePath: Mandatory, must not be empty +// - encryptionKey: Optional, must be 32 bytes long if provided +// - collections: Optional. If provided, only the collections with the given names +// are imported. If not provided, all collections are imported. +func (db *DB) ImportFromFile(filePath string, encryptionKey string, collections ...string) error { if filePath == "" { return fmt.Errorf("file path is empty") } @@ -244,6 +247,9 @@ func (db *DB) ImportFromFile(filePath string, encryptionKey string) error { } for _, pc := range persistenceDB.Collections { + if len(collections) > 0 && !slices.Contains(collections, pc.Name) { + continue + } c := &Collection{ Name: pc.Name, @@ -267,9 +273,11 @@ func (db *DB) ImportFromFile(filePath string, encryptionKey string) error { // Existing collections are overwritten. // If the writer has to be closed, it's the caller's responsibility. // -// - reader: An implementation of [io.ReadSeeker] -// - encryptionKey: Optional, must be 32 bytes long if provided -func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string) error { +// - reader: An implementation of [io.ReadSeeker] +// - encryptionKey: Optional, must be 32 bytes long if provided +// - collections: Optional. If provided, only the collections with the given names +// are imported. If not provided, all collections are imported. +func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string, collections ...string) error { if encryptionKey != "" { // AES 256 requires a 32 byte key if len(encryptionKey) != 32 { @@ -299,6 +307,9 @@ func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string) error } for _, pc := range persistenceDB.Collections { + if len(collections) > 0 && !slices.Contains(collections, pc.Name) { + continue + } c := &Collection{ Name: pc.Name, @@ -339,7 +350,9 @@ func (db *DB) Export(filePath string, compress bool, encryptionKey string) error // - compress: Optional. Compresses as gzip if true. // - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes // long if provided. -func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string) error { +// - collections: Optional. If provided, only the collections with the given names +// are exported. If not provided, all collections are exported. +func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string, collections ...string) error { if filePath == "" { filePath = "./chromem-go.gob" if compress { @@ -373,10 +386,12 @@ func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string) defer db.collectionsLock.RUnlock() for k, v := range db.collections { - persistenceDB.Collections[k] = &persistenceCollection{ - Name: v.Name, - Metadata: v.metadata, - Documents: v.documents, + if len(collections) == 0 || slices.Contains(collections, k) { + persistenceDB.Collections[k] = &persistenceCollection{ + Name: v.Name, + Metadata: v.metadata, + Documents: v.documents, + } } } @@ -397,7 +412,9 @@ func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string) // - compress: Optional. Compresses as gzip if true. // - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes // long if provided. -func (db *DB) ExportToWriter(writer io.Writer, compress bool, encryptionKey string) error { +// - collections: Optional. If provided, only the collections with the given names +// are exported. If not provided, all collections are exported. +func (db *DB) ExportToWriter(writer io.Writer, compress bool, encryptionKey string, collections ...string) error { if encryptionKey != "" { // AES 256 requires a 32 byte key if len(encryptionKey) != 32 { @@ -422,10 +439,12 @@ func (db *DB) ExportToWriter(writer io.Writer, compress bool, encryptionKey stri defer db.collectionsLock.RUnlock() for k, v := range db.collections { - persistenceDB.Collections[k] = &persistenceCollection{ - Name: v.Name, - Metadata: v.metadata, - Documents: v.documents, + if len(collections) == 0 || slices.Contains(collections, k) { + persistenceDB.Collections[k] = &persistenceCollection{ + Name: v.Name, + Metadata: v.metadata, + Documents: v.documents, + } } } From fde4a02fd7af7f4316d05f41b138b4e97c49546d Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Mon, 1 Jul 2024 15:38:10 +0200 Subject: [PATCH 2/4] fix: persist imported documents when importing to a persistent database --- db.go | 14 ++++++ db_test.go | 123 +++++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 133 insertions(+), 4 deletions(-) diff --git a/db.go b/db.go index 7b08439..3db30d1 100644 --- a/db.go +++ b/db.go @@ -259,6 +259,13 @@ func (db *DB) ImportFromFile(filePath string, encryptionKey string, collections if db.persistDirectory != "" { c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name)) c.compress = db.compress + for _, doc := range c.documents { + docPath := c.getDocPath(doc.ID) + err := persistToFile(docPath, doc, c.compress, "") + if err != nil { + return fmt.Errorf("couldn't persist document to %q: %w", docPath, err) + } + } } db.collections[c.Name] = c } @@ -319,6 +326,13 @@ func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string, colle if db.persistDirectory != "" { c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name)) c.compress = db.compress + for _, doc := range c.documents { + docPath := c.getDocPath(doc.ID) + err := persistToFile(docPath, doc, c.compress, "") + if err != nil { + return fmt.Errorf("couldn't persist document to %q: %w", docPath, err) + } + } } db.collections[c.Name] = c } diff --git a/db_test.go b/db_test.go index 791d788..93f63d3 100644 --- a/db_test.go +++ b/db_test.go @@ -144,10 +144,10 @@ func TestDB_ImportExport(t *testing.T) { t.Fatal("expected no error, got", err) } - new := NewDB() + newDB := NewDB() // Import - err = new.ImportFromFile(tc.filePath, tc.encryptionKey) + err = newDB.ImportFromFile(tc.filePath, tc.encryptionKey) if err != nil { t.Fatal("expected no error, got", err) } @@ -156,13 +156,128 @@ func TestDB_ImportExport(t *testing.T) { // We have to reset the embed function, but otherwise the DB objects // should be deep equal. c.embed = nil - if !reflect.DeepEqual(orig, new) { - t.Fatalf("expected DB %+v, got %+v", orig, new) + if !reflect.DeepEqual(orig, newDB) { + t.Fatalf("expected DB %+v, got %+v", orig, newDB) } }) } } +func TestDB_ImportExportSpecificCollections(t *testing.T) { + r := rand.New(rand.NewSource(rand.Int63())) + randString := randomString(r, 10) + path := filepath.Join(os.TempDir(), randString) + filePath := path + ".gob" + defer os.RemoveAll(path) + + // Values in the collection + name := "test" + name2 := "test2" + metadata := map[string]string{"foo": "bar"} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return vectors, nil + } + + // Create DB, can just be in-memory + orig := NewDB() + + // Create collections + c, err := orig.CreateCollection(name, metadata, embeddingFunc) + if err != nil { + t.Fatal("expected no error, got", err) + } + + c2, err := orig.CreateCollection(name2, metadata, embeddingFunc) + if err != nil { + t.Fatal("expected no error, got", err) + } + + // Add documents + doc := Document{ + ID: name, + Metadata: metadata, + Embedding: vectors, + Content: "test", + } + + doc2 := Document{ + ID: name2, + Metadata: metadata, + Embedding: vectors, + Content: "test2", + } + + err = c.AddDocument(context.Background(), doc) + if err != nil { + t.Fatal("expected no error, got", err) + } + + err = c2.AddDocument(context.Background(), doc2) + if err != nil { + t.Fatal("expected no error, got", err) + } + + // Export + err = orig.ExportToFile(filePath, false, "", name2) + if err != nil { + t.Fatal("expected no error, got", err) + } + + dir := filepath.Join(path, randomString(r, 10)) + defer os.RemoveAll(dir) + + newPDB, err := NewPersistentDB(dir, false) + if err != nil { + t.Fatal("expected no error, got", err) + } + + err = newPDB.ImportFromFile(filePath, "") + if err != nil { + t.Fatal("expected no error, got", err) + } + + if len(newPDB.collections) != 1 { + t.Fatalf("expected 1 collection, got %d", len(newPDB.collections)) + } + + // Make sure that the imported documents are actually persisted on disk + for _, col := range newPDB.collections { + for _, d := range col.documents { + _, err = os.Stat(col.getDocPath(d.ID)) + if err != nil { + t.Fatalf("expected no error when looking up persistent file for doc %q, got %v", d.ID, err) + } + } + } + + // Now export both collections and import them into the same persistent DB (overwriting the one we just imported) + filePath2 := path + "2.gob" + err = orig.ExportToFile(filePath2, false, "") + if err != nil { + t.Fatal("expected no error, got", err) + } + + err = newPDB.ImportFromFile(filePath2, "") + if err != nil { + t.Fatal("expected no error, got", err) + } + + if len(newPDB.collections) != 2 { + t.Fatalf("expected 2 collections, got %d", len(newPDB.collections)) + } + + // Make sure that the imported documents are actually persisted on disk + for _, col := range newPDB.collections { + for _, d := range col.documents { + _, err = os.Stat(col.getDocPath(d.ID)) + if err != nil { + t.Fatalf("expected no error when looking up persistent file for doc %q, got %v", d.ID, err) + } + } + } +} + func TestDB_CreateCollection(t *testing.T) { // Values in the collection name := "test" From 49eb4988eab12a8c23c01e638c4efd232168c34e Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Mon, 1 Jul 2024 15:59:46 +0200 Subject: [PATCH 3/4] fix: also persist metadata files on import to persistent db --- collection.go | 41 ++++++++++++++++++++++++----------------- db.go | 10 +++++++++- 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/collection.go b/collection.go index 7c17c93..5a00f4a 100644 --- a/collection.go +++ b/collection.go @@ -111,23 +111,7 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, safeName := hash2hex(name) c.persistDirectory = filepath.Join(dbDir, safeName) c.compress = compress - // Persist name and metadata - metadataPath := filepath.Join(c.persistDirectory, metadataFileName) - metadataPath += ".gob" - if c.compress { - metadataPath += ".gz" - } - pc := struct { - Name string - Metadata map[string]string - }{ - Name: name, - Metadata: m, - } - err := persistToFile(metadataPath, pc, compress, "") - if err != nil { - return nil, fmt.Errorf("couldn't persist collection metadata: %w", err) - } + return c, c.persistMetadata() } return c, nil @@ -545,3 +529,26 @@ func (c *Collection) getDocPath(docID string) string { } return docPath } + +// persistMetadata persists the collection metadata to disk +func (c *Collection) persistMetadata() error { + // Persist name and metadata + metadataPath := filepath.Join(c.persistDirectory, metadataFileName) + metadataPath += ".gob" + if c.compress { + metadataPath += ".gz" + } + pc := struct { + Name string + Metadata map[string]string + }{ + Name: c.Name, + Metadata: c.metadata, + } + err := persistToFile(metadataPath, pc, c.compress, "") + if err != nil { + return err + } + + return nil +} diff --git a/db.go b/db.go index 3db30d1..7b89686 100644 --- a/db.go +++ b/db.go @@ -259,9 +259,13 @@ func (db *DB) ImportFromFile(filePath string, encryptionKey string, collections if db.persistDirectory != "" { c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name)) c.compress = db.compress + err = c.persistMetadata() + if err != nil { + return fmt.Errorf("couldn't persist collection metadata: %w", err) + } for _, doc := range c.documents { docPath := c.getDocPath(doc.ID) - err := persistToFile(docPath, doc, c.compress, "") + err = persistToFile(docPath, doc, c.compress, "") if err != nil { return fmt.Errorf("couldn't persist document to %q: %w", docPath, err) } @@ -326,6 +330,10 @@ func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string, colle if db.persistDirectory != "" { c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name)) c.compress = db.compress + err = c.persistMetadata() + if err != nil { + return fmt.Errorf("couldn't persist collection metadata: %w", err) + } for _, doc := range c.documents { docPath := c.getDocPath(doc.ID) err := persistToFile(docPath, doc, c.compress, "") From a4a56530ad1760598e9e2f7def5483dedc67a99c Mon Sep 17 00:00:00 2001 From: Thorsten Klein Date: Wed, 3 Jul 2024 10:20:46 +0200 Subject: [PATCH 4/4] fix: pr change requests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Philipp Gillé --- db.go | 12 ++++++++---- db_test.go | 25 +++++++++++++------------ 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/db.go b/db.go index 7b89686..001271e 100644 --- a/db.go +++ b/db.go @@ -202,7 +202,8 @@ func (db *DB) Import(filePath string, encryptionKey string) error { // - filePath: Mandatory, must not be empty // - encryptionKey: Optional, must be 32 bytes long if provided // - collections: Optional. If provided, only the collections with the given names -// are imported. If not provided, all collections are imported. +// are imported. Non-existing collections are ignored. +// If not provided, all collections are imported. func (db *DB) ImportFromFile(filePath string, encryptionKey string, collections ...string) error { if filePath == "" { return fmt.Errorf("file path is empty") @@ -287,7 +288,8 @@ func (db *DB) ImportFromFile(filePath string, encryptionKey string, collections // - reader: An implementation of [io.ReadSeeker] // - encryptionKey: Optional, must be 32 bytes long if provided // - collections: Optional. If provided, only the collections with the given names -// are imported. If not provided, all collections are imported. +// are imported. Non-existing collections are ignored. +// If not provided, all collections are imported. func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string, collections ...string) error { if encryptionKey != "" { // AES 256 requires a 32 byte key @@ -373,7 +375,8 @@ func (db *DB) Export(filePath string, compress bool, encryptionKey string) error // - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes // long if provided. // - collections: Optional. If provided, only the collections with the given names -// are exported. If not provided, all collections are exported. +// are exported. Non-existing collections are ignored. +// If not provided, all collections are exported. func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string, collections ...string) error { if filePath == "" { filePath = "./chromem-go.gob" @@ -435,7 +438,8 @@ func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string, // - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes // long if provided. // - collections: Optional. If provided, only the collections with the given names -// are exported. If not provided, all collections are exported. +// are exported. Non-existing collections are ignored. +// If not provided, all collections are exported. func (db *DB) ExportToWriter(writer io.Writer, compress bool, encryptionKey string, collections ...string) error { if encryptionKey != "" { // AES 256 requires a 32 byte key diff --git a/db_test.go b/db_test.go index 93f63d3..ac82d6a 100644 --- a/db_test.go +++ b/db_test.go @@ -116,10 +116,10 @@ func TestDB_ImportExport(t *testing.T) { for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { // Create DB, can just be in-memory - orig := NewDB() + origDB := NewDB() // Create collection - c, err := orig.CreateCollection(name, metadata, embeddingFunc) + c, err := origDB.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -139,7 +139,7 @@ func TestDB_ImportExport(t *testing.T) { } // Export - err = orig.ExportToFile(tc.filePath, tc.compress, tc.encryptionKey) + err = origDB.ExportToFile(tc.filePath, tc.compress, tc.encryptionKey) if err != nil { t.Fatal("expected no error, got", err) } @@ -156,8 +156,8 @@ func TestDB_ImportExport(t *testing.T) { // We have to reset the embed function, but otherwise the DB objects // should be deep equal. c.embed = nil - if !reflect.DeepEqual(orig, newDB) { - t.Fatalf("expected DB %+v, got %+v", orig, newDB) + if !reflect.DeepEqual(origDB, newDB) { + t.Fatalf("expected DB %+v, got %+v", origDB, newDB) } }) } @@ -180,15 +180,15 @@ func TestDB_ImportExportSpecificCollections(t *testing.T) { } // Create DB, can just be in-memory - orig := NewDB() + origDB := NewDB() // Create collections - c, err := orig.CreateCollection(name, metadata, embeddingFunc) + c, err := origDB.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } - c2, err := orig.CreateCollection(name2, metadata, embeddingFunc) + c2, err := origDB.CreateCollection(name2, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -218,8 +218,8 @@ func TestDB_ImportExportSpecificCollections(t *testing.T) { t.Fatal("expected no error, got", err) } - // Export - err = orig.ExportToFile(filePath, false, "", name2) + // Export only one of the two collections + err = origDB.ExportToFile(filePath, false, "", name2) if err != nil { t.Fatal("expected no error, got", err) } @@ -227,6 +227,7 @@ func TestDB_ImportExportSpecificCollections(t *testing.T) { dir := filepath.Join(path, randomString(r, 10)) defer os.RemoveAll(dir) + // Instead of importing to an in-memory DB we use a persistent one to cover the behavior of immediate persistent files being created for the imported data newPDB, err := NewPersistentDB(dir, false) if err != nil { t.Fatal("expected no error, got", err) @@ -252,8 +253,8 @@ func TestDB_ImportExportSpecificCollections(t *testing.T) { } // Now export both collections and import them into the same persistent DB (overwriting the one we just imported) - filePath2 := path + "2.gob" - err = orig.ExportToFile(filePath2, false, "") + filePath2 := filepath.Join(path, "2.gob") + err = origDB.ExportToFile(filePath2, false, "") if err != nil { t.Fatal("expected no error, got", err) }