From 797931a33868579b695f8a6b006043a8c8bbbc0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sun, 25 Feb 2024 16:22:38 +0100 Subject: [PATCH 1/5] Implement reading of persisted DB --- db.go | 85 +++++++++++++++++++++++++++++++++++++++++++++----- persistence.go | 21 ++++++++++++- 2 files changed, 98 insertions(+), 8 deletions(-) diff --git a/db.go b/db.go index 48f77f5..abc4471 100644 --- a/db.go +++ b/db.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "os" + "path/filepath" "sync" ) @@ -42,16 +43,86 @@ func NewPersistentDB(path string) (*DB, error) { path = "./chromem-go" } - // Make directory if it doesn't exist. - err := os.MkdirAll(path, 0o700) + db := &DB{ + persistDirectory: path, + collections: make(map[string]*Collection), + } + + // If the directory doesn't exist, create it and return an empty DB. + if _, err := os.Stat(path); os.IsNotExist(err) { + err := os.MkdirAll(path, 0o700) + if err != nil { + return nil, fmt.Errorf("couldn't create persistence directory: %w", err) + } + + return db, nil + } + + // Otherwise, read all collections and their documents from the directory. + err := filepath.WalkDir(path, func(p string, info os.DirEntry, err error) error { + if err != nil { + return fmt.Errorf("couldn't walk DB directory: %w", err) + } + // First level is the subdirectories for the collections, so skip any files. + if !info.IsDir() { + return nil + } + // For each subdirectory, create a collection and read its name, metadata + // and documents. + // TODO: Parallelize this (e.g. chan with $numCPU buffer and $numCPU goroutines + // reading from it). + c := &Collection{ + // We can fill Name, persistDirectory and metadata only after reading + // the metadata. + documents: make(map[string]*document), + // We can fill embed only when the user calls DB.GetCollection() or + // DB.GetOrCreateCollection(). 
+ } + err = filepath.WalkDir(filepath.Join(path, info.Name()), func(p string, info os.DirEntry, err error) error { + if err != nil { + return fmt.Errorf("couldn't walk collection directory: %w", err) + } + // Files should be metadata and documents; skip subdirectories. + if info.IsDir() { + return nil + } + + if info.Name() == metadataFileName+".gob" { + pc := struct { + Name string + Metadata map[string]string + }{} + err := read(p, &pc) + if err != nil { + return fmt.Errorf("couldn't read collection metadata: %w", err) + } + c.Name = pc.Name + c.persistDirectory = filepath.Dir(p) + c.metadata = pc.Metadata + } else { + // Read document + d := &document{} + err := read(p, d) + if err != nil { + return fmt.Errorf("couldn't read document: %w", err) + } + c.documents[d.ID] = d + } + + return nil + }) + if err != nil { + return fmt.Errorf("couldn't read collection directory: %w", err) + } + db.collections[c.Name] = c + + return nil + }) if err != nil { - return nil, fmt.Errorf("couldn't create persistence directory: %w", err) + return nil, fmt.Errorf("couldn't read persisted database: %w", err) } - return &DB{ - persistDirectory: path, - collections: make(map[string]*Collection), - }, nil + return db, nil } // CreateCollection creates a new collection with the given name and metadata. diff --git a/persistence.go b/persistence.go index f81a6f6..5462b10 100644 --- a/persistence.go +++ b/persistence.go @@ -18,7 +18,8 @@ func hash2hex(name string) string { return hex.EncodeToString(hash[:4]) } -// persist persists an object to a file at the given path. +// persist persists an object to a file at the given path. The object is serialized +// as gob. func persist(filePath string, obj any) error { filePath += ".gob" @@ -36,3 +37,21 @@ func persist(filePath string, obj any) error { return nil } + +// read reads an object from a file at the given path. The object is deserialized +// from gob. `obj` must be a pointer to an instantiated object. 
+func read(filePath string, obj any) error { + f, err := os.Open(filePath) + if err != nil { + return fmt.Errorf("couldn't open file '%s': %w", filePath, err) + } + defer f.Close() + + dec := gob.NewDecoder(f) + err = dec.Decode(obj) + if err != nil { + return fmt.Errorf("couldn't decode or read object: %w", err) + } + + return nil +} From c24bb3799bb18c5fa0ac9668b0f4aae6933add89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sun, 25 Feb 2024 18:48:06 +0100 Subject: [PATCH 2/5] Fix root dir is read as collection dir --- db.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/db.go b/db.go index abc4471..a0c626a 100644 --- a/db.go +++ b/db.go @@ -41,6 +41,9 @@ func NewDB() *DB { func NewPersistentDB(path string) (*DB, error) { if path == "" { path = "./chromem-go" + } else { + // Clean in case the user provides something like "./db/../db" + path = filepath.Clean(path) } db := &DB{ @@ -63,6 +66,10 @@ func NewPersistentDB(path string) (*DB, error) { if err != nil { return fmt.Errorf("couldn't walk DB directory: %w", err) } + // WalkDir reads root, which we can skip. + if path == p { + return nil + } // First level is the subdirectories for the collections, so skip any files. if !info.IsDir() { return nil @@ -78,7 +85,8 @@ func NewPersistentDB(path string) (*DB, error) { // We can fill embed only when the user calls DB.GetCollection() or // DB.GetOrCreateCollection(). 
} - err = filepath.WalkDir(filepath.Join(path, info.Name()), func(p string, info os.DirEntry, err error) error { + collectionPath := filepath.Join(path, info.Name()) + err = filepath.WalkDir(collectionPath, func(p string, info os.DirEntry, err error) error { if err != nil { return fmt.Errorf("couldn't walk collection directory: %w", err) } From acccec02090c391117949e6caf9da832faf666cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sun, 25 Feb 2024 20:02:53 +0100 Subject: [PATCH 3/5] Use os.ReadDir() instead of filepath.WalkDir() Simpler to handle, more efficient --- db.go | 68 ++++++++++++++++++++++++++++------------------------------- 1 file changed, 32 insertions(+), 36 deletions(-) diff --git a/db.go b/db.go index a0c626a..855806f 100644 --- a/db.go +++ b/db.go @@ -62,17 +62,15 @@ func NewPersistentDB(path string) (*DB, error) { } // Otherwise, read all collections and their documents from the directory. - err := filepath.WalkDir(path, func(p string, info os.DirEntry, err error) error { - if err != nil { - return fmt.Errorf("couldn't walk DB directory: %w", err) - } - // WalkDir reads root, which we can skip. - if path == p { - return nil - } - // First level is the subdirectories for the collections, so skip any files. - if !info.IsDir() { - return nil + dirEntries, err := os.ReadDir(path) + if err != nil { + return nil, fmt.Errorf("couldn't read persistence directory: %w", err) + } + for _, dirEntry := range dirEntries { + // Collections are subdirectories, so skip any files (which the user might + // have placed). + if !dirEntry.IsDir() { + continue } // For each subdirectory, create a collection and read its name, metadata // and documents. @@ -85,49 +83,47 @@ func NewPersistentDB(path string) (*DB, error) { // We can fill embed only when the user calls DB.GetCollection() or // DB.GetOrCreateCollection(). 
} - collectionPath := filepath.Join(path, info.Name()) - err = filepath.WalkDir(collectionPath, func(p string, info os.DirEntry, err error) error { - if err != nil { - return fmt.Errorf("couldn't walk collection directory: %w", err) - } - // Files should be metadata and documents; skip subdirectories. - if info.IsDir() { - return nil + collectionPath := filepath.Join(path, dirEntry.Name()) + collectionDirEntries, err := os.ReadDir(collectionPath) + if err != nil { + return nil, fmt.Errorf("couldn't read collection directory: %w", err) + } + for _, collectionDirEntry := range collectionDirEntries { + // Files should be metadata and documents; skip subdirectories which + // the user might have placed. + if collectionDirEntry.IsDir() { + continue } - if info.Name() == metadataFileName+".gob" { + fPath := filepath.Join(collectionPath, collectionDirEntry.Name()) + // Differentiate between collection metadata, documents and other files. + if collectionDirEntry.Name() == metadataFileName+".gob" { + // Read name and metadata pc := struct { Name string Metadata map[string]string }{} - err := read(p, &pc) + err := read(fPath, &pc) if err != nil { - return fmt.Errorf("couldn't read collection metadata: %w", err) + return nil, fmt.Errorf("couldn't read collection metadata: %w", err) } c.Name = pc.Name - c.persistDirectory = filepath.Dir(p) + c.persistDirectory = filepath.Dir(collectionPath) c.metadata = pc.Metadata - } else { + } else if filepath.Ext(collectionDirEntry.Name()) == ".gob" { // Read document d := &document{} - err := read(p, d) + err := read(fPath, d) if err != nil { - return fmt.Errorf("couldn't read document: %w", err) + return nil, fmt.Errorf("couldn't read document: %w", err) } c.documents[d.ID] = d + } else { + // Might be a file that the user has placed + continue } - - return nil - }) - if err != nil { - return fmt.Errorf("couldn't read collection directory: %w", err) } db.collections[c.Name] = c - - return nil - }) - if err != nil { - return nil, 
fmt.Errorf("couldn't read persisted database: %w", err) } return db, nil From 407db963f19b52e5ec2fb457d7a400913a982963 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sun, 25 Feb 2024 20:06:36 +0100 Subject: [PATCH 4/5] Only instantiate collection object when necessary --- db.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/db.go b/db.go index 855806f..ad4103a 100644 --- a/db.go +++ b/db.go @@ -76,6 +76,11 @@ func NewPersistentDB(path string) (*DB, error) { // and documents. // TODO: Parallelize this (e.g. chan with $numCPU buffer and $numCPU goroutines // reading from it). + collectionPath := filepath.Join(path, dirEntry.Name()) + collectionDirEntries, err := os.ReadDir(collectionPath) + if err != nil { + return nil, fmt.Errorf("couldn't read collection directory: %w", err) + } c := &Collection{ // We can fill Name, persistDirectory and metadata only after reading // the metadata. @@ -83,11 +88,6 @@ func NewPersistentDB(path string) (*DB, error) { // We can fill embed only when the user calls DB.GetCollection() or // DB.GetOrCreateCollection(). } - collectionPath := filepath.Join(path, dirEntry.Name()) - collectionDirEntries, err := os.ReadDir(collectionPath) - if err != nil { - return nil, fmt.Errorf("couldn't read collection directory: %w", err) - } for _, collectionDirEntry := range collectionDirEntries { // Files should be metadata and documents; skip subdirectories which // the user might have placed. From 727a5f1b14edcba456e20f1a13b8745795280ae8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sun, 25 Feb 2024 20:30:23 +0100 Subject: [PATCH 5/5] Add EmbeddingFunc param to DB.GetCollection() Required for when the DB was just loaded from persistent storage, as funcs can't be (de-)serialized. 
--- db.go | 24 ++++++++++++++++++++---- db_test.go | 2 +- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/db.go b/db.go index ad4103a..a172d7f 100644 --- a/db.go +++ b/db.go @@ -170,14 +170,30 @@ func (db *DB) ListCollections() map[string]*Collection { } // GetCollection returns the collection with the given name. -// The returned value is a reference to the original collection, so any methods +// The embeddingFunc param is only used if the DB is persistent and was just loaded +// from storage, in which case no embedding func is set yet (funcs are not (de-)serializable). +// It can be nil, in which case the default one will be used. +// The returned collection is a reference to the original collection, so any methods // on the collection like Add() will be reflected on the DB's collection. Those // operations are concurrency-safe. // If the collection doesn't exist, this returns nil. -func (db *DB) GetCollection(name string) *Collection { +func (db *DB) GetCollection(name string, embeddingFunc EmbeddingFunc) *Collection { db.collectionsLock.RLock() defer db.collectionsLock.RUnlock() - return db.collections[name] + + c, ok := db.collections[name] + if !ok { + return nil + } + + if c.embed == nil { + if embeddingFunc == nil { + c.embed = NewEmbeddingFuncDefault() + } else { + c.embed = embeddingFunc + } + } + return c } // GetOrCreateCollection returns the collection with the given name if it exists @@ -189,7 +205,7 @@ func (db *DB) GetCollection(name string) *Collection { // Uses the default embedding function if not provided. func (db *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) { // No need to lock here, because the methods we call do that. 
- collection := db.GetCollection(name) + collection := db.GetCollection(name, embeddingFunc) if collection == nil { var err error collection, err = db.CreateCollection(name, metadata, embeddingFunc) diff --git a/db_test.go b/db_test.go index d441ba5..449131a 100644 --- a/db_test.go +++ b/db_test.go @@ -96,7 +96,7 @@ func TestDB_GetCollection(t *testing.T) { } // Get collection - c := db.GetCollection(name) + c := db.GetCollection(name, nil) // Check expectations if c.Name != name {