From 30d764759e245c3335b94dd0094e89b9cbee6247 Mon Sep 17 00:00:00 2001 From: Yahav Itzhak Date: Thu, 12 Oct 2023 14:13:13 +0300 Subject: [PATCH] Add force cancel option to the producer-consumer (#40) --- parallel/runner.go | 70 +++++++++--------- parallel/runner_test.go | 153 ++++++++++++++++++++++++++++++++++++---- 2 files changed, 176 insertions(+), 47 deletions(-) diff --git a/parallel/runner.go b/parallel/runner.go index a8311cc..5920ef6 100644 --- a/parallel/runner.go +++ b/parallel/runner.go @@ -14,10 +14,10 @@ type Runner interface { AddTaskWithError(TaskFunc, OnErrorFunc) (int, error) Run() Done() - Cancel() + Cancel(bool) Errors() map[int]error ActiveThreads() uint32 - OpenThreads() int + OpenThreads() uint32 IsStarted() bool SetMaxParallel(int) GetFinishedNotification() chan bool @@ -39,10 +39,12 @@ type runner struct { tasks chan *task // Tasks counter, used to give each task an identifier (task.num). taskId uint32 - // A channel that is closed when the runner is cancelled. - cancel chan struct{} - // Used to make sure the cancel channel is closed only once. + // True when Cancel was invoked + cancel atomic.Bool + // Used to make sure that cancel is called only once. cancelOnce sync.Once + // Used to make sure that done is called only once. + doneOnce sync.Once // The maximum number of threads running in parallel. maxParallel int // If true, the runner will be cancelled on the first error thrown from a task. @@ -52,15 +54,15 @@ type runner struct { // A WaitGroup that waits for all the threads to close. threadsWaitGroup sync.WaitGroup // Threads counter, used to give each thread an identifier (threadId). - threadCount uint32 + threadCount atomic.Uint32 // The number of open threads. - openThreads int + openThreads atomic.Uint32 // A lock on openThreads. openThreadsLock sync.Mutex // The number of threads currently running tasks. - activeThreads uint32 + activeThreads atomic.Uint32 // The number of tasks in the queue. - totalTasksInQueue uint32 + totalTasksInQueue atomic.Uint32 // Indicate that the runner has finished. finishedNotifier chan bool // Indicates that the finish channel is closed. @@ -91,7 +93,7 @@ func NewRunner(maxParallel int, capacity uint, failFast bool) *runner { finishedNotifier: make(chan bool, 1), maxParallel: consumers, failFast: failFast, - cancel: make(chan struct{}), + cancel: atomic.Bool{}, tasks: make(chan *task, capacity), } r.errors = make(map[int]error) @@ -122,14 +124,12 @@ func (r *runner) addTask(t TaskFunc, errorHandler OnErrorFunc) (int, error) { nextCount := atomic.AddUint32(&r.taskId, 1) task := &task{run: t, num: nextCount - 1, onError: errorHandler} - select { - case <-r.cancel: + if r.cancel.Load() { return -1, errors.New("runner stopped") - default: - atomic.AddUint32(&r.totalTasksInQueue, 1) - r.tasks <- task - return int(task.num), nil } + r.totalTasksInQueue.Add(1) + r.tasks <- task + return int(task.num), nil } // Run r.maxParallel go routines in order to consume all the tasks @@ -154,7 +154,9 @@ func (r *runner) Run() { // Done is used to notify that no more tasks will be produced. func (r *runner) Done() { - close(r.tasks) + r.doneOnce.Do(func() { + close(r.tasks) + }) } // GetFinishedNotification returns the finishedNotifier channel, which notifies when the runner is done. @@ -171,10 +173,14 @@ func (r *runner) IsStarted() bool { // Cancel stops the Runner from getting new tasks and empties the tasks queue. // No new tasks will be executed, and tasks that already started will continue running and won't be interrupted. // If this Runner is already cancelled, then this function will do nothing. -func (r *runner) Cancel() { +// force - If true, pending tasks in the queue will not be handled. +func (r *runner) Cancel(force bool) { + // No more adding tasks + r.cancel.Store(true) + if force { + r.Done() + } r.cancelOnce.Do(func() { - // No more adding tasks - close(r.cancel) // Consume all tasks left for len(r.tasks) > 0 { <-r.tasks @@ -191,12 +197,12 @@ func (r *runner) Errors() map[int]error { } // OpenThreads returns the number of open threads (including idle threads). -func (r *runner) OpenThreads() int { - return r.openThreads +func (r *runner) OpenThreads() uint32 { + return r.openThreads.Load() } func (r *runner) ActiveThreads() uint32 { - return r.activeThreads + return r.activeThreads.Load() } func (r *runner) SetFinishedNotification(toEnable bool) { @@ -222,28 +228,28 @@ func (r *runner) SetMaxParallel(newVal int) { func (r *runner) addThread() { r.threadsWaitGroup.Add(1) - nextThreadId := atomic.AddUint32(&r.threadCount, 1) - 1 + nextThreadId := r.threadCount.Add(1) - 1 go func(threadId int) { defer r.threadsWaitGroup.Done() r.openThreadsLock.Lock() - r.openThreads++ + r.openThreads.Add(1) r.openThreadsLock.Unlock() // Keep on taking tasks from the queue. for t := range r.tasks { // Increase the total of active threads. - atomic.AddUint32(&r.activeThreads, 1) + r.activeThreads.Add(1) atomic.AddUint32(&r.started, 1) // Run the task. e := t.run(threadId) // Decrease the total of active threads. - atomic.AddUint32(&r.activeThreads, ^uint32(0)) + r.activeThreads.Add(^uint32(0)) // Decrease the total of in progress tasks. - atomic.AddUint32(&r.totalTasksInQueue, ^uint32(0)) + r.totalTasksInQueue.Add(^uint32(0)) if r.finishedNotificationEnabled { r.finishedNotifierLock.Lock() // Notify that the runner has finished its job. - if r.activeThreads == 0 && r.totalTasksInQueue == 0 { + if r.activeThreads.Load() == 0 && r.totalTasksInQueue.Load() == 0 { r.notifyFinished() } r.finishedNotifierLock.Unlock() @@ -260,15 +266,15 @@ func (r *runner) addThread() { r.errorsLock.Unlock() if r.failFast { - r.Cancel() + r.Cancel(false) break } } r.openThreadsLock.Lock() // If the total of open threads is larger than the maximum (maxParallel), then this thread should be closed. - if r.openThreads > r.maxParallel { - r.openThreads-- + if int(r.openThreads.Load()) > r.maxParallel { + r.openThreads.Add(^uint32(0)) r.openThreadsLock.Unlock() break } diff --git a/parallel/runner_test.go b/parallel/runner_test.go index a130cdb..bafd3c1 100644 --- a/parallel/runner_test.go +++ b/parallel/runner_test.go @@ -1,14 +1,29 @@ package parallel import ( + "errors" "fmt" - "math" "math/rand" + "sync" "testing" "time" + + "github.com/stretchr/testify/assert" ) -func TestTask(t *testing.T) { +var errTest = errors.New("some error") + +func TestIsStarted(t *testing.T) { + runner := NewBounedRunner(1, false) + runner.AddTask(func(i int) error { + return nil + }) + runner.Done() + runner.Run() + assert.True(t, runner.IsStarted()) +} + +func TestAddTask(t *testing.T) { const count = 70 results := make(chan int, 100) @@ -17,44 +32,152 @@ func TestTask(t *testing.T) { var expectedErrorTotal int for i := 0; i < count; i++ { expectedTotal += i - if float64(i) > math.Floor(float64(count)/2) { + if float64(i) > float64(count)/2 { expectedErrorTotal += i } x := i - runner.AddTask(func(i int) error { + _, err := runner.AddTask(func(int) error { results <- x time.Sleep(time.Millisecond * time.Duration(rand.Intn(50))) - if float64(x) > math.Floor(float64(count)/2) { + if float64(x) > float64(count)/2 { return fmt.Errorf("Second half value %d not counted", x) } return nil }) + assert.NoError(t, err) } runner.Done() runner.Run() - errs := runner.Errors() - close(results) var resultsTotal int for result := range results { resultsTotal += result } - if resultsTotal != expectedTotal { - t.Error("Unexpected results total:", resultsTotal) - } + assert.Equal(t, expectedTotal, resultsTotal) var errorsTotal int - for k, v := range errs { + for k, v := range runner.Errors() { if v != nil { errorsTotal += k } } - if errorsTotal != expectedErrorTotal { - t.Error("Unexpected errs total:", errorsTotal) + assert.Equal(t, expectedErrorTotal, errorsTotal) + assert.NotZero(t, errorsTotal) +} + +func TestAddTaskWithError(t *testing.T) { + // Create new runner + runner := NewRunner(1, 1, false) + + // Add task with error + var receivedError = new(error) + onError := func(err error) { *receivedError = err } + taskFunc := func(int) error { return errTest } + _, err := runner.AddTaskWithError(taskFunc, onError) + assert.NoError(t, err) + + // Wait for task to finish + runner.Done() + runner.Run() + + // Assert error captured + assert.Equal(t, errTest, *receivedError) + assert.Equal(t, errTest, runner.Errors()[0]) +} + +func TestCancel(t *testing.T) { + // Create new runner + runner := NewBounedRunner(1, false) + + // Cancel to prevent receiving another tasks + runner.Cancel(false) + + // Add task and expect error + _, err := runner.AddTask(func(int) error { return nil }) + assert.ErrorContains(t, err, "runner stopped") +} + +func TestForceCancel(t *testing.T) { + // Create new runner + const capacity = 10 + runner := NewRunner(1, capacity, true) + // Run tasks + for i := 0; i < capacity; i++ { + taskId := i + _, err := runner.AddTask(func(int) error { + assert.Less(t, taskId, 9) + time.Sleep(100 * time.Millisecond) + return nil + }) + assert.NoError(t, err) } - if errorsTotal == 0 { - t.Error("Unexpected 0 errs total") + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + runner.Run() + }() + go func() { + time.Sleep(200 * time.Millisecond) + runner.Cancel(true) + }() + wg.Wait() + + assert.InDelta(t, 5, runner.started, 4) +} + +func TestFailFast(t *testing.T) { + // Create new runner with fail-fast + runner := NewBounedRunner(1, true) + + // Add task that returns an error + _, err := runner.AddTask(func(int) error { + return errTest + }) + assert.NoError(t, err) + + // Wait for task to finish + runner.Run() + + // Add another task and expect error + _, err = runner.AddTask(func(int) error { + return nil + }) + assert.ErrorContains(t, err, "runner stopped") +} + +func TestNotifyFinished(t *testing.T) { + // Create new runner + runner := NewBounedRunner(1, false) + runner.SetFinishedNotification(true) + + // Cancel to prevent receiving another tasks + runner.Cancel(false) + <-runner.GetFinishedNotification() +} + +func TestMaxParallel(t *testing.T) { + // Create new runner with capacity of 10 and max parallelism of 3 + const capacity = 10 + const parallelism = 3 + runner := NewRunner(parallelism, capacity, false) + + // Run tasks in parallel + for i := 0; i < capacity; i++ { + _, err := runner.AddTask(func(int) error { + // Assert in range between 1 and 3 + assert.InDelta(t, 2, runner.ActiveThreads(), 1) + assert.InDelta(t, 2, runner.OpenThreads(), 1) + time.Sleep(100 * time.Millisecond) + return nil + }) + assert.NoError(t, err) } + + // Wait for tasks to finish + runner.Done() + runner.Run() + assert.Equal(t, uint32(capacity), runner.started) }