Skip to content

Commit

Permalink
Add force cancel option to the producer-consumer (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
yahavi authored Oct 12, 2023
1 parent 2bf299d commit 30d7647
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 47 deletions.
70 changes: 38 additions & 32 deletions parallel/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ type Runner interface {
AddTaskWithError(TaskFunc, OnErrorFunc) (int, error)
Run()
Done()
Cancel()
Cancel(bool)
Errors() map[int]error
ActiveThreads() uint32
OpenThreads() int
OpenThreads() uint32
IsStarted() bool
SetMaxParallel(int)
GetFinishedNotification() chan bool
Expand All @@ -39,10 +39,12 @@ type runner struct {
tasks chan *task
// Tasks counter, used to give each task an identifier (task.num).
taskId uint32
// A channel that is closed when the runner is cancelled.
cancel chan struct{}
// Used to make sure the cancel channel is closed only once.
// True when Cancel was invoked
cancel atomic.Bool
// Used to make sure that cancel is called only once.
cancelOnce sync.Once
// Used to make sure that done is called only once.
doneOnce sync.Once
// The maximum number of threads running in parallel.
maxParallel int
// If true, the runner will be cancelled on the first error thrown from a task.
Expand All @@ -52,15 +54,15 @@ type runner struct {
// A WaitGroup that waits for all the threads to close.
threadsWaitGroup sync.WaitGroup
// Threads counter, used to give each thread an identifier (threadId).
threadCount uint32
threadCount atomic.Uint32
// The number of open threads.
openThreads int
openThreads atomic.Uint32
// A lock on openThreads.
openThreadsLock sync.Mutex
// The number of threads currently running tasks.
activeThreads uint32
activeThreads atomic.Uint32
// The number of tasks in the queue.
totalTasksInQueue uint32
totalTasksInQueue atomic.Uint32
// Indicate that the runner has finished.
finishedNotifier chan bool
// Indicates that the finish channel is closed.
Expand Down Expand Up @@ -91,7 +93,7 @@ func NewRunner(maxParallel int, capacity uint, failFast bool) *runner {
finishedNotifier: make(chan bool, 1),
maxParallel: consumers,
failFast: failFast,
cancel: make(chan struct{}),
cancel: atomic.Bool{},
tasks: make(chan *task, capacity),
}
r.errors = make(map[int]error)
Expand Down Expand Up @@ -122,14 +124,12 @@ func (r *runner) addTask(t TaskFunc, errorHandler OnErrorFunc) (int, error) {
nextCount := atomic.AddUint32(&r.taskId, 1)
task := &task{run: t, num: nextCount - 1, onError: errorHandler}

select {
case <-r.cancel:
if r.cancel.Load() {
return -1, errors.New("runner stopped")
default:
atomic.AddUint32(&r.totalTasksInQueue, 1)
r.tasks <- task
return int(task.num), nil
}
r.totalTasksInQueue.Add(1)
r.tasks <- task
return int(task.num), nil
}

// Run r.maxParallel go routines in order to consume all the tasks
Expand All @@ -154,7 +154,9 @@ func (r *runner) Run() {

// Done is used to notify that no more tasks will be produced.
func (r *runner) Done() {
close(r.tasks)
r.doneOnce.Do(func() {
close(r.tasks)
})
}

// GetFinishedNotification returns the finishedNotifier channel, which notifies when the runner is done.
Expand All @@ -171,10 +173,14 @@ func (r *runner) IsStarted() bool {
// Cancel stops the Runner from getting new tasks and empties the tasks queue.
// No new tasks will be executed, and tasks that already started will continue running and won't be interrupted.
// If this Runner is already cancelled, then this function will do nothing.
func (r *runner) Cancel() {
// force - If true, pending tasks in the queue will not be handled.
func (r *runner) Cancel(force bool) {
// No more adding tasks
r.cancel.Store(true)
if force {
r.Done()
}
r.cancelOnce.Do(func() {
// No more adding tasks
close(r.cancel)
// Consume all tasks left
for len(r.tasks) > 0 {
<-r.tasks
Expand All @@ -191,12 +197,12 @@ func (r *runner) Errors() map[int]error {
}

// OpenThreads returns the number of open threads (including idle threads).
func (r *runner) OpenThreads() int {
return r.openThreads
func (r *runner) OpenThreads() uint32 {
return r.openThreads.Load()
}

func (r *runner) ActiveThreads() uint32 {
return r.activeThreads
return r.activeThreads.Load()
}

func (r *runner) SetFinishedNotification(toEnable bool) {
Expand All @@ -222,28 +228,28 @@ func (r *runner) SetMaxParallel(newVal int) {

func (r *runner) addThread() {
r.threadsWaitGroup.Add(1)
nextThreadId := atomic.AddUint32(&r.threadCount, 1) - 1
nextThreadId := r.threadCount.Add(1) - 1
go func(threadId int) {
defer r.threadsWaitGroup.Done()
r.openThreadsLock.Lock()
r.openThreads++
r.openThreads.Add(1)
r.openThreadsLock.Unlock()

// Keep on taking tasks from the queue.
for t := range r.tasks {
// Increase the total of active threads.
atomic.AddUint32(&r.activeThreads, 1)
r.activeThreads.Add(1)
atomic.AddUint32(&r.started, 1)
// Run the task.
e := t.run(threadId)
// Decrease the total of active threads.
atomic.AddUint32(&r.activeThreads, ^uint32(0))
r.activeThreads.Add(^uint32(0))
// Decrease the total of in progress tasks.
atomic.AddUint32(&r.totalTasksInQueue, ^uint32(0))
r.totalTasksInQueue.Add(^uint32(0))
if r.finishedNotificationEnabled {
r.finishedNotifierLock.Lock()
// Notify that the runner has finished its job.
if r.activeThreads == 0 && r.totalTasksInQueue == 0 {
if r.activeThreads.Load() == 0 && r.totalTasksInQueue.Load() == 0 {
r.notifyFinished()
}
r.finishedNotifierLock.Unlock()
Expand All @@ -260,15 +266,15 @@ func (r *runner) addThread() {
r.errorsLock.Unlock()

if r.failFast {
r.Cancel()
r.Cancel(false)
break
}
}

r.openThreadsLock.Lock()
// If the total of open threads is larger than the maximum (maxParallel), then this thread should be closed.
if r.openThreads > r.maxParallel {
r.openThreads--
if int(r.openThreads.Load()) > r.maxParallel {
r.openThreads.Add(^uint32(0))
r.openThreadsLock.Unlock()
break
}
Expand Down
153 changes: 138 additions & 15 deletions parallel/runner_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,29 @@
package parallel

import (
"errors"
"fmt"
"math"
"math/rand"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestTask(t *testing.T) {
var errTest = errors.New("some error")

func TestIsStarted(t *testing.T) {
runner := NewBounedRunner(1, false)
runner.AddTask(func(i int) error {
return nil
})
runner.Done()
runner.Run()
assert.True(t, runner.IsStarted())
}

func TestAddTask(t *testing.T) {
const count = 70
results := make(chan int, 100)

Expand All @@ -17,44 +32,152 @@ func TestTask(t *testing.T) {
var expectedErrorTotal int
for i := 0; i < count; i++ {
expectedTotal += i
if float64(i) > math.Floor(float64(count)/2) {
if float64(i) > float64(count)/2 {
expectedErrorTotal += i
}

x := i
runner.AddTask(func(i int) error {
_, err := runner.AddTask(func(int) error {
results <- x
time.Sleep(time.Millisecond * time.Duration(rand.Intn(50)))
if float64(x) > math.Floor(float64(count)/2) {
if float64(x) > float64(count)/2 {
return fmt.Errorf("Second half value %d not counted", x)
}
return nil
})
assert.NoError(t, err)
}
runner.Done()
runner.Run()

errs := runner.Errors()

close(results)
var resultsTotal int
for result := range results {
resultsTotal += result
}
if resultsTotal != expectedTotal {
t.Error("Unexpected results total:", resultsTotal)
}
assert.Equal(t, expectedTotal, resultsTotal)

var errorsTotal int
for k, v := range errs {
for k, v := range runner.Errors() {
if v != nil {
errorsTotal += k
}
}
if errorsTotal != expectedErrorTotal {
t.Error("Unexpected errs total:", errorsTotal)
assert.Equal(t, expectedErrorTotal, errorsTotal)
assert.NotZero(t, errorsTotal)
}

func TestAddTaskWithError(t *testing.T) {
// Create new runner
runner := NewRunner(1, 1, false)

// Add task with error
var receivedError = new(error)
onError := func(err error) { *receivedError = err }
taskFunc := func(int) error { return errTest }
_, err := runner.AddTaskWithError(taskFunc, onError)
assert.NoError(t, err)

// Wait for task to finish
runner.Done()
runner.Run()

// Assert error captured
assert.Equal(t, errTest, *receivedError)
assert.Equal(t, errTest, runner.Errors()[0])
}

func TestCancel(t *testing.T) {
// Create new runner
runner := NewBounedRunner(1, false)

// Cancel to prevent receiving another tasks
runner.Cancel(false)

// Add task and expect error
_, err := runner.AddTask(func(int) error { return nil })
assert.ErrorContains(t, err, "runner stopped")
}

func TestForceCancel(t *testing.T) {
// Create new runner
const capacity = 10
runner := NewRunner(1, capacity, true)
// Run tasks
for i := 0; i < capacity; i++ {
taskId := i
_, err := runner.AddTask(func(int) error {
assert.Less(t, taskId, 9)
time.Sleep(100 * time.Millisecond)
return nil
})
assert.NoError(t, err)
}
if errorsTotal == 0 {
t.Error("Unexpected 0 errs total")
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
runner.Run()
}()
go func() {
time.Sleep(200 * time.Millisecond)
runner.Cancel(true)
}()
wg.Wait()

assert.InDelta(t, 5, runner.started, 4)
}

func TestFailFast(t *testing.T) {
// Create new runner with fail-fast
runner := NewBounedRunner(1, true)

// Add task that returns an error
_, err := runner.AddTask(func(int) error {
return errTest
})
assert.NoError(t, err)

// Wait for task to finish
runner.Run()

// Add another task and expect error
_, err = runner.AddTask(func(int) error {
return nil
})
assert.ErrorContains(t, err, "runner stopped")
}

func TestNotifyFinished(t *testing.T) {
// Create new runner
runner := NewBounedRunner(1, false)
runner.SetFinishedNotification(true)

// Cancel to prevent receiving another tasks
runner.Cancel(false)
<-runner.GetFinishedNotification()
}

func TestMaxParallel(t *testing.T) {
// Create new runner with capacity of 10 and max parallelism of 3
const capacity = 10
const parallelism = 3
runner := NewRunner(parallelism, capacity, false)

// Run tasks in parallel
for i := 0; i < capacity; i++ {
_, err := runner.AddTask(func(int) error {
// Assert in range between 1 and 3
assert.InDelta(t, 2, runner.ActiveThreads(), 1)
assert.InDelta(t, 2, runner.OpenThreads(), 1)
time.Sleep(100 * time.Millisecond)
return nil
})
assert.NoError(t, err)
}

// Wait for tasks to finish
runner.Done()
runner.Run()
assert.Equal(t, uint32(capacity), runner.started)
}

0 comments on commit 30d7647

Please sign in to comment.