From e162b623cb3b903f7835a4cc471243289033e6dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Tue, 5 Mar 2024 20:10:09 +0100 Subject: [PATCH 1/5] Extend DB tests as whitebox --- db_test.go | 206 ++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 165 insertions(+), 41 deletions(-) diff --git a/db_test.go b/db_test.go index da74bad..97134ce 100644 --- a/db_test.go +++ b/db_test.go @@ -1,21 +1,21 @@ -package chromem_test +package chromem import ( "context" + "slices" "testing" - - "github.com/philippgille/chromem-go" ) func TestDB_CreateCollection(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} + vectors := []float32{-0.1, 0.1, 0.2} embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { - return []float32{-0.1, 0.1, 0.2}, nil + return vectors, nil } - db := chromem.NewDB() + db := NewDB() t.Run("OK", func(t *testing.T) { c, err := db.CreateCollection(name, metadata, embeddingFunc) @@ -23,14 +23,50 @@ func TestDB_CreateCollection(t *testing.T) { t.Error("expected no error, got", err) } if c == nil { - t.Error("expected collection, got nil") + t.Fatal("expected collection, got nil") } // Check expectations + + // DB should have one collection now + if len(db.collections) != 1 { + t.Error("expected 1 collection, got", len(db.collections)) + } + // The collection should be the one we just created + c2, ok := db.collections[name] + if !ok { + t.Error("expected collection", name, "not found") + } + if c2.Name != name { + t.Error("expected name", name, "got", c2.Name) + } + // The returned collection should also be the same if c.Name != name { t.Error("expected name", name, "got", c.Name) } - // TODO: Check metadata etc when they become accessible + // The collection's persistent dir should be empty + if c.persistDirectory != "" { + t.Error("expected empty persistent directory, got", c.persistDirectory) + } + // It's metadata should match + if len(c.metadata) != 1 || c.metadata["foo"] != "bar" { + t.Error("expected metadata", metadata, "got", c.metadata) + } + // Documents should be empty, but not nil + if c.documents == nil { + t.Error("expected non-nil documents, got nil") + } + if len(c.documents) != 0 { + t.Error("expected empty documents, got", len(c.documents)) + } + // The embedding function should be the one we passed + gotVectors, err := c.embed(context.Background(), "test") + if err != nil { + t.Error("expected no error, got", err) + } + if !slices.Equal(gotVectors, vectors) { + t.Error("expected vectors", vectors, "got", gotVectors) + } }) t.Run("NOK - Empty name", func(t *testing.T) { @@ -45,12 +81,13 @@ func TestDB_ListCollections(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} + vectors := []float32{-0.1, 0.1, 0.2} embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { - return []float32{-0.1, 0.1, 0.2}, nil + return vectors, nil } // Create initial collection - db := chromem.NewDB() + db := NewDB() // We ignore the return value. CreateCollection is tested elsewhere. _, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { @@ -61,9 +98,12 @@ func TestDB_ListCollections(t *testing.T) { res := db.ListCollections() // Check expectations + + // Should've returned a map with one collection if len(res) != 1 { t.Error("expected 1 collection, got", len(res)) } + // The collection should be the one we just created c, ok := res[name] if !ok { t.Error("expected collection", name, "not found") @@ -71,18 +111,33 @@ func TestDB_ListCollections(t *testing.T) { if c.Name != name { t.Error("expected name", name, "got", c.Name) } - // TODO: Check metadata when it's accessible (e.g. with GetMetadata()) - // if len(c.Metadata) != 1 { - // t.Error("expected 1 metadata, got", len(c.Metadata)) - // } - // if c.Metadata["foo"] != "bar" { - // t.Error("expected metadata", metadata, "got", c.Metadata) - // } - // TODO: Same for documents and EmbeddingFunc + // The collection's persistent dir should be empty + if c.persistDirectory != "" { + t.Error("expected empty persistent directory, got", c.persistDirectory) + } + // It's metadata should match + if len(c.metadata) != 1 || c.metadata["foo"] != "bar" { + t.Error("expected metadata", metadata, "got", c.metadata) + } + // Documents should be empty, but not nil + if c.documents == nil { + t.Error("expected non-nil documents, got nil") + } + if len(c.documents) != 0 { + t.Error("expected empty documents, got", len(c.documents)) + } + // The embedding function should be the one we passed + gotVectors, err := c.embed(context.Background(), "test") + if err != nil { + t.Error("expected no error, got", err) + } + if !slices.Equal(gotVectors, vectors) { + t.Error("expected vectors", vectors, "got", gotVectors) + } // And it should be a copy. Adding a value here should not reflect on the DB's // collection. - res["foo"] = &chromem.Collection{} + res["foo"] = &Collection{} if len(db.ListCollections()) != 1 { t.Error("expected 1 collection, got", len(db.ListCollections())) } @@ -92,12 +147,13 @@ func TestDB_GetCollection(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} + vectors := []float32{-0.1, 0.1, 0.2} embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { - return []float32{-0.1, 0.1, 0.2}, nil + return vectors, nil } // Create initial collection - db := chromem.NewDB() + db := NewDB() // We ignore the return value. CreateCollection is tested elsewhere. _, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { @@ -111,29 +167,43 @@ func TestDB_GetCollection(t *testing.T) { if c.Name != name { t.Error("expected name", name, "got", c.Name) } - // TODO: Check metadata when it's accessible (e.g. with GetMetadata()) - // if len(c.Metadata) != 1 { - // t.Error("expected 1 metadata, got", len(c.Metadata)) - // } - // if c.Metadata["foo"] != "bar" { - // t.Error("expected metadata", metadata, "got", c.Metadata) - // } - // TODO: Check documents content as soon as we have access to them - // TODO: Same for the EmbeddingFunc - // TODO: Check documents map being a copy as soon as we have access to it + // The collection's persistent dir should be empty + if c.persistDirectory != "" { + t.Error("expected empty persistent directory, got", c.persistDirectory) + } + // It's metadata should match + if len(c.metadata) != 1 || c.metadata["foo"] != "bar" { + t.Error("expected metadata", metadata, "got", c.metadata) + } + // Documents should be empty, but not nil + if c.documents == nil { + t.Error("expected non-nil documents, got nil") + } + if len(c.documents) != 0 { + t.Error("expected empty documents, got", len(c.documents)) + } + // The embedding function should be the one we passed + gotVectors, err := c.embed(context.Background(), "test") + if err != nil { + t.Error("expected no error, got", err) + } + if !slices.Equal(gotVectors, vectors) { + t.Error("expected vectors", vectors, "got", gotVectors) + } } func TestDB_GetOrCreateCollection(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} + vectors := []float32{-0.1, 0.1, 0.2} embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { - return []float32{-0.1, 0.1, 0.2}, nil + return vectors, nil } t.Run("Get", func(t *testing.T) { // Create initial collection - db := chromem.NewDB() + db := NewDB() // Create collection so that the GetOrCreateCollection() call below only // gets it. // We ignore the return value. CreateCollection is tested elsewhere. @@ -150,19 +220,41 @@ func TestDB_GetOrCreateCollection(t *testing.T) { t.Error("expected no error, got", err) } if c == nil { - t.Error("expected collection, got nil") + t.Fatal("expected collection, got nil") } // Check expectations if c.Name != name { t.Error("expected name", name, "got", c.Name) } - // TODO: Check metadata when it's accessible (e.g. with GetMetadata()) + // The collection's persistent dir should be empty + if c.persistDirectory != "" { + t.Error("expected empty persistent directory, got", c.persistDirectory) + } + // It's metadata should match + if len(c.metadata) != 1 || c.metadata["foo"] != "bar" { + t.Error("expected metadata", metadata, "got", c.metadata) + } + // Documents should be empty, but not nil + if c.documents == nil { + t.Error("expected non-nil documents, got nil") + } + if len(c.documents) != 0 { + t.Error("expected empty documents, got", len(c.documents)) + } + // The embedding function should be the one we passed + gotVectors, err := c.embed(context.Background(), "test") + if err != nil { + t.Error("expected no error, got", err) + } + if !slices.Equal(gotVectors, vectors) { + t.Error("expected vectors", vectors, "got", gotVectors) + } }) t.Run("Create", func(t *testing.T) { // Create initial collection - db := chromem.NewDB() + db := NewDB() // Call GetOrCreateCollection() c, err := db.GetOrCreateCollection(name, metadata, embeddingFunc) @@ -170,14 +262,36 @@ func TestDB_GetOrCreateCollection(t *testing.T) { t.Error("expected no error, got", err) } if c == nil { - t.Error("expected collection, got nil") + t.Fatal("expected collection, got nil") } // Check like we check CreateCollection() if c.Name != name { t.Error("expected name", name, "got", c.Name) } - // TODO: Check metadata when it's accessible (e.g. with GetMetadata()) + // The collection's persistent dir should be empty + if c.persistDirectory != "" { + t.Error("expected empty persistent directory, got", c.persistDirectory) + } + // It's metadata should match + if len(c.metadata) != 1 || c.metadata["foo"] != "bar" { + t.Error("expected metadata", metadata, "got", c.metadata) + } + // Documents should be empty, but not nil + if c.documents == nil { + t.Error("expected non-nil documents, got nil") + } + if len(c.documents) != 0 { + t.Error("expected empty documents, got", len(c.documents)) + } + // The embedding function should be the one we passed + gotVectors, err := c.embed(context.Background(), "test") + if err != nil { + t.Error("expected no error, got", err) + } + if !slices.Equal(gotVectors, vectors) { + t.Error("expected vectors", vectors, "got", gotVectors) + } }) } @@ -185,12 +299,13 @@ func TestDB_DeleteCollection(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} + vectors := []float32{-0.1, 0.1, 0.2} embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { - return []float32{-0.1, 0.1, 0.2}, nil + return vectors, nil } // Create initial collection - db := chromem.NewDB() + db := NewDB() // We ignore the return value. CreateCollection is tested elsewhere. _, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { @@ -206,18 +321,23 @@ func TestDB_DeleteCollection(t *testing.T) { if len(db.ListCollections()) != 0 { t.Error("expected 0 collections, got", len(db.ListCollections())) } + // Also check internally + if len(db.collections) != 0 { + t.Error("expected 0 collections, got", len(db.collections)) + } } func TestDB_Reset(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} + vectors := []float32{-0.1, 0.1, 0.2} embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { - return []float32{-0.1, 0.1, 0.2}, nil + return vectors, nil } // Create initial collection - db := chromem.NewDB() + db := NewDB() // We ignore the return value. CreateCollection is tested elsewhere. _, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { @@ -233,4 +353,8 @@ func TestDB_Reset(t *testing.T) { if len(db.ListCollections()) != 0 { t.Error("expected 0 collections, got", len(db.ListCollections())) } + // Also check internally + if len(db.collections) != 0 { + t.Error("expected 0 collections, got", len(db.collections)) + } } From 9d971a9e7b2312856b7572233a7ff993e63682d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Tue, 5 Mar 2024 20:49:10 +0100 Subject: [PATCH 2/5] Extend collection tests --- collection_test.go | 290 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 280 insertions(+), 10 deletions(-) diff --git a/collection_test.go b/collection_test.go index 8b1059a..795a0c2 100644 --- a/collection_test.go +++ b/collection_test.go @@ -1,20 +1,118 @@ -package chromem_test +package chromem import ( "context" + "slices" "testing" - - "github.com/philippgille/chromem-go" ) func TestCollection_Add(t *testing.T) { + ctx := context.Background() + name := "test" + metadata := map[string]string{"foo": "bar"} + vectors := []float32{-0.1, 0.1, 0.2} + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return vectors, nil + } + // Create collection - db := chromem.NewDB() + db := NewDB() + c, err := db.CreateCollection(name, metadata, embeddingFunc) + if err != nil { + t.Error("expected no error, got", err) + } + if c == nil { + t.Error("expected collection, got nil") + } + + // Add documents + + ids := []string{"1", "2"} + embeddings := [][]float32{vectors, vectors} + metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}} + contents := []string{"hello world", "hallo welt"} + + tt := []struct { + name string + ids []string + embeddings [][]float32 + metadatas []map[string]string + contents []string + }{ + { + name: "No embeddings", + ids: ids, + embeddings: nil, + metadatas: metadatas, + contents: contents, + }, + { + name: "With embeddings", + ids: ids, + embeddings: embeddings, + metadatas: metadatas, + contents: contents, + }, + { + name: "With embeddings but no contents", + ids: ids, + embeddings: embeddings, + metadatas: metadatas, + contents: nil, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + err = c.Add(ctx, ids, nil, metadatas, contents) + if err != nil { + t.Error("expected nil, got", err) + } + + // Check documents + if len(c.documents) != 2 { + t.Error("expected 2, got", len(c.documents)) + } + for i, id := range ids { + doc, ok := c.documents[id] + if !ok { + t.Error("expected document, got nil") + } + if doc.ID != id { + t.Error("expected", id, "got", doc.ID) + } + if len(doc.Metadata) != 1 { + t.Error("expected 1, got", len(doc.Metadata)) + } + if !slices.Equal(doc.Embedding, vectors) { + t.Error("expected", vectors, "got", doc.Embedding) + } + if doc.Content != contents[i] { + t.Error("expected", contents[i], "got", doc.Content) + } + } + // Metadata can't be accessed with the loop's i + if c.documents[ids[0]].Metadata["foo"] != "bar" { + t.Error("expected bar, got", c.documents[ids[0]].Metadata["foo"]) + } + if c.documents[ids[1]].Metadata["a"] != "b" { + t.Error("expected b, got", c.documents[ids[1]].Metadata["a"]) + } + }) + } +} + +func TestCollection_Add_Error(t *testing.T) { + ctx := context.Background() name := "test" metadata := map[string]string{"foo": "bar"} + vectors := []float32{-0.1, 0.1, 0.2} embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { - return []float32{-0.1, 0.1, 0.2}, nil + return vectors, nil } + + // Create collection + db := NewDB() c, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { t.Error("expected no error, got", err) @@ -23,21 +121,193 @@ func TestCollection_Add(t *testing.T) { t.Error("expected collection, got nil") } - // Add document + // Add documents, provoking errors ids := []string{"1", "2"} + embeddings := [][]float32{vectors, vectors} metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}} contents := []string{"hello world", "hallo welt"} - err = c.Add(context.Background(), ids, nil, metadatas, contents) + // Empty IDs + err = c.Add(ctx, []string{}, embeddings, metadatas, contents) + if err == nil { + t.Error("expected error, got nil") + } + // Empty embeddings and contents (both at the same time!) + err = c.Add(ctx, ids, [][]float32{}, metadatas, []string{}) + if err == nil { + t.Error("expected error, got nil") + } + // Bad embeddings length + err = c.Add(ctx, ids, [][]float32{vectors}, metadatas, contents) + if err == nil { + t.Error("expected error, got nil") + } + // Bad metadatas length + err = c.Add(ctx, ids, embeddings, []map[string]string{{"foo": "bar"}}, contents) + if err == nil { + t.Error("expected error, got nil") + } + // Bad contents length + err = c.Add(ctx, ids, embeddings, metadatas, []string{"hello world"}) + if err == nil { + t.Error("expected error, got nil") + } +} + +func TestCollection_AddConcurrently(t *testing.T) { + ctx := context.Background() + name := "test" + metadata := map[string]string{"foo": "bar"} + vectors := []float32{-0.1, 0.1, 0.2} + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return vectors, nil + } + + // Create collection + db := NewDB() + c, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { - t.Error("expected nil, got", err) + t.Error("expected no error, got", err) + } + if c == nil { + t.Error("expected collection, got nil") + } + + // Add documents + + ids := []string{"1", "2"} + embeddings := [][]float32{vectors, vectors} + metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}} + contents := []string{"hello world", "hallo welt"} + + tt := []struct { + name string + ids []string + embeddings [][]float32 + metadatas []map[string]string + contents []string + }{ + { + name: "No embeddings", + ids: ids, + embeddings: nil, + metadatas: metadatas, + contents: contents, + }, + { + name: "With embeddings", + ids: ids, + embeddings: embeddings, + metadatas: metadatas, + contents: contents, + }, + { + name: "With embeddings but no contents", + ids: ids, + embeddings: embeddings, + metadatas: metadatas, + contents: nil, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + err = c.AddConcurrently(ctx, ids, nil, metadatas, contents, 2) + if err != nil { + t.Error("expected nil, got", err) + } + + // Check documents + if len(c.documents) != 2 { + t.Error("expected 2, got", len(c.documents)) + } + for i, id := range ids { + doc, ok := c.documents[id] + if !ok { + t.Error("expected document, got nil") + } + if doc.ID != id { + t.Error("expected", id, "got", doc.ID) + } + if len(doc.Metadata) != 1 { + t.Error("expected 1, got", len(doc.Metadata)) + } + if !slices.Equal(doc.Embedding, vectors) { + t.Error("expected", vectors, "got", doc.Embedding) + } + if doc.Content != contents[i] { + t.Error("expected", contents[i], "got", doc.Content) + } + } + // Metadata can't be accessed with the loop's i + if c.documents[ids[0]].Metadata["foo"] != "bar" { + t.Error("expected bar, got", c.documents[ids[0]].Metadata["foo"]) + } + if c.documents[ids[1]].Metadata["a"] != "b" { + t.Error("expected b, got", c.documents[ids[1]].Metadata["a"]) + } + }) + } +} + +func TestCollection_AddConcurrently_Error(t *testing.T) { + ctx := context.Background() + name := "test" + metadata := map[string]string{"foo": "bar"} + vectors := []float32{-0.1, 0.1, 0.2} + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return vectors, nil } - // TODO: Check expectations when documents become accessible + // Create collection + db := NewDB() + c, err := db.CreateCollection(name, metadata, embeddingFunc) + if err != nil { + t.Error("expected no error, got", err) + } + if c == nil { + t.Error("expected collection, got nil") + } + + // Add documents, provoking errors + ids := []string{"1", "2"} + embeddings := [][]float32{vectors, vectors} + metadatas := []map[string]string{{"foo": "bar"}, {"a": "b"}} + contents := []string{"hello world", "hallo welt"} + // Empty IDs + err = c.AddConcurrently(ctx, []string{}, embeddings, metadatas, contents, 2) + if err == nil { + t.Error("expected error, got nil") + } + // Empty embeddings and contents (both at the same time!) + err = c.AddConcurrently(ctx, ids, [][]float32{}, metadatas, []string{}, 2) + if err == nil { + t.Error("expected error, got nil") + } + // Bad embeddings length + err = c.AddConcurrently(ctx, ids, [][]float32{vectors}, metadatas, contents, 2) + if err == nil { + t.Error("expected error, got nil") + } + // Bad metadatas length + err = c.AddConcurrently(ctx, ids, embeddings, []map[string]string{{"foo": "bar"}}, contents, 2) + if err == nil { + t.Error("expected error, got nil") + } + // Bad contents length + err = c.AddConcurrently(ctx, ids, embeddings, metadatas, []string{"hello world"}, 2) + if err == nil { + t.Error("expected error, got nil") + } + // Bad concurrency + err = c.AddConcurrently(ctx, ids, embeddings, metadatas, contents, 0) + if err == nil { + t.Error("expected error, got nil") + } } func TestCollection_Count(t *testing.T) { // Create collection - db := chromem.NewDB() + db := NewDB() name := "test" metadata := map[string]string{"foo": "bar"} embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { From 61b07f220aeced3e86c301f8c2f3f6001df4f187 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Tue, 5 Mar 2024 20:49:14 +0100 Subject: [PATCH 3/5] Add document test --- document_test.go | 66 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 document_test.go diff --git a/document_test.go b/document_test.go new file mode 100644 index 0000000..4924a77 --- /dev/null +++ b/document_test.go @@ -0,0 +1,66 @@ +package chromem + +import ( + "context" + "slices" + "testing" +) + +func TestDocument_New(t *testing.T) { + ctx := context.Background() + id := "test" + metadata := map[string]string{"foo": "bar"} + vectors := []float32{-0.1, 0.1, 0.2} + content := "hello world" + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return vectors, nil + } + + tt := []struct { + name string + id string + metadata map[string]string + vectors []float32 + content string + embeddingFunc EmbeddingFunc + }{ + { + name: "No embedding", + id: id, + metadata: metadata, + vectors: nil, + content: content, + embeddingFunc: embeddingFunc, + }, + { + name: "With embedding", + id: id, + metadata: metadata, + vectors: vectors, + content: content, + embeddingFunc: embeddingFunc, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + // Create document + d, err := NewDocument(ctx, id, metadata, vectors, content, embeddingFunc) + if err != nil { + t.Error("expected no error, got", err) + } + if d.ID != id { + t.Error("expected id", id, "got", d.ID) + } + if d.Metadata["foo"] != metadata["foo"] { + t.Error("expected metadata", metadata, "got", d.Metadata) + } + if !slices.Equal(d.Embedding, vectors) { + t.Error("expected vectors", vectors, "got", d.Embedding) + } + if d.Content != content { + t.Error("expected content", content, "got", d.Content) + } + }) + } +} From f47e78256460ec75fb15531b44f5a2991efdf637 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Tue, 5 Mar 2024 21:06:55 +0100 Subject: [PATCH 4/5] Add Ollama test --- embed_ollama.go | 4 ++- embed_ollama_test.go | 79 ++++++++++++++++++++++++++++++++++++++++++++ embed_openai_test.go | 4 +-- 3 files changed, 84 insertions(+), 3 deletions(-) create mode 100644 embed_ollama_test.go diff --git a/embed_ollama.go b/embed_ollama.go index 2e5f718..7ed0573 100644 --- a/embed_ollama.go +++ b/embed_ollama.go @@ -10,7 +10,9 @@ import ( "net/http" ) -const baseURLOllama = "http://localhost:11434/api" +// TODO: Turn into const and use as default, but allow user to pass custom URL +// as well as custom API key, in case Ollama runs on a remote (secured) server. +var baseURLOllama = "http://localhost:11434/api" type ollamaResponse struct { Embedding []float32 `json:"embedding"` diff --git a/embed_ollama_test.go b/embed_ollama_test.go new file mode 100644 index 0000000..b2f83df --- /dev/null +++ b/embed_ollama_test.go @@ -0,0 +1,79 @@ +package chromem + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "net/url" + "slices" + "strings" + "testing" +) + +func TestNewEmbeddingFuncOllama(t *testing.T) { + model := "model-small" + baseURLSuffix := "/api" + prompt := "hello world" + + wantBody, err := json.Marshal(map[string]string{ + "model": model, + "prompt": prompt, + }) + if err != nil { + t.Error("unexpected error:", err) + } + wantRes := []float32{-0.1, 0.1, 0.2} + + // Mock server + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check URL + if !strings.HasSuffix(r.URL.Path, baseURLSuffix+"/embeddings") { + t.Error("expected URL", baseURLSuffix+"/embeddings", "got", r.URL.Path) + } + // Check method + if r.Method != "POST" { + t.Error("expected method POST, got", r.Method) + } + // Check headers + if r.Header.Get("Content-Type") != "application/json" { + t.Error("expected Content-Type header", "application/json", "got", r.Header.Get("Content-Type")) + } + // Check body + body, err := io.ReadAll(r.Body) + if err != nil { + t.Error("unexpected error:", err) + } + if !bytes.Equal(body, wantBody) { + t.Error("expected body", wantBody, "got", body) + } + + // Write response + resp := ollamaResponse{ + Embedding: wantRes, + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(resp) + })) + defer ts.Close() + + // Get port from URL + u, err := url.Parse(ts.URL) + if err != nil { + t.Error("unexpected error:", err) + } + // TODO: It's bad to overwrite a global var for testing. Follow-up with a change + // to allow passing custom URLs to the function. + baseURLOllama = strings.Replace(baseURLOllama, "11434", u.Port(), 1) + + f := NewEmbeddingFuncOllama(model) + res, err := f(context.Background(), prompt) + if err != nil { + t.Error("expected nil, got", err) + } + if slices.Compare(wantRes, res) != 0 { + t.Error("expected res", wantRes, "got", res) + } +} diff --git a/embed_openai_test.go b/embed_openai_test.go index 70c62f9..07cbe94 100644 --- a/embed_openai_test.go +++ b/embed_openai_test.go @@ -39,7 +39,7 @@ func TestNewEmbeddingFuncOpenAICompat(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Check URL if !strings.HasSuffix(r.URL.Path, baseURLSuffix+"/embeddings") { - t.Error("expected URL", baseURLSuffix+"/embedding", "got", r.URL.Path) + t.Error("expected URL", baseURLSuffix+"/embeddings", "got", r.URL.Path) } // Check method if r.Method != "POST" { @@ -80,7 +80,7 @@ func TestNewEmbeddingFuncOpenAICompat(t *testing.T) { if err != nil { t.Error("expected nil, got", err) } - if slices.Compare[[]float32](wantRes, res) != 0 { + if slices.Compare(wantRes, res) != 0 { t.Error("expected res", wantRes, "got", res) } } From 8373b11058eb77b937d98c1f42ad9b2e13bc533d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philipp=20Gill=C3=A9?= Date: Tue, 5 Mar 2024 21:21:14 +0100 Subject: [PATCH 5/5] Add filter test --- query.go | 5 +++ query_test.go | 99 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+) create mode 100644 query_test.go diff --git a/query.go b/query.go index 3e34573..e7e5cd3 100644 --- a/query.go +++ b/query.go @@ -61,6 +61,11 @@ func filterDocs(docs map[string]*Document, where, whereDocument map[string]strin wg.Wait() + // With filteredDocs being initialized as potentially large slice, let's return + // nil instead of the empty slice. + if len(filteredDocs) == 0 { + filteredDocs = nil + } return filteredDocs } diff --git a/query_test.go b/query_test.go new file mode 100644 index 0000000..6529dc7 --- /dev/null +++ b/query_test.go @@ -0,0 +1,99 @@ +package chromem + +import ( + "reflect" + "testing" +) + +func TestFilterDocs(t *testing.T) { + docs := map[string]*Document{ + "1": { + ID: "1", + Metadata: map[string]string{ + "language": "en", + }, + Embedding: []float32{0.1, 0.2, 0.3}, + Content: "hello world", + }, + "2": { + ID: "2", + Metadata: map[string]string{ + "language": "de", + }, + Embedding: []float32{0.2, 0.3, 0.4}, + Content: "hallo welt", + }, + } + + tt := []struct { + name string + where map[string]string + whereDocument map[string]string + want []*Document + }{ + { + name: "meta match", + where: map[string]string{"language": "de"}, + whereDocument: nil, + want: []*Document{docs["2"]}, + }, + { + name: "meta no match", + where: map[string]string{"language": "fr"}, + whereDocument: nil, + want: nil, + }, + { + name: "content contains all", + where: nil, + whereDocument: map[string]string{"$contains": "llo"}, + want: []*Document{docs["1"], docs["2"]}, + }, + { + name: "content contains one", + where: nil, + whereDocument: map[string]string{"$contains": "hallo"}, + want: []*Document{docs["2"]}, + }, + { + name: "content contains none", + where: nil, + whereDocument: map[string]string{"$contains": "bonjour"}, + want: nil, + }, + { + name: "content not_contains all", + where: nil, + whereDocument: map[string]string{"$not_contains": "bonjour"}, + want: []*Document{docs["1"], docs["2"]}, + }, + { + name: "content not_contains one", + where: nil, + whereDocument: map[string]string{"$not_contains": "hello"}, + want: []*Document{docs["2"]}, + }, + { + name: "meta and content match", + where: map[string]string{"language": "de"}, + whereDocument: map[string]string{"$contains": "hallo"}, + want: []*Document{docs["2"]}, + }, + { + name: "meta + contains + not_contains", + where: map[string]string{"language": "de"}, + whereDocument: map[string]string{"$contains": "hallo", "$not_contains": "bonjour"}, + want: []*Document{docs["2"]}, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + got := filterDocs(docs, tc.where, tc.whereDocument) + + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("got %v; want %v", got, tc.want) + } + }) + } +}