diff --git a/db.go b/db.go index 48f77f5..a172d7f 100644 --- a/db.go +++ b/db.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "path/filepath" "sync" ) @@ -40,18 +41,92 @@ func NewDB() *DB { func NewPersistentDB(path string) (*DB, error) { if path == "" { path = "./chromem-go" + } else { + // Clean in case the user provides something like "./db/../db" + path = filepath.Clean(path) } - // Make directory if it doesn't exist. - err := os.MkdirAll(path, 0o700) + db := &DB{ + persistDirectory: path, + collections: make(map[string]*Collection), + } + + // If the directory doesn't exist, create it and return an empty DB. + if _, err := os.Stat(path); os.IsNotExist(err) { + err := os.MkdirAll(path, 0o700) + if err != nil { + return nil, fmt.Errorf("couldn't create persistence directory: %w", err) + } + + return db, nil + } + + // Otherwise, read all collections and their documents from the directory. + dirEntries, err := os.ReadDir(path) if err != nil { - return nil, fmt.Errorf("couldn't create persistence directory: %w", err) + return nil, fmt.Errorf("couldn't read persistence directory: %w", err) + } + for _, dirEntry := range dirEntries { + // Collections are subdirectories, so skip any files (which the user might + // have placed). + if !dirEntry.IsDir() { + continue + } + // For each subdirectory, create a collection and read its name, metadata + // and documents. + // TODO: Parallelize this (e.g. chan with $numCPU buffer and $numCPU goroutines + // reading from it). + collectionPath := filepath.Join(path, dirEntry.Name()) + collectionDirEntries, err := os.ReadDir(collectionPath) + if err != nil { + return nil, fmt.Errorf("couldn't read collection directory: %w", err) + } + c := &Collection{ + // We can fill Name, persistDirectory and metadata only after reading + // the metadata. + documents: make(map[string]*document), + // We can fill embed only when the user calls DB.GetCollection() or + // DB.GetOrCreateCollection(). 
+ } + for _, collectionDirEntry := range collectionDirEntries { + // Files should be metadata and documents; skip subdirectories which + // the user might have placed. + if collectionDirEntry.IsDir() { + continue + } + + fPath := filepath.Join(collectionPath, collectionDirEntry.Name()) + // Differentiate between collection metadata, documents and other files. + if collectionDirEntry.Name() == metadataFileName+".gob" { + // Read name and metadata + pc := struct { + Name string + Metadata map[string]string + }{} + err := read(fPath, &pc) + if err != nil { + return nil, fmt.Errorf("couldn't read collection metadata: %w", err) + } + c.Name = pc.Name + c.persistDirectory = filepath.Dir(collectionPath) + c.metadata = pc.Metadata + } else if filepath.Ext(collectionDirEntry.Name()) == ".gob" { + // Read document + d := &document{} + err := read(fPath, d) + if err != nil { + return nil, fmt.Errorf("couldn't read document: %w", err) + } + c.documents[d.ID] = d + } else { + // Might be a file that the user has placed + continue + } + } + db.collections[c.Name] = c } - return &DB{ - persistDirectory: path, - collections: make(map[string]*Collection), - }, nil + return db, nil } // CreateCollection creates a new collection with the given name and metadata. @@ -95,14 +170,30 @@ func (db *DB) ListCollections() map[string]*Collection { } // GetCollection returns the collection with the given name. -// The returned value is a reference to the original collection, so any methods +// The embeddingFunc param is only used if the DB is persistent and was just loaded +// from storage, in which case no embedding func is set yet (funcs are not (de-)serializable). +// It can be nil, in which case the default one will be used. +// The returned collection is a reference to the original collection, so any methods // on the collection like Add() will be reflected on the DB's collection. Those // operations are concurrency-safe. // If the collection doesn't exist, this returns nil. 
-func (db *DB) GetCollection(name string) *Collection { +func (db *DB) GetCollection(name string, embeddingFunc EmbeddingFunc) *Collection { db.collectionsLock.RLock() defer db.collectionsLock.RUnlock() - return db.collections[name] + + c, ok := db.collections[name] + if !ok { + return nil + } + + if c.embed == nil { + if embeddingFunc == nil { + c.embed = NewEmbeddingFuncDefault() + } else { + c.embed = embeddingFunc + } + } + return c } // GetOrCreateCollection returns the collection with the given name if it exists @@ -114,7 +205,7 @@ func (db *DB) GetCollection(name string) *Collection { // Uses the default embedding function if not provided. func (db *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) { // No need to lock here, because the methods we call do that. - collection := db.GetCollection(name) + collection := db.GetCollection(name, embeddingFunc) if collection == nil { var err error collection, err = db.CreateCollection(name, metadata, embeddingFunc) diff --git a/db_test.go b/db_test.go index d441ba5..449131a 100644 --- a/db_test.go +++ b/db_test.go @@ -96,7 +96,7 @@ func TestDB_GetCollection(t *testing.T) { } // Get collection - c := db.GetCollection(name) + c := db.GetCollection(name, nil) // Check expectations if c.Name != name { diff --git a/persistence.go b/persistence.go index f81a6f6..5462b10 100644 --- a/persistence.go +++ b/persistence.go @@ -18,7 +18,8 @@ func hash2hex(name string) string { return hex.EncodeToString(hash[:4]) } -// persist persists an object to a file at the given path. +// persist persists an object to a file at the given path. The object is serialized +// as gob. func persist(filePath string, obj any) error { filePath += ".gob" @@ -36,3 +37,21 @@ func persist(filePath string, obj any) error { return nil } + +// read reads an object from a file at the given path. The object is deserialized +// from gob. `obj` must be a pointer to an instantiated object. 
// read decodes the gob-encoded content of the file at filePath into obj.
// `obj` must be a pointer to an instantiated object. Unlike persist, no file
// extension is appended: filePath is opened exactly as given.
// It returns a wrapped error if the file can't be opened or the content can't
// be decoded.
func read(filePath string, obj any) error {
	file, openErr := os.Open(filePath)
	if openErr != nil {
		return fmt.Errorf("couldn't open file '%s': %w", filePath, openErr)
	}
	defer file.Close()

	if decodeErr := gob.NewDecoder(file).Decode(obj); decodeErr != nil {
		return fmt.Errorf("couldn't decode or read object: %w", decodeErr)
	}

	return nil
}