From dd313ebc99e4fefa43af87cfad12a9024977e764 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sat, 2 Mar 2024 13:11:49 +0100 Subject: [PATCH] Add Collection.Count() method --- collection.go | 7 +++++++ collection_test.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/collection.go b/collection.go index 1cd7bdc..bffa921 100644 --- a/collection.go +++ b/collection.go @@ -93,6 +93,13 @@ func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddin return c.add(ctx, ids, documents, embeddings, metadatas, concurrency) } +// Count returns the number of documents in the collection. +func (c *Collection) Count() int { + c.documentsLock.RLock() + defer c.documentsLock.RUnlock() + return len(c.documents) +} + // Performs a nearest neighbors query on a collection specified by UUID. // // - queryText: The text to search for. diff --git a/collection_test.go b/collection_test.go index 6fc2a8e..472e119 100644 --- a/collection_test.go +++ b/collection_test.go @@ -34,3 +34,34 @@ func TestCollection_Add(t *testing.T) { // TODO: Check expectations when documents become accessible } + +func TestCollection_Count(t *testing.T) { + // Create collection + db := chromem.NewDB() + name := "test" + metadata := map[string]string{"foo": "bar"} + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{-0.1, 0.1, 0.2}, nil + } + c, err := db.CreateCollection(name, metadata, embeddingFunc) + if err != nil { + t.Error("expected no error, got", err) + } + if c == nil { + t.Error("expected collection, got nil") + } + + // Add documents + ids := []string{"1", "2"} + metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}} + documents := []string{"hello world", "hallo welt"} + err = c.Add(context.Background(), ids, nil, metadatas, documents) + if err != nil { + t.Error("expected nil, got", err) + } + + // Check count + if c.Count() != 2 { + t.Error("expected 2, got", c.Count()) + } +}