Skip to content

Commit

Permalink
Merge pull request #25 from philippgille/persistence-read
Browse files Browse the repository at this point in the history
Add persistence (read)
  • Loading branch information
philippgille authored Feb 25, 2024
2 parents f558425 + 727a5f1 commit fb2b325
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 13 deletions.
113 changes: 102 additions & 11 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"os"
"path/filepath"
"sync"
)

Expand Down Expand Up @@ -40,18 +41,92 @@ func NewDB() *DB {
// NewPersistentDB creates a DB that is backed by the directory at path and
// loads any collections that are already stored in it.
//
// If path is empty, "./chromem-go" is used. If the directory doesn't exist yet,
// it's created (including parents, with 0o700 permissions) and an empty DB is
// returned. Otherwise every subdirectory is treated as a collection whose
// metadata and documents are read from the ".gob" files inside it.
//
// Embedding funcs are not restored here; they are set when the user calls
// DB.GetCollection() or DB.GetOrCreateCollection().
//
// An error is returned if the path exists but is not a directory, if the
// directory or any of its files can't be read or decoded, or if a collection
// directory contains documents but no metadata file.
func NewPersistentDB(path string) (*DB, error) {
	if path == "" {
		path = "./chromem-go"
	} else {
		// Clean in case the user provides something like "./db/../db"
		path = filepath.Clean(path)
	}

	db := &DB{
		persistDirectory: path,
		collections:      make(map[string]*Collection),
	}

	// If the directory doesn't exist, create it and return an empty DB.
	fi, err := os.Stat(path)
	if err != nil {
		if !os.IsNotExist(err) {
			// E.g. permission denied. Don't try to create or read anything.
			return nil, fmt.Errorf("couldn't get info about persistence directory: %w", err)
		}
		if err := os.MkdirAll(path, 0o700); err != nil {
			return nil, fmt.Errorf("couldn't create persistence directory: %w", err)
		}
		return db, nil
	}
	if !fi.IsDir() {
		return nil, fmt.Errorf("persistence path is not a directory: %s", path)
	}

	// Otherwise, read all collections and their documents from the directory.
	dirEntries, err := os.ReadDir(path)
	if err != nil {
		return nil, fmt.Errorf("couldn't read persistence directory: %w", err)
	}
	for _, dirEntry := range dirEntries {
		// Collections are subdirectories, so skip any files (which the user might
		// have placed).
		if !dirEntry.IsDir() {
			continue
		}
		// TODO: Parallelize this (e.g. chan with $numCPU buffer and $numCPU goroutines
		// reading from it).
		collectionPath := filepath.Join(path, dirEntry.Name())
		c, hasMetadata, err := readPersistedCollection(collectionPath)
		if err != nil {
			return nil, err
		}
		if !hasMetadata {
			// Neither metadata nor documents: most likely a directory that the
			// user has placed, so skip it like we skip foreign files.
			if len(c.documents) == 0 {
				continue
			}
			// Documents without metadata can't be registered under a name.
			return nil, fmt.Errorf("collection metadata file not found: %s", collectionPath)
		}
		db.collections[c.Name] = c
	}

	return db, nil
}

// readPersistedCollection reads the collection metadata and all documents from
// the files directly inside collectionPath. The returned bool reports whether
// a metadata file was found; if it wasn't, the collection's Name,
// persistDirectory and metadata are left unset.
func readPersistedCollection(collectionPath string) (*Collection, bool, error) {
	entries, err := os.ReadDir(collectionPath)
	if err != nil {
		return nil, false, fmt.Errorf("couldn't read collection directory: %w", err)
	}

	c := &Collection{
		// Name, persistDirectory and metadata are filled after reading the
		// metadata file. The embedding func is not set here.
		documents: make(map[string]*document),
	}
	hasMetadata := false
	for _, entry := range entries {
		// Files should be metadata and documents; skip subdirectories which
		// the user might have placed.
		if entry.IsDir() {
			continue
		}

		fPath := filepath.Join(collectionPath, entry.Name())
		// Differentiate between collection metadata, documents and other files.
		switch {
		case entry.Name() == metadataFileName+".gob":
			// Read name and metadata
			pc := struct {
				Name     string
				Metadata map[string]string
			}{}
			if err := read(fPath, &pc); err != nil {
				return nil, false, fmt.Errorf("couldn't read collection metadata: %w", err)
			}
			c.Name = pc.Name
			// The parent of the collection directory, i.e. the DB's
			// persistence directory.
			c.persistDirectory = filepath.Dir(collectionPath)
			c.metadata = pc.Metadata
			hasMetadata = true
		case filepath.Ext(entry.Name()) == ".gob":
			// Read document
			d := &document{}
			if err := read(fPath, d); err != nil {
				return nil, false, fmt.Errorf("couldn't read document: %w", err)
			}
			c.documents[d.ID] = d
		}
		// Anything else might be a file that the user has placed; ignore it.
	}

	return c, hasMetadata, nil
}

// CreateCollection creates a new collection with the given name and metadata.
Expand Down Expand Up @@ -95,14 +170,30 @@ func (db *DB) ListCollections() map[string]*Collection {
}

// GetCollection returns the collection with the given name.
// The returned value is a reference to the original collection, so any methods
// The embeddingFunc param is only used if the DB is persistent and was just loaded
// from storage, in which case no embedding func is set yet (funcs are not (de-)serializable).
// It can be nil, in which case the default one will be used.
// The returned collection is a reference to the original collection, so any methods
// on the collection like Add() will be reflected on the DB's collection. Those
// operations are concurrency-safe.
// If the collection doesn't exist, this returns nil.
func (db *DB) GetCollection(name string) *Collection {
func (db *DB) GetCollection(name string, embeddingFunc EmbeddingFunc) *Collection {
db.collectionsLock.RLock()
defer db.collectionsLock.RUnlock()
return db.collections[name]

c, ok := db.collections[name]
if !ok {
return nil
}

if c.embed == nil {
if embeddingFunc == nil {
c.embed = NewEmbeddingFuncDefault()
} else {
c.embed = embeddingFunc
}
}
return c
}

// GetOrCreateCollection returns the collection with the given name if it exists
Expand All @@ -114,7 +205,7 @@ func (db *DB) GetCollection(name string) *Collection {
// Uses the default embedding function if not provided.
func (db *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) {
// No need to lock here, because the methods we call do that.
collection := db.GetCollection(name)
collection := db.GetCollection(name, embeddingFunc)
if collection == nil {
var err error
collection, err = db.CreateCollection(name, metadata, embeddingFunc)
Expand Down
2 changes: 1 addition & 1 deletion db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func TestDB_GetCollection(t *testing.T) {
}

// Get collection
c := db.GetCollection(name)
c := db.GetCollection(name, nil)

// Check expectations
if c.Name != name {
Expand Down
21 changes: 20 additions & 1 deletion persistence.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ func hash2hex(name string) string {
return hex.EncodeToString(hash[:4])
}

// persist persists an object to a file at the given path.
// persist persists an object to a file at the given path. The object is serialized
// as gob.
func persist(filePath string, obj any) error {
filePath += ".gob"

Expand All @@ -36,3 +37,21 @@ func persist(filePath string, obj any) error {

return nil
}

// read opens the file at filePath and gob-decodes its content into obj.
// obj has to be a pointer to an already instantiated value. Failing to open
// the file and failing to decode it are reported as distinct wrapped errors.
func read(filePath string, obj any) error {
	file, openErr := os.Open(filePath)
	if openErr != nil {
		return fmt.Errorf("couldn't open file '%s': %w", filePath, openErr)
	}
	defer file.Close()

	if decodeErr := gob.NewDecoder(file).Decode(obj); decodeErr != nil {
		return fmt.Errorf("couldn't decode or read object: %w", decodeErr)
	}
	return nil
}

0 comments on commit fb2b325

Please sign in to comment.