Skip to content

Commit

Permalink
Add basic query benchmark
Browse files Browse the repository at this point in the history
- Normalized vectors
- Only query similarity
- No metadata or content filter
  • Loading branch information
philippgille committed Mar 12, 2024
1 parent 46b4509 commit b916a81
Showing 1 changed file with 89 additions and 0 deletions.
89 changes: 89 additions & 0 deletions collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ package chromem

import (
"context"
"errors"
"math/rand"
"slices"
"strconv"
"testing"
)

Expand Down Expand Up @@ -335,3 +338,89 @@ func TestCollection_Count(t *testing.T) {
t.Fatal("expected 2, got", c.Count())
}
}

// Global var for assignment in the benchmark to avoid compiler optimizations.
var globalRes []Result

func BenchmarkCollection_Query_100(b *testing.B) {
benchmarkCollection_Query(b, 100)
}

func BenchmarkCollection_Query_1000(b *testing.B) {
benchmarkCollection_Query(b, 1000)
}

func BenchmarkCollection_Query_5000(b *testing.B) {
benchmarkCollection_Query(b, 5000)
}

func BenchmarkCollection_Query_25000(b *testing.B) {
benchmarkCollection_Query(b, 25000)
}

func BenchmarkCollection_Query_100000(b *testing.B) {
benchmarkCollection_Query(b, 100_000)
}

// n is number of documents in the collection
func benchmarkCollection_Query(b *testing.B, n int) {
ctx := context.Background()

// Seed to make deterministic
r := rand.New(rand.NewSource(42))

d := 1536 // dimensions, same as text-embedding-3-small
// Random query vector
qv := make([]float32, d)
for j := 0; j < d; j++ {
qv[j] = r.Float32()
}
// Most embeddings are normalized, so we normalize this one too
qv = normalizeVector(qv)
embeddingFunc := func(_ context.Context, text string) ([]float32, error) {
if text != "foo" {
return nil, errors.New("embedding func not expected to be called")
}
return qv, nil
}

// Create collection
db := NewDB()
name := "test"
c, err := db.CreateCollection(name, nil, embeddingFunc)
if err != nil {
b.Fatal("expected no error, got", err)
}
if c == nil {
b.Fatal("expected collection, got nil")
}

// Add documents
for i := 0; i < n; i++ {
// Random embedding
v := make([]float32, d)
for j := 0; j < d; j++ {
v[j] = r.Float32()
}
v = normalizeVector(v)

// Add document without metadata or content.
// When providing embeddings, the embedding func is not called.
c.AddDocument(ctx, Document{
ID: strconv.Itoa(i),
Embedding: v,
})
}

b.ResetTimer()

// Query
var res []Result
for i := 0; i < b.N; i++ {
res, err = c.Query(ctx, "foo", 10, nil, nil)
}
if err != nil {
b.Fatal("expected nil, got", err)
}
globalRes = res
}

0 comments on commit b916a81

Please sign in to comment.