Skip to content

Commit

Permalink
Added GetResults and StreamResults
Browse files Browse the repository at this point in the history
  • Loading branch information
Francesco Cosentino committed May 17, 2023
1 parent 3687f27 commit 8566fd8
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 39 deletions.
10 changes: 6 additions & 4 deletions examples/test/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ func main() {
// create a new task manager
tm := worker.NewTaskManager(4, 10, 5, time.Second*30, time.Second*30, 3)
// close the task manager
// defer tm.Close()

// register and execute 10 tasks in a separate goroutine
go func() {
Expand All @@ -34,7 +35,8 @@ func main() {
log.Fatal(error)
}
emptyFile.Close()
time.Sleep(time.Second)
time.Sleep(time.Millisecond * 100)

return fmt.Sprintf("** task number %v with id %s executed", j, id), err
},
Retries: 10,
Expand All @@ -61,7 +63,7 @@ func main() {
log.Fatal(error)
}
emptyFile.Close()
time.Sleep(time.Second)
time.Sleep(time.Millisecond * 100)
return fmt.Sprintf("**** task number %v with id %s executed", j, id), err
},
}
Expand All @@ -76,8 +78,8 @@ func main() {
// tm.Close()

// wait for the tasks to finish and print the results
for result := range tm.GetResults() {
fmt.Println(result)
for id, result := range tm.GetResults() {
fmt.Println(id, result)
}

}
7 changes: 6 additions & 1 deletion middleware/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,13 @@ func (mw *loggerMiddleware) GetActiveTasks() int {
return mw.next.GetActiveTasks()
}

// StreamResults streams the results channel
func (mw *loggerMiddleware) StreamResults() <-chan worker.Result {
return mw.next.StreamResults()
}

// GetResults returns the results channel
func (mw *loggerMiddleware) GetResults() <-chan worker.Result {
func (mw *loggerMiddleware) GetResults() []worker.Result {
return mw.next.GetResults()
}

Expand Down
4 changes: 3 additions & 1 deletion service.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ type Service interface {
CancelTask(id uuid.UUID)
// GetActiveTasks returns the number of active tasks
GetActiveTasks() int
// StreamResults streams the `Result` channel
StreamResults() <-chan Result
// GetResults retruns the `Result` channel
GetResults() <-chan Result
GetResults() []Result
// GetCancelled gets the cancelled tasks channel
GetCancelled() <-chan Task
// GetTask gets a task by its ID
Expand Down
4 changes: 2 additions & 2 deletions tests/worker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestTaskManager_Start(t *testing.T) {
}
tm.RegisterTask(context.Background(), task)

res := <-tm.GetResults()
res := <-tm.StreamResults()
if res.Task == nil {
t.Fatalf("Task result was not added to the results channel")
}
Expand All @@ -62,7 +62,7 @@ func TestTaskManager_GetResults(t *testing.T) {
}
tm.RegisterTask(context.Background(), task)

results := <-tm.GetResults()
results := <-tm.StreamResults()
if results.Task == nil {
t.Fatalf("results channel is nil")
}
Expand Down
77 changes: 46 additions & 31 deletions worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ type TaskManager struct {
MaxRetries int // MaxRetries is the maximum number of retries
limiter *rate.Limiter // limiter is a rate limiter that limits the number of tasks that can be executed at once
wg sync.WaitGroup // wg is a wait group that waits for all tasks to finish
running sync.WaitGroup // running is a wait group that waits for all running tasks to finish
mutex sync.RWMutex // mutex protects the task handling
quit chan struct{} // quit is a channel to signal all goroutines to stop
ctx context.Context // ctx is the context for the task manager
Expand Down Expand Up @@ -102,6 +103,7 @@ func NewTaskManager(maxWorkers int, maxTasks int, tasksPerSecond float64, timeou
MaxRetries: maxRetries,
limiter: rate.NewLimiter(rate.Limit(tasksPerSecond), maxTasks),
wg: sync.WaitGroup{},
running: sync.WaitGroup{},
mutex: sync.RWMutex{},
quit: make(chan struct{}),
ctx: ctx,
Expand Down Expand Up @@ -282,37 +284,23 @@ func (tm *TaskManager) RegisterTasks(ctx context.Context, tasks ...Task) {

// Wait waits for all tasks to complete or for the timeout to elapse
func (tm *TaskManager) Wait(timeout time.Duration) {
timer := time.NewTimer(timeout)
defer timer.Stop()

// flag to indicate if any tasks have been executed
executed := false
done := make(chan struct{})
go func() {
tm.wg.Wait() // Wait for all tasks to be started
tm.running.Wait() // Wait for all running tasks to finish
close(done)
}()

for {
select {
case <-tm.quit:
// task manager has been closed, cancel all tasks
tm.CancelAll()
close(tm.Results)
close(tm.ctx.Value(ctxKeyCancelled{}).(chan Task))

case <-timer.C:
// timeout reached, cancel all tasks
tm.CancelAll()
default:
// check if any tasks have been executed
if !executed {
// no tasks have been executed, return immediately
return
}
// wait for all tasks to finish
// tm.scheduler.Wait()
tm.wg.Wait()
// close the results and cancelled channels
close(tm.Results)
close(tm.ctx.Value(ctxKeyCancelled{}).(chan Task))
}
select {
case <-done:
// All tasks have finished
case <-time.After(timeout):
// Timeout reached before all tasks finished
}

// close the results and cancelled channels
close(tm.Results)
close(tm.ctx.Value(ctxKeyCancelled{}).(chan Task))
}

// Close stops the task manager and waits for all tasks to finish
Expand Down Expand Up @@ -378,11 +366,35 @@ func (tm *TaskManager) GetActiveTasks() int {
return int(tm.limiter.Limit()) - tm.limiter.Burst()
}

// GetResults gets the results channel
func (tm *TaskManager) GetResults() <-chan Result {
// StreamResults streams the results channel
func (tm *TaskManager) StreamResults() <-chan Result {
return tm.Results
}

// GetResults gets the results channel
func (tm *TaskManager) GetResults() []Result {
results := make([]Result, 0)

// Create a done channel to signal when all tasks have finished
done := make(chan struct{})

// Start a goroutine to read from the Results channel
go func() {
for result := range tm.Results {
results = append(results, result)
}
close(done)
}()

// Wait for all tasks to finish
tm.Wait(tm.Timeout)

// Wait for the results goroutine to finish
<-done

return results
}

// GetCancelled gets the cancelled tasks channel
func (tm *TaskManager) GetCancelled() <-chan Task {
return tm.ctx.Value(ctxKeyCancelled{}).(chan Task)
Expand Down Expand Up @@ -532,6 +544,9 @@ func (tm *TaskManager) ExecuteTask(id uuid.UUID, timeout time.Duration) (interfa
return nil, ErrTaskCancelled
default:
task.setStarted()
// Increment the running wait group when the task starts
tm.running.Add(1)
defer tm.running.Done()

// execute the task
result, err := task.Fn()
Expand Down

0 comments on commit 8566fd8

Please sign in to comment.