diff --git a/query.go b/query.go index ab6b832..4d15f60 100644 --- a/query.go +++ b/query.go @@ -2,6 +2,7 @@ package chromem import ( "context" + "fmt" "runtime" "strings" "sync" @@ -109,14 +110,23 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu concurrency = numDocs } - ctx, cancel := context.WithCancelCause(ctx) - defer cancel(nil) - - docChan := make(chan *Document, concurrency*2) var globalErr error globalErrLock := sync.Mutex{} + ctx, cancel := context.WithCancelCause(ctx) + defer cancel(nil) + setGlobalErr := func(err error) { + globalErrLock.Lock() + defer globalErrLock.Unlock() + // Another goroutine might have already set the error. + if globalErr == nil { + globalErr = err + // Cancel the operation for all other goroutines. + cancel(globalErr) + } + } wg := sync.WaitGroup{} + docChan := make(chan *Document, concurrency*2) for i := 0; i < concurrency; i++ { wg.Add(1) go func() { @@ -129,14 +139,7 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu sim, err := cosineSimilarity(queryVectors, doc.Embedding) if err != nil { - globalErrLock.Lock() - defer globalErrLock.Unlock() - // Another goroutine might have already set the error. - if globalErr == nil { - globalErr = err - // Cancel the operation for all other goroutines. - cancel(globalErr) - } + setGlobalErr(fmt.Errorf("couldn't calculate similarity for document '%s': %w", doc.ID, err)) return }