Skip to content

Commit

Permalink
Merge pull request #22 from philippgille/add-get-or-create-collection
Browse files Browse the repository at this point in the history
Add DB.GetOrCreateCollection()
  • Loading branch information
philippgille authored Feb 18, 2024
2 parents cf08535 + d1415e2 commit ad49c6e
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 0 deletions.
16 changes: 16 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,22 @@ func (c *DB) GetCollection(name string) *Collection {
return c.collections[name]
}

// GetOrCreateCollection returns the collection with the given name if it exists
// in the DB, or otherwise creates it. When creating:
//
// - name: The name of the collection to create.
// - metadata: Optional metadata to associate with the collection.
// - embeddingFunc: Optional function to use to embed documents.
// Uses the default embedding function if not provided.
func (c *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) *Collection {
// No need to lock here, because the methods we call do that.
collection := c.GetCollection(name)
if collection == nil {
collection = c.CreateCollection(name, metadata, embeddingFunc)
}
return collection
}

// DeleteCollection deletes the collection with the given name.
// If the collection doesn't exist, this is a no-op.
func (c *DB) DeleteCollection(name string) {
Expand Down
49 changes: 49 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,55 @@ func TestDB_GetCollection(t *testing.T) {
// TODO: Check documents map being a copy as soon as we have access to it
}

func TestDB_GetOrCreateCollection(t *testing.T) {
// Values in the collection
name := "test"
metadata := map[string]string{"foo": "bar"}
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return []float32{-0.1, 0.1, 0.2}, nil
}

t.Run("Get", func(t *testing.T) {
// Create initial collection
db := chromem.NewDB()
// Create collection so that the GetOrCreateCollection() call below only
// gets it.
// We ignore the return value. CreateCollection is tested elsewhere.
_ = db.CreateCollection(name, metadata, embeddingFunc)

// Call GetOrCreateCollection() with the same name to only get it. We pass
// nil for the metadata and embeddingFunc so we can check that the returned
// collection is the original one, and not a new one.
c := db.GetOrCreateCollection(name, nil, nil)
if c == nil {
t.Error("expected collection, got nil")
}

// Check expectations
if c.Name != name {
t.Error("expected name", name, "got", c.Name)
}
// TODO: Check metadata when it's accessible (e.g. with GetMetadata())
})

t.Run("Create", func(t *testing.T) {
// Create initial collection
db := chromem.NewDB()

// Call GetOrCreateCollection()
c := db.GetOrCreateCollection(name, metadata, embeddingFunc)
if c == nil {
t.Error("expected collection, got nil")
}

// Check like we check CreateCollection()
if c.Name != name {
t.Error("expected name", name, "got", c.Name)
}
// TODO: Check metadata when it's accessible (e.g. with GetMetadata())
})
}

func TestDB_DeleteCollection(t *testing.T) {
// Values in the collection
name := "test"
Expand Down

0 comments on commit ad49c6e

Please sign in to comment.