
Improve query performance #47

Merged · 12 commits · Mar 16, 2024
59 changes: 40 additions & 19 deletions README.md
@@ -11,9 +11,7 @@ Because `chromem-go` is embeddable it enables you to add retrieval augmented gen

It's *not* a library to connect to Chroma and also not a reimplementation of it in Go. It's a database on its own.

The focus is not scale or number of features, but simplicity.

Performance has not been a priority yet. Without optimizations (except some parallelization with goroutines) querying 5,000 documents takes ~500ms on a mid-range laptop CPU (11th Gen Intel i5-1135G7, like in the first generation Framework Laptop 13).
The focus is not scale (millions of documents) or number of features, but simplicity and performance for the most common use cases. On a mid-range 2020 Intel laptop CPU you can query 1,000 documents in 0.5 ms and 100,000 documents in 56 ms, both with just 44 memory allocations. See [Benchmarks](#benchmarks) for details.

> ⚠️ The project is in beta, under heavy construction, and may introduce breaking changes in releases before `v1.0.0`. All changes are documented in the [`CHANGELOG`](./CHANGELOG.md).

@@ -23,8 +21,9 @@ Performance has not been a priority yet. Without optimizations (except some para
2. [Interface](#interface)
3. [Features](#features)
4. [Usage](#usage)
5. [Motivation](#motivation)
6. [Related projects](#related-projects)
5. [Benchmarks](#benchmarks)
6. [Motivation](#motivation)
7. [Related projects](#related-projects)

## Use cases

@@ -156,32 +155,54 @@ See the Godoc for details: <https://pkg.go.dev/github.com/philippgille/chromem-g
### Roadmap

- Performance:
- [ ] Add Go benchmark code
- [ ] Improve code based on CPU and memory profiles
- Add SIMD / Assembler to speed up dot product calculation
- Add [roaring bitmaps](https://github.com/RoaringBitmap/roaring) to speed up full text filtering
- Embedding creators:
- [ ] Add an `EmbeddingFunc` that downloads and shells out to [llamafile](https://github.com/Mozilla-Ocho/llamafile)
- Add an `EmbeddingFunc` that downloads and shells out to [llamafile](https://github.com/Mozilla-Ocho/llamafile)
- Similarity search:
- [ ] Approximate nearest neighbor search with index (ANN)
- [ ] Hierarchical Navigable Small World (HNSW)
- [ ] Inverted file flat (IVFFlat)
- Approximate nearest neighbor search with index (ANN)
- Hierarchical Navigable Small World (HNSW)
- Inverted file flat (IVFFlat)
- Filters:
- [ ] Operators (`$and`, `$or` etc.)
- Operators (`$and`, `$or` etc.)
- Storage:
- [ ] JSON as second encoding format
- [ ] Write-ahead log (WAL) as second file format
- [ ] Compression
- [ ] Encryption (at rest)
- [ ] Optional remote storage (S3, PostgreSQL, ...)
- JSON as second encoding format
- Write-ahead log (WAL) as second file format
- Compression
- Encryption (at rest)
- Optional remote storage (S3, PostgreSQL, ...)
- Data types:
- [ ] Images
- [ ] Videos
- Images
- Videos

## Usage

See the Godoc for a reference: <https://pkg.go.dev/github.com/philippgille/chromem-go>

For full, working examples, using the vector database for retrieval augmented generation (RAG) and semantic search and using either OpenAI or locally running the embeddings model and LLM (in Ollama), see the [example code](examples).
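
For a quick impression, here is a minimal sketch of the core flow. It is illustrative rather than copied from the examples: the Ollama model name is an assumption, and the two `nil` arguments to `Query` are assumed to skip the optional metadata and content filters.

```go
package main

import (
	"context"
	"fmt"

	"github.com/philippgille/chromem-go"
)

func main() {
	ctx := context.Background()

	db := chromem.NewDB() // in-memory; nothing is persisted

	// Assumes a locally running Ollama instance serving this embedding model.
	c, err := db.CreateCollection("knowledge-base", nil, chromem.NewEmbeddingFuncOllama("nomic-embed-text"))
	if err != nil {
		panic(err)
	}

	err = c.AddDocument(ctx, chromem.Document{
		ID:      "1",
		Content: "The sky is blue because of Rayleigh scattering.",
	})
	if err != nil {
		panic(err)
	}

	// Ask for the single most similar document, with no filters.
	res, err := c.Query(ctx, "Why is the sky blue?", 1, nil, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(res[0].Content, res[0].Similarity)
}
```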

## Benchmarks

```console
$ go test -benchmem -run=^$ -bench .
goos: linux
goarch: amd64
pkg: github.com/philippgille/chromem-go
cpu: 11th Gen Intel(R) Core(TM) i5-1135G7 @ 2.40GHz
BenchmarkCollection_Query_NoContent_100-8 10000 110126 ns/op 6492 B/op 44 allocs/op
BenchmarkCollection_Query_NoContent_1000-8 2020 537416 ns/op 35669 B/op 44 allocs/op
BenchmarkCollection_Query_NoContent_5000-8 351 4264192 ns/op 166728 B/op 44 allocs/op
BenchmarkCollection_Query_NoContent_25000-8 75 16411744 ns/op 813928 B/op 44 allocs/op
BenchmarkCollection_Query_NoContent_100000-8 18 64670962 ns/op 3205962 B/op 44 allocs/op
BenchmarkCollection_Query_100-8 10923 109936 ns/op 6480 B/op 44 allocs/op
BenchmarkCollection_Query_1000-8 2184 562778 ns/op 35667 B/op 44 allocs/op
BenchmarkCollection_Query_5000-8 400 2986732 ns/op 166750 B/op 44 allocs/op
BenchmarkCollection_Query_25000-8 88 15433911 ns/op 813896 B/op 44 allocs/op
BenchmarkCollection_Query_100000-8 19 63696478 ns/op 3205982 B/op 44 allocs/op
PASS
ok github.com/philippgille/chromem-go 31.373s
```
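
The benchmark code itself is not part of this diff. A rough sketch of what one of the `NoContent` cases presumably looks like, with an illustrative random-vector helper and a fixed seed (the real setup may differ):

```go
package chromem

import (
	"context"
	"math"
	"math/rand"
	"strconv"
	"testing"
)

// randNormalizedVector returns a random unit-length vector.
// Illustrative helper, not part of the library's API.
func randNormalizedVector(r *rand.Rand, dim int) []float32 {
	v := make([]float32, dim)
	var norm float64
	for i := range v {
		v[i] = float32(r.NormFloat64())
		norm += float64(v[i]) * float64(v[i])
	}
	mag := float32(math.Sqrt(norm))
	for i := range v {
		v[i] /= mag
	}
	return v
}

func BenchmarkCollection_Query_NoContent_1000(b *testing.B) {
	ctx := context.Background()
	r := rand.New(rand.NewSource(42))

	// The embedding func is only called for the query text, because every
	// document below is added with a precomputed embedding and no content.
	queryVec := randNormalizedVector(r, 1536)
	embed := func(_ context.Context, _ string) ([]float32, error) { return queryVec, nil }

	c, err := NewDB().CreateCollection("bench", nil, embed)
	if err != nil {
		b.Fatal(err)
	}
	for i := 0; i < 1000; i++ {
		doc := Document{ID: strconv.Itoa(i), Embedding: randNormalizedVector(r, 1536)}
		if err := c.AddDocument(ctx, doc); err != nil {
			b.Fatal(err)
		}
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := c.Query(ctx, "benchmark query", 10, nil, nil); err != nil {
			b.Fatal(err)
		}
	}
}
```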

## Motivation

In December 2023, when I wanted to play around with retrieval augmented generation (RAG) in a Go program, I looked for a vector database that could be embedded in the Go program, just like you would embed SQLite in order to not require any separate DB setup and maintenance. I was surprised when I didn't find any, given the abundance of embedded key-value stores in the Go ecosystem.
42 changes: 37 additions & 5 deletions collection.go
@@ -213,13 +213,17 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
m[k] = v
}

// Create embedding if they don't exist
// Create the embedding if it doesn't exist, otherwise normalize it if necessary
if len(doc.Embedding) == 0 {
embedding, err := c.embed(ctx, doc.Content)
if err != nil {
return fmt.Errorf("couldn't create embedding of document: %w", err)
}
doc.Embedding = embedding
} else if !isNormalized(doc.Embedding) {
doc.Embedding = normalizeVector(doc.Embedding)
}

c.documentsLock.Lock()
@@ -247,6 +251,19 @@ func (c *Collection) Count() int {
return len(c.documents)
}

// Result represents a single result from a query.
type Result struct {
ID string
Metadata map[string]string
Embedding []float32
Content string

// The cosine similarity between the query and the document.
// The higher the value, the more similar the document is to the query.
// The value is in the range [-1, 1].
Similarity float32
}

// Performs an exhaustive nearest neighbor search on the collection.
//
// - queryText: The text to search for.
@@ -288,17 +305,32 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int,
}

// For the remaining documents, calculate cosine similarity.
res, err := calcDocSimilarity(ctx, queryVectors, filteredDocs)
docSim, err := calcDocSimilarity(ctx, queryVectors, filteredDocs)
if err != nil {
return nil, fmt.Errorf("couldn't calculate cosine similarity: %w", err)
}

// Sort by similarity
sort.Slice(res, func(i, j int) bool {
sort.Slice(docSim, func(i, j int) bool {
// The `less` function would usually use `<`, but we want to sort descending.
return res[i].Similarity > res[j].Similarity
return docSim[i].similarity > docSim[j].similarity
})

// Return the top nResults or len(docSim), whichever is smaller
if len(docSim) < nResults {
nResults = len(docSim)
}
res := make([]Result, 0, nResults)
for i := 0; i < nResults; i++ {
res = append(res, Result{
ID: docSim[i].docID,
Metadata: c.documents[docSim[i].docID].Metadata,
Embedding: c.documents[docSim[i].docID].Embedding,
Content: c.documents[docSim[i].docID].Content,
Similarity: docSim[i].similarity,
})
}

// Return the top nResults
return res[:nResults], nil
return res, nil
}
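
`calcDocSimilarity` and the `docSim` type it returns are defined outside this diff. Based on the fields accessed in `Query` above (`docID`, `similarity`), they presumably look roughly like this sequential sketch; the signature is assumed from the call site, and the actual implementation parallelizes the work with goroutines:

```go
// docSim pairs a document ID with its cosine similarity to the query.
// Sketch based on the field names used in Query above.
type docSim struct {
	docID      string
	similarity float32
}

// calcDocSimilarity computes the similarity of each document to the query.
// Because all vectors are normalized, cosine similarity reduces to the
// dot product.
func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document) ([]docSim, error) {
	res := make([]docSim, 0, len(docs))
	for _, doc := range docs {
		if ctx.Err() != nil {
			return nil, ctx.Err()
		}
		var sim float32
		for i, q := range queryVectors {
			sim += q * doc.Embedding[i]
		}
		res = append(res, docSim{docID: doc.ID, similarity: sim})
	}
	return res, nil
}
```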
11 changes: 6 additions & 5 deletions collection_test.go
@@ -13,7 +13,7 @@ func TestCollection_Add(t *testing.T) {
ctx := context.Background()
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
@@ -109,7 +109,7 @@ func TestCollection_Add_Error(t *testing.T) {
ctx := context.Background()
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
@@ -160,7 +160,7 @@ func TestCollection_AddConcurrently(t *testing.T) {
ctx := context.Background()
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
@@ -256,7 +256,7 @@ func TestCollection_AddConcurrently_Error(t *testing.T) {
ctx := context.Background()
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
@@ -313,8 +313,9 @@ func TestCollection_Count(t *testing.T) {
db := NewDB()
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return []float32{-0.1, 0.1, 0.2}, nil
return vectors, nil
}
c, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
3 changes: 3 additions & 0 deletions db.go
@@ -13,6 +13,9 @@ import (
// EmbeddingFunc is a function that creates embeddings for a given text.
// chromem-go will use OpenAI's "text-embedding-3-small" model by default,
// but you can provide your own function, using any model you like.
// The function must return a *normalized* vector, i.e. the length of the vector
// must be 1. OpenAI's and Mistral's embedding models do this by default. Some
// others like Nomic's "nomic-embed-text-v1.5" don't.
type EmbeddingFunc func(ctx context.Context, text string) ([]float32, error)

// DB is the chromem-go database. It holds collections, which hold documents.
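
The `isNormalized` and `normalizeVector` helpers referenced throughout this PR are defined outside the diff. A minimal sketch of what they plausibly do; the epsilon value is an assumption:

```go
import "math"

// isNormalizedEpsilon is the tolerance for treating a vector as unit-length.
// The exact value in the library may differ.
const isNormalizedEpsilon = 1e-6

// isNormalized reports whether the vector's Euclidean magnitude is ~1.
func isNormalized(v []float32) bool {
	var sum float64
	for _, x := range v {
		sum += float64(x) * float64(x)
	}
	return math.Abs(math.Sqrt(sum)-1) < isNormalizedEpsilon
}

// normalizeVector scales v to unit length, so that cosine similarity can be
// computed as a plain dot product.
func normalizeVector(v []float32) []float32 {
	var sum float64
	for _, x := range v {
		sum += float64(x) * float64(x)
	}
	mag := float32(math.Sqrt(sum))
	res := make([]float32, len(v))
	for i, x := range v {
		res[i] = x / mag
	}
	return res
}
```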
12 changes: 6 additions & 6 deletions db_test.go
@@ -10,7 +10,7 @@ func TestDB_CreateCollection(t *testing.T) {
// Values in the collection
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
@@ -81,7 +81,7 @@ func TestDB_ListCollections(t *testing.T) {
// Values in the collection
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
@@ -147,7 +147,7 @@ func TestDB_GetCollection(t *testing.T) {
// Values in the collection
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
@@ -196,7 +196,7 @@ func TestDB_GetOrCreateCollection(t *testing.T) {
// Values in the collection
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
@@ -299,7 +299,7 @@ func TestDB_DeleteCollection(t *testing.T) {
// Values in the collection
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
@@ -331,7 +331,7 @@ func TestDB_Reset(t *testing.T) {
// Values in the collection
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
2 changes: 1 addition & 1 deletion document_test.go
@@ -10,7 +10,7 @@ func TestDocument_New(t *testing.T) {
ctx := context.Background()
id := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
content := "hello world"
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
12 changes: 8 additions & 4 deletions embed_compat.go
@@ -9,9 +9,13 @@ const (
// NewEmbeddingFuncMistral returns a function that creates embeddings for a text
// using the Mistral API.
func NewEmbeddingFuncMistral(apiKey string) EmbeddingFunc {
// Mistral embeddings are normalized, see section "Distance Measures" on
// https://docs.mistral.ai/guides/embeddings/.
normalized := true

// The Mistral API docs don't list `encoding_format` as optional,
// but it seems to be, just like OpenAI's. So we reuse the OpenAI function.
return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral)
return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral, &normalized)
}

const baseURLJina = "https://api.jina.ai/v1"
@@ -28,7 +32,7 @@ const (
// NewEmbeddingFuncJina returns a function that creates embeddings for a text
// using the Jina API.
func NewEmbeddingFuncJina(apiKey string, model EmbeddingModelJina) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model))
return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model), nil)
}

const baseURLMixedbread = "https://api.mixedbread.ai"
@@ -49,7 +53,7 @@ const (
// NewEmbeddingFuncMixedbread returns a function that creates embeddings for a text
// using the mixedbread.ai API.
func NewEmbeddingFuncMixedbread(apiKey string, model EmbeddingModelMixedbread) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model))
return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model), nil)
}

const baseURLLocalAI = "http://localhost:8080/v1"
@@ -64,5 +68,5 @@
// But other embedding models are supported as well. See the LocalAI documentation
// for details.
func NewEmbeddingFuncLocalAI(model string) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model)
return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model, nil)
}
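
All of the constructors above now pass a `*bool` through to `NewEmbeddingFuncOpenAICompat`, whose updated signature lives in `embed_openai.go`, outside this diff. The tri-state logic it implies can be sketched as follows, reusing the `isNormalized`/`normalizeVector` helpers sketched above; the function name and exact behavior are assumptions:

```go
// ensureNormalized applies the semantics implied by the new `normalized *bool`
// parameter: an explicit true is trusted, an explicit false forces
// normalization, and nil falls back to checking the vector at runtime.
// Sketch only, not the library's actual code.
func ensureNormalized(v []float32, normalized *bool) []float32 {
	if normalized != nil {
		if *normalized {
			return v // the caller guarantees unit length
		}
		return normalizeVector(v)
	}
	if isNormalized(v) {
		return v
	}
	return normalizeVector(v)
}
```

For Mistral, for example, the constructor passes a pointer to `true`, so no per-call check is needed at all.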
18 changes: 17 additions & 1 deletion embed_ollama.go
@@ -8,6 +8,7 @@ import (
"fmt"
"io"
"net/http"
"sync"
)

// TODO: Turn into const and use as default, but allow user to pass custom URL
@@ -28,6 +29,9 @@ func NewEmbeddingFuncOllama(model string) EmbeddingFunc {
// and it might have to be a long timeout, depending on the text length.
client := &http.Client{}

var checkedNormalized bool
checkNormalized := sync.Once{}

return func(ctx context.Context, text string) ([]float32, error) {
// Prepare the request body.
reqBody, err := json.Marshal(map[string]string{
@@ -74,6 +78,18 @@ func NewEmbeddingFuncOllama(model string) EmbeddingFunc {
return nil, errors.New("no embeddings found in the response")
}

return embeddingResponse.Embedding, nil
v := embeddingResponse.Embedding
checkNormalized.Do(func() {
checkedNormalized = isNormalized(v)
})
if !checkedNormalized {
v = normalizeVector(v)
}

return v, nil
}
}
2 changes: 1 addition & 1 deletion embed_ollama_test.go
@@ -25,7 +25,7 @@ func TestNewEmbeddingFuncOllama(t *testing.T) {
if err != nil {
t.Fatal("unexpected error:", err)
}
wantRes := []float32{-0.1, 0.1, 0.2}
wantRes := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`

// Mock server
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {