Skip to content

Commit

Permalink
Merge pull request #71 from philippgille/export-to-writer
Browse files Browse the repository at this point in the history
Export to io.Writer
  • Loading branch information
philippgille authored May 5, 2024
2 parents c33c572 + 0ef9693 commit 82f4efe
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 36 deletions.
8 changes: 3 additions & 5 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc,
Name: name,
Metadata: m,
}
err := persist(metadataPath, pc, compress, "")
err := persistToFile(metadataPath, pc, compress, "")
if err != nil {
return nil, fmt.Errorf("couldn't persist collection metadata: %w", err)
}
Expand Down Expand Up @@ -237,7 +237,7 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
// Persist the document
if c.persistDirectory != "" {
docPath := c.getDocPath(doc.ID)
err := persist(docPath, doc, c.compress, "")
err := persistToFile(docPath, doc, c.compress, "")
if err != nil {
return fmt.Errorf("couldn't persist document to %q: %w", docPath, err)
}
Expand All @@ -252,7 +252,6 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
// - whereDocument: Conditional filtering on documents. Optional.
// - ids: The ids of the documents to delete. If empty, all documents are deleted.
func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]string, ids ...string) error {

// must have at least one of where, whereDocument or ids
if len(where) == 0 && len(whereDocument) == 0 && len(ids) == 0 {
return fmt.Errorf("must have at least one of where, whereDocument or ids")
Expand Down Expand Up @@ -294,15 +293,14 @@ func (c *Collection) Delete(_ context.Context, where, whereDocument map[string]s
// Remove the document from disk
if c.persistDirectory != "" {
docPath := c.getDocPath(docID)
err := remove(docPath)
err := removeFile(docPath)
if err != nil {
return fmt.Errorf("couldn't remove document at %q: %w", docPath, err)
}
}
}

return nil

}

// Count returns the number of documents in the collection.
Expand Down
73 changes: 69 additions & 4 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
Expand Down Expand Up @@ -142,7 +143,7 @@ func NewPersistentDB(path string, compress bool) (*DB, error) {
Name string
Metadata map[string]string
}{}
err := read(fPath, &pc, "")
err := readFromFile(fPath, &pc, "")
if err != nil {
return nil, fmt.Errorf("couldn't read collection metadata: %w", err)
}
Expand All @@ -151,7 +152,7 @@ func NewPersistentDB(path string, compress bool) (*DB, error) {
} else if strings.HasSuffix(collectionDirEntry.Name(), ext) {
// Read document
d := &Document{}
err := read(fPath, d, "")
err := readFromFile(fPath, d, "")
if err != nil {
return nil, fmt.Errorf("couldn't read document: %w", err)
}
Expand Down Expand Up @@ -223,7 +224,7 @@ func (db *DB) Import(filePath string, encryptionKey string) error {
db.collectionsLock.Lock()
defer db.collectionsLock.Unlock()

err = read(filePath, &persistenceDB, encryptionKey)
err = readFromFile(filePath, &persistenceDB, encryptionKey)
if err != nil {
return fmt.Errorf("couldn't read file: %w", err)
}
Expand Down Expand Up @@ -254,7 +255,22 @@ func (db *DB) Import(filePath string, encryptionKey string) error {
// - compress: Optional. Compresses as gzip if true.
// - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes
// long if provided.
//
// Deprecated: Use [DB.ExportToFile] instead.
func (db *DB) Export(filePath string, compress bool, encryptionKey string) error {
// Deprecated entry point: forwards all arguments unchanged to ExportToFile
// and returns its error as-is.
return db.ExportToFile(filePath, compress, encryptionKey)
}

// ExportToFile exports the DB to a file at the given path. The file is encoded as gob,
// optionally compressed with flate (as gzip) and optionally encrypted with AES-GCM.
// This works for both the in-memory and persistent DBs.
// If the file exists, it's overwritten, otherwise created.
//
// - filePath: If empty, it defaults to "./chromem-go.gob" (+ ".gz" + ".enc")
// - compress: Optional. Compresses as gzip if true.
// - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes
// long if provided.
func (db *DB) ExportToFile(filePath string, compress bool, encryptionKey string) error {
if filePath == "" {
filePath = "./chromem-go.gob"
if compress {
Expand Down Expand Up @@ -295,7 +311,56 @@ func (db *DB) Export(filePath string, compress bool, encryptionKey string) error
}
}

err := persist(filePath, persistenceDB, compress, encryptionKey)
err := persistToFile(filePath, persistenceDB, compress, encryptionKey)
if err != nil {
return fmt.Errorf("couldn't export DB: %w", err)
}

return nil
}

// ExportToWriter exports the DB to a writer. The stream is encoded as gob,
// optionally compressed with flate (as gzip) and optionally encrypted with AES-GCM.
// This works for both the in-memory and persistent DBs.
// If the writer has to be closed, it's the caller's responsibility.
//
// - writer: An implementation of [io.Writer]
// - compress: Optional. Compresses as gzip if true.
// - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes
// long if provided.
func (db *DB) ExportToWriter(writer io.Writer, compress bool, encryptionKey string) error {
if encryptionKey != "" {
// AES 256 requires a 32 byte key
if len(encryptionKey) != 32 {
return errors.New("encryption key must be 32 bytes long")
}
}

// Create persistence structs with exported fields so that they can be encoded
// as gob.
type persistenceCollection struct {
Name string
Metadata map[string]string
Documents map[string]*Document
}
persistenceDB := struct {
Collections map[string]*persistenceCollection
}{
Collections: make(map[string]*persistenceCollection, len(db.collections)),
}

db.collectionsLock.RLock()
defer db.collectionsLock.RUnlock()

for k, v := range db.collections {
persistenceDB.Collections[k] = &persistenceCollection{
Name: v.Name,
Metadata: v.metadata,
Documents: v.documents,
}
}

err := persistToWriter(writer, persistenceDB, compress, encryptionKey)
if err != nil {
return fmt.Errorf("couldn't export DB: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func TestDB_ImportExport(t *testing.T) {
}

// Export
err = orig.Export(tc.filePath, tc.compress, tc.encryptionKey)
err = orig.ExportToFile(tc.filePath, tc.compress, tc.encryptionKey)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand Down
56 changes: 36 additions & 20 deletions persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ func hash2hex(name string) string {
return hex.EncodeToString(hash[:4])
}

// persist persists an object to a file at the given path. The object is serialized
// persistToFile persists an object to a file at the given path. The object is serialized
// as gob, optionally compressed with flate (as gzip) and optionally encrypted with
// AES-GCM. The encryption key must be 32 bytes long. If the file exists, it's
// overwritten, otherwise created.
func persist(filePath string, obj any, compress bool, encryptionKey string) error {
func persistToFile(filePath string, obj any, compress bool, encryptionKey string) error {
if filePath == "" {
return fmt.Errorf("file path is empty")
}
Expand Down Expand Up @@ -66,25 +66,41 @@ func persist(filePath string, obj any, compress bool, encryptionKey string) erro
}
defer f.Close()

return persistToWriter(f, obj, compress, encryptionKey)
}

// persistToWriter persists an object to a writer. The object is serialized
// as gob, optionally compressed with flate (as gzip) and optionally encrypted with
// AES-GCM. The encryption key must be 32 bytes long.
// If the writer has to be closed, it's the caller's responsibility.
func persistToWriter(w io.Writer, obj any, compress bool, encryptionKey string) error {
// AES 256 requires a 32 byte key
if encryptionKey != "" {
if len(encryptionKey) != 32 {
return errors.New("encryption key must be 32 bytes long")
}
}

// We want to:
// Encode as gob -> compress with flate -> encrypt with AES-GCM -> write file.
// Encode as gob -> compress with flate -> encrypt with AES-GCM -> write to
// passed writer.
// To reduce memory usage we chain the writers instead of buffering, so we start
// from the end. For AES GCM sealing the stdlib doesn't provide a writer though.

var w io.Writer
var chainedWriter io.Writer
if encryptionKey == "" {
w = f
chainedWriter = w
} else {
w = &bytes.Buffer{}
chainedWriter = &bytes.Buffer{}
}

var gzw *gzip.Writer
var enc *gob.Encoder
if compress {
gzw = gzip.NewWriter(w)
gzw = gzip.NewWriter(chainedWriter)
enc = gob.NewEncoder(gzw)
} else {
enc = gob.NewEncoder(w)
enc = gob.NewEncoder(chainedWriter)
}

// Start encoding, it will write to the chain of writers.
Expand All @@ -93,22 +109,22 @@ func persist(filePath string, obj any, compress bool, encryptionKey string) erro
}

// If compressing, close the gzip writer. Otherwise the gzip footer won't be
// written yet. When using encryption (and w is a buffer) then we'll encrypt
// an incomplete file. Without encryption when we return here and having
// written yet. When using encryption (and chainedWriter is a buffer) then
// we'll encrypt an incomplete stream. Without encryption when we return here and having
// a deferred Close(), there might be a silenced error.
if compress {
err = gzw.Close()
err := gzw.Close()
if err != nil {
return fmt.Errorf("couldn't close gzip writer: %w", err)
}
}

// Without encyrption, the chain is done and the file is written.
// Without encryption, the chain is done and the writing is finished.
if encryptionKey == "" {
return nil
}

// Otherwise, encrypt and then write to the file
// Otherwise, encrypt and then write to the unchained target writer.
block, err := aes.NewCipher([]byte(encryptionKey))
if err != nil {
return fmt.Errorf("couldn't create new AES cipher: %w", err)
Expand All @@ -121,22 +137,22 @@ func persist(filePath string, obj any, compress bool, encryptionKey string) erro
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return fmt.Errorf("couldn't read random bytes for nonce: %w", err)
}
// w is a *bytes.Buffer
buf := w.(*bytes.Buffer)
// chainedWriter is a *bytes.Buffer
buf := chainedWriter.(*bytes.Buffer)
encrypted := gcm.Seal(nonce, nonce, buf.Bytes(), nil)
_, err = f.Write(encrypted)
_, err = w.Write(encrypted)
if err != nil {
return fmt.Errorf("couldn't write encrypted data: %w", err)
}

return nil
}

// read reads an object from a file at the given path. The object is deserialized
// readFromFile reads an object from a file at the given path. The object is deserialized
// from gob. `obj` must be a pointer to an instantiated object. The file may
// optionally be compressed as gzip and/or encrypted with AES-GCM. The encryption
// key must be 32 bytes long.
func read(filePath string, obj any, encryptionKey string) error {
func readFromFile(filePath string, obj any, encryptionKey string) error {
if filePath == "" {
return fmt.Errorf("file path is empty")
}
Expand Down Expand Up @@ -226,8 +242,8 @@ func read(filePath string, obj any, encryptionKey string) error {
return nil
}

// remove removes a file at the given path. If the file doesn't exist, it's a no-op.
func remove(filePath string) error {
// removeFile removes a file at the given path. If the file doesn't exist, it's a no-op.
func removeFile(filePath string) error {
if filePath == "" {
return fmt.Errorf("file path is empty")
}
Expand Down
12 changes: 6 additions & 6 deletions persistence_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func TestPersistenceWrite(t *testing.T) {

t.Run("gob", func(t *testing.T) {
tempFilePath := tempDir + ".gob"
persist(tempFilePath, obj, false, "")
persistToFile(tempFilePath, obj, false, "")

// Check if the file exists.
_, err = os.Stat(tempFilePath)
Expand Down Expand Up @@ -57,7 +57,7 @@ func TestPersistenceWrite(t *testing.T) {

t.Run("gob gzipped", func(t *testing.T) {
tempFilePath := tempDir + ".gob.gz"
persist(tempFilePath, obj, true, "")
persistToFile(tempFilePath, obj, true, "")

// Check if the file exists.
_, err = os.Stat(tempFilePath)
Expand Down Expand Up @@ -123,7 +123,7 @@ func TestPersistenceRead(t *testing.T) {

// Read the file.
var res s
err = read(tempFilePath, &res, "")
err = readFromFile(tempFilePath, &res, "")
if err != nil {
t.Fatal("expected nil, got", err)
}
Expand Down Expand Up @@ -157,7 +157,7 @@ func TestPersistenceRead(t *testing.T) {

// Read the file.
var res s
err = read(tempFilePath, &res, "")
err = readFromFile(tempFilePath, &res, "")
if err != nil {
t.Fatal("expected nil, got", err)
}
Expand Down Expand Up @@ -207,7 +207,7 @@ func TestPersistenceEncryption(t *testing.T) {

for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
err := persist(tc.filePath, obj, tc.compress, encryptionKey)
err := persistToFile(tc.filePath, obj, tc.compress, encryptionKey)
if err != nil {
t.Fatal("expected nil, got", err)
}
Expand All @@ -220,7 +220,7 @@ func TestPersistenceEncryption(t *testing.T) {

// Read the file.
var res s
err = read(tc.filePath, &res, encryptionKey)
err = readFromFile(tc.filePath, &res, encryptionKey)
if err != nil {
t.Fatal("expected nil, got", err)
}
Expand Down

0 comments on commit 82f4efe

Please sign in to comment.