From bb6a18e97db1a7f3794536c60bbf2d0a1b702610 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sun, 17 Mar 2024 20:00:50 +0100 Subject: [PATCH 01/20] Implement DB export With optional compression and encryption --- db.go | 107 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/db.go b/db.go index 74e550e..7fc4383 100644 --- a/db.go +++ b/db.go @@ -1,9 +1,16 @@ package chromem import ( + "bytes" + "compress/gzip" "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/gob" "errors" "fmt" + "io" "io/fs" "os" "path/filepath" @@ -155,6 +162,106 @@ func NewPersistentDB(path string) (*DB, error) { return db, nil } +// TODO: Godoc +func (db *DB) Import(filePath string, decryptionKey string) error { + return errors.New("not implemented") // TODO: implement +} + +// TODO: Godoc +func (db *DB) Export(filePath string, compress bool, encryptionKey string) error { + if filePath == "" { + filePath = "./chromem-go.gob" + if encryptionKey != "" { + filePath += ".enc" + } + if compress { + filePath += ".gz" + } + } + + // AES 256 requires a 32 byte key + if encryptionKey != "" { + if len(encryptionKey) != 32 { + return errors.New("encryption key must be 32 bytes long") + } + } + + // If path doesn't exist, create the parent path. + // If path exists and it's a directory, return an error. 
+ fi, err := os.Stat(filePath) + if err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("couldn't get info about export path: %w", err) + } else { + // If the file doesn't exist, create the parent path + err := os.MkdirAll(filepath.Dir(filePath), 0o700) + if err != nil { + return fmt.Errorf("couldn't create export directory: %w", err) + } + } + } else if fi.IsDir() { + return fmt.Errorf("path is a directory: %s", filePath) + } + + // Open file for writing + f, err := os.Create(filePath) + if err != nil { + return fmt.Errorf("couldn't create file: %w", err) + } + defer f.Close() + + // We want to: + // Encode as gob -> compress with flate -> encrypt with AES-GCM -> write file. + // To reduce memory usage we chain the writers instead of buffering, so we start + // from the end. For AES GCM sealing the stdlib doesn't provide a writer though. + + var w io.Writer + if encryptionKey == "" { + w = f + } else { + w = &bytes.Buffer{} + } + if compress { + gzw := gzip.NewWriter(w) + defer gzw.Close() + w = gzw + } + enc := gob.NewEncoder(w) + + // Start encoding, it will write to the chain of writers. + if err := enc.Encode(db); err != nil { + return fmt.Errorf("couldn't encode DB as gob: %w", err) + } + + // Without encyrption, the chain is done and the file is written. 
+ if encryptionKey == "" { + return nil + } + + // Otherwise, encrypt and then write to the file + block, err := aes.NewCipher([]byte(encryptionKey)) + if err != nil { + return fmt.Errorf("couldn't create new AES cipher: %w", err) + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return fmt.Errorf("couldn't create GCM wrapper: %w", err) + } + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return fmt.Errorf("couldn't read random bytes for nonce: %w", err) + } + // w is a *bytes.Buffer + buf := w.(*bytes.Buffer) + encrypted := gcm.Seal(nonce, nonce, buf.Bytes(), nil) + _, err = f.Write(encrypted) + if err != nil { + return fmt.Errorf("couldn't write encrypted data: %w", err) + } + + return nil +} + // CreateCollection creates a new collection with the given name and metadata. // // - name: The name of the collection to create. From 10739d0d9bd24dfa4dd1351ea9b480e991b4fa96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Mon, 18 Mar 2024 22:35:05 +0100 Subject: [PATCH 02/20] Move DB export persistence code to persistence.go This requires the ".gob" suffix to be added in the callers. --- collection.go | 14 ++++--- db.go | 89 +----------------------------------------- persistence.go | 95 +++++++++++++++++++++++++++++++++++++++++---- persistence_test.go | 3 +- 4 files changed, 99 insertions(+), 102 deletions(-) diff --git a/collection.go b/collection.go index 52d8844..7d53625 100644 --- a/collection.go +++ b/collection.go @@ -25,7 +25,7 @@ type Collection struct { // We don't export this yet to keep the API surface to the bare minimum. // Users create collections via [Client.CreateCollection]. 
-func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dir string) (*Collection, error) { +func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dbDir string) (*Collection, error) { // We copy the metadata to avoid data races in case the caller modifies the // map after creating the collection while we range over it. m := make(map[string]string, len(metadata)) @@ -42,9 +42,9 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, } // Persistence - if dir != "" { + if dbDir != "" { safeName := hash2hex(name) - c.persistDirectory = filepath.Join(dir, safeName) + c.persistDirectory = filepath.Join(dbDir, safeName) // Create dir err := os.MkdirAll(c.persistDirectory, 0o700) if err != nil { @@ -52,6 +52,7 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, } // Persist name and metadata metadataPath := filepath.Join(c.persistDirectory, metadataFileName) + metadataPath += ".gob" pc := struct { Name string Metadata map[string]string @@ -59,7 +60,7 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, Name: name, Metadata: m, } - err = persist(metadataPath, pc) + err = persist(metadataPath, pc, false, "") if err != nil { return nil, fmt.Errorf("couldn't persist collection metadata: %w", err) } @@ -233,8 +234,9 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error { // Persist the document if c.persistDirectory != "" { safeID := hash2hex(doc.ID) - filePath := filepath.Join(c.persistDirectory, safeID) - err := persist(filePath, doc) + docPath := filepath.Join(c.persistDirectory, safeID) + docPath += ".gob" + err := persist(docPath, doc, false, "") if err != nil { return fmt.Errorf("couldn't persist document: %w", err) } diff --git a/db.go b/db.go index 7fc4383..3afc8f8 100644 --- a/db.go +++ b/db.go @@ -1,16 +1,9 @@ package chromem import ( - "bytes" - "compress/gzip" "context" - "crypto/aes" - "crypto/cipher" 
- "crypto/rand" - "encoding/gob" "errors" "fmt" - "io" "io/fs" "os" "path/filepath" @@ -179,87 +172,7 @@ func (db *DB) Export(filePath string, compress bool, encryptionKey string) error } } - // AES 256 requires a 32 byte key - if encryptionKey != "" { - if len(encryptionKey) != 32 { - return errors.New("encryption key must be 32 bytes long") - } - } - - // If path doesn't exist, create the parent path. - // If path exists and it's a directory, return an error. - fi, err := os.Stat(filePath) - if err != nil { - if !errors.Is(err, fs.ErrNotExist) { - return fmt.Errorf("couldn't get info about export path: %w", err) - } else { - // If the file doesn't exist, create the parent path - err := os.MkdirAll(filepath.Dir(filePath), 0o700) - if err != nil { - return fmt.Errorf("couldn't create export directory: %w", err) - } - } - } else if fi.IsDir() { - return fmt.Errorf("path is a directory: %s", filePath) - } - - // Open file for writing - f, err := os.Create(filePath) - if err != nil { - return fmt.Errorf("couldn't create file: %w", err) - } - defer f.Close() - - // We want to: - // Encode as gob -> compress with flate -> encrypt with AES-GCM -> write file. - // To reduce memory usage we chain the writers instead of buffering, so we start - // from the end. For AES GCM sealing the stdlib doesn't provide a writer though. - - var w io.Writer - if encryptionKey == "" { - w = f - } else { - w = &bytes.Buffer{} - } - if compress { - gzw := gzip.NewWriter(w) - defer gzw.Close() - w = gzw - } - enc := gob.NewEncoder(w) - - // Start encoding, it will write to the chain of writers. - if err := enc.Encode(db); err != nil { - return fmt.Errorf("couldn't encode DB as gob: %w", err) - } - - // Without encyrption, the chain is done and the file is written. 
- if encryptionKey == "" { - return nil - } - - // Otherwise, encrypt and then write to the file - block, err := aes.NewCipher([]byte(encryptionKey)) - if err != nil { - return fmt.Errorf("couldn't create new AES cipher: %w", err) - } - gcm, err := cipher.NewGCM(block) - if err != nil { - return fmt.Errorf("couldn't create GCM wrapper: %w", err) - } - nonce := make([]byte, gcm.NonceSize()) - if _, err := io.ReadFull(rand.Reader, nonce); err != nil { - return fmt.Errorf("couldn't read random bytes for nonce: %w", err) - } - // w is a *bytes.Buffer - buf := w.(*bytes.Buffer) - encrypted := gcm.Seal(nonce, nonce, buf.Bytes(), nil) - _, err = f.Write(encrypted) - if err != nil { - return fmt.Errorf("couldn't write encrypted data: %w", err) - } - - return nil + return persist(filePath, db, compress, encryptionKey) } // CreateCollection creates a new collection with the given name and metadata. diff --git a/persistence.go b/persistence.go index 5462b10..7cfa852 100644 --- a/persistence.go +++ b/persistence.go @@ -1,11 +1,20 @@ package chromem import ( + "bytes" + "compress/gzip" + "crypto/aes" + "crypto/cipher" + "crypto/rand" "crypto/sha256" "encoding/gob" "encoding/hex" + "errors" "fmt" + "io" + "io/fs" "os" + "path/filepath" ) const metadataFileName = "00000000" @@ -19,22 +28,94 @@ func hash2hex(name string) string { } // persist persists an object to a file at the given path. The object is serialized -// as gob. -func persist(filePath string, obj any) error { - filePath += ".gob" +// as gob, optionally compressed with flate (as gzip) and optionally encrypted with +// AES-GCM. The encryption key must be 32 bytes long. If the file exists, it's +// overwritten, otherwise created. 
+func persist(filePath string, obj any, compress bool, encryptionKey string) error { + if filePath == "" { + return fmt.Errorf("file path is empty") + } + + // AES 256 requires a 32 byte key + if encryptionKey != "" { + if len(encryptionKey) != 32 { + return errors.New("encryption key must be 32 bytes long") + } + } + + // If path doesn't exist, create the parent path. + // If path exists and it's a directory, return an error. + fi, err := os.Stat(filePath) + if err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("couldn't get info about the path: %w", err) + } else { + // If the file doesn't exist, create the parent path + err := os.MkdirAll(filepath.Dir(filePath), 0o700) + if err != nil { + return fmt.Errorf("couldn't create parent directories to path: %w", err) + } + } + } else if fi.IsDir() { + return fmt.Errorf("path is a directory: %s", filePath) + } + // Open file for writing f, err := os.Create(filePath) if err != nil { - return fmt.Errorf("couldn't create file '%s': %w", filePath, err) + return fmt.Errorf("couldn't create file: %w", err) } defer f.Close() - enc := gob.NewEncoder(f) - err = enc.Encode(obj) - if err != nil { + // We want to: + // Encode as gob -> compress with flate -> encrypt with AES-GCM -> write file. + // To reduce memory usage we chain the writers instead of buffering, so we start + // from the end. For AES GCM sealing the stdlib doesn't provide a writer though. + + var w io.Writer + if encryptionKey == "" { + w = f + } else { + w = &bytes.Buffer{} + } + if compress { + gzw := gzip.NewWriter(w) + defer gzw.Close() + w = gzw + } + enc := gob.NewEncoder(w) + + // Start encoding, it will write to the chain of writers. + if err := enc.Encode(obj); err != nil { return fmt.Errorf("couldn't encode or write object: %w", err) } + // Without encyrption, the chain is done and the file is written. 
+ if encryptionKey == "" { + return nil + } + + // Otherwise, encrypt and then write to the file + block, err := aes.NewCipher([]byte(encryptionKey)) + if err != nil { + return fmt.Errorf("couldn't create new AES cipher: %w", err) + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return fmt.Errorf("couldn't create GCM wrapper: %w", err) + } + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return fmt.Errorf("couldn't read random bytes for nonce: %w", err) + } + // w is a *bytes.Buffer + buf := w.(*bytes.Buffer) + encrypted := gcm.Seal(nonce, nonce, buf.Bytes(), nil) + _, err = f.Write(encrypted) + if err != nil { + return fmt.Errorf("couldn't write encrypted data: %w", err) + } + return nil } diff --git a/persistence_test.go b/persistence_test.go index 06d0cdd..e216e01 100644 --- a/persistence_test.go +++ b/persistence_test.go @@ -26,7 +26,8 @@ func TestPersistence(t *testing.T) { Bar: []float32{-0.40824828, 0.40824828, 0.81649655}, // normalized version of `{-0.1, 0.1, 0.2}` } - persist(tempDir, obj) + tempFilePath := tempDir + ".gob" + persist(tempFilePath, obj, false, "") // Check if the file exists. _, err = os.Stat(tempDir + ".gob") From 5d7cb4a8ea9b11bdc8527cec1e24d0cda7721648 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Mon, 18 Mar 2024 22:54:36 +0100 Subject: [PATCH 03/20] Fix export only contains exported fields --- collection.go | 3 +++ db.go | 29 ++++++++++++++++++++++++++++- document.go | 3 +++ 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/collection.go b/collection.go index 7d53625..fde2740 100644 --- a/collection.go +++ b/collection.go @@ -21,6 +21,9 @@ type Collection struct { documents map[string]*Document documentsLock sync.RWMutex embed EmbeddingFunc + + // ⚠️ When adding fields here, consider adding them to the persistence struct + // version in [DB.Export] as well! } // We don't export this yet to keep the API surface to the bare minimum. 
diff --git a/db.go b/db.go index 3afc8f8..bbd296a 100644 --- a/db.go +++ b/db.go @@ -27,6 +27,9 @@ type DB struct { collections map[string]*Collection collectionsLock sync.RWMutex persistDirectory string + + // ⚠️ When adding fields here, consider adding them to the persistence struct + // version in [DB.Export] as well! } // NewDB creates a new in-memory chromem-go DB. @@ -172,7 +175,31 @@ func (db *DB) Export(filePath string, compress bool, encryptionKey string) error } } - return persist(filePath, db, compress, encryptionKey) + // Create persistence structs with exported fields so that they can be encoded + // as gob. + type persistenceCollection struct { + Name string + Metadata map[string]string + Documents map[string]*Document + } + persistenceDB := struct { + Collections map[string]*persistenceCollection + }{ + Collections: make(map[string]*persistenceCollection, len(db.collections)), + } + + db.collectionsLock.RLock() + defer db.collectionsLock.RUnlock() + + for k, v := range db.collections { + persistenceDB.Collections[k] = &persistenceCollection{ + Name: v.Name, + Metadata: v.metadata, + Documents: v.documents, + } + } + + return persist(filePath, persistenceDB, compress, encryptionKey) } // CreateCollection creates a new collection with the given name and metadata. diff --git a/document.go b/document.go index 4fe1f54..120dada 100644 --- a/document.go +++ b/document.go @@ -11,6 +11,9 @@ type Document struct { Metadata map[string]string Embedding []float32 Content string + + // ⚠️ When adding unexported fields here, consider adding a persistence struct + // version of this in [DB.Export]. } // NewDocument creates a new document, including its embeddings. 
From cb22a70fd2f0ad6656eb4b7d227b35de68873d7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Mon, 18 Mar 2024 23:02:19 +0100 Subject: [PATCH 04/20] Fix order of filename suffixes --- db.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/db.go b/db.go index bbd296a..8e09c25 100644 --- a/db.go +++ b/db.go @@ -167,12 +167,12 @@ func (db *DB) Import(filePath string, decryptionKey string) error { func (db *DB) Export(filePath string, compress bool, encryptionKey string) error { if filePath == "" { filePath = "./chromem-go.gob" - if encryptionKey != "" { - filePath += ".enc" - } if compress { filePath += ".gz" } + if encryptionKey != "" { + filePath += ".enc" + } } // Create persistence structs with exported fields so that they can be encoded From 7cfbe947a2fec8ef3305800c9ed54c757ddc9b74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Mon, 18 Mar 2024 23:02:40 +0100 Subject: [PATCH 05/20] Add Godoc for DB.Export() --- db.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/db.go b/db.go index 8e09c25..144f8f1 100644 --- a/db.go +++ b/db.go @@ -163,7 +163,13 @@ func (db *DB) Import(filePath string, decryptionKey string) error { return errors.New("not implemented") // TODO: implement } -// TODO: Godoc +// Export exports the DB to a file at the given path. The file is encoded as gob, +// optionally compressed with flate (as gzip) and optionally encrypted with AES-GCM. +// This works for both the in-memory and persistent DBs. +// +// If filePath is empty, it defaults to "./chromem-go.gob" (+ ".gz" + ".enc"). +// If the file exists, it's overwritten, otherwise created. +// For encryption you must provide a 32 bytes long key. 
func (db *DB) Export(filePath string, compress bool, encryptionKey string) error { if filePath == "" { filePath = "./chromem-go.gob" From 59bd365c001675f60b0edd38bd5f7623e259feb3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Tue, 19 Mar 2024 00:09:50 +0100 Subject: [PATCH 06/20] Implement basic DB import handling But no decryption or decompression yet --- collection.go | 2 +- db.go | 52 +++++++++++++++++++++++++++++++++++++++++++++++++-- document.go | 2 +- 3 files changed, 52 insertions(+), 4 deletions(-) diff --git a/collection.go b/collection.go index fde2740..11bdc73 100644 --- a/collection.go +++ b/collection.go @@ -23,7 +23,7 @@ type Collection struct { embed EmbeddingFunc // ⚠️ When adding fields here, consider adding them to the persistence struct - // version in [DB.Export] as well! + // versions in [DB.Export] and [DB.Import] as well! } // We don't export this yet to keep the API surface to the bare minimum. diff --git a/db.go b/db.go index 144f8f1..9ec9795 100644 --- a/db.go +++ b/db.go @@ -29,7 +29,7 @@ type DB struct { persistDirectory string // ⚠️ When adding fields here, consider adding them to the persistence struct - // version in [DB.Export] as well! + // versions in [DB.Export] and [DB.Import] as well! } // NewDB creates a new in-memory chromem-go DB. @@ -160,7 +160,55 @@ func NewPersistentDB(path string) (*DB, error) { // TODO: Godoc func (db *DB) Import(filePath string, decryptionKey string) error { - return errors.New("not implemented") // TODO: implement + if filePath == "" { + return fmt.Errorf("file path is empty") + } + + // If the file doesn't exist or is a directory, return an error. 
+ fi, err := os.Stat(filePath) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("file doesn't exist: %s", filePath) + } + return fmt.Errorf("couldn't get info about the file: %w", err) + } else if fi.IsDir() { + return fmt.Errorf("path is a directory: %s", filePath) + } + + // Create persistence structs with exported fields so that they can be decoded + // from gob. + type persistenceCollection struct { + Name string + Metadata map[string]string + Documents map[string]*Document + } + persistenceDB := struct { + Collections map[string]*persistenceCollection + }{ + Collections: make(map[string]*persistenceCollection, len(db.collections)), + } + + db.collectionsLock.Lock() + defer db.collectionsLock.Unlock() + + // TODO: Implement decryption and decompression + err = read(filePath, &persistenceDB) + if err != nil { + return fmt.Errorf("couldn't read file: %w", err) + } + + for _, pc := range persistenceDB.Collections { + c := &Collection{ + Name: pc.Name, + + persistDirectory: filepath.Join(db.persistDirectory, hash2hex(pc.Name)), + metadata: pc.Metadata, + documents: pc.Documents, + } + db.collections[c.Name] = c + } + + return nil } // Export exports the DB to a file at the given path. The file is encoded as gob, diff --git a/document.go b/document.go index 120dada..5c42ca8 100644 --- a/document.go +++ b/document.go @@ -13,7 +13,7 @@ type Document struct { Content string // ⚠️ When adding unexported fields here, consider adding a persistence struct - // version of this in [DB.Export]. + // version of this in [DB.Export] and [DB.Import]. } // NewDocument creates a new document, including its embeddings. 
From 3632003c9965615b1264861fc7fab2118ff907ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Tue, 19 Mar 2024 00:12:37 +0100 Subject: [PATCH 07/20] Add Godoc for DB.Import() --- db.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/db.go b/db.go index 9ec9795..8bedbd2 100644 --- a/db.go +++ b/db.go @@ -158,7 +158,14 @@ func NewPersistentDB(path string) (*DB, error) { return db, nil } -// TODO: Godoc +// Import imports the DB from a file at the given path. The file must be encoded +// as gob and can optionally be compressed with flate (as gzip) and encrypted +// with AES-GCM. +// This works for both the in-memory and persistent DBs. +// Existing collections are overwritten. +// +// - filePath: Mandatory, must not be empty +// - decryptionKey: Optional, must be 32 bytes long if provided func (db *DB) Import(filePath string, decryptionKey string) error { if filePath == "" { return fmt.Errorf("file path is empty") From e00d0479cb825d74a1485043f745fde4b4f96f73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Tue, 19 Mar 2024 00:51:36 +0100 Subject: [PATCH 08/20] Implement decryption and decompression for file reading --- db.go | 7 ++-- persistence.go | 91 +++++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 86 insertions(+), 12 deletions(-) diff --git a/db.go b/db.go index 8bedbd2..f438178 100644 --- a/db.go +++ b/db.go @@ -123,7 +123,7 @@ func NewPersistentDB(path string) (*DB, error) { Name string Metadata map[string]string }{} - err := read(fPath, &pc) + err := read(fPath, &pc, "") if err != nil { return nil, fmt.Errorf("couldn't read collection metadata: %w", err) } @@ -132,7 +132,7 @@ func NewPersistentDB(path string) (*DB, error) { } else if filepath.Ext(collectionDirEntry.Name()) == ".gob" { // Read document d := &Document{} - err := read(fPath, d) + err := read(fPath, d, "") if err != nil { return nil, fmt.Errorf("couldn't read document: %w", err) } @@ -198,8 +198,7 @@ func (db 
*DB) Import(filePath string, decryptionKey string) error { db.collectionsLock.Lock() defer db.collectionsLock.Unlock() - // TODO: Implement decryption and decompression - err = read(filePath, &persistenceDB) + err = read(filePath, &persistenceDB, decryptionKey) if err != nil { return fmt.Errorf("couldn't read file: %w", err) } diff --git a/persistence.go b/persistence.go index 7cfa852..de2a10a 100644 --- a/persistence.go +++ b/persistence.go @@ -35,7 +35,6 @@ func persist(filePath string, obj any, compress bool, encryptionKey string) erro if filePath == "" { return fmt.Errorf("file path is empty") } - // AES 256 requires a 32 byte key if encryptionKey != "" { if len(encryptionKey) != 32 { @@ -120,18 +119,94 @@ func persist(filePath string, obj any, compress bool, encryptionKey string) erro } // read reads an object from a file at the given path. The object is deserialized -// from gob. `obj` must be a pointer to an instantiated object. -func read(filePath string, obj any) error { - f, err := os.Open(filePath) +// from gob. `obj` must be a pointer to an instantiated object. The file may +// optionally be compressed as gzip and/or encrypted with AES-GCM. The decryption +// key must be 32 bytes long. +func read(filePath string, obj any, decryptionKey string) error { + if filePath == "" { + return fmt.Errorf("file path is empty") + } + // AES 256 requires a 32 byte key + if decryptionKey != "" { + if len(decryptionKey) != 32 { + return errors.New("decryption key must be 32 bytes long") + } + } + + // We want to: + // Read file -> decrypt with AES-GCM -> decompress with flate -> decode as gob + // To reduce memory usage we chain the readers instead of buffering, so we start + // from the end. For the decryption there's no reader though. 
+ + var r io.Reader + + // Decrypt if an encryption key is provided + if decryptionKey != "" { + encrypted, err := os.ReadFile(filePath) + if err != nil { + return fmt.Errorf("couldn't read file: %w", err) + } + block, err := aes.NewCipher([]byte(decryptionKey)) + if err != nil { + return fmt.Errorf("couldn't create AES cipher: %w", err) + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return fmt.Errorf("couldn't create GCM wrapper: %w", err) + } + nonceSize := gcm.NonceSize() + if len(encrypted) < nonceSize { + return fmt.Errorf("encrypted data too short") + } + nonce, ciphertext := encrypted[:nonceSize], encrypted[nonceSize:] + data, err := gcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return fmt.Errorf("couldn't decrypt data: %w", err) + } + + r = bytes.NewReader(data) + } else { + var err error + r, err = os.Open(filePath) + if err != nil { + return fmt.Errorf("couldn't open file: %w", err) + } + } + + // Determine if the file is compressed + magicNumber := make([]byte, 2) + _, err := r.Read(magicNumber) if err != nil { - return fmt.Errorf("couldn't open file '%s': %w", filePath, err) + return fmt.Errorf("couldn't read magic number to determine whether the file is compressed: %w", err) + } + var compressed bool + if magicNumber[0] == 0x1f && magicNumber[1] == 0x8b { + compressed = true + } + + // Reset reader. Both file and bytes.Reader support seeking. 
+ if s, ok := r.(io.Seeker); !ok { + return fmt.Errorf("reader doesn't support seeking") + } else { + _, err := s.Seek(0, 0) + if err != nil { + return fmt.Errorf("couldn't reset reader: %w", err) + } + } + + if compressed { + gzr, err := gzip.NewReader(r) + if err != nil { + return fmt.Errorf("couldn't create gzip reader: %w", err) + } + defer gzr.Close() + r = gzr } - defer f.Close() - dec := gob.NewDecoder(f) + dec := gob.NewDecoder(r) err = dec.Decode(obj) if err != nil { - return fmt.Errorf("couldn't decode or read object: %w", err) + return fmt.Errorf("couldn't decode object: %w", err) } return nil From 778e21099bd9f4e6048a67e1fbf4a1503121283d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Tue, 19 Mar 2024 21:28:44 +0100 Subject: [PATCH 09/20] Add unit tests for persistence --- persistence_test.go | 185 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 172 insertions(+), 13 deletions(-) diff --git a/persistence_test.go b/persistence_test.go index e216e01..041540b 100644 --- a/persistence_test.go +++ b/persistence_test.go @@ -1,22 +1,101 @@ package chromem import ( - "bytes" + "compress/gzip" "encoding/gob" + "math/rand" "os" "reflect" "testing" + "time" ) -func TestPersistence(t *testing.T) { +func TestPersistenceWrite(t *testing.T) { tempDir, err := os.MkdirTemp("", "chromem-go") if err != nil { t.Fatal("expected nil, got", err) } - t.Cleanup(func() { - _ = os.RemoveAll(tempDir) + defer os.RemoveAll(tempDir) + + type s struct { + Foo string + Bar []float32 + } + obj := s{ + Foo: "test", + Bar: []float32{-0.40824828, 0.40824828, 0.81649655}, // normalized version of `{-0.1, 0.1, 0.2}` + } + + t.Run("gob", func(t *testing.T) { + tempFilePath := tempDir + ".gob" + persist(tempFilePath, obj, false, "") + + // Check if the file exists. 
+ _, err = os.Stat(tempFilePath) + if err != nil { + t.Fatal("expected nil, got", err) + } + + // Read file and decode + f, err := os.Open(tempFilePath) + if err != nil { + t.Fatal("expected nil, got", err) + } + defer f.Close() + d := gob.NewDecoder(f) + res := s{} + err = d.Decode(&res) + if err != nil { + t.Fatal("expected nil, got", err) + } + + // Compare + if !reflect.DeepEqual(obj, res) { + t.Fatalf("expected %+v, got %+v", obj, res) + } }) + t.Run("gob gzipped", func(t *testing.T) { + tempFilePath := tempDir + ".gob.gz" + persist(tempFilePath, obj, true, "") + + // Check if the file exists. + _, err = os.Stat(tempFilePath) + if err != nil { + t.Fatal("expected nil, got", err) + } + + // Read file, decompress and decode + f, err := os.Open(tempFilePath) + if err != nil { + t.Fatal("expected nil, got", err) + } + defer f.Close() + gzr, err := gzip.NewReader(f) + if err != nil { + t.Fatal("expected nil, got", err) + } + d := gob.NewDecoder(gzr) + res := s{} + err = d.Decode(&res) + if err != nil { + t.Fatal("expected nil, got", err) + } + + // Compare + if !reflect.DeepEqual(obj, res) { + t.Fatalf("expected %+v, got %+v", obj, res) + } + }) +} + +func TestPersistenceRead(t *testing.T) { + tempDir, err := os.MkdirTemp("", "chromem-go") + if err != nil { + t.Fatal("expected nil, got", err) + } + defer os.RemoveAll(tempDir) + type s struct { Foo string Bar []float32 @@ -26,25 +105,105 @@ func TestPersistence(t *testing.T) { Bar: []float32{-0.40824828, 0.40824828, 0.81649655}, // normalized version of `{-0.1, 0.1, 0.2}` } - tempFilePath := tempDir + ".gob" - persist(tempFilePath, obj, false, "") + t.Run("gob", func(t *testing.T) { + tempFilePath := tempDir + ".gob" + f, err := os.Create(tempFilePath) + if err != nil { + t.Fatal("expected nil, got", err) + } + defer f.Close() + enc := gob.NewEncoder(f) + err = enc.Encode(obj) + if err != nil { + t.Fatal("expected nil, got", err) + } - // Check if the file exists. 
- _, err = os.Stat(tempDir + ".gob") + // Read the file. + var res s + err = read(tempFilePath, &res, "") + if err != nil { + t.Fatal("expected nil, got", err) + } + + // Compare + if !reflect.DeepEqual(obj, res) { + t.Fatalf("expected %+v, got %+v", obj, res) + } + }) + + t.Run("gob gzipped", func(t *testing.T) { + tempFilePath := tempDir + ".gob.gz" + f, err := os.Create(tempFilePath) + if err != nil { + t.Fatal("expected nil, got", err) + } + defer f.Close() + gzw := gzip.NewWriter(f) + enc := gob.NewEncoder(gzw) + err = enc.Encode(obj) + if err != nil { + t.Fatal("expected nil, got", err) + } + err = gzw.Close() + if err != nil { + t.Fatal("expected nil, got", err) + } + + // Read the file. + var res s + err = read(tempFilePath, &res, "") + if err != nil { + t.Fatal("expected nil, got", err) + } + + // Compare + if !reflect.DeepEqual(obj, res) { + t.Fatalf("expected %+v, got %+v", obj, res) + } + }) +} + +func TestPersistenceEncryption(t *testing.T) { + // Instead of copy pasting encryption/decryption code, we resort to using both + // functions under test, instead of one combined with an independent implementation. + + tempDir, err := os.MkdirTemp("", "chromem-go") + if err != nil { + t.Fatal("expected nil, got", err) + } + defer os.RemoveAll(tempDir) + + type s struct { + Foo string + Bar []float32 + } + obj := s{ + Foo: "test", + Bar: []float32{-0.40824828, 0.40824828, 0.81649655}, // normalized version of `{-0.1, 0.1, 0.2}` + } + + tempFilePath := tempDir + ".gob.enc" + r := rand.New(rand.NewSource(time.Now().Unix())) + encryptionKey := randomString(r, 32) + err = persist(tempFilePath, obj, false, encryptionKey) if err != nil { t.Fatal("expected nil, got", err) } - // Check if the file contains the expected data. - b, err := os.ReadFile(tempDir + ".gob") + + // Check if the file exists. 
+ _, err = os.Stat(tempFilePath) if err != nil { t.Fatal("expected nil, got", err) } - d := gob.NewDecoder(bytes.NewReader(b)) - res := s{} - err = d.Decode(&res) + + // Read the file. + var res s + err = read(tempFilePath, &res, encryptionKey) if err != nil { t.Fatal("expected nil, got", err) } + + // Compare if !reflect.DeepEqual(obj, res) { t.Fatalf("expected %+v, got %+v", obj, res) } From 51335fb6cdc2a87c35976fab805722e0d08a4f6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Tue, 19 Mar 2024 21:39:04 +0100 Subject: [PATCH 10/20] Check encryption key early So far only in persist() and read(), now already in Export() and Import() --- db.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/db.go b/db.go index f438178..81bf8ee 100644 --- a/db.go +++ b/db.go @@ -170,6 +170,12 @@ func (db *DB) Import(filePath string, decryptionKey string) error { if filePath == "" { return fmt.Errorf("file path is empty") } + if decryptionKey != "" { + // AES 256 requires a 32 byte key + if len(decryptionKey) != 32 { + return errors.New("decryption key must be 32 bytes long") + } + } // If the file doesn't exist or is a directory, return an error. fi, err := os.Stat(filePath) @@ -234,6 +240,12 @@ func (db *DB) Export(filePath string, compress bool, encryptionKey string) error filePath += ".enc" } } + if encryptionKey != "" { + // AES 256 requires a 32 byte key + if len(encryptionKey) != 32 { + return errors.New("encryption key must be 32 bytes long") + } + } // Create persistence structs with exported fields so that they can be encoded // as gob. 
From 6ba14cbdcae876e1c823f3d11d81ccaa95becc87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Tue, 19 Mar 2024 21:39:32 +0100 Subject: [PATCH 11/20] Improve returned error Now wrapped like all others --- db.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/db.go b/db.go index 81bf8ee..e9aa231 100644 --- a/db.go +++ b/db.go @@ -271,7 +271,12 @@ func (db *DB) Export(filePath string, compress bool, encryptionKey string) error } } - return persist(filePath, persistenceDB, compress, encryptionKey) + err = persist(filePath, persistenceDB, compress, encryptionKey) + if err != nil { + return fmt.Errorf("couldn't export DB: %w", err) + } + + return nil } // CreateCollection creates a new collection with the given name and metadata. From bc68cf44063fac71175a6a476ff6d6135243d131 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Tue, 19 Mar 2024 21:39:55 +0100 Subject: [PATCH 12/20] Create parent directory of target file if it doesn't exist --- db.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/db.go b/db.go index e9aa231..a1e1797 100644 --- a/db.go +++ b/db.go @@ -247,6 +247,13 @@ func (db *DB) Export(filePath string, compress bool, encryptionKey string) error } } + // Create parent dir if it doesn't exist + parentDir := filepath.Dir(filePath) + err := os.MkdirAll(parentDir, 0o700) + if err != nil { + return fmt.Errorf("couldn't create parent directory: %w", err) + } + // Create persistence structs with exported fields so that they can be encoded // as gob. 
type persistenceCollection struct { From 0bdcc5f25a9a207b663281f95779e01769853377 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Tue, 19 Mar 2024 21:40:01 +0100 Subject: [PATCH 13/20] Improve Godoc --- db.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/db.go b/db.go index a1e1797..7933b05 100644 --- a/db.go +++ b/db.go @@ -226,10 +226,12 @@ func (db *DB) Import(filePath string, decryptionKey string) error { // Export exports the DB to a file at the given path. The file is encoded as gob, // optionally compressed with flate (as gzip) and optionally encrypted with AES-GCM. // This works for both the in-memory and persistent DBs. -// -// If filePath is empty, it defaults to "./chromem-go.gob" (+ ".gz" + ".enc"). // If the file exists, it's overwritten, otherwise created. -// For encryption you must provide a 32 bytes long key. +// +// - filePath: If empty, it defaults to "./chromem-go.gob" (+ ".gz" + ".enc") +// - compress: Optional. Compresses as gzip if true. +// - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes +// long if provided. func (db *DB) Export(filePath string, compress bool, encryptionKey string) error { if filePath == "" { filePath = "./chromem-go.gob" From ad1e14345e9d584213e7ded5f85a1653a2d27759 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Thu, 21 Mar 2024 22:54:11 +0100 Subject: [PATCH 14/20] Fix persistence issue (compress + encrypt) When compression and encryption were both used. Add unit test for it. 
--- persistence.go | 22 +++++++++--- persistence_test.go | 82 +++++++++++++++++++++++++++++---------------- 2 files changed, 71 insertions(+), 33 deletions(-) diff --git a/persistence.go b/persistence.go index de2a10a..75dd1e9 100644 --- a/persistence.go +++ b/persistence.go @@ -77,18 +77,32 @@ func persist(filePath string, obj any, compress bool, encryptionKey string) erro } else { w = &bytes.Buffer{} } + + var gzw *gzip.Writer + var enc *gob.Encoder if compress { - gzw := gzip.NewWriter(w) - defer gzw.Close() - w = gzw + gzw = gzip.NewWriter(w) + enc = gob.NewEncoder(gzw) + } else { + enc = gob.NewEncoder(w) } - enc := gob.NewEncoder(w) // Start encoding, it will write to the chain of writers. if err := enc.Encode(obj); err != nil { return fmt.Errorf("couldn't encode or write object: %w", err) } + // If compressing, close the gzip writer. Otherwise the gzip footer won't be + // written yet. When using encryption (and w is a buffer) we'd then encrypt + // an incomplete file. Without encryption, returning here with only a + // deferred Close() could silently drop its error. + if compress { + err = gzw.Close() + if err != nil { + return fmt.Errorf("couldn't close gzip writer: %w", err) + } + } + // Without encyrption, the chain is done and the file is written. if encryptionKey == "" { return nil diff --git a/persistence_test.go b/persistence_test.go index 041540b..09676f9 100644 --- a/persistence_test.go +++ b/persistence_test.go @@ -5,9 +5,9 @@ import ( "encoding/gob" "math/rand" "os" + "path/filepath" "reflect" "testing" - "time" ) func TestPersistenceWrite(t *testing.T) { @@ -111,12 +111,15 @@ func TestPersistenceRead(t *testing.T) { if err != nil { t.Fatal("expected nil, got", err) } - defer f.Close() enc := gob.NewEncoder(f) err = enc.Encode(obj) if err != nil { t.Fatal("expected nil, got", err) } + err = f.Close() + if err != nil { + t.Fatal("expected nil, got", err) + } // Read the file.
var res s @@ -137,7 +140,6 @@ func TestPersistenceRead(t *testing.T) { if err != nil { t.Fatal("expected nil, got", err) } - defer f.Close() gzw := gzip.NewWriter(f) enc := gob.NewEncoder(gzw) err = enc.Encode(obj) @@ -148,6 +150,10 @@ func TestPersistenceRead(t *testing.T) { if err != nil { t.Fatal("expected nil, got", err) } + err = f.Close() + if err != nil { + t.Fatal("expected nil, got", err) + } // Read the file. var res s @@ -167,11 +173,10 @@ func TestPersistenceEncryption(t *testing.T) { // Instead of copy pasting encryption/decryption code, we resort to using both // functions under test, instead of one combined with an independent implementation. - tempDir, err := os.MkdirTemp("", "chromem-go") - if err != nil { - t.Fatal("expected nil, got", err) - } - defer os.RemoveAll(tempDir) + r := rand.New(rand.NewSource(rand.Int63())) + // randString := randomString(r, 10) + path := filepath.Join(os.TempDir(), "a", "chromem-go") + // defer os.RemoveAll(path) type s struct { Foo string @@ -181,30 +186,49 @@ func TestPersistenceEncryption(t *testing.T) { Foo: "test", Bar: []float32{-0.40824828, 0.40824828, 0.81649655}, // normalized version of `{-0.1, 0.1, 0.2}` } - - tempFilePath := tempDir + ".gob.enc" - r := rand.New(rand.NewSource(time.Now().Unix())) encryptionKey := randomString(r, 32) - err = persist(tempFilePath, obj, false, encryptionKey) - if err != nil { - t.Fatal("expected nil, got", err) - } - // Check if the file exists. - _, err = os.Stat(tempFilePath) - if err != nil { - t.Fatal("expected nil, got", err) - } - - // Read the file. 
- var res s - err = read(tempFilePath, &res, encryptionKey) - if err != nil { - t.Fatal("expected nil, got", err) + tt := []struct { + name string + filePath string + compress bool + }{ + { + name: "compress false", + filePath: path + ".gob.enc", + compress: false, + }, + { + name: "compress true", + filePath: path + ".gob.gz.enc", + compress: true, + }, } - // Compare - if !reflect.DeepEqual(obj, res) { - t.Fatalf("expected %+v, got %+v", obj, res) + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + err := persist(tc.filePath, obj, tc.compress, encryptionKey) + if err != nil { + t.Fatal("expected nil, got", err) + } + + // Check if the file exists. + _, err = os.Stat(tc.filePath) + if err != nil { + t.Fatal("expected nil, got", err) + } + + // Read the file. + var res s + err = read(tc.filePath, &res, encryptionKey) + if err != nil { + t.Fatal("expected nil, got", err) + } + + // Compare + if !reflect.DeepEqual(obj, res) { + t.Fatalf("expected %+v, got %+v", obj, res) + } + }) } } From 6af84036bfcf5934ecce007d8faed93090e8ebd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Thu, 21 Mar 2024 23:01:26 +0100 Subject: [PATCH 15/20] Fix DB is wrongfully set to persistence after import --- db.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/db.go b/db.go index 7933b05..46679ce 100644 --- a/db.go +++ b/db.go @@ -213,9 +213,11 @@ func (db *DB) Import(filePath string, decryptionKey string) error { c := &Collection{ Name: pc.Name, - persistDirectory: filepath.Join(db.persistDirectory, hash2hex(pc.Name)), - metadata: pc.Metadata, - documents: pc.Documents, + metadata: pc.Metadata, + documents: pc.Documents, + } + if db.persistDirectory != "" { + c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name)) } db.collections[c.Name] = c } From e1d641bec8480284a6ae6f7869ad949190a03bbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Thu, 21 Mar 2024 23:01:42 +0100 Subject: [PATCH 
16/20] Add unit tests for DB import/export --- db_test.go | 99 +++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 98 insertions(+), 1 deletion(-) diff --git a/db_test.go b/db_test.go index a503fbc..3c7ba80 100644 --- a/db_test.go +++ b/db_test.go @@ -12,7 +12,8 @@ import ( func TestNewPersistentDB(t *testing.T) { t.Run("Create directory", func(t *testing.T) { - randString := randomString(rand.New(rand.NewSource(rand.Int63())), 10) + r := rand.New(rand.NewSource(rand.Int63())) + randString := randomString(r, 10) path := filepath.Join(os.TempDir(), randString) defer os.RemoveAll(path) @@ -66,6 +67,102 @@ func TestNewPersistentDB_Errors(t *testing.T) { }) } +func TestDB_ImportExport(t *testing.T) { + r := rand.New(rand.NewSource(rand.Int63())) + randString := randomString(r, 10) + path := filepath.Join(os.TempDir(), randString) + defer os.RemoveAll(path) + + // Values in the collection + name := "test" + metadata := map[string]string{"foo": "bar"} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return vectors, nil + } + + tt := []struct { + name string + filePath string + compress bool + encryptionKey string + }{ + { + name: "gob", + filePath: path + ".gob", + compress: false, + encryptionKey: "", + }, + { + name: "gob compressed", + filePath: path + ".gob.gz", + compress: true, + encryptionKey: "", + }, + { + name: "gob compressed encrypted", + filePath: path + ".gob.gz.enc", + compress: true, + encryptionKey: randomString(r, 32), + }, + { + name: "gob encrypted", + filePath: path + ".gob.enc", + compress: false, + encryptionKey: randomString(r, 32), + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + // Create DB, can just be in-memory + orig := NewDB() + + // Create collection + c, err := orig.CreateCollection(name, metadata, embeddingFunc) + if err != nil { + t.Fatal("expected no error, 
got", err) + } + if c == nil { + t.Fatal("expected collection, got nil") + } + // Add document + doc := Document{ + ID: name, + Metadata: metadata, + Embedding: vectors, + Content: "test", + } + err = c.AddDocument(context.Background(), doc) + if err != nil { + t.Fatal("expected no error, got", err) + } + + // Export + err = orig.Export(tc.filePath, tc.compress, tc.encryptionKey) + if err != nil { + t.Fatal("expected no error, got", err) + } + + new := NewDB() + + // Import + err = new.Import(tc.filePath, tc.encryptionKey) + if err != nil { + t.Fatal("expected no error, got", err) + } + + // Check expectations + // We have to reset the embed function, but otherwise the DB objects + // should be deep equal. + c.embed = nil + if !reflect.DeepEqual(orig, new) { + t.Fatalf("expected DB %+v, got %+v", orig, new) + } + }) + } +} + func TestDB_CreateCollection(t *testing.T) { // Values in the collection name := "test" From 9c180220f3c3c0352acb61ed779df015d9c827d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Thu, 21 Mar 2024 23:48:59 +0100 Subject: [PATCH 17/20] Stop creating parent dirs ahead of time Now that persist() includes it --- collection.go | 8 +------- db.go | 9 +-------- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/collection.go b/collection.go index 11bdc73..5c9c893 100644 --- a/collection.go +++ b/collection.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "os" "path/filepath" "slices" "sync" @@ -48,11 +47,6 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, if dbDir != "" { safeName := hash2hex(name) c.persistDirectory = filepath.Join(dbDir, safeName) - // Create dir - err := os.MkdirAll(c.persistDirectory, 0o700) - if err != nil { - return nil, fmt.Errorf("couldn't create collection directory: %w", err) - } // Persist name and metadata metadataPath := filepath.Join(c.persistDirectory, metadataFileName) metadataPath += ".gob" @@ -63,7 +57,7 @@ func newCollection(name string, 
metadata map[string]string, embed EmbeddingFunc, Name: name, Metadata: m, } - err = persist(metadataPath, pc, false, "") + err := persist(metadataPath, pc, false, "") if err != nil { return nil, fmt.Errorf("couldn't persist collection metadata: %w", err) } diff --git a/db.go b/db.go index 46679ce..99910bf 100644 --- a/db.go +++ b/db.go @@ -251,13 +251,6 @@ func (db *DB) Export(filePath string, compress bool, encryptionKey string) error } } - // Create parent dir if it doesn't exist - parentDir := filepath.Dir(filePath) - err := os.MkdirAll(parentDir, 0o700) - if err != nil { - return fmt.Errorf("couldn't create parent directory: %w", err) - } - // Create persistence structs with exported fields so that they can be encoded // as gob. type persistenceCollection struct { @@ -282,7 +275,7 @@ func (db *DB) Export(filePath string, compress bool, encryptionKey string) error } } - err = persist(filePath, persistenceDB, compress, encryptionKey) + err := persist(filePath, persistenceDB, compress, encryptionKey) if err != nil { return fmt.Errorf("couldn't export DB: %w", err) } From cc361c365fb3c6f1934e3b1ff14f8960ffbfe5fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Fri, 22 Mar 2024 20:12:03 +0100 Subject: [PATCH 18/20] Extend Godoc --- db.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/db.go b/db.go index 99910bf..5f0d6e2 100644 --- a/db.go +++ b/db.go @@ -33,6 +33,9 @@ type DB struct { } // NewDB creates a new in-memory chromem-go DB. +// While it doesn't write files when you add collections and documents, you can +// still use [DB.Export] and [DB.Import] to export and import the entire DB +// to/from a file. func NewDB() *DB { return &DB{ collections: make(map[string]*Collection), @@ -51,6 +54,10 @@ func NewDB() *DB { // Currently the persistence is done synchronously on each write operation, and // each document addition leads to a new file, encoded as gob.
In the future we // will make this configurable (encoding, async writes, WAL-based writes, etc.). +// +// In addition to persistence for each added collection and document you can use +// [DB.Export] and [DB.Import] to export and import the entire DB to/from a file, +// which also works for the pure in-memory DB. func NewPersistentDB(path string) (*DB, error) { if path == "" { path = "./chromem-go" From 3ee41fc06b4003267db21998f2c36a291454f9e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Fri, 22 Mar 2024 20:21:38 +0100 Subject: [PATCH 19/20] Update README --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 5b817ef..0cd17b8 100644 --- a/README.md +++ b/README.md @@ -148,7 +148,8 @@ See the Godoc for details: Date: Fri, 22 Mar 2024 20:36:38 +0100 Subject: [PATCH 20/20] Rename decryptionKey to encryptionKey The separate naming is uncommon as the key is the same in symmetric encryption --- db.go | 12 ++++++------ persistence.go | 14 +++++++------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/db.go b/db.go index 5f0d6e2..6a6311c 100644 --- a/db.go +++ b/db.go @@ -172,15 +172,15 @@ func NewPersistentDB(path string) (*DB, error) { // Existing collections are overwritten. 
// // - filePath: Mandatory, must not be empty -// - decryptionKey: Optional, must be 32 bytes long if provided -func (db *DB) Import(filePath string, decryptionKey string) error { +// - encryptionKey: Optional, must be 32 bytes long if provided +func (db *DB) Import(filePath string, encryptionKey string) error { if filePath == "" { return fmt.Errorf("file path is empty") } - if decryptionKey != "" { + if encryptionKey != "" { // AES 256 requires a 32 byte key - if len(decryptionKey) != 32 { - return errors.New("decryption key must be 32 bytes long") + if len(encryptionKey) != 32 { + return errors.New("encryption key must be 32 bytes long") } } @@ -211,7 +211,7 @@ func (db *DB) Import(filePath string, decryptionKey string) error { db.collectionsLock.Lock() defer db.collectionsLock.Unlock() - err = read(filePath, &persistenceDB, decryptionKey) + err = read(filePath, &persistenceDB, encryptionKey) if err != nil { return fmt.Errorf("couldn't read file: %w", err) } diff --git a/persistence.go b/persistence.go index 75dd1e9..4a385d4 100644 --- a/persistence.go +++ b/persistence.go @@ -134,16 +134,16 @@ func persist(filePath string, obj any, compress bool, encryptionKey string) erro // read reads an object from a file at the given path. The object is deserialized // from gob. `obj` must be a pointer to an instantiated object. The file may -// optionally be compressed as gzip and/or encrypted with AES-GCM. The decryption +// optionally be compressed as gzip and/or encrypted with AES-GCM. The encryption // key must be 32 bytes long. 
-func read(filePath string, obj any, decryptionKey string) error { +func read(filePath string, obj any, encryptionKey string) error { if filePath == "" { return fmt.Errorf("file path is empty") } // AES 256 requires a 32 byte key - if decryptionKey != "" { - if len(decryptionKey) != 32 { - return errors.New("decryption key must be 32 bytes long") + if encryptionKey != "" { + if len(encryptionKey) != 32 { + return errors.New("encryption key must be 32 bytes long") } } @@ -155,12 +155,12 @@ func read(filePath string, obj any, decryptionKey string) error { var r io.Reader // Decrypt if an encryption key is provided - if decryptionKey != "" { + if encryptionKey != "" { encrypted, err := os.ReadFile(filePath) if err != nil { return fmt.Errorf("couldn't read file: %w", err) } - block, err := aes.NewCipher([]byte(decryptionKey)) + block, err := aes.NewCipher([]byte(encryptionKey)) if err != nil { return fmt.Errorf("couldn't create AES cipher: %w", err) }