diff --git a/db.go b/db.go index 1b52ef7..2d949d7 100644 --- a/db.go +++ b/db.go @@ -75,6 +75,22 @@ func (c *DB) GetCollection(name string) *Collection { return c.collections[name] } +// GetOrCreateCollection returns the collection with the given name if it exists +// in the DB, or otherwise creates it. When creating: +// +// - name: The name of the collection to create. +// - metadata: Optional metadata to associate with the collection. +// - embeddingFunc: Optional function to use to embed documents. +// Uses the default embedding function if not provided. +func (c *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) *Collection { + // No need to lock here, because the methods we call do that. + collection := c.GetCollection(name) + if collection == nil { + collection = c.CreateCollection(name, metadata, embeddingFunc) + } + return collection +} + // DeleteCollection deletes the collection with the given name. // If the collection doesn't exist, this is a no-op. func (c *DB) DeleteCollection(name string) { diff --git a/db_test.go b/db_test.go index 8e9269b..c4d9f1c 100644 --- a/db_test.go +++ b/db_test.go @@ -105,6 +105,55 @@ func TestDB_GetCollection(t *testing.T) { // TODO: Check documents map being a copy as soon as we have access to it } +func TestDB_GetOrCreateCollection(t *testing.T) { + // Values in the collection + name := "test" + metadata := map[string]string{"foo": "bar"} + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{-0.1, 0.1, 0.2}, nil + } + + t.Run("Get", func(t *testing.T) { + // Create initial collection + db := chromem.NewDB() + // Create collection so that the GetOrCreateCollection() call below only + // gets it. + // We ignore the return value. CreateCollection is tested elsewhere. + _ = db.CreateCollection(name, metadata, embeddingFunc) + + // Call GetOrCreateCollection() with the same name to only get it. We pass + // nil for the metadata and embeddingFunc so we can check that the returned + // collection is the original one, and not a new one. + c := db.GetOrCreateCollection(name, nil, nil) + if c == nil { + t.Error("expected collection, got nil") + } + + // Check expectations + if c.Name != name { + t.Error("expected name", name, "got", c.Name) + } + // TODO: Check metadata when it's accessible (e.g. with GetMetadata()) + }) + + t.Run("Create", func(t *testing.T) { + // Create initial collection + db := chromem.NewDB() + + // Call GetOrCreateCollection() + c := db.GetOrCreateCollection(name, metadata, embeddingFunc) + if c == nil { + t.Error("expected collection, got nil") + } + + // Check like we check CreateCollection() + if c.Name != name { + t.Error("expected name", name, "got", c.Name) + } + // TODO: Check metadata when it's accessible (e.g. with GetMetadata()) + }) +} + func TestDB_DeleteCollection(t *testing.T) { // Values in the collection name := "test"