From 727a5f1b14edcba456e20f1a13b8745795280ae8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sun, 25 Feb 2024 20:30:23 +0100 Subject: [PATCH] Add EmbeddingFunc param to DB.GetCollection() Required for when the DB was just loaded from persistent storage, as funcs can't be (de-)serialized. --- db.go | 24 ++++++++++++++++++++---- db_test.go | 2 +- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/db.go b/db.go index ad4103a..a172d7f 100644 --- a/db.go +++ b/db.go @@ -170,14 +170,30 @@ func (db *DB) ListCollections() map[string]*Collection { } // GetCollection returns the collection with the given name. -// The returned value is a reference to the original collection, so any methods +// The embeddingFunc param is only used if the DB is persistent and was just loaded +// from storage, in which case no embedding func is set yet (funcs are not (de-)serializable). +// It can be nil, in which case the default one will be used. +// The returned collection is a reference to the original collection, so any methods // on the collection like Add() will be reflected on the DB's collection. Those // operations are concurrency-safe. // If the collection doesn't exist, this returns nil. -func (db *DB) GetCollection(name string) *Collection { +func (db *DB) GetCollection(name string, embeddingFunc EmbeddingFunc) *Collection { db.collectionsLock.RLock() defer db.collectionsLock.RUnlock() - return db.collections[name] + + c, ok := db.collections[name] + if !ok { + return nil + } + + if c.embed == nil { + if embeddingFunc == nil { + c.embed = NewEmbeddingFuncDefault() + } else { + c.embed = embeddingFunc + } + } + return c } // GetOrCreateCollection returns the collection with the given name if it exists @@ -189,7 +205,7 @@ func (db *DB) GetCollection(name string) *Collection { // Uses the default embedding function if not provided. 
func (db *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) { // No need to lock here, because the methods we call do that. - collection := db.GetCollection(name) + collection := db.GetCollection(name, embeddingFunc) if collection == nil { var err error collection, err = db.CreateCollection(name, metadata, embeddingFunc) diff --git a/db_test.go b/db_test.go index d441ba5..449131a 100644 --- a/db_test.go +++ b/db_test.go @@ -96,7 +96,7 @@ func TestDB_GetCollection(t *testing.T) { } // Get collection - c := db.GetCollection(name) + c := db.GetCollection(name, nil) // Check expectations if c.Name != name {