Skip to content

Commit

Permalink
Revert "Add "normalized" parameter to skip check if normalization is …
Browse files Browse the repository at this point in the history
…known"

This reverts commit ff28a38.
  • Loading branch information
philippgille committed Mar 16, 2024
1 parent ff28a38 commit 05c4f76
Show file tree
Hide file tree
Showing 9 changed files with 34 additions and 62 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func main() {
db := chromem.NewDB()

// Create collection. GetCollection, GetOrCreateCollection, DeleteCollection also available!
collection, _ := db.CreateCollection("all-my-documents", nil, nil, nil)
collection, _ := db.CreateCollection("all-my-documents", nil, nil)

// Add docs to the collection. Update and delete will be added in the future.
// Can be multi-threaded with AddConcurrently()!
Expand Down
12 changes: 5 additions & 7 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@ type Collection struct {
documents map[string]*Document
documentsLock sync.RWMutex
embed EmbeddingFunc
normalized *bool
}

// We don't export this yet to keep the API surface to the bare minimum.
// Users create collections via [DB.CreateCollection].
func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, normalized *bool, dir string) (*Collection, error) {
func newCollection(name string, metadata map[string]string, embed EmbeddingFunc, dir string) (*Collection, error) {
// We copy the metadata to avoid data races in case the caller modifies the
// map after creating the collection while we range over it.
m := make(map[string]string, len(metadata))
Expand All @@ -38,10 +37,9 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc,
c := &Collection{
Name: name,

metadata: m,
documents: make(map[string]*Document),
embed: embed,
normalized: normalized,
metadata: m,
documents: make(map[string]*Document),
embed: embed,
}

// Persistence
Expand Down Expand Up @@ -303,7 +301,7 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int,
}

// For the remaining documents, calculate cosine similarity.
docSim, err := calcDocSimilarity(ctx, queryVectors, filteredDocs, c.normalized)
docSim, err := calcDocSimilarity(ctx, queryVectors, filteredDocs)
if err != nil {
return nil, fmt.Errorf("couldn't calculate cosine similarity: %w", err)
}
Expand Down
12 changes: 6 additions & 6 deletions collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestCollection_Add(t *testing.T) {

// Create collection
db := NewDB()
c, err := db.CreateCollection(name, metadata, embeddingFunc, nil)
c, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand Down Expand Up @@ -116,7 +116,7 @@ func TestCollection_Add_Error(t *testing.T) {

// Create collection
db := NewDB()
c, err := db.CreateCollection(name, metadata, embeddingFunc, nil)
c, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand Down Expand Up @@ -167,7 +167,7 @@ func TestCollection_AddConcurrently(t *testing.T) {

// Create collection
db := NewDB()
c, err := db.CreateCollection(name, metadata, embeddingFunc, nil)
c, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand Down Expand Up @@ -263,7 +263,7 @@ func TestCollection_AddConcurrently_Error(t *testing.T) {

// Create collection
db := NewDB()
c, err := db.CreateCollection(name, metadata, embeddingFunc, nil)
c, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand Down Expand Up @@ -316,7 +316,7 @@ func TestCollection_Count(t *testing.T) {
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return []float32{-0.1, 0.1, 0.2}, nil
}
c, err := db.CreateCollection(name, metadata, embeddingFunc, nil)
c, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand Down Expand Up @@ -407,7 +407,7 @@ func benchmarkCollection_Query(b *testing.B, n int, withContent bool) {
// Create collection
db := NewDB()
name := "test"
c, err := db.CreateCollection(name, nil, embeddingFunc, &trueVal)
c, err := db.CreateCollection(name, nil, embeddingFunc)
if err != nil {
b.Fatal("expected no error, got", err)
}
Expand Down
14 changes: 4 additions & 10 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,14 @@ func NewPersistentDB(path string) (*DB, error) {
// - metadata: Optional metadata to associate with the collection.
// - embeddingFunc: Optional function to use to embed documents.
// Uses the default embedding function if not provided.
// - normalized: Optional flag to indicate if the embeddings of the collection
// are normalized (when you add embeddings yourself, or the embeddings created
// by the embeddingFunc). If nil it will be autodetected.
func (db *DB) CreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc, normalized *bool) (*Collection, error) {
func (db *DB) CreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) {
if name == "" {
return nil, errors.New("collection name is empty")
}
if embeddingFunc == nil {
embeddingFunc = NewEmbeddingFuncDefault()
}
collection, err := newCollection(name, metadata, embeddingFunc, normalized, db.persistDirectory)
collection, err := newCollection(name, metadata, embeddingFunc, db.persistDirectory)
if err != nil {
return nil, fmt.Errorf("couldn't create collection: %w", err)
}
Expand Down Expand Up @@ -216,15 +213,12 @@ func (db *DB) GetCollection(name string, embeddingFunc EmbeddingFunc) *Collectio
// - metadata: Optional metadata to associate with the collection.
// - embeddingFunc: Optional function to use to embed documents.
// Uses the default embedding function if not provided.
// - normalized: Optional flag to indicate if the embeddings of the collection
// are normalized (when you add embeddings yourself, or the embeddings created
// by the embeddingFunc). If nil it will be autodetected.
func (db *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc, normalized *bool) (*Collection, error) {
func (db *DB) GetOrCreateCollection(name string, metadata map[string]string, embeddingFunc EmbeddingFunc) (*Collection, error) {
// No need to lock here, because the methods we call do that.
collection := db.GetCollection(name, embeddingFunc)
if collection == nil {
var err error
collection, err = db.CreateCollection(name, metadata, embeddingFunc, normalized)
collection, err = db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
return nil, fmt.Errorf("couldn't create collection: %w", err)
}
Expand Down
18 changes: 9 additions & 9 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func TestDB_CreateCollection(t *testing.T) {
db := NewDB()

t.Run("OK", func(t *testing.T) {
c, err := db.CreateCollection(name, metadata, embeddingFunc, nil)
c, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand Down Expand Up @@ -70,7 +70,7 @@ func TestDB_CreateCollection(t *testing.T) {
})

t.Run("NOK - Empty name", func(t *testing.T) {
_, err := db.CreateCollection("", metadata, embeddingFunc, nil)
_, err := db.CreateCollection("", metadata, embeddingFunc)
if err == nil {
t.Fatal("expected error, got nil")
}
Expand All @@ -89,7 +89,7 @@ func TestDB_ListCollections(t *testing.T) {
// Create initial collection
db := NewDB()
// We ignore the return value. CreateCollection is tested elsewhere.
_, err := db.CreateCollection(name, metadata, embeddingFunc, nil)
_, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand Down Expand Up @@ -155,7 +155,7 @@ func TestDB_GetCollection(t *testing.T) {
// Create initial collection
db := NewDB()
// We ignore the return value. CreateCollection is tested elsewhere.
_, err := db.CreateCollection(name, metadata, embeddingFunc, nil)
_, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand Down Expand Up @@ -207,15 +207,15 @@ func TestDB_GetOrCreateCollection(t *testing.T) {
// Create collection so that the GetOrCreateCollection() call below only
// gets it.
// We ignore the return value. CreateCollection is tested elsewhere.
_, err := db.CreateCollection(name, metadata, embeddingFunc, nil)
_, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}

// Call GetOrCreateCollection() with the same name to only get it. We pass
// nil for the metadata and embeddingFunc so we can check that the returned
// collection is the original one, and not a new one.
c, err := db.GetOrCreateCollection(name, nil, embeddingFunc, nil)
c, err := db.GetOrCreateCollection(name, nil, nil)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand Down Expand Up @@ -257,7 +257,7 @@ func TestDB_GetOrCreateCollection(t *testing.T) {
db := NewDB()

// Call GetOrCreateCollection()
c, err := db.GetOrCreateCollection(name, metadata, embeddingFunc, nil)
c, err := db.GetOrCreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand Down Expand Up @@ -307,7 +307,7 @@ func TestDB_DeleteCollection(t *testing.T) {
// Create initial collection
db := NewDB()
// We ignore the return value. CreateCollection is tested elsewhere.
_, err := db.CreateCollection(name, metadata, embeddingFunc, nil)
_, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand Down Expand Up @@ -339,7 +339,7 @@ func TestDB_Reset(t *testing.T) {
// Create initial collection
db := NewDB()
// We ignore the return value. CreateCollection is tested elsewhere.
_, err := db.CreateCollection(name, metadata, embeddingFunc, nil)
_, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
t.Fatal("expected no error, got", err)
}
Expand Down
5 changes: 1 addition & 4 deletions examples/rag-wikipedia-ollama/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ const (
embeddingModel = "nomic-embed-text"
)

// The nomic-embed-text-v1.5 model doesn't return normalized embeddings
var normalized = false

func main() {
ctx := context.Background()

Expand Down Expand Up @@ -52,7 +49,7 @@ func main() {
// variable to be set.
// For this example we choose to use a locally running embedding model though.
// It requires Ollama to serve its API at "http://localhost:11434/api".
collection, err := db.GetOrCreateCollection("Wikipedia", nil, chromem.NewEmbeddingFuncOllama(embeddingModel), &normalized)
collection, err := db.GetOrCreateCollection("Wikipedia", nil, chromem.NewEmbeddingFuncOllama(embeddingModel))
if err != nil {
panic(err)
}
Expand Down
5 changes: 1 addition & 4 deletions examples/semantic-search-arxiv-openai/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ import (

const searchTerm = "semantic search with vector databases"

// OpenAI embeddings are already normalized.
var normalized = true

func main() {
ctx := context.Background()

Expand All @@ -33,7 +30,7 @@ func main() {
// We pass nil as embedding function to use the default (OpenAI text-embedding-3-small),
// which is very good and cheap. It requires the OPENAI_API_KEY environment
// variable to be set.
collection, err := db.GetOrCreateCollection("arXiv cs.CL 2023", nil, nil, &normalized)
collection, err := db.GetOrCreateCollection("arXiv cs.CL 2023", nil, nil)
if err != nil {
panic(err)
}
Expand Down
4 changes: 2 additions & 2 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func documentMatchesFilters(document *Document, where, whereDocument map[string]
return true
}

func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document, isNormalized *bool) ([]docSim, error) {
func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document) ([]docSim, error) {
similarities := make([]docSim, 0, len(docs))
similaritiesLock := sync.Mutex{}

Expand Down Expand Up @@ -145,7 +145,7 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu
return
}

sim, err := cosineSimilarity(queryVectors, doc.Embedding, isNormalized)
sim, err := cosineSimilarity(queryVectors, doc.Embedding)
if err != nil {
setSharedErr(fmt.Errorf("couldn't calculate similarity for document '%s': %w", doc.ID, err))
return
Expand Down
24 changes: 5 additions & 19 deletions vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,19 @@ import (

const isNormalizedPrecisionTolerance = 1e-6

var (
falseVal = false
trueVal = true
)

// cosineSimilarity calculates the cosine similarity between two vectors.
// Pass isNormalized=true if the vectors are already normalized, false
// to normalize them, and nil to autodetect.
// Both vectors are normalized first, unless both are already normalized.
// The resulting value represents the similarity, so a higher value means the
// vectors are more similar.
func cosineSimilarity(a, b []float32, isNormalized *bool) (float32, error) {
func cosineSimilarity(a, b []float32) (float32, error) {
// The vectors must have the same length
if len(a) != len(b) {
return 0, errors.New("vectors must have the same length")
}

if isNormalized == nil {
if !checkNormalized(a) || !checkNormalized(b) {
isNormalized = &falseVal
} else {
isNormalized = &trueVal
}
}
if !*isNormalized {
if !isNormalized(a) || !isNormalized(b) {
a, b = normalizeVector(a), normalizeVector(b)
}

var dotProduct float32
for i := range a {
dotProduct += a[i] * b[i]
Expand All @@ -58,8 +44,8 @@ func normalizeVector(v []float32) []float32 {
return res
}

// checkNormalized checks if the vector is normalized.
func checkNormalized(v []float32) bool {
// isNormalized checks if the vector is normalized.
func isNormalized(v []float32) bool {
var sqSum float64
for _, val := range v {
sqSum += float64(val) * float64(val)
Expand Down

0 comments on commit 05c4f76

Please sign in to comment.