Skip to content

Commit

Permalink
Add tests for Collection.GetByID()
Browse files Browse the repository at this point in the history
  • Loading branch information
philippgille committed Sep 1, 2024
1 parent 88dad4f commit 366cbca
Showing 1 changed file with 54 additions and 0 deletions.
54 changes: 54 additions & 0 deletions collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,60 @@ func TestCollection_QueryError(t *testing.T) {
}
}

func TestCollection_Get(t *testing.T) {
ctx := context.Background()

// Create collection
db := NewDB()
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
c, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
if c == nil {
t.Fatal("expected collection, got nil")
}

// Add documents
ids := []string{"1", "2"}
metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}}
contents := []string{"hello world", "hallo welt"}
err = c.Add(context.Background(), ids, nil, metadatas, contents)
if err != nil {
t.Fatal("expected nil, got", err)
}

// Get by ID
doc, err := c.GetByID(ctx, ids[0])
if err != nil {
t.Fatal("expected nil, got", err)
}
// Check fields
if doc.ID != ids[0] {
t.Fatal("expected", ids[0], "got", doc.ID)
}
if len(doc.Metadata) != 1 {
t.Fatal("expected 1, got", len(doc.Metadata))
}
if !slices.Equal(doc.Embedding, vectors) {
t.Fatal("expected", vectors, "got", doc.Embedding)
}
if doc.Content != contents[0] {
t.Fatal("expected", contents[0], "got", doc.Content)
}

// Check error
_, err = c.GetByID(ctx, "3")
if err == nil {
t.Fatal("expected error, got nil")
}
}

func TestCollection_Count(t *testing.T) {
// Create collection
db := NewDB()
Expand Down

0 comments on commit 366cbca

Please sign in to comment.