From 88cda2aeaea01ab3d482c3dcdea422f502f3e30b Mon Sep 17 00:00:00 2001 From: Lucas Godoy Date: Thu, 1 Jul 2021 09:42:56 -0300 Subject: [PATCH] Replace WaitGroup using semphore instead --- wpool/exec.go | 56 +++++++++++++++++++++++++++++++++++++--------- wpool/exec_test.go | 20 +++++++++-------- 2 files changed, 56 insertions(+), 20 deletions(-) diff --git a/wpool/exec.go b/wpool/exec.go index f9057a6..2efdfd2 100644 --- a/wpool/exec.go +++ b/wpool/exec.go @@ -3,11 +3,10 @@ package wpool import ( "context" "fmt" - "sync" ) -func worker(ctx context.Context, wg *sync.WaitGroup, jobs <-chan Job, results chan<- Result) { - defer wg.Done() +func worker(ctx context.Context, jobs <-chan Job, results chan<- Result, sem semaphore) { + defer sem.release() for { select { case job, ok := <-jobs: @@ -26,11 +25,47 @@ func worker(ctx context.Context, wg *sync.WaitGroup, jobs <-chan Job, results ch } } +type semaphore interface { + acquire() + release() + wait() + close() +} + +type slot struct{} +type slots chan slot + +type execution struct { + slots slots +} + +func newExecutionSlots(capacity int) execution { + slots := make(chan slot, capacity) + return execution{slots: slots} +} + +func (e execution) acquire() { + e.slots <- slot{} +} + +func (e execution) release() { + <-e.slots +} + +func (e execution) wait() { + for i := 0; i < cap(e.slots); i++ { + e.slots <- slot{} + } +} + +func (e execution) close() { + close(e.slots) +} + type WorkerPool struct { workersCount int jobs chan Job results chan Result - Done chan struct{} } func New(wcount int) WorkerPool { @@ -38,23 +73,22 @@ func New(wcount int) WorkerPool { workersCount: wcount, jobs: make(chan Job, wcount), results: make(chan Result, wcount), - Done: make(chan struct{}), } } func (wp WorkerPool) Run(ctx context.Context) { - var wg sync.WaitGroup + eSlots := newExecutionSlots(wp.workersCount) + defer eSlots.close() for i := 0; i < wp.workersCount; i++ { - wg.Add(1) + eSlots.acquire() // fan out worker goroutines //reading from jobs channel and //pushing calcs into results channel - go worker(ctx, &wg, wp.jobs, wp.results) + go worker(ctx, wp.jobs, wp.results, eSlots) } - wg.Wait() - close(wp.Done) + eSlots.wait() close(wp.results) } @@ -63,7 +97,7 @@ func (wp WorkerPool) Results() <-chan Result { } func (wp WorkerPool) GenerateFrom(jobsBulk []Job) { - for i, _ := range jobsBulk { + for i := range jobsBulk { wp.jobs <- jobsBulk[i] } close(wp.jobs) diff --git a/wpool/exec_test.go b/wpool/exec_test.go index 617c58b..b475cc9 100644 --- a/wpool/exec_test.go +++ b/wpool/exec_test.go @@ -27,7 +27,7 @@ func TestWorkerPool(t *testing.T) { select { case r, ok := <-wp.Results(): if !ok { - continue + return } i, err := strconv.ParseInt(string(r.Descriptor.ID), 10, 64) @@ -39,8 +39,6 @@ func TestWorkerPool(t *testing.T) { if val != int(i)*2 { t.Fatalf("wrong value %v; expected %v", val, int(i)*2) } - case <-wp.Done: - return default: } } @@ -56,12 +54,14 @@ func TestWorkerPool_TimeOut(t *testing.T) { for { select { - case r := <-wp.Results(): + case r, ok := <-wp.Results(): + if !ok { + return + } + if r.Err != nil && r.Err != context.DeadlineExceeded { t.Fatalf("expected error: %v; got: %v", context.DeadlineExceeded, r.Err) } - case <-wp.Done: - return default: } } @@ -77,12 +77,14 @@ func TestWorkerPool_Cancel(t *testing.T) { for { select { - case r := <-wp.Results(): + case r, ok := <-wp.Results(): + if !ok { + return + } + if r.Err != nil && r.Err != context.Canceled { t.Fatalf("expected error: %v; got: %v", context.Canceled, r.Err) } - case <-wp.Done: - return default: } }