Skip to content

Commit

Permalink
Refactor Query() concurrent error handling
Browse files Browse the repository at this point in the history
To match the refactoring in Collection.AddDocuments
from the previous commit
  • Loading branch information
philippgille committed Mar 3, 2024
1 parent 236b6a9 commit a4fd279
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package chromem

import (
"context"
"fmt"
"runtime"
"strings"
"sync"
Expand Down Expand Up @@ -109,14 +110,23 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu
concurrency = numDocs
}

ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)

docChan := make(chan *Document, concurrency*2)
var globalErr error
globalErrLock := sync.Mutex{}
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
setGlobalErr := func(err error) {
globalErrLock.Lock()
defer globalErrLock.Unlock()
// Another goroutine might have already set the error.
if globalErr == nil {
globalErr = err
// Cancel the operation for all other goroutines.
cancel(globalErr)
}
}

wg := sync.WaitGroup{}
docChan := make(chan *Document, concurrency*2)
for i := 0; i < concurrency; i++ {
wg.Add(1)
go func() {
Expand All @@ -129,14 +139,7 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu

sim, err := cosineSimilarity(queryVectors, doc.Embedding)
if err != nil {
globalErrLock.Lock()
defer globalErrLock.Unlock()
// Another goroutine might have already set the error.
if globalErr == nil {
globalErr = err
// Cancel the operation for all other goroutines.
cancel(globalErr)
}
setGlobalErr(fmt.Errorf("couldn't calculate similarity for document '%s': %w", doc.ID, err))
return
}

Expand Down

0 comments on commit a4fd279

Please sign in to comment.