diff --git a/collection.go b/collection.go index 5c9c893..b0765cd 100644 --- a/collection.go +++ b/collection.go @@ -15,11 +15,13 @@ import ( type Collection struct { Name string + metadata map[string]string + documents map[string]*Document + documentsLock sync.RWMutex + embed EmbeddingFunc + persistDirectory string - metadata map[string]string - documents map[string]*Document - documentsLock sync.RWMutex - embed EmbeddingFunc + compress bool // ⚠️ When adding fields here, consider adding them to the persistence struct // versions in [DB.Export] and [DB.Import] as well! @@ -27,7 +29,7 @@ type Collection struct { // We don't export this yet to keep the API surface to the bare minimum. // Users create collections via [Client.CreateCollection]. -func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dbDir string) (*Collection, error) { +func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dbDir string, compress bool) (*Collection, error) { // We copy the metadata to avoid data races in case the caller modifies the // map after creating the collection while we range over it. m := make(map[string]string, len(metadata)) @@ -47,9 +49,13 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, if dbDir != "" { safeName := hash2hex(name) c.persistDirectory = filepath.Join(dbDir, safeName) + c.compress = compress // Persist name and metadata metadataPath := filepath.Join(c.persistDirectory, metadataFileName) metadataPath += ".gob" + if c.compress { + metadataPath += ".gz" + } pc := struct { Name string Metadata map[string]string @@ -57,7 +63,7 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, Name: name, Metadata: m, } - err := persist(metadataPath, pc, false, "") + err := persist(metadataPath, pc, compress, "") if err != nil { return nil, fmt.Errorf("couldn't persist collection metadata: %w", err) } @@ -233,7 +239,10 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error { safeID := hash2hex(doc.ID) docPath := filepath.Join(c.persistDirectory, safeID) docPath += ".gob" - err := persist(docPath, doc, false, "") + if c.compress { + docPath += ".gz" + } + err := persist(docPath, doc, c.compress, "") if err != nil { return fmt.Errorf("couldn't persist document: %w", err) } diff --git a/db.go b/db.go index 6a6311c..518f1db 100644 --- a/db.go +++ b/db.go @@ -7,6 +7,7 @@ import ( "io/fs" "os" "path/filepath" + "strings" "sync" ) @@ -24,9 +25,11 @@ type EmbeddingFunc func(ctx context.Context, text string) ([]float32, error) // | DB |-----------| Collection |-----------| Document | // +----+ +------------+ +----------+ type DB struct { - collections map[string]*Collection - collectionsLock sync.RWMutex + collections map[string]*Collection + collectionsLock sync.RWMutex + persistDirectory string + compress bool // ⚠️ When adding fields here, consider adding them to the persistence struct // versions in [DB.Export] and [DB.Import] as well! @@ -44,6 +47,7 @@ func NewDB() *DB { // NewPersistentDB creates a new persistent chromem-go DB. // If the path is empty, it defaults to "./chromem-go". +// If compress is true, the files are compressed with gzip. // // The persistence covers the collections (including their documents) and the metadata. // However it doesn't cover the EmbeddingFunc, as functions can't be serialized. @@ -58,7 +62,7 @@ func NewDB() *DB { // In addition to persistence for each added collection and document you can use // [DB.Export] and [DB.Import] to export and import the entire DB to/from a file, // which also works for the pure in-memory DB. -func NewPersistentDB(path string) (*DB, error) { +func NewPersistentDB(path string, compress bool) (*DB, error) { if path == "" { path = "./chromem-go" } else { @@ -66,9 +70,16 @@ func NewPersistentDB(path string) (*DB, error) { path = filepath.Clean(path) } + // We check for this file extension and skip others + ext := ".gob" + if compress { + ext += ".gz" + } + db := &DB{ - persistDirectory: path, collections: make(map[string]*Collection), + persistDirectory: path, + compress: compress, } // If the directory doesn't exist, create it and return an empty DB. @@ -108,8 +119,9 @@ func NewPersistentDB(path string) (*DB, error) { return nil, fmt.Errorf("couldn't read collection directory: %w", err) } c := &Collection{ - persistDirectory: collectionPath, documents: make(map[string]*Document), + persistDirectory: collectionPath, + compress: compress, // We can fill Name and metadata only after reading // the metadata. // We can fill embed only when the user calls DB.GetCollection() or @@ -124,7 +136,7 @@ func NewPersistentDB(path string) (*DB, error) { fPath := filepath.Join(collectionPath, collectionDirEntry.Name()) // Differentiate between collection metadata, documents and other files. - if collectionDirEntry.Name() == metadataFileName+".gob" { + if collectionDirEntry.Name() == metadataFileName+ext { // Read name and metadata pc := struct { Name string @@ -136,7 +148,7 @@ func NewPersistentDB(path string) (*DB, error) { } c.Name = pc.Name c.metadata = pc.Metadata - } else if filepath.Ext(collectionDirEntry.Name()) == ".gob" { + } else if strings.HasSuffix(collectionDirEntry.Name(), ext) { // Read document d := &Document{} err := read(fPath, d, "") @@ -225,6 +237,7 @@ func (db *DB) Import(filePath string, encryptionKey string) error { } if db.persistDirectory != "" { c.persistDirectory = filepath.Join(db.persistDirectory, hash2hex(pc.Name)) + c.compress = db.compress } db.collections[c.Name] = c } @@ -303,7 +316,7 @@ func (db *DB) CreateCollection(name string, metadata map[string]string, embeddin if embeddingFunc == nil { embeddingFunc = NewEmbeddingFuncDefault() } - collection, err := newCollection(name, metadata, embeddingFunc, db.persistDirectory) + collection, err := newCollection(name, metadata, embeddingFunc, db.persistDirectory, db.compress) if err != nil { return nil, fmt.Errorf("couldn't create collection: %w", err) } diff --git a/db_test.go b/db_test.go index 3c7ba80..fc0d230 100644 --- a/db_test.go +++ b/db_test.go @@ -22,7 +22,7 @@ func TestNewPersistentDB(t *testing.T) { t.Fatal("expected path to not exist, got", err) } - db, err := NewPersistentDB(path) + db, err := NewPersistentDB(path, false) if err != nil { t.Fatal("expected no error, got", err) } @@ -42,7 +42,7 @@ func TestNewPersistentDB(t *testing.T) { } defer os.RemoveAll(path) - db, err := NewPersistentDB(path) + db, err := NewPersistentDB(path, false) if err != nil { t.Fatal("expected no error, got", err) } @@ -60,7 +60,7 @@ func TestNewPersistentDB_Errors(t *testing.T) { } defer os.RemoveAll(f.Name()) - _, err = NewPersistentDB(f.Name()) + _, err = NewPersistentDB(f.Name(), false) if err == nil { t.Fatal("expected error, got nil") } diff --git a/examples/rag-wikipedia-ollama/main.go b/examples/rag-wikipedia-ollama/main.go index ce1cc47..3b865da 100644 --- a/examples/rag-wikipedia-ollama/main.go +++ b/examples/rag-wikipedia-ollama/main.go @@ -39,7 +39,7 @@ func main() { // Set up chromem-go with persistence, so that when the program restarts, the // DB's data is still available. log.Println("Setting up chromem-go...") - db, err := chromem.NewPersistentDB("./db") + db, err := chromem.NewPersistentDB("./db", false) if err != nil { panic(err) } diff --git a/examples/semantic-search-arxiv-openai/main.go b/examples/semantic-search-arxiv-openai/main.go index c52c366..e3c9fd7 100644 --- a/examples/semantic-search-arxiv-openai/main.go +++ b/examples/semantic-search-arxiv-openai/main.go @@ -22,7 +22,7 @@ func main() { // Set up chromem-go with persistence, so that when the program restarts, the // DB's data is still available. log.Println("Setting up chromem-go...") - db, err := chromem.NewPersistentDB("./db") + db, err := chromem.NewPersistentDB("./db", false) if err != nil { panic(err) }