diff --git a/collection_test.go b/collection_test.go index 807b4f6..fc37689 100644 --- a/collection_test.go +++ b/collection_test.go @@ -129,6 +129,7 @@ func TestCollection_Add_Error(t *testing.T) { embeddings := [][]float32{vectors, vectors} metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}} contents := []string{"hello world", "hallo welt"} + // Empty IDs err = c.Add(ctx, []string{}, embeddings, metadatas, contents) if err == nil { diff --git a/db_test.go b/db_test.go index fb25373..b5e1d57 100644 --- a/db_test.go +++ b/db_test.go @@ -2,6 +2,7 @@ package chromem import ( "context" + "reflect" "slices" "testing" ) @@ -37,29 +38,7 @@ func TestDB_CreateCollection(t *testing.T) { if !ok { t.Fatal("expected collection", name, "not found") } - if c2.Name != name { - t.Fatal("expected name", name, "got", c2.Name) - } - // The returned collection should also be the same - if c.Name != name { - t.Fatal("expected name", name, "got", c.Name) - } - // The collection's persistent dir should be empty - if c.persistDirectory != "" { - t.Fatal("expected empty persistent directory, got", c.persistDirectory) - } - // It's metadata should match - if len(c.metadata) != 1 || c.metadata["foo"] != "bar" { - t.Fatal("expected metadata", metadata, "got", c.metadata) - } - // Documents should be empty, but not nil - if c.documents == nil { - t.Fatal("expected non-nil documents, got nil") - } - if len(c.documents) != 0 { - t.Fatal("expected empty documents, got", len(c.documents)) - } - // The embedding function should be the one we passed + // Check the embedding function first, then the rest with DeepEqual gotVectors, err := c.embed(context.Background(), "test") if err != nil { t.Fatal("expected no error, got", err) @@ -67,6 +46,10 @@ func TestDB_CreateCollection(t *testing.T) { if !slices.Equal(gotVectors, vectors) { t.Fatal("expected vectors", vectors, "got", gotVectors) } + c.embed, c2.embed = nil, nil + if !reflect.DeepEqual(c, c2) { + t.Fatalf("expected collection %+v, got %+v", c, c2) + } }) t.Run("NOK - Empty name", func(t *testing.T) { @@ -88,8 +71,7 @@ func TestDB_ListCollections(t *testing.T) { // Create initial collection db := NewDB() - // We ignore the return value. CreateCollection is tested elsewhere. - _, err := db.CreateCollection(name, metadata, embeddingFunc) + orig, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -108,25 +90,7 @@ func TestDB_ListCollections(t *testing.T) { if !ok { t.Fatal("expected collection", name, "not found") } - if c.Name != name { - t.Fatal("expected name", name, "got", c.Name) - } - // The collection's persistent dir should be empty - if c.persistDirectory != "" { - t.Fatal("expected empty persistent directory, got", c.persistDirectory) - } - // It's metadata should match - if len(c.metadata) != 1 || c.metadata["foo"] != "bar" { - t.Fatal("expected metadata", metadata, "got", c.metadata) - } - // Documents should be empty, but not nil - if c.documents == nil { - t.Fatal("expected non-nil documents, got nil") - } - if len(c.documents) != 0 { - t.Fatal("expected empty documents, got", len(c.documents)) - } - // The embedding function should be the one we passed + // Check the embedding function first, then the rest with DeepEqual gotVectors, err := c.embed(context.Background(), "test") if err != nil { t.Fatal("expected no error, got", err) @@ -134,6 +98,10 @@ func TestDB_ListCollections(t *testing.T) { if !slices.Equal(gotVectors, vectors) { t.Fatal("expected vectors", vectors, "got", gotVectors) } + orig.embed, c.embed = nil, nil + if !reflect.DeepEqual(orig, c) { + t.Fatalf("expected collection %+v, got %+v", orig, c) + } // And it should be a copy. Adding a value here should not reflect on the DB's // collection. @@ -154,8 +122,7 @@ func TestDB_GetCollection(t *testing.T) { // Create initial collection db := NewDB() - // We ignore the return value. CreateCollection is tested elsewhere. - _, err := db.CreateCollection(name, metadata, embeddingFunc) + orig, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -163,26 +130,7 @@ func TestDB_GetCollection(t *testing.T) { // Get collection c := db.GetCollection(name, nil) - // Check expectations - if c.Name != name { - t.Fatal("expected name", name, "got", c.Name) - } - // The collection's persistent dir should be empty - if c.persistDirectory != "" { - t.Fatal("expected empty persistent directory, got", c.persistDirectory) - } - // It's metadata should match - if len(c.metadata) != 1 || c.metadata["foo"] != "bar" { - t.Fatal("expected metadata", metadata, "got", c.metadata) - } - // Documents should be empty, but not nil - if c.documents == nil { - t.Fatal("expected non-nil documents, got nil") - } - if len(c.documents) != 0 { - t.Fatal("expected empty documents, got", len(c.documents)) - } - // The embedding function should be the one we passed + // Check the embedding function first, then the rest with DeepEqual gotVectors, err := c.embed(context.Background(), "test") if err != nil { t.Fatal("expected no error, got", err) @@ -190,6 +138,10 @@ func TestDB_GetCollection(t *testing.T) { if !slices.Equal(gotVectors, vectors) { t.Fatal("expected vectors", vectors, "got", gotVectors) } + orig.embed, c.embed = nil, nil + if !reflect.DeepEqual(orig, c) { + t.Fatalf("expected collection %+v, got %+v", orig, c) + } } func TestDB_GetOrCreateCollection(t *testing.T) { @@ -206,8 +158,7 @@ func TestDB_GetOrCreateCollection(t *testing.T) { db := NewDB() // Create collection so that the GetOrCreateCollection() call below only // gets it. - // We ignore the return value. CreateCollection is tested elsewhere. - _, err := db.CreateCollection(name, metadata, embeddingFunc) + orig, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Fatal("expected no error, got", err) } @@ -223,26 +174,7 @@ func TestDB_GetOrCreateCollection(t *testing.T) { t.Fatal("expected collection, got nil") } - // Check expectations - if c.Name != name { - t.Fatal("expected name", name, "got", c.Name) - } - // The collection's persistent dir should be empty - if c.persistDirectory != "" { - t.Fatal("expected empty persistent directory, got", c.persistDirectory) - } - // It's metadata should match - if len(c.metadata) != 1 || c.metadata["foo"] != "bar" { - t.Fatal("expected metadata", metadata, "got", c.metadata) - } - // Documents should be empty, but not nil - if c.documents == nil { - t.Fatal("expected non-nil documents, got nil") - } - if len(c.documents) != 0 { - t.Fatal("expected empty documents, got", len(c.documents)) - } - // The embedding function should be the one we passed + // Check the embedding function first, then the rest with DeepEqual gotVectors, err := c.embed(context.Background(), "test") if err != nil { t.Fatal("expected no error, got", err) @@ -250,6 +182,10 @@ func TestDB_GetOrCreateCollection(t *testing.T) { if !slices.Equal(gotVectors, vectors) { t.Fatal("expected vectors", vectors, "got", gotVectors) } + orig.embed, c.embed = nil, nil + if !reflect.DeepEqual(orig, c) { + t.Fatalf("expected collection %+v, got %+v", orig, c) + } }) t.Run("Create", func(t *testing.T) { @@ -266,25 +202,10 @@ func TestDB_GetOrCreateCollection(t *testing.T) { } // Check like we check CreateCollection() - if c.Name != name { - t.Fatal("expected name", name, "got", c.Name) - } - // The collection's persistent dir should be empty - if c.persistDirectory != "" { - t.Fatal("expected empty persistent directory, got", c.persistDirectory) - } - // It's metadata should match - if len(c.metadata) != 1 || c.metadata["foo"] != "bar" { - t.Fatal("expected metadata", metadata, "got", c.metadata) - } - // Documents should be empty, but not nil - if c.documents == nil { - t.Fatal("expected non-nil documents, got nil") - } - if len(c.documents) != 0 { - t.Fatal("expected empty documents, got", len(c.documents)) + c2, ok := db.collections[name] + if !ok { + t.Fatal("expected collection", name, "not found") } - // The embedding function should be the one we passed gotVectors, err := c.embed(context.Background(), "test") if err != nil { t.Fatal("expected no error, got", err) @@ -292,6 +213,10 @@ func TestDB_GetOrCreateCollection(t *testing.T) { if !slices.Equal(gotVectors, vectors) { t.Fatal("expected vectors", vectors, "got", gotVectors) } + c.embed, c2.embed = nil, nil + if !reflect.DeepEqual(c, c2) { + t.Fatalf("expected collection %+v, got %+v", c, c2) + } }) } diff --git a/document_test.go b/document_test.go index 4290c5a..5412bd6 100644 --- a/document_test.go +++ b/document_test.go @@ -2,7 +2,7 @@ package chromem import ( "context" - "slices" + "reflect" "testing" ) @@ -49,17 +49,15 @@ func TestDocument_New(t *testing.T) { if err != nil { t.Fatal("expected no error, got", err) } - if d.ID != id { - t.Fatal("expected id", id, "got", d.ID) + // We can compare with DeepEqual after removing the embedding function + d.Embedding = nil + exp := Document{ + ID: id, + Metadata: metadata, + Content: content, } - if d.Metadata["foo"] != metadata["foo"] { - t.Fatal("expected metadata", metadata, "got", d.Metadata) - } - if !slices.Equal(d.Embedding, vectors) { - t.Fatal("expected vectors", vectors, "got", d.Embedding) - } - if d.Content != content { - t.Fatal("expected content", content, "got", d.Content) + if !reflect.DeepEqual(exp, d) { + t.Fatalf("expected %+v, got %+v", exp, d) } }) } diff --git a/persistence_test.go b/persistence_test.go index c5799af..06d0cdd 100644 --- a/persistence_test.go +++ b/persistence_test.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/gob" "os" - "slices" + "reflect" "testing" ) @@ -44,10 +44,7 @@ func TestPersistence(t *testing.T) { if err != nil { t.Fatal("expected nil, got", err) } - if res.Foo != obj.Foo { - t.Fatal("expected", obj.Foo, "got", res.Foo) - } - if slices.Compare[[]float32](res.Bar, obj.Bar) != 0 { - t.Fatal("expected", obj.Bar, "got", res.Bar) + if !reflect.DeepEqual(obj, res) { + t.Fatalf("expected %+v, got %+v", obj, res) } }