Skip to content

Commit

Permalink
Add EmbeddingFunc param to DB.GetCollection()
Browse files Browse the repository at this point in the history
Required for when the DB was just loaded from persistant storage,
as funcs can't be (de-)serialized.
  • Loading branch information
philippgille committed Feb 25, 2024
1 parent 407db96 commit 727a5f1
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
24 changes: 20 additions & 4 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,30 @@ func (db *DB) ListCollections() map[string]*Collection {
}

// GetCollection returns the collection with the given name.
// The returned value is a reference to the original collection, so any methods
// The embeddingFunc param is only used if the DB is persistent and was just loaded
// from storage, in which case no embedding func is set yet (funcs are not (de-)serializable).
// It can be nil, in which case the default one will be used.
// The returned collection is a reference to the original collection, so any methods
// on the collection like Add() will be reflected on the DB's collection. Those
// operations are concurrency-safe.
// If the collection doesn't exist, this returns nil.
func (db *DB) GetCollection(name string) *Collection {
func (db *DB) GetCollection(name string, embeddingFunc EmbeddingFunc) *Collection {
db.collectionsLock.RLock()
defer db.collectionsLock.RUnlock()
return db.collections[name]

c, ok := db.collections[name]
if !ok {
return nil
}

if c.embed == nil {
if embeddingFunc == nil {
c.embed = NewEmbeddingFuncDefault()
} else {
c.embed = embeddingFunc
}
}
return c
}

// GetOrCreateCollection returns the collection with the given name if it exists
Expand All @@ -189,7 +205,7 @@ func (db *DB) GetCollection(name string) *Collection {
// Uses the default embedding function if not provided.
func (db *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) {
// No need to lock here, because the methods we call do that.
collection := db.GetCollection(name)
collection := db.GetCollection(name, embeddingFunc)
if collection == nil {
var err error
collection, err = db.CreateCollection(name, metadata, embeddingFunc)
Expand Down
2 changes: 1 addition & 1 deletion db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func TestDB_GetCollection(t *testing.T) {
}

// Get collection
c := db.GetCollection(name)
c := db.GetCollection(name, nil)

// Check expectations
if c.Name != name {
Expand Down

0 comments on commit 727a5f1

Please sign in to comment.