Import from io.Reader #72

Merged (4 commits) on May 5, 2024
76 changes: 73 additions & 3 deletions db.go
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io"
"io/fs"
"os"
"path/filepath"
@@ -142,7 +143,7 @@ func NewPersistentDB(path string, compress bool) (*DB, error) {
Name string
Metadata map[string]string
}{}
err := read(fPath, &pc, "")
err := readFromFile(fPath, &pc, "")
if err != nil {
return nil, fmt.Errorf("couldn't read collection metadata: %w", err)
}
@@ -151,7 +152,7 @@ func NewPersistentDB(path string, compress bool) (*DB, error) {
} else if strings.HasSuffix(collectionDirEntry.Name(), ext) {
// Read document
d := &Document{}
err := read(fPath, d, "")
err := readFromFile(fPath, d, "")
if err != nil {
return nil, fmt.Errorf("couldn't read document: %w", err)
}
@@ -185,7 +186,21 @@ func NewPersistentDB(path string, compress bool) (*DB, error) {
//
// - filePath: Mandatory, must not be empty
// - encryptionKey: Optional, must be 32 bytes long if provided
//
// Deprecated: Use [DB.ImportFromFile] instead.
func (db *DB) Import(filePath string, encryptionKey string) error {
return db.ImportFromFile(filePath, encryptionKey)
}

// ImportFromFile imports the DB from a file at the given path. The file must be
// encoded as gob and can optionally be compressed with flate (as gzip) and encrypted
// with AES-GCM.
// This works for both the in-memory and persistent DBs.
// Existing collections are overwritten.
//
// - filePath: Mandatory, must not be empty
// - encryptionKey: Optional, must be 32 bytes long if provided
func (db *DB) ImportFromFile(filePath string, encryptionKey string) error {
if filePath == "" {
return fmt.Errorf("file path is empty")
}
@@ -223,7 +238,7 @@ func (db *DB) Import(filePath string, encryptionKey string) error {
db.collectionsLock.Lock()
defer db.collectionsLock.Unlock()

err = read(filePath, &persistenceDB, encryptionKey)
err = readFromFile(filePath, &persistenceDB, encryptionKey)
if err != nil {
return fmt.Errorf("couldn't read file: %w", err)
}
@@ -245,6 +260,61 @@ func (db *DB) Import(filePath string, encryptionKey string) error {
return nil
}

// ImportFromReader imports the DB from a reader. The stream must be encoded as
// gob and can optionally be compressed with flate (as gzip) and encrypted with
// AES-GCM.
// This works for both the in-memory and persistent DBs.
// Existing collections are overwritten.
// If the reader has to be closed, it's the caller's responsibility.
//
// - reader: An implementation of [io.ReadSeeker]
// - encryptionKey: Optional, must be 32 bytes long if provided
func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string) error {
if encryptionKey != "" {
// AES 256 requires a 32 byte key
if len(encryptionKey) != 32 {
return errors.New("encryption key must be 32 bytes long")
}
}

// Create persistence structs with exported fields so that they can be decoded
// from gob.
type persistenceCollection struct {
Name string
Metadata map[string]string
Documents map[string]*Document
}
persistenceDB := struct {
Collections map[string]*persistenceCollection
}{
Collections: make(map[string]*persistenceCollection, len(db.collections)),
}

db.collectionsLock.Lock()
defer db.collectionsLock.Unlock()

err := readFromReader(reader, &persistenceDB, encryptionKey)
if err != nil {
return fmt.Errorf("couldn't read stream: %w", err)
}

for _, pc := range persistenceDB.Collections {
c := &Collection{
Name: pc.Name,

metadata: pc.Metadata,
documents: pc.Documents,
}
if db.persistDirectory != "" {
c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name))
c.compress = db.compress
}
db.collections[c.Name] = c
}

return nil
}

// Export exports the DB to a file at the given path. The file is encoded as gob,
// optionally compressed with flate (as gzip) and optionally encrypted with AES-GCM.
// This works for both the in-memory and persistent DBs.
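For reviewers who want to try the new API surface: a minimal usage sketch, assuming the package is imported as chromem (the import path and the backup file name are illustrative, not taken from this diff). Only NewDB, ImportFromFile, and ImportFromReader come from the changes above.

package main

import (
	"fmt"
	"os"

	"github.com/philippgille/chromem-go" // assumed import path
)

func main() {
	db := chromem.NewDB()

	// Import from a file path; this replaces the now-deprecated Import.
	if err := db.ImportFromFile("./backup.gob.gz", ""); err != nil {
		fmt.Println("import from file failed:", err)
	}

	// Import from any io.ReadSeeker, e.g. an opened file or a bytes.Reader.
	f, err := os.Open("./backup.gob.gz")
	if err != nil {
		fmt.Println("open failed:", err)
		return
	}
	// Closing the reader is the caller's responsibility.
	defer f.Close()
	if err := db.ImportFromReader(f, ""); err != nil {
		fmt.Println("import from reader failed:", err)
	}
}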
2 changes: 1 addition & 1 deletion db_test.go
@@ -147,7 +147,7 @@ func TestDB_ImportExport(t *testing.T) {
new := NewDB()

// Import
err = new.Import(tc.filePath, tc.encryptionKey)
err = new.ImportFromFile(tc.filePath, tc.encryptionKey)
if err != nil {
t.Fatal("expected no error, got", err)
}
61 changes: 41 additions & 20 deletions persistence.go
@@ -132,11 +132,11 @@ func persist(filePath string, obj any, compress bool, encryptionKey string) erro
return nil
}

// read reads an object from a file at the given path. The object is deserialized
// readFromFile reads an object from a file at the given path. The object is deserialized
// from gob. `obj` must be a pointer to an instantiated object. The file may
// optionally be compressed as gzip and/or encrypted with AES-GCM. The encryption
// key must be 32 bytes long.
func read(filePath string, obj any, encryptionKey string) error {
func readFromFile(filePath string, obj any, encryptionKey string) error {
if filePath == "" {
return fmt.Errorf("file path is empty")
}
@@ -147,18 +147,43 @@ func read(filePath string, obj any, encryptionKey string) error {
}
}

r, err := os.Open(filePath)
if err != nil {
return fmt.Errorf("couldn't open file: %w", err)
}
defer r.Close()

return readFromReader(r, obj, encryptionKey)
}

// readFromReader reads an object from a Reader. The object is deserialized from gob.
// `obj` must be a pointer to an instantiated object. The stream may optionally
// be compressed as gzip and/or encrypted with AES-GCM. The encryption key must
// be 32 bytes long.
// If the reader has to be closed, it's the caller's responsibility.
func readFromReader(r io.ReadSeeker, obj any, encryptionKey string) error {
// AES 256 requires a 32 byte key
if encryptionKey != "" {
if len(encryptionKey) != 32 {
return errors.New("encryption key must be 32 bytes long")
}
}

// We want to:
// Read file -> decrypt with AES-GCM -> decompress with flate -> decode as gob
// Read from reader -> decrypt with AES-GCM -> decompress with flate -> decode
// as gob.
// To reduce memory usage we chain the readers instead of buffering, so we start
// from the end. For the decryption there's no reader though.

var r io.Reader
// For the chainedReader we don't declare it as ReadSeeker so we can reassign
// the gzip reader to it.
var chainedReader io.Reader

// Decrypt if an encryption key is provided
if encryptionKey != "" {
encrypted, err := os.ReadFile(filePath)
encrypted, err := io.ReadAll(r)
if err != nil {
return fmt.Errorf("couldn't read file: %w", err)
return fmt.Errorf("couldn't read from reader: %w", err)
}
block, err := aes.NewCipher([]byte(encryptionKey))
if err != nil {
@@ -178,28 +203,24 @@ func read(filePath string, obj any, encryptionKey string) error {
return fmt.Errorf("couldn't decrypt data: %w", err)
}

r = bytes.NewReader(data)
chainedReader = bytes.NewReader(data)
} else {
var err error
r, err = os.Open(filePath)
if err != nil {
return fmt.Errorf("couldn't open file: %w", err)
}
chainedReader = r
}

// Determine if the file is compressed
// Determine if the stream is compressed
magicNumber := make([]byte, 2)
_, err := r.Read(magicNumber)
_, err := chainedReader.Read(magicNumber)
if err != nil {
return fmt.Errorf("couldn't read magic number to determine whether the file is compressed: %w", err)
return fmt.Errorf("couldn't read magic number to determine whether the stream is compressed: %w", err)
}
var compressed bool
if magicNumber[0] == 0x1f && magicNumber[1] == 0x8b {
compressed = true
}

// Reset reader. Both file and bytes.Reader support seeking.
if s, ok := r.(io.Seeker); !ok {
// Reset reader. Both the reader from the param and bytes.Reader support seeking.
if s, ok := chainedReader.(io.Seeker); !ok {
return fmt.Errorf("reader doesn't support seeking")
} else {
_, err := s.Seek(0, 0)
@@ -209,15 +230,15 @@ func read(filePath string, obj any, encryptionKey string) error {
}

if compressed {
gzr, err := gzip.NewReader(r)
gzr, err := gzip.NewReader(chainedReader)
if err != nil {
return fmt.Errorf("couldn't create gzip reader: %w", err)
}
defer gzr.Close()
r = gzr
chainedReader = gzr
}

dec := gob.NewDecoder(r)
dec := gob.NewDecoder(chainedReader)
err = dec.Decode(obj)
if err != nil {
return fmt.Errorf("couldn't decode object: %w", err)
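To make the decode pipeline in readFromReader easier to follow, here is a self-contained sketch of the same pattern with the AES-GCM branch left out: peek at the first two bytes to detect the gzip magic number, seek back to the start, then chain an optional gzip reader in front of the gob decoder. This is an illustration of the technique, not the library code itself.

package main

import (
	"bytes"
	"compress/gzip"
	"encoding/gob"
	"fmt"
	"io"
)

// decodeMaybeGzipped decodes a gob-encoded value from r, transparently
// handling gzip-compressed input by checking for the 0x1f 0x8b magic number.
func decodeMaybeGzipped(r io.ReadSeeker, obj any) error {
	magic := make([]byte, 2)
	if _, err := io.ReadFull(r, magic); err != nil {
		return fmt.Errorf("couldn't read magic number: %w", err)
	}
	// Rewind so the gzip reader or gob decoder sees the full stream.
	if _, err := r.Seek(0, io.SeekStart); err != nil {
		return fmt.Errorf("couldn't seek to start: %w", err)
	}

	var chained io.Reader = r
	if magic[0] == 0x1f && magic[1] == 0x8b {
		gzr, err := gzip.NewReader(r)
		if err != nil {
			return fmt.Errorf("couldn't create gzip reader: %w", err)
		}
		defer gzr.Close()
		chained = gzr
	}
	return gob.NewDecoder(chained).Decode(obj)
}

func main() {
	var buf bytes.Buffer
	_ = gob.NewEncoder(&buf).Encode(map[string]string{"hello": "world"})

	var out map[string]string
	if err := decodeMaybeGzipped(bytes.NewReader(buf.Bytes()), &out); err != nil {
		fmt.Println("decode failed:", err)
		return
	}
	fmt.Println(out["hello"]) // world
}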
6 changes: 3 additions & 3 deletions persistence_test.go
@@ -123,7 +123,7 @@ func TestPersistenceRead(t *testing.T) {

// Read the file.
var res s
err = read(tempFilePath, &res, "")
err = readFromFile(tempFilePath, &res, "")
if err != nil {
t.Fatal("expected nil, got", err)
}
@@ -157,7 +157,7 @@ func TestPersistenceRead(t *testing.T) {

// Read the file.
var res s
err = read(tempFilePath, &res, "")
err = readFromFile(tempFilePath, &res, "")
if err != nil {
t.Fatal("expected nil, got", err)
}
@@ -220,7 +220,7 @@ func TestPersistenceEncryption(t *testing.T) {

// Read the file.
var res s
err = read(tc.filePath, &res, encryptionKey)
err = readFromFile(tc.filePath, &res, encryptionKey)
if err != nil {
t.Fatal("expected nil, got", err)
}
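Because readFromReader no longer depends on a file path, it can also be exercised with an in-memory io.ReadSeeker. A hypothetical in-package test sketch follows (the package name and struct fields are assumptions, not part of this diff).

package chromem // assumed package name

import (
	"bytes"
	"encoding/gob"
	"testing"
)

// TestReadFromReader_Plain gob-encodes a value into an in-memory buffer and
// decodes it again via readFromReader through a bytes.Reader (which satisfies
// io.ReadSeeker). It covers only the uncompressed, unencrypted path.
func TestReadFromReader_Plain(t *testing.T) {
	type payload struct {
		Foo string
	}
	want := payload{Foo: "bar"}

	var buf bytes.Buffer
	if err := gob.NewEncoder(&buf).Encode(want); err != nil {
		t.Fatal("expected nil, got", err)
	}

	var got payload
	if err := readFromReader(bytes.NewReader(buf.Bytes()), &got, ""); err != nil {
		t.Fatal("expected nil, got", err)
	}
	if got.Foo != want.Foo {
		t.Fatalf("expected %q, got %q", want.Foo, got.Foo)
	}
}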