diff --git a/collection.go b/collection.go index bffa921..4e45f2a 100644 --- a/collection.go +++ b/collection.go @@ -107,16 +107,19 @@ func (c *Collection) Count() int { // - where: Conditional filtering on metadata. Optional. // - whereDocument: Conditional filtering on documents. Optional. func (c *Collection) Query(ctx context.Context, queryText string, nResults int, where, whereDocument map[string]string) ([]Result, error) { + if queryText == "" { + return nil, errors.New("queryText is empty") + } + if nResults <= 0 { + return nil, errors.New("nResults must be > 0") + } + c.documentsLock.RLock() defer c.documentsLock.RUnlock() if len(c.documents) == 0 { return nil, nil } - if nResults <= 0 { - return nil, errors.New("nResults must be > 0") - } - // Validate whereDocument operators for k := range whereDocument { if !slices.Contains(supportedFilters, k) { diff --git a/db.go b/db.go index 95cf337..5cdf0d9 100644 --- a/db.go +++ b/db.go @@ -138,6 +138,9 @@ func NewPersistentDB(path string) (*DB, error) { // - embeddingFunc: Optional function to use to embed documents. // Uses the default embedding function if not provided. func (db *DB) CreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) { + if name == "" { + return nil, errors.New("collection name is empty") + } if embeddingFunc == nil { embeddingFunc = NewEmbeddingFuncDefault() } diff --git a/db_test.go b/db_test.go index 449131a..da74bad 100644 --- a/db_test.go +++ b/db_test.go @@ -15,21 +15,30 @@ func TestDB_CreateCollection(t *testing.T) { return []float32{-0.1, 0.1, 0.2}, nil } - // Create collection db := chromem.NewDB() - c, err := db.CreateCollection(name, metadata, embeddingFunc) - if err != nil { - t.Error("expected no error, got", err) - } - if c == nil { - t.Error("expected collection, got nil") - } - // Check expectations - if c.Name != name { - t.Error("expected name", name, "got", c.Name) - } - // TODO: Check metadata etc when they become accessible + t.Run("OK", func(t *testing.T) { + c, err := db.CreateCollection(name, metadata, embeddingFunc) + if err != nil { + t.Error("expected no error, got", err) + } + if c == nil { + t.Error("expected collection, got nil") + } + + // Check expectations + if c.Name != name { + t.Error("expected name", name, "got", c.Name) + } + // TODO: Check metadata etc when they become accessible + }) + + t.Run("NOK - Empty name", func(t *testing.T) { + _, err := db.CreateCollection("", metadata, embeddingFunc) + if err == nil { + t.Error("expected error, got nil") + } + }) } func TestDB_ListCollections(t *testing.T) {