diff --git a/examples/test/test.go b/examples/test/test.go index 78e0be8..7c6bff8 100644 --- a/examples/test/test.go +++ b/examples/test/test.go @@ -16,6 +16,7 @@ func main() { // create a new task manager tm := worker.NewTaskManager(4, 10, 5, time.Second*30, time.Second*30, 3) // close the task manager + // defer tm.Close() // register and execute 10 tasks in a separate goroutine go func() { @@ -34,7 +35,8 @@ func main() { log.Fatal(error) } emptyFile.Close() - time.Sleep(time.Second) + time.Sleep(time.Millisecond * 100) + return fmt.Sprintf("** task number %v with id %s executed", j, id), err }, Retries: 10, @@ -61,7 +63,7 @@ func main() { log.Fatal(error) } emptyFile.Close() - time.Sleep(time.Second) + time.Sleep(time.Millisecond * 100) return fmt.Sprintf("**** task number %v with id %s executed", j, id), err }, } @@ -76,8 +78,8 @@ func main() { // tm.Close() // wait for the tasks to finish and print the results - for result := range tm.GetResults() { - fmt.Println(result) + for id, result := range tm.GetResults() { + fmt.Println(id, result) } } diff --git a/middleware/logger.go b/middleware/logger.go index ac0a0dd..33f055e 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -131,8 +131,13 @@ func (mw *loggerMiddleware) GetActiveTasks() int { return mw.next.GetActiveTasks() } +// StreamResults streams the results channel +func (mw *loggerMiddleware) StreamResults() <-chan worker.Result { + return mw.next.StreamResults() +} + // GetResults returns the results channel -func (mw *loggerMiddleware) GetResults() <-chan worker.Result { +func (mw *loggerMiddleware) GetResults() []worker.Result { return mw.next.GetResults() } diff --git a/service.go b/service.go index 29ec7a7..a7f4639 100644 --- a/service.go +++ b/service.go @@ -27,8 +27,10 @@ type Service interface { CancelTask(id uuid.UUID) // GetActiveTasks returns the number of active tasks GetActiveTasks() int + // StreamResults streams the `Result` channel + StreamResults() <-chan Result // GetResults retruns the `Result` channel - GetResults() <-chan Result + GetResults() []Result // GetCancelled gets the cancelled tasks channel GetCancelled() <-chan Task // GetTask gets a task by its ID diff --git a/tests/worker_test.go b/tests/worker_test.go index 6aca81f..ee52a19 100644 --- a/tests/worker_test.go +++ b/tests/worker_test.go @@ -47,7 +47,7 @@ func TestTaskManager_Start(t *testing.T) { } tm.RegisterTask(context.Background(), task) - res := <-tm.GetResults() + res := <-tm.StreamResults() if res.Task == nil { t.Fatalf("Task result was not added to the results channel") } @@ -62,7 +62,7 @@ func TestTaskManager_GetResults(t *testing.T) { } tm.RegisterTask(context.Background(), task) - results := <-tm.GetResults() + results := <-tm.StreamResults() if results.Task == nil { t.Fatalf("results channel is nil") } diff --git a/worker.go b/worker.go index 66823a1..098dfb9 100644 --- a/worker.go +++ b/worker.go @@ -44,6 +44,7 @@ type TaskManager struct { MaxRetries int // MaxRetries is the maximum number of retries limiter *rate.Limiter // limiter is a rate limiter that limits the number of tasks that can be executed at once wg sync.WaitGroup // wg is a wait group that waits for all tasks to finish + running sync.WaitGroup // running is a wait group that waits for all running tasks to finish mutex sync.RWMutex // mutex protects the task handling quit chan struct{} // quit is a channel to signal all goroutines to stop ctx context.Context // ctx is the context for the task manager @@ -102,6 +103,7 @@ func NewTaskManager(maxWorkers int, maxTasks int, tasksPerSecond float64, timeou MaxRetries: maxRetries, limiter: rate.NewLimiter(rate.Limit(tasksPerSecond), maxTasks), wg: sync.WaitGroup{}, + running: sync.WaitGroup{}, mutex: sync.RWMutex{}, quit: make(chan struct{}), ctx: ctx, @@ -282,37 +284,23 @@ func (tm *TaskManager) RegisterTasks(ctx context.Context, tasks ...Task) { // Wait waits for all tasks to complete or for the timeout to elapse func (tm *TaskManager) Wait(timeout time.Duration) { - timer := time.NewTimer(timeout) - defer timer.Stop() - - // flag to indicate if any tasks have been executed - executed := false + done := make(chan struct{}) + go func() { + tm.wg.Wait() // Wait for all tasks to be started + tm.running.Wait() // Wait for all running tasks to finish + close(done) + }() - for { - select { - case <-tm.quit: - // task manager has been closed, cancel all tasks - tm.CancelAll() - close(tm.Results) - close(tm.ctx.Value(ctxKeyCancelled{}).(chan Task)) - - case <-timer.C: - // timeout reached, cancel all tasks - tm.CancelAll() - default: - // check if any tasks have been executed - if !executed { - // no tasks have been executed, return immediately - return - } - // wait for all tasks to finish - // tm.scheduler.Wait() - tm.wg.Wait() - // close the results and cancelled channels - close(tm.Results) - close(tm.ctx.Value(ctxKeyCancelled{}).(chan Task)) - } + select { + case <-done: + // All tasks have finished + case <-time.After(timeout): + // Timeout reached before all tasks finished } + + // close the results and cancelled channels + close(tm.Results) + close(tm.ctx.Value(ctxKeyCancelled{}).(chan Task)) } // Close stops the task manager and waits for all tasks to finish @@ -378,11 +366,35 @@ func (tm *TaskManager) GetActiveTasks() int { return int(tm.limiter.Limit()) - tm.limiter.Burst() } -// GetResults gets the results channel -func (tm *TaskManager) GetResults() <-chan Result { +// StreamResults streams the results channel +func (tm *TaskManager) StreamResults() <-chan Result { return tm.Results } +// GetResults gets the results channel +func (tm *TaskManager) GetResults() []Result { + results := make([]Result, 0) + + // Create a done channel to signal when all tasks have finished + done := make(chan struct{}) + + // Start a goroutine to read from the Results channel + go func() { + for result := range tm.Results { + results = append(results, result) + } + close(done) + }() + + // Wait for all tasks to finish + tm.Wait(tm.Timeout) + + // Wait for the results goroutine to finish + <-done + + return results +} + // GetCancelled gets the cancelled tasks channel func (tm *TaskManager) GetCancelled() <-chan Task { return tm.ctx.Value(ctxKeyCancelled{}).(chan Task) @@ -532,6 +544,9 @@ func (tm *TaskManager) ExecuteTask(id uuid.UUID, timeout time.Duration) (interfa return nil, ErrTaskCancelled default: task.setStarted() + // Increment the running wait group when the task starts + tm.running.Add(1) + defer tm.running.Done() // execute the task result, err := task.Fn()