Skip to content

Commit

Permalink
Add method to add embeddings concurrently
Browse files Browse the repository at this point in the history
  • Loading branch information
philippgille committed Jan 1, 2024
1 parent cefec66 commit 50fe3b7
Showing 1 changed file with 70 additions and 10 deletions.
80 changes: 70 additions & 10 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package chromem
import (
"context"
"errors"
"fmt"
"sync"
)

Expand Down Expand Up @@ -41,6 +42,20 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc)
//
// A row-based API will be added when Chroma adds it (they already plan to).
func (c *Collection) Add(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, documents []string) error {
return c.add(ctx, ids, documents, embeddings, metadatas, 1)
}

// AddConcurrently is like Add, but adds embeddings concurrently.
// This is mostly useful when you don't pass any embeddings so they have to be created.
// Upon error, concurrently running operations are canceled and the error is returned.
func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, documents []string, concurrency int) error {
if concurrency < 1 {
return errors.New("concurrency must be at least 1")
}
return c.add(ctx, ids, documents, embeddings, metadatas, concurrency)
}

func (c *Collection) add(ctx context.Context, ids []string, documents []string, embeddings [][]float32, metadatas []map[string]string, concurrency int) error {
if len(ids) == 0 || len(documents) == 0 {
return errors.New("ids and documents must not be empty")
}
Expand All @@ -54,23 +69,68 @@ func (c *Collection) Add(ctx context.Context, ids []string, embeddings [][]float
return errors.New("ids, metadatas and documents must have the same length")
}

var embedding []float32
var metadata map[string]string
var err error
ctx, cancel := context.WithCancelCause(ctx)

var wg sync.WaitGroup
var globalErr error
var globalErrLock sync.RWMutex
semaphore := make(chan struct{}, concurrency)
for i, document := range documents {
var embedding []float32
var metadata map[string]string
if len(embeddings) != 0 {
embedding = embeddings[i]
}
if len(metadatas) != 0 {
metadata = metadatas[i]
}
c.documentsLock.Lock()
// We don't defer the unlock because we want to unlock much earlier
c.documents[ids[i]], err = newDocument(ctx, ids[i], embedding, metadata, document, c.embed)
c.documentsLock.Unlock()
if err != nil {
return err
}

wg.Add(1)
go func(id string, embedding []float32, metadata map[string]string, document string) {
defer wg.Done()

// Don't even start if we already have an error
globalErrLock.RLock()
// We don't defer the unlock because we want to unlock much earlier.
if globalErr != nil {
globalErrLock.RUnlock()
return
}
globalErrLock.RUnlock()

// Wait here while $concurrency other goroutines are creating documents.
semaphore <- struct{}{}
defer func() { <-semaphore }()

err := c.addRow(ctx, id, document, embedding, metadata)
if err != nil {
globalErrLock.Lock()
defer globalErrLock.Unlock()
// Another goroutine might have already set the error.
if globalErr == nil {
globalErr = err
// Cancel the operation for all other goroutines.
cancel(globalErr)
}
return
}
}(ids[i], embedding, metadata, document)
}

wg.Wait()

return globalErr
}

func (c *Collection) addRow(ctx context.Context, id string, document string, embedding []float32, metadata map[string]string) error {
doc, err := newDocument(ctx, id, embedding, metadata, document, c.embed)
if err != nil {
return fmt.Errorf("couldn't create document '%s': %w", id, err)
}

c.documentsLock.Lock()
defer c.documentsLock.Unlock()
c.documents[id] = doc

return nil
}

0 comments on commit 50fe3b7

Please sign in to comment.