Skip to content

Commit

Permalink
Merge pull request #58 from philippgille/db-export-import
Browse files Browse the repository at this point in the history
DB export and import with optional compression and encryption
  • Loading branch information
philippgille authored Mar 22, 2024
2 parents 9ab9af9 + b3b2c4b commit d33678d
Show file tree
Hide file tree
Showing 7 changed files with 639 additions and 51 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ See the Godoc for details: <https://pkg.go.dev/github.com/philippgille/chromem-g
- [X] Metadata filters: Exact matches
- Storage:
- [X] In-memory
- [X] Optional local persistence (file based, encoded as [gob](https://go.dev/blog/gob))
- [X] Backups: Export and import of the entire DB to/from a single file (optionally gzip-compressed and AES-GCM encrypted)
- [X] Optional immediate persistence (writes one file for each added collection and document, encoded as [gob](https://go.dev/blog/gob))
- Data types:
- [X] Documents (text)

Expand All @@ -168,8 +169,7 @@ See the Godoc for details: <https://pkg.go.dev/github.com/philippgille/chromem-g
- Storage:
- JSON as second encoding format
- Write-ahead log (WAL) as second file format
- Compression
- Encryption (at rest)
- Compression and encryption not only for the export, but also for each collection/document file
- Optional remote storage (S3, PostgreSQL, ...)
- Data types:
- Images
Expand Down
23 changes: 11 additions & 12 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"slices"
"sync"
Expand All @@ -21,11 +20,14 @@ type Collection struct {
documents map[string]*Document
documentsLock sync.RWMutex
embed EmbeddingFunc

// ⚠️ When adding fields here, consider adding them to the persistence struct
// versions in [DB.Export] and [DB.Import] as well!
}

// We don't export this yet to keep the API surface to the bare minimum.
// Users create collections via [Client.CreateCollection].
func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dir string) (*Collection, error) {
func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dbDir string) (*Collection, error) {
// We copy the metadata to avoid data races in case the caller modifies the
// map after creating the collection while we range over it.
m := make(map[string]string, len(metadata))
Expand All @@ -42,24 +44,20 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc,
}

// Persistence
if dir != "" {
if dbDir != "" {
safeName := hash2hex(name)
c.persistDirectory = filepath.Join(dir, safeName)
// Create dir
err := os.MkdirAll(c.persistDirectory, 0o700)
if err != nil {
return nil, fmt.Errorf("couldn't create collection directory: %w", err)
}
c.persistDirectory = filepath.Join(dbDir, safeName)
// Persist name and metadata
metadataPath := filepath.Join(c.persistDirectory, metadataFileName)
metadataPath += ".gob"
pc := struct {
Name string
Metadata map[string]string
}{
Name: name,
Metadata: m,
}
err = persist(metadataPath, pc)
err := persist(metadataPath, pc, false, "")
if err != nil {
return nil, fmt.Errorf("couldn't persist collection metadata: %w", err)
}
Expand Down Expand Up @@ -233,8 +231,9 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
// Persist the document
if c.persistDirectory != "" {
safeID := hash2hex(doc.ID)
filePath := filepath.Join(c.persistDirectory, safeID)
err := persist(filePath, doc)
docPath := filepath.Join(c.persistDirectory, safeID)
docPath += ".gob"
err := persist(docPath, doc, false, "")
if err != nil {
return fmt.Errorf("couldn't persist document: %w", err)
}
Expand Down
139 changes: 137 additions & 2 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,15 @@ type DB struct {
collections map[string]*Collection
collectionsLock sync.RWMutex
persistDirectory string

// ⚠️ When adding fields here, consider adding them to the persistence struct
// versions in [DB.Export] and [DB.Import] as well!
}

// NewDB creates a new in-memory chromem-go DB.
// While it doesn't write files when you add collections and documents, you can
// still use [DB.Export] and [DB.Import] to export and import the the entire DB
// from a file.
func NewDB() *DB {
return &DB{
collections: make(map[string]*Collection),
Expand All @@ -48,6 +54,10 @@ func NewDB() *DB {
// Currently the persistence is done synchronously on each write operation, and
// each document addition leads to a new file, encoded as gob. In the future we
// will make this configurable (encoding, async writes, WAL-based writes, etc.).
//
// In addition to persistence for each added collection and document you can use
// [DB.Export] and [DB.Import] to export and import the entire DB to/from a file,
// which also works for the pure in-memory DB.
func NewPersistentDB(path string) (*DB, error) {
if path == "" {
path = "./chromem-go"
Expand Down Expand Up @@ -120,7 +130,7 @@ func NewPersistentDB(path string) (*DB, error) {
Name string
Metadata map[string]string
}{}
err := read(fPath, &pc)
err := read(fPath, &pc, "")
if err != nil {
return nil, fmt.Errorf("couldn't read collection metadata: %w", err)
}
Expand All @@ -129,7 +139,7 @@ func NewPersistentDB(path string) (*DB, error) {
} else if filepath.Ext(collectionDirEntry.Name()) == ".gob" {
// Read document
d := &Document{}
err := read(fPath, d)
err := read(fPath, d, "")
if err != nil {
return nil, fmt.Errorf("couldn't read document: %w", err)
}
Expand All @@ -155,6 +165,131 @@ func NewPersistentDB(path string) (*DB, error) {
return db, nil
}

// Import imports the DB from a file at the given path. The file must be encoded
// as gob and can optionally be compressed with flate (as gzip) and encrypted
// with AES-GCM.
// This works for both the in-memory and persistent DBs.
// Existing collections are overwritten.
//
// - filePath: Mandatory, must not be empty
// - encryptionKey: Optional, must be 32 bytes long if provided
func (db *DB) Import(filePath string, encryptionKey string) error {
if filePath == "" {
return fmt.Errorf("file path is empty")
}
if encryptionKey != "" {
// AES 256 requires a 32 byte key
if len(encryptionKey) != 32 {
return errors.New("encryption key must be 32 bytes long")
}
}

// If the file doesn't exist or is a directory, return an error.
fi, err := os.Stat(filePath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("file doesn't exist: %s", filePath)
}
return fmt.Errorf("couldn't get info about the file: %w", err)
} else if fi.IsDir() {
return fmt.Errorf("path is a directory: %s", filePath)
}

// Create persistence structs with exported fields so that they can be decoded
// from gob.
type persistenceCollection struct {
Name string
Metadata map[string]string
Documents map[string]*Document
}
persistenceDB := struct {
Collections map[string]*persistenceCollection
}{
Collections: make(map[string]*persistenceCollection, len(db.collections)),
}

db.collectionsLock.Lock()
defer db.collectionsLock.Unlock()

err = read(filePath, &persistenceDB, encryptionKey)
if err != nil {
return fmt.Errorf("couldn't read file: %w", err)
}

for _, pc := range persistenceDB.Collections {
c := &Collection{
Name: pc.Name,

metadata: pc.Metadata,
documents: pc.Documents,
}
if db.persistDirectory != "" {
c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name))
}
db.collections[c.Name] = c
}

return nil
}

// Export exports the DB to a file at the given path. The file is encoded as gob,
// optionally compressed with flate (as gzip) and optionally encrypted with AES-GCM.
// This works for both the in-memory and persistent DBs.
// If the file exists, it's overwritten, otherwise created.
//
// - filePath: If empty, it defaults to "./chromem-go.gob" (+ ".gz" + ".enc")
// - compress: Optional. Compresses as gzip if true.
// - encryptionKey: Optional. Encrypts with AES-GCM if provided. Must be 32 bytes
// long if provided.
func (db *DB) Export(filePath string, compress bool, encryptionKey string) error {
if filePath == "" {
filePath = "./chromem-go.gob"
if compress {
filePath += ".gz"
}
if encryptionKey != "" {
filePath += ".enc"
}
}
if encryptionKey != "" {
// AES 256 requires a 32 byte key
if len(encryptionKey) != 32 {
return errors.New("encryption key must be 32 bytes long")
}
}

// Create persistence structs with exported fields so that they can be encoded
// as gob.
type persistenceCollection struct {
Name string
Metadata map[string]string
Documents map[string]*Document
}
persistenceDB := struct {
Collections map[string]*persistenceCollection
}{
Collections: make(map[string]*persistenceCollection, len(db.collections)),
}

db.collectionsLock.RLock()
defer db.collectionsLock.RUnlock()

for k, v := range db.collections {
persistenceDB.Collections[k] = &persistenceCollection{
Name: v.Name,
Metadata: v.metadata,
Documents: v.documents,
}
}

err := persist(filePath, persistenceDB, compress, encryptionKey)
if err != nil {
return fmt.Errorf("couldn't export DB: %w", err)
}

return nil
}

// CreateCollection creates a new collection with the given name and metadata.
//
// - name: The name of the collection to create.
Expand Down
99 changes: 98 additions & 1 deletion db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import (

func TestNewPersistentDB(t *testing.T) {
t.Run("Create directory", func(t *testing.T) {
randString := randomString(rand.New(rand.NewSource(rand.Int63())), 10)
r := rand.New(rand.NewSource(rand.Int63()))
randString := randomString(r, 10)
path := filepath.Join(os.TempDir(), randString)
defer os.RemoveAll(path)

Expand Down Expand Up @@ -66,6 +67,102 @@ func TestNewPersistentDB_Errors(t *testing.T) {
})
}

func TestDB_ImportExport(t *testing.T) {
r := rand.New(rand.NewSource(rand.Int63()))
randString := randomString(r, 10)
path := filepath.Join(os.TempDir(), randString)
defer os.RemoveAll(path)

// Values in the collection
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}

tt := []struct {
name string
filePath string
compress bool
encryptionKey string
}{
{
name: "gob",
filePath: path + ".gob",
compress: false,
encryptionKey: "",
},
{
name: "gob compressed",
filePath: path + ".gob.gz",
compress: true,
encryptionKey: "",
},
{
name: "gob compressed encrypted",
filePath: path + ".gob.gz.enc",
compress: true,
encryptionKey: randomString(r, 32),
},
{
name: "gob encrypted",
filePath: path + ".gob.enc",
compress: false,
encryptionKey: randomString(r, 32),
},
}

for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
// Create DB, can just be in-memory
orig := NewDB()

// Create collection
c, err := orig.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
if c == nil {
t.Fatal("expected collection, got nil")
}
// Add document
doc := Document{
ID: name,
Metadata: metadata,
Embedding: vectors,
Content: "test",
}
err = c.AddDocument(context.Background(), doc)
if err != nil {
t.Fatal("expected no error, got", err)
}

// Export
err = orig.Export(tc.filePath, tc.compress, tc.encryptionKey)
if err != nil {
t.Fatal("expected no error, got", err)
}

new := NewDB()

// Import
err = new.Import(tc.filePath, tc.encryptionKey)
if err != nil {
t.Fatal("expected no error, got", err)
}

// Check expectations
// We have to reset the embed function, but otherwise the DB objects
// should be deep equal.
c.embed = nil
if !reflect.DeepEqual(orig, new) {
t.Fatalf("expected DB %+v, got %+v", orig, new)
}
})
}
}

func TestDB_CreateCollection(t *testing.T) {
// Values in the collection
name := "test"
Expand Down
3 changes: 3 additions & 0 deletions document.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ type Document struct {
Metadata map[string]string
Embedding []float32
Content string

// ⚠️ When adding unexported fields here, consider adding a persistence struct
// version of this in [DB.Export] and [DB.Import].
}

// NewDocument creates a new document, including its embeddings.
Expand Down
Loading

0 comments on commit d33678d

Please sign in to comment.