From 6bd5ae6c530a954bf0a029ecf9ce78e9b7281cee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Mon, 29 Apr 2024 23:33:11 +0200 Subject: [PATCH 1/4] Implement persistToWriter --- collection.go | 8 +++---- db.go | 8 +++---- persistence.go | 56 +++++++++++++++++++++++++++++---------------- persistence_test.go | 12 +++++----- 4 files changed, 49 insertions(+), 35 deletions(-) diff --git a/collection.go b/collection.go index e01540a..2a2c859 100644 --- a/collection.go +++ b/collection.go @@ -63,7 +63,7 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, Name: name, Metadata: m, } - err := persist(metadataPath, pc, compress, "") + err := persistToFile(metadataPath, pc, compress, "") if err != nil { return nil, fmt.Errorf("couldn't persist collection metadata: %w", err) } @@ -237,7 +237,7 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error { // Persist the document if c.persistDirectory != "" { docPath := c.getDocPath(doc.ID) - err := persist(docPath, doc, c.compress, "") + err := persistToFile(docPath, doc, c.compress, "") if err != nil { return fmt.Errorf("couldn't persist document to %q: %w", docPath, err) } @@ -252,7 +252,6 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error { // - whereDocument: Conditional filtering on documents. Optional. // - ids: The ids of the documents to delete. If empty, all documents are deleted. 
func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]string, ids ...string) error { - // must have at least one of where, whereDocument or ids if len(where) == 0 && len(whereDocument) == 0 && len(ids) == 0 { return fmt.Errorf("must have at least one of where, whereDocument or ids") @@ -294,7 +293,7 @@ func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]s // Remove the document from disk if c.persistDirectory != "" { docPath := c.getDocPath(docID) - err := remove(docPath) + err := removeFile(docPath) if err != nil { return fmt.Errorf("couldn't remove document at %q: %w", docPath, err) } @@ -302,7 +301,6 @@ func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]s } return nil - } // Count returns the number of documents in the collection. diff --git a/db.go b/db.go index 518f1db..fe7fc80 100644 --- a/db.go +++ b/db.go @@ -142,7 +142,7 @@ func NewPersistentDB(path string, compress bool) (*DB, error) { Name string Metadata map[string]string }{} - err := read(fPath, &pc, "") + err := readFromFile(fPath, &pc, "") if err != nil { return nil, fmt.Errorf("couldn't read collection metadata: %w", err) } @@ -151,7 +151,7 @@ func NewPersistentDB(path string, compress bool) (*DB, error) { } else if strings.HasSuffix(collectionDirEntry.Name(), ext) { // Read document d := &Document{} - err := read(fPath, d, "") + err := readFromFile(fPath, d, "") if err != nil { return nil, fmt.Errorf("couldn't read document: %w", err) } @@ -223,7 +223,7 @@ func (db *DB) Import(filePath string, encryptionKey string) error { db.collectionsLock.Lock() defer db.collectionsLock.Unlock() - err = read(filePath, &persistenceDB, encryptionKey) + err = readFromFile(filePath, &persistenceDB, encryptionKey) if err != nil { return fmt.Errorf("couldn't read file: %w", err) } @@ -295,7 +295,7 @@ func (db *DB) Export(filePath string, compress bool, encryptionKey string) error } } - err := persist(filePath, persistenceDB, compress, 
encryptionKey) + err := persistToFile(filePath, persistenceDB, compress, encryptionKey) if err != nil { return fmt.Errorf("couldn't export DB: %w", err) } diff --git a/persistence.go b/persistence.go index 5748afd..a74aee9 100644 --- a/persistence.go +++ b/persistence.go @@ -27,11 +27,11 @@ func hash2hex(name string) string { return hex.EncodeToString(hash[:4]) } -// persist persists an object to a file at the given path. The object is serialized +// persistToFile persists an object to a file at the given path. The object is serialized // as gob, optionally compressed with flate (as gzip) and optionally encrypted with // AES-GCM. The encryption key must be 32 bytes long. If the file exists, it's // overwritten, otherwise created. -func persist(filePath string, obj any, compress bool, encryptionKey string) error { +func persistToFile(filePath string, obj any, compress bool, encryptionKey string) error { if filePath == "" { return fmt.Errorf("file path is empty") } @@ -66,25 +66,41 @@ func persist(filePath string, obj any, compress bool, encryptionKey string) erro } defer f.Close() + return persistToWriter(f, obj, compress, encryptionKey) +} + +// persistToWriter persists an object to a writer. The object is serialized +// as gob, optionally compressed with flate (as gzip) and optionally encrypted with +// AES-GCM. The encryption key must be 32 bytes long. +// If the writer has to be closed, it's the caller's responsibility. +func persistToWriter(w io.Writer, obj any, compress bool, encryptionKey string) error { + // AES 256 requires a 32 byte key + if encryptionKey != "" { + if len(encryptionKey) != 32 { + return errors.New("encryption key must be 32 bytes long") + } + } + // We want to: - // Encode as gob -> compress with flate -> encrypt with AES-GCM -> write file. + // Encode as gob -> compress with flate -> encrypt with AES-GCM -> write to + // passed writer. // To reduce memory usage we chain the writers instead of buffering, so we start // from the end. 
For AES GCM sealing the stdlib doesn't provide a writer though. - var w io.Writer + var chainedWriter io.Writer if encryptionKey == "" { - w = f + chainedWriter = w } else { - w = &bytes.Buffer{} + chainedWriter = &bytes.Buffer{} } var gzw *gzip.Writer var enc *gob.Encoder if compress { - gzw = gzip.NewWriter(w) + gzw = gzip.NewWriter(chainedWriter) enc = gob.NewEncoder(gzw) } else { - enc = gob.NewEncoder(w) + enc = gob.NewEncoder(chainedWriter) } // Start encoding, it will write to the chain of writers. @@ -93,22 +109,22 @@ func persist(filePath string, obj any, compress bool, encryptionKey string) erro } // If compressing, close the gzip writer. Otherwise the gzip footer won't be - // written yet. When using encryption (and w is a buffer) then we'll encrypt - // an incomplete file. Without encryption when we return here and having + // written yet. When using encryption (and chainedWriter is a buffer) then + // we'll encrypt an incomplete stream. Without encryption when we return here and having // a deferred Close(), there might be a silenced error. if compress { - err = gzw.Close() + err := gzw.Close() if err != nil { return fmt.Errorf("couldn't close gzip writer: %w", err) } } - // Without encyrption, the chain is done and the file is written. + // Without encryption, the chain is done and the writing is finished. if encryptionKey == "" { return nil } - // Otherwise, encrypt and then write to the file + // Otherwise, encrypt and then write to the unchained target writer. 
block, err := aes.NewCipher([]byte(encryptionKey)) if err != nil { return fmt.Errorf("couldn't create new AES cipher: %w", err) @@ -121,10 +137,10 @@ func persist(filePath string, obj any, compress bool, encryptionKey string) erro if _, err := io.ReadFull(rand.Reader, nonce); err != nil { return fmt.Errorf("couldn't read random bytes for nonce: %w", err) } - // w is a *bytes.Buffer - buf := w.(*bytes.Buffer) + // chainedWriter is a *bytes.Buffer + buf := chainedWriter.(*bytes.Buffer) encrypted := gcm.Seal(nonce, nonce, buf.Bytes(), nil) - _, err = f.Write(encrypted) + _, err = w.Write(encrypted) if err != nil { return fmt.Errorf("couldn't write encrypted data: %w", err) } @@ -132,11 +148,11 @@ func persist(filePath string, obj any, compress bool, encryptionKey string) erro return nil } -// read reads an object from a file at the given path. The object is deserialized +// readFromFile reads an object from a file at the given path. The object is deserialized // from gob. `obj` must be a pointer to an instantiated object. The file may // optionally be compressed as gzip and/or encrypted with AES-GCM. The encryption // key must be 32 bytes long. -func read(filePath string, obj any, encryptionKey string) error { +func readFromFile(filePath string, obj any, encryptionKey string) error { if filePath == "" { return fmt.Errorf("file path is empty") } @@ -226,8 +242,8 @@ func read(filePath string, obj any, encryptionKey string) error { return nil } -// remove removes a file at the given path. If the file doesn't exist, it's a no-op. -func remove(filePath string) error { +// removeFile removes a file at the given path. If the file doesn't exist, it's a no-op. 
+func removeFile(filePath string) error { if filePath == "" { return fmt.Errorf("file path is empty") } diff --git a/persistence_test.go b/persistence_test.go index 09676f9..4515d58 100644 --- a/persistence_test.go +++ b/persistence_test.go @@ -28,7 +28,7 @@ func TestPersistenceWrite(t *testing.T) { t.Run("gob", func(t *testing.T) { tempFilePath := tempDir + ".gob" - persist(tempFilePath, obj, false, "") + persistToFile(tempFilePath, obj, false, "") // Check if the file exists. _, err = os.Stat(tempFilePath) @@ -57,7 +57,7 @@ func TestPersistenceWrite(t *testing.T) { t.Run("gob gzipped", func(t *testing.T) { tempFilePath := tempDir + ".gob.gz" - persist(tempFilePath, obj, true, "") + persistToFile(tempFilePath, obj, true, "") // Check if the file exists. _, err = os.Stat(tempFilePath) @@ -123,7 +123,7 @@ func TestPersistenceRead(t *testing.T) { // Read the file. var res s - err = read(tempFilePath, &res, "") + err = readFromFile(tempFilePath, &res, "") if err != nil { t.Fatal("expected nil, got", err) } @@ -157,7 +157,7 @@ func TestPersistenceRead(t *testing.T) { // Read the file. var res s - err = read(tempFilePath, &res, "") + err = readFromFile(tempFilePath, &res, "") if err != nil { t.Fatal("expected nil, got", err) } @@ -207,7 +207,7 @@ func TestPersistenceEncryption(t *testing.T) { for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { - err := persist(tc.filePath, obj, tc.compress, encryptionKey) + err := persistToFile(tc.filePath, obj, tc.compress, encryptionKey) if err != nil { t.Fatal("expected nil, got", err) } @@ -220,7 +220,7 @@ func TestPersistenceEncryption(t *testing.T) { // Read the file. 
var res s - err = read(tc.filePath, &res, encryptionKey) + err = readFromFile(tc.filePath, &res, encryptionKey) if err != nil { t.Fatal("expected nil, got", err) } From abff93784e33357963a11fa5e89b8017b87bc2a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sat, 4 May 2024 16:54:17 +0200 Subject: [PATCH 2/4] Implement DB.ExportToWriter --- db.go | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/db.go b/db.go index fe7fc80..87e5f22 100644 --- a/db.go +++ b/db.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "io/fs" "os" "path/filepath" @@ -303,6 +304,55 @@ func (db *DB) Export(filePath string, compress bool, encryptionKey string) error return nil } +// ExportToWriter exports the DB to a writer. The stream is encoded as gob, +// optionally compressed with flate (as gzip) and optionally encrypted with AES-GCM. +// This works for both the in-memory and persistent DBs. +// If the writer has to be closed, it's the caller's responsibility. +// +// - writer: An implementation of [io.Writer] +// - compress: Optional. Compresses as gzip if true. +// - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes +// long if provided. +func (db *DB) ExportToWriter(writer io.Writer, compress bool, encryptionKey string) error { + if encryptionKey != "" { + // AES 256 requires a 32 byte key + if len(encryptionKey) != 32 { + return errors.New("encryption key must be 32 bytes long") + } + } + + // Create persistence structs with exported fields so that they can be encoded + // as gob. 
+ type persistenceCollection struct { + Name string + Metadata map[string]string + Documents map[string]*Document + } + persistenceDB := struct { + Collections map[string]*persistenceCollection + }{ + Collections: make(map[string]*persistenceCollection, len(db.collections)), + } + + db.collectionsLock.RLock() + defer db.collectionsLock.RUnlock() + + for k, v := range db.collections { + persistenceDB.Collections[k] = &persistenceCollection{ + Name: v.Name, + Metadata: v.metadata, + Documents: v.documents, + } + } + + err := persistToWriter(writer, persistenceDB, compress, encryptionKey) + if err != nil { + return fmt.Errorf("couldn't export DB: %w", err) + } + + return nil +} + // CreateCollection creates a new collection with the given name and metadata. // // - name: The name of the collection to create. From 1e74ee0439ebacd14ce3b9639adb9edad53200bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sat, 4 May 2024 16:55:08 +0200 Subject: [PATCH 3/4] Implement DB.ExportToFile and deprecate DB.Export --- db.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/db.go b/db.go index 87e5f22..9e77451 100644 --- a/db.go +++ b/db.go @@ -255,7 +255,22 @@ func (db *DB) Import(filePath string, encryptionKey string) error { // - compress: Optional. Compresses as gzip if true. // - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes // long if provided. +// +// Deprecated: Use [DB.ExportToFile] instead. func (db *DB) Export(filePath string, compress bool, encryptionKey string) error { + return db.ExportToFile(filePath, compress, encryptionKey) +} + +// ExportToFile exports the DB to a file at the given path. The file is encoded as gob, +// optionally compressed with flate (as gzip) and optionally encrypted with AES-GCM. +// This works for both the in-memory and persistent DBs. +// If the file exists, it's overwritten, otherwise created. 
+// +// - filePath: If empty, it defaults to "./chromem-go.gob" (+ ".gz" + ".enc") +// - compress: Optional. Compresses as gzip if true. +// - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes +// long if provided. +func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string) error { if filePath == "" { filePath = "./chromem-go.gob" if compress { From 0ef9693f2d3b36104f780bbe1076b2bb0ccd0284 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sat, 4 May 2024 17:35:49 +0200 Subject: [PATCH 4/4] Use new method in unit test --- db_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/db_test.go b/db_test.go index fc0d230..e3516a0 100644 --- a/db_test.go +++ b/db_test.go @@ -139,7 +139,7 @@ func TestDB_ImportExport(t *testing.T) { } // Export - err = orig.Export(tc.filePath, tc.compress, tc.encryptionKey) + err = orig.ExportToFile(tc.filePath, tc.compress, tc.encryptionKey) if err != nil { t.Fatal("expected no error, got", err) }