Skip to content

Commit

Permalink
Merge pull request #51 from philippgille/fix-param-validation
Browse files Browse the repository at this point in the history
Fix param validation
  • Loading branch information
philippgille authored Mar 17, 2024
2 parents efb6890 + 98516b1 commit 9b47246
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 2 deletions.
4 changes: 2 additions & 2 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,8 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int,
}
c.documentsLock.RLock()
defer c.documentsLock.RUnlock()
if nResults < len(c.documents) {
return nil, errors.New("nResults must be greater than the number of documents in the collection")
if nResults > len(c.documents) {
return nil, errors.New("nResults must be <= the number of documents in the collection")
}

if len(c.documents) == 0 {
Expand Down
79 changes: 79 additions & 0 deletions collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,85 @@ func TestCollection_AddConcurrently_Error(t *testing.T) {
}
}

func TestCollection_QueryError(t *testing.T) {
// Create collection
db := NewDB()
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
c, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
if c == nil {
t.Fatal("expected collection, got nil")
}
// Add a document
err = c.AddDocument(context.Background(), Document{ID: "1", Content: "hello world"})
if err != nil {
t.Fatal("expected nil, got", err)
}

tt := []struct {
name string
query func() error
expErr string
}{
{
name: "Empty query",
query: func() error {
_, err := c.Query(context.Background(), "", 1, nil, nil)
return err
},
expErr: "queryText is empty",
},
{
name: "Negative limit",
query: func() error {
_, err := c.Query(context.Background(), "foo", -1, nil, nil)
return err
},
expErr: "nResults must be > 0",
},
{
name: "Zero limit",
query: func() error {
_, err := c.Query(context.Background(), "foo", 0, nil, nil)
return err
},
expErr: "nResults must be > 0",
},
{
name: "Limit greater than number of documents",
query: func() error {
_, err := c.Query(context.Background(), "foo", 2, nil, nil)
return err
},
expErr: "nResults must be <= the number of documents in the collection",
},
{
name: "Bad content filter",
query: func() error {
_, err := c.Query(context.Background(), "foo", 1, nil, map[string]string{"invalid": "foo"})
return err
},
expErr: "unsupported operator",
},
}

for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
err := tc.query()
if err.Error() != tc.expErr {
t.Fatal("expected", tc.expErr, "got", err)
}
})
}
}

func TestCollection_Count(t *testing.T) {
// Create collection
db := NewDB()
Expand Down

0 comments on commit 9b47246

Please sign in to comment.