From a8b7e80e6aa839d425fa5e1a2e0539b5eeef5948 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sun, 17 Mar 2024 12:47:08 +0100 Subject: [PATCH 1/2] Fix nResults validation --- collection.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/collection.go b/collection.go index 732cb37..92454ff 100644 --- a/collection.go +++ b/collection.go @@ -279,8 +279,8 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int, } c.documentsLock.RLock() defer c.documentsLock.RUnlock() - if nResults < len(c.documents) { - return nil, errors.New("nResults must be greater than the number of documents in the collection") + if nResults > len(c.documents) { + return nil, errors.New("nResults must be <= the number of documents in the collection") } if len(c.documents) == 0 { From 98516b1a6818edd05aa8967ee14d2b4ff3d2bfc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Sun, 17 Mar 2024 13:24:17 +0100 Subject: [PATCH 2/2] Add unit test for query errors --- collection_test.go | 79 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/collection_test.go b/collection_test.go index 274b271..89ffd74 100644 --- a/collection_test.go +++ b/collection_test.go @@ -308,6 +308,85 @@ func TestCollection_AddConcurrently_Error(t *testing.T) { } } +func TestCollection_QueryError(t *testing.T) { + // Create collection + db := NewDB() + name := "test" + metadata := map[string]string{"foo": "bar"} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return vectors, nil + } + c, err := db.CreateCollection(name, metadata, embeddingFunc) + if err != nil { + t.Fatal("expected no error, got", err) + } + if c == nil { + t.Fatal("expected collection, got nil") + } + // Add a document + err = c.AddDocument(context.Background(), Document{ID: "1", Content: "hello world"}) + if err != nil { + t.Fatal("expected nil, got", err) + } + + tt := []struct { + name string + query func() error + expErr string + }{ + { + name: "Empty query", + query: func() error { + _, err := c.Query(context.Background(), "", 1, nil, nil) + return err + }, + expErr: "queryText is empty", + }, + { + name: "Negative limit", + query: func() error { + _, err := c.Query(context.Background(), "foo", -1, nil, nil) + return err + }, + expErr: "nResults must be > 0", + }, + { + name: "Zero limit", + query: func() error { + _, err := c.Query(context.Background(), "foo", 0, nil, nil) + return err + }, + expErr: "nResults must be > 0", + }, + { + name: "Limit greater than number of documents", + query: func() error { + _, err := c.Query(context.Background(), "foo", 2, nil, nil) + return err + }, + expErr: "nResults must be <= the number of documents in the collection", + }, + { + name: "Bad content filter", + query: func() error { + _, err := c.Query(context.Background(), "foo", 1, nil, map[string]string{"invalid": "foo"}) + return err + }, + expErr: "unsupported operator", + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + err := tc.query() + if err.Error() != tc.expErr { + t.Fatal("expected", tc.expErr, "got", err) + } + }) + } +} + func TestCollection_Count(t *testing.T) { // Create collection db := NewDB()