diff --git a/db.go b/db.go index 518f1db..50838b7 100644 --- a/db.go +++ b/db.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "io/fs" "os" "path/filepath" @@ -142,7 +143,7 @@ func NewPersistentDB(path string, compress bool) (*DB, error) { Name string Metadata map[string]string }{} - err := read(fPath, &pc, "") + err := readFromFile(fPath, &pc, "") if err != nil { return nil, fmt.Errorf("couldn't read collection metadata: %w", err) } @@ -151,7 +152,7 @@ func NewPersistentDB(path string, compress bool) (*DB, error) { } else if strings.HasSuffix(collectionDirEntry.Name(), ext) { // Read document d := &Document{} - err := read(fPath, d, "") + err := readFromFile(fPath, d, "") if err != nil { return nil, fmt.Errorf("couldn't read document: %w", err) } @@ -185,7 +186,21 @@ func NewPersistentDB(path string, compress bool) (*DB, error) { // // - filePath: Mandatory, must not be empty // - encryptionKey: Optional, must be 32 bytes long if provided +// +// Deprecated: Use [DB.ImportFromFile] instead. func (db *DB) Import(filePath string, encryptionKey string) error { + return db.ImportFromFile(filePath, encryptionKey) +} + +// ImportFromFile imports the DB from a file at the given path. The file must be +// encoded as gob and can optionally be compressed with flate (as gzip) and encrypted +// with AES-GCM. +// This works for both the in-memory and persistent DBs. +// Existing collections are overwritten. +// +// - filePath: Mandatory, must not be empty +// - encryptionKey: Optional, must be 32 bytes long if provided +func (db *DB) ImportFromFile(filePath string, encryptionKey string) error { if filePath == "" { return fmt.Errorf("file path is empty") } @@ -223,7 +238,7 @@ func (db *DB) Import(filePath string, encryptionKey string) error { db.collectionsLock.Lock() defer db.collectionsLock.Unlock() - err = read(filePath, &persistenceDB, encryptionKey) + err = readFromFile(filePath, &persistenceDB, encryptionKey) if err != nil { return fmt.Errorf("couldn't read file: %w", err) } @@ -245,6 +260,61 @@ func (db *DB) Import(filePath string, encryptionKey string) error { return nil } +// ImportFromReader imports the DB from a reader. The stream must be encoded as +// gob and can optionally be compressed with flate (as gzip) and encrypted with +// AES-GCM. +// This works for both the in-memory and persistent DBs. +// Existing collections are overwritten. +// If the writer has to be closed, it's the caller's responsibility. +// +// - reader: An implementation of [io.ReadSeeker] +// - encryptionKey: Optional, must be 32 bytes long if provided +func (db *DB) ImportFromReader(reader io.ReadSeeker, encryptionKey string) error { + if encryptionKey != "" { + // AES 256 requires a 32 byte key + if len(encryptionKey) != 32 { + return errors.New("encryption key must be 32 bytes long") + } + } + + // Create persistence structs with exported fields so that they can be decoded + // from gob. + type persistenceCollection struct { + Name string + Metadata map[string]string + Documents map[string]*Document + } + persistenceDB := struct { + Collections map[string]*persistenceCollection + }{ + Collections: make(map[string]*persistenceCollection, len(db.collections)), + } + + db.collectionsLock.Lock() + defer db.collectionsLock.Unlock() + + err := readFromReader(reader, &persistenceDB, encryptionKey) + if err != nil { + return fmt.Errorf("couldn't read stream: %w", err) + } + + for _, pc := range persistenceDB.Collections { + c := &Collection{ + Name: pc.Name, + + metadata: pc.Metadata, + documents: pc.Documents, + } + if db.persistDirectory != "" { + c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name)) + c.compress = db.compress + } + db.collections[c.Name] = c + } + + return nil +} + // Export exports the DB to a file at the given path. The file is encoded as gob, // optionally compressed with flate (as gzip) and optionally encrypted with AES-GCM. // This works for both the in-memory and persistent DBs. diff --git a/db_test.go b/db_test.go index fc0d230..227b2da 100644 --- a/db_test.go +++ b/db_test.go @@ -147,7 +147,7 @@ func TestDB_ImportExport(t *testing.T) { new := NewDB() // Import - err = new.Import(tc.filePath, tc.encryptionKey) + err = new.ImportFromFile(tc.filePath, tc.encryptionKey) if err != nil { t.Fatal("expected no error, got", err) } diff --git a/persistence.go b/persistence.go index 5748afd..b65a125 100644 --- a/persistence.go +++ b/persistence.go @@ -132,11 +132,11 @@ func persist(filePath string, obj any, compress bool, encryptionKey string) erro return nil } -// read reads an object from a file at the given path. The object is deserialized +// readFromFile reads an object from a file at the given path. The object is deserialized // from gob. `obj` must be a pointer to an instantiated object. The file may // optionally be compressed as gzip and/or encrypted with AES-GCM. The encryption // key must be 32 bytes long. -func read(filePath string, obj any, encryptionKey string) error { +func readFromFile(filePath string, obj any, encryptionKey string) error { if filePath == "" { return fmt.Errorf("file path is empty") } @@ -147,18 +147,43 @@ func read(filePath string, obj any, encryptionKey string) error { } } + r, err := os.Open(filePath) + if err != nil { + return fmt.Errorf("couldn't open file: %w", err) + } + defer r.Close() + + return readFromReader(r, obj, encryptionKey) +} + +// readFromReader reads an object from a Reader. The object is deserialized from gob. +// `obj` must be a pointer to an instantiated object. The stream may optionally +// be compressed as gzip and/or encrypted with AES-GCM. The encryption key must +// be 32 bytes long. +// If the reader has to be closed, it's the caller's responsibility. +func readFromReader(r io.ReadSeeker, obj any, encryptionKey string) error { + // AES 256 requires a 32 byte key + if encryptionKey != "" { + if len(encryptionKey) != 32 { + return errors.New("encryption key must be 32 bytes long") + } + } + // We want to: - // Read file -> decrypt with AES-GCM -> decompress with flate -> decode as gob + // Read from reader -> decrypt with AES-GCM -> decompress with flate -> decode + // as gob. // To reduce memory usage we chain the readers instead of buffering, so we start // from the end. For the decryption there's no reader though. - var r io.Reader + // For the chainedReader we don't declare it as ReadSeeker so we can reassign + // the gzip reader to it. + var chainedReader io.Reader // Decrypt if an encryption key is provided if encryptionKey != "" { - encrypted, err := os.ReadFile(filePath) + encrypted, err := io.ReadAll(r) if err != nil { - return fmt.Errorf("couldn't read file: %w", err) + return fmt.Errorf("couldn't read from reader: %w", err) } block, err := aes.NewCipher([]byte(encryptionKey)) if err != nil { @@ -178,28 +203,24 @@ func read(filePath string, obj any, encryptionKey string) error { return fmt.Errorf("couldn't decrypt data: %w", err) } - r = bytes.NewReader(data) + chainedReader = bytes.NewReader(data) } else { - var err error - r, err = os.Open(filePath) - if err != nil { - return fmt.Errorf("couldn't open file: %w", err) - } + chainedReader = r } - // Determine if the file is compressed + // Determine if the stream is compressed magicNumber := make([]byte, 2) - _, err := r.Read(magicNumber) + _, err := chainedReader.Read(magicNumber) if err != nil { - return fmt.Errorf("couldn't read magic number to determine whether the file is compressed: %w", err) + return fmt.Errorf("couldn't read magic number to determine whether the stream is compressed: %w", err) } var compressed bool if magicNumber[0] == 0x1f && magicNumber[1] == 0x8b { compressed = true } - // Reset reader. Both file and bytes.Reader support seeking. - if s, ok := r.(io.Seeker); !ok { + // Reset reader. Both the reader from the param and bytes.Reader support seeking. + if s, ok := chainedReader.(io.Seeker); !ok { return fmt.Errorf("reader doesn't support seeking") } else { _, err := s.Seek(0, 0) @@ -209,15 +230,15 @@ func read(filePath string, obj any, encryptionKey string) error { } if compressed { - gzr, err := gzip.NewReader(r) + gzr, err := gzip.NewReader(chainedReader) if err != nil { return fmt.Errorf("couldn't create gzip reader: %w", err) } defer gzr.Close() - r = gzr + chainedReader = gzr } - dec := gob.NewDecoder(r) + dec := gob.NewDecoder(chainedReader) err = dec.Decode(obj) if err != nil { return fmt.Errorf("couldn't decode object: %w", err) diff --git a/persistence_test.go b/persistence_test.go index 09676f9..64f4c68 100644 --- a/persistence_test.go +++ b/persistence_test.go @@ -123,7 +123,7 @@ func TestPersistenceRead(t *testing.T) { // Read the file. var res s - err = read(tempFilePath, &res, "") + err = readFromFile(tempFilePath, &res, "") if err != nil { t.Fatal("expected nil, got", err) } @@ -157,7 +157,7 @@ func TestPersistenceRead(t *testing.T) { // Read the file. var res s - err = read(tempFilePath, &res, "") + err = readFromFile(tempFilePath, &res, "") if err != nil { t.Fatal("expected nil, got", err) } @@ -220,7 +220,7 @@ func TestPersistenceEncryption(t *testing.T) { // Read the file. var res s - err = read(tc.filePath, &res, encryptionKey) + err = readFromFile(tc.filePath, &res, encryptionKey) if err != nil { t.Fatal("expected nil, got", err) }