diff --git a/collection.go b/collection.go index 4262a92..114026d 100644 --- a/collection.go +++ b/collection.go @@ -3,6 +3,7 @@ package chromem import ( "context" "errors" + "fmt" "sync" ) @@ -41,6 +42,20 @@ func newCollection(name string, metadata map[string]string, embed EmbeddingFunc) // // A row-based API will be added when Chroma adds it (they already plan to). func (c *Collection) Add(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, documents []string) error { + return c.add(ctx, ids, documents, embeddings, metadatas, 1) +} + +// AddConcurrently is like Add, but adds embeddings concurrently. +// This is mostly useful when you don't pass any embeddings so they have to be created. +// Upon error, concurrently running operations are canceled and the error is returned. +func (c *Collection) AddConcurrently(ctx context.Context, ids []string, embeddings [][]float32, metadatas []map[string]string, documents []string, concurrency int) error { + if concurrency < 1 { + return errors.New("concurrency must be at least 1") + } + return c.add(ctx, ids, documents, embeddings, metadatas, concurrency) +} + +func (c *Collection) add(ctx context.Context, ids []string, documents []string, embeddings [][]float32, metadatas []map[string]string, concurrency int) error { if len(ids) == 0 || len(documents) == 0 { return errors.New("ids and documents must not be empty") } @@ -54,23 +69,68 @@ func (c *Collection) Add(ctx context.Context, ids []string, embeddings [][]float return errors.New("ids, metadatas and documents must have the same length") } - var embedding []float32 - var metadata map[string]string - var err error + ctx, cancel := context.WithCancelCause(ctx) + + var wg sync.WaitGroup + var globalErr error + var globalErrLock sync.RWMutex + semaphore := make(chan struct{}, concurrency) for i, document := range documents { + var embedding []float32 + var metadata map[string]string if len(embeddings) != 0 { embedding = embeddings[i] } if len(metadatas) != 0 { metadata = metadatas[i] } - c.documentsLock.Lock() - // We don't defer the unlock because we want to unlock much earlier - c.documents[ids[i]], err = newDocument(ctx, ids[i], embedding, metadata, document, c.embed) - c.documentsLock.Unlock() - if err != nil { - return err - } + + wg.Add(1) + go func(id string, embedding []float32, metadata map[string]string, document string) { + defer wg.Done() + + // Don't even start if we already have an error + globalErrLock.RLock() + // We don't defer the unlock because we want to unlock much earlier. + if globalErr != nil { + globalErrLock.RUnlock() + return + } + globalErrLock.RUnlock() + + // Wait here while $concurrency other goroutines are creating documents. + semaphore <- struct{}{} + defer func() { <-semaphore }() + + err := c.addRow(ctx, id, document, embedding, metadata) + if err != nil { + globalErrLock.Lock() + defer globalErrLock.Unlock() + // Another goroutine might have already set the error. + if globalErr == nil { + globalErr = err + // Cancel the operation for all other goroutines. + cancel(globalErr) + } + return + } + }(ids[i], embedding, metadata, document) } + + wg.Wait() + + return globalErr +} + +func (c *Collection) addRow(ctx context.Context, id string, document string, embedding []float32, metadata map[string]string) error { + doc, err := newDocument(ctx, id, embedding, metadata, document, c.embed) + if err != nil { + return fmt.Errorf("couldn't create document '%s': %w", id, err) + } + + c.documentsLock.Lock() + defer c.documentsLock.Unlock() + c.documents[id] = doc + return nil }