From d2e12d88c00124f53bad2dd98e626d611a529e85 Mon Sep 17 00:00:00 2001 From: Francesco Cosentino Date: Fri, 19 May 2023 13:36:29 +0300 Subject: [PATCH] Refactored Workers and Execution flows, improved Task --- .gitignore | 2 +- examples/manual/main.go | 2 +- examples/middleware/main.go | 39 ++- examples/multi/multi.go | 129 +++++++++ examples/{test => multi}/res/.gitkeep | 0 examples/test/test.go | 87 ------ middleware/logger.go | 6 +- service.go | 4 +- task.go | 97 ++++--- tests/worker_test.go | 10 +- worker.go | 369 ++++++++++++-------------- 11 files changed, 393 insertions(+), 352 deletions(-) create mode 100644 examples/multi/multi.go rename examples/{test => multi}/res/.gitkeep (100%) delete mode 100644 examples/test/test.go diff --git a/.gitignore b/.gitignore index 18d96b7..97d09fa 100644 --- a/.gitignore +++ b/.gitignore @@ -118,5 +118,5 @@ $RECYCLE.BIN/ # Local example *.bak *.now -examples/test/res/*.txt +examples/multi/res/*.txt .dccache diff --git a/examples/manual/main.go b/examples/manual/main.go index adfbcde..2da11c1 100644 --- a/examples/manual/main.go +++ b/examples/manual/main.go @@ -27,7 +27,7 @@ func main() { task := worker.Task{ ID: uuid.New(), Priority: 1, - Fn: func() (val interface{}, err error) { return "Hello, World from Task!", err }, + Execute: func() (val interface{}, err error) { return "Hello, World from Task!", err }, } res, err := srv.ExecuteTask(task.ID, time.Second*5) diff --git a/examples/middleware/main.go b/examples/middleware/main.go index 5c42a5c..e24a02c 100644 --- a/examples/middleware/main.go +++ b/examples/middleware/main.go @@ -11,9 +11,9 @@ import ( ) func main() { - tm := worker.NewTaskManager(context.TODO(), 4, 10, 5, time.Second*3, time.Second*30, 3) + tm := worker.NewTaskManager(context.TODO(), 4, 10, 5, time.Second*3, time.Second*3, 3) - // defer tm.Close() + // defer tm.Stop() var srv worker.Service = tm // apply middleware in the same order as you want to execute them @@ -24,12 +24,12 @@ func main() { }, ) - // defer srv.Close() + // defer srv.Stop() task := worker.Task{ ID: uuid.New(), Priority: 1, - Fn: func() (val interface{}, err error) { + Execute: func() (val interface{}, err error) { return func(a int, b int) (val interface{}, err error) { return a + b, err }(2, 5) @@ -40,22 +40,23 @@ func main() { task1 := worker.Task{ ID: uuid.New(), Priority: 10, - // Fn: func() (val interface{}, err error) { return "Hello, World from Task 1!", err }, + // Execute: func() (val interface{}, err error) { return "Hello, World from Task 1!", err }, } task2 := worker.Task{ ID: uuid.New(), Priority: 5, - Fn: func() (val interface{}, err error) { + Execute: func() (val interface{}, err error) { time.Sleep(time.Second * 2) return "Hello, World from Task 2!", err }, + Ctx: context.TODO(), } task3 := worker.Task{ ID: uuid.New(), Priority: 90, - Fn: func() (val interface{}, err error) { + Execute: func() (val interface{}, err error) { // Simulate a long running task // time.Sleep(3 * time.Second) return "Hello, World from Task 3!", err @@ -65,7 +66,7 @@ func main() { task4 := worker.Task{ ID: uuid.New(), Priority: 150, - Fn: func() (val interface{}, err error) { + Execute: func() (val interface{}, err error) { // Simulate a long running task time.Sleep(1 * time.Second) return "Hello, World from Task 4!", err @@ -79,12 +80,26 @@ func main() { srv.RegisterTask(context.TODO(), task4) // Print results - for result := range srv.GetResults() { - fmt.Println(result) - } + // for result := range srv.GetResults() { + // fmt.Println(result) + // } + // tm.Wait(tm.Timeout) tasks := srv.GetTasks() for _, task := range tasks { - fmt.Println(task) + fmt.Print(task.ID, " ", task.Priority, " ", task.Status, " ", task.Error, " ", "\n") + } + + fmt.Println("printing cancelled tasks") + + // get the cancelled tasks + cancelledTasks := tm.GetCancelledTasks() + + select { + case task := <-cancelledTasks: + fmt.Printf("Task %s was cancelled\n", task.ID.String()) + default: + fmt.Println("No tasks have been cancelled yet") } + } diff --git a/examples/multi/multi.go b/examples/multi/multi.go new file mode 100644 index 0000000..3991226 --- /dev/null +++ b/examples/multi/multi.go @@ -0,0 +1,129 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "path" + "time" + + "github.com/google/uuid" + worker "github.com/hyp3rd/go-worker" +) + +func main() { + // create a new task manager + tm := worker.NewTaskManager(context.TODO(), 1, 50, 10, time.Second*60, time.Second*2, 5) + + // register and execute 10 tasks in a separate goroutine + go func() { + for i := 0; i < 10; i++ { + j := i + // create a new task + id := uuid.New() + task := worker.Task{ + ID: id, + Name: "Some task", + Description: "Here goes the description of the task", + Priority: 10, + Execute: func() (val interface{}, err error) { + emptyFile, error := os.Create(path.Join("examples", "multi", "res", fmt.Sprintf("1st__EmptyFile___%v.txt", j))) + if error != nil { + log.Fatal(error) + } + emptyFile.Close() + return fmt.Sprintf("** task number %v with id %s executed", j, id), err + }, + Ctx: context.TODO(), + Retries: 5, + RetryDelay: 1, + } + + // register the task + tm.RegisterTask(context.TODO(), task) + } + }() + + // register and execute 10 tasks in a separate goroutine + go func() { + for i := 0; i < 10; i++ { + j := i + // create a new task + id := uuid.New() + task := worker.Task{ + ID: id, + Execute: func() (val interface{}, err error) { + emptyFile, error := os.Create(path.Join("examples", "multi", "res", fmt.Sprintf("2nd__EmptyFile___%v.txt", j))) + + if error != nil { + log.Fatal(error) + } + emptyFile.Close() + // time.Sleep(time.Millisecond * 100) + return fmt.Sprintf("**** task number %v with id %s executed", j, id), err + }, + // Ctx: context.TODO(), + Retries: 5, + RetryDelay: 1, + } + + // register the task + tm.RegisterTask(context.TODO(), task) + } + }() + + for i := 0; i < 10; i++ { + j := i + // create a new task + id := uuid.New() + task := worker.Task{ + ID: id, + Execute: func() (val interface{}, err error) { + emptyFile, error := os.Create(path.Join("examples", "multi", "res", fmt.Sprintf("3nd__EmptyFile___%v.txt", j))) + + if error != nil { + log.Fatal(error) + } + emptyFile.Close() + // time.Sleep(time.Millisecond * 100) + return fmt.Sprintf("**** task number %v with id %s executed", j, id), err + }, + Ctx: context.TODO(), + Retries: 5, + RetryDelay: 1, + } + + // register the task + tm.RegisterTask(context.TODO(), task) + } + + for i := 0; i < 10; i++ { + j := i + // create a new task + id := uuid.New() + task := worker.Task{ + ID: id, + Execute: func() (val interface{}, err error) { + emptyFile, err := os.Create(path.Join("examples", "wrong-path", "res", fmt.Sprintf("4nd__EmptyFile___%v.txt", j))) + + if err != nil { + log.Println(err) + } + emptyFile.Close() + // time.Sleep(time.Millisecond * 100) + return fmt.Sprintf("**** wrong task number %v with id %s executed", j, id), err + }, + Ctx: context.TODO(), + Retries: 3, + } + + // register the task + tm.RegisterTask(context.TODO(), task) + } + + // wait for the tasks to finish and print the results + for id, result := range tm.GetResults() { + fmt.Println(id, result) + } +} diff --git a/examples/test/res/.gitkeep b/examples/multi/res/.gitkeep similarity index 100% rename from examples/test/res/.gitkeep rename to examples/multi/res/.gitkeep diff --git a/examples/test/test.go b/examples/test/test.go deleted file mode 100644 index e0b7343..0000000 --- a/examples/test/test.go +++ /dev/null @@ -1,87 +0,0 @@ -package main - -import ( - "context" - "fmt" - "log" - "os" - "path" - "time" - - "github.com/google/uuid" - worker "github.com/hyp3rd/go-worker" -) - -func main() { - // create a new task manager - tm := worker.NewTaskManager(context.TODO(), 4, 20, 10, time.Second*60, time.Second*1, 5) - - // register and execute 10 tasks in a separate goroutine - go func() { - for i := 0; i < 10; i++ { - j := i - // create a new task - id := uuid.New() - task := worker.Task{ - ID: id, - Name: "Some task", - Description: "Here goes the description of the task", - Priority: 10, - Fn: func() (val interface{}, err error) { - emptyFile, error := os.Create(path.Join("examples", "test", "res", fmt.Sprintf("1st__EmptyFile___%v.txt", j))) - if error != nil { - log.Fatal(error) - } - emptyFile.Close() - // time.Sleep(time.Millisecond * 100) - - return fmt.Sprintf("** task number %v with id %s executed", j, id), err - }, - Retries: 5, - RetryDelay: 1, - } - - // register the task - tm.RegisterTask(context.TODO(), task) - } - }() - - // register and execute 10 tasks in a separate goroutine - go func() { - for i := 0; i < 10; i++ { - j := i - // create a new task - id := uuid.New() - task := worker.Task{ - ID: id, - Fn: func() (val interface{}, err error) { - emptyFile, error := os.Create(path.Join("examples", "test", "res", fmt.Sprintf("2nd__EmptyFile___%v.txt", j))) - - if error != nil { - log.Fatal(error) - } - emptyFile.Close() - // time.Sleep(time.Millisecond * 100) - return fmt.Sprintf("**** task number %v with id %s executed", j, id), err - }, - Retries: 5, - RetryDelay: 1, - } - - // register the task - tm.RegisterTask(context.TODO(), task) - } - }() - - // tm.CancelAll() - - // wait for the tasks to finish and print the results - for id, result := range tm.GetResults() { - fmt.Println(id, result) - } - - for cancelled := range tm.GetCancelled() { - fmt.Println(cancelled) - } - -} diff --git a/middleware/logger.go b/middleware/logger.go index e1f8724..c8592a1 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -132,9 +132,9 @@ func (mw *loggerMiddleware) GetResults() []worker.Result { return mw.next.GetResults() } -// GetCancelled streams the cancelled tasks channel -func (mw *loggerMiddleware) GetCancelled() <-chan worker.Task { - return mw.next.GetCancelled() +// GetCancelledTasks streams the cancelled tasks channel +func (mw *loggerMiddleware) GetCancelledTasks() <-chan worker.Task { + return mw.next.GetCancelledTasks() } // GetTask gets a task by its ID diff --git a/service.go b/service.go index fad1d77..ab1d25f 100644 --- a/service.go +++ b/service.go @@ -29,8 +29,8 @@ type Service interface { StreamResults() <-chan Result // GetResults retruns the `Result` channel GetResults() []Result - // GetCancelled gets the cancelled tasks channel - GetCancelled() <-chan Task + // GetCancelledTasks gets the cancelled tasks channel + GetCancelledTasks() <-chan Task // GetTask gets a task by its ID GetTask(id uuid.UUID) (task *Task, err error) // GetTasks gets all tasks diff --git a/task.go b/task.go index 347108a..c7de388 100644 --- a/task.go +++ b/task.go @@ -3,6 +3,7 @@ package worker import ( "context" "errors" + "fmt" "sync/atomic" "time" @@ -15,6 +16,8 @@ var ( ErrInvalidTaskID = errors.New("invalid task id") // ErrInvalidTaskFunc is returned when a task has an invalid function ErrInvalidTaskFunc = errors.New("invalid task function") + // ErrInvalidTaskContext is returned when a task has an invalid context + ErrInvalidTaskContext = errors.New("invalid task context") // ErrTaskNotFound is returned when a task is not found ErrTaskNotFound = errors.New("task not found") // ErrTaskTimeout is returned when a task times out @@ -40,6 +43,10 @@ type ( // - 2: `RateLimited` // - 3: `Cancelled` // - 4: `Failed` +// - 5: `Queued` +// - 6: `Running` +// - 7: `Invalid` +// - 8: `Completed` const ( // ContextDeadlineReached means the context is past its deadline. ContextDeadlineReached = TaskStatus(1) @@ -53,6 +60,10 @@ const ( Queued = TaskStatus(5) // Running means the `Task` is running. Running = TaskStatus(6) + // Invalid means the `Task` is invalid. + Invalid = TaskStatus(7) + // Completed means the `Task` is completed. + Completed = TaskStatus(8) ) // String returns the string representation of the task status. @@ -70,6 +81,8 @@ func (ts TaskStatus) String() string { return "Queued" case Running: return "Running" + case Invalid: + return "Invalid" default: return "Unknown" } @@ -81,7 +94,7 @@ type Task struct { Name string `json:"name"` // Name is the name of the task Description string `json:"description"` // Description is the description of the task Priority int `json:"priority"` // Priority is the priority of the task - Fn TaskFunc `json:"-"` // Fn is the function that will be executed by the task + Execute TaskFunc `json:"-"` // Execute is the function that will be executed by the task Ctx context.Context `json:"context"` // Ctx is the context of the task CancelFunc context.CancelFunc `json:"-"` // CancelFunc is the cancel function of the task Status TaskStatus `json:"task_status"` // TaskStatus is stores the status of the task @@ -94,6 +107,26 @@ type Task struct { index int `json:"-"` // index is the index of the task in the task manager } +// NewTask creates a new task with the provided function and context +func NewTask(fn TaskFunc, ctx context.Context) (*Task, error) { + task := &Task{ + ID: uuid.New(), + Execute: fn, + Ctx: ctx, + Retries: 0, + RetryDelay: 0, + } + + if err := task.IsValid(); err != nil { + // prevent the task from being rescheduled + task.Status = Invalid + task.setCancelled() + return nil, err + } + + return task, nil +} + // IsValid returns an error if the task is invalid func (task *Task) IsValid() (err error) { if task.ID == uuid.Nil { @@ -101,7 +134,12 @@ func (task *Task) IsValid() (err error) { task.Error.Store(err.Error()) return } - if task.Fn == nil { + if task.Ctx == nil { + err = ErrInvalidTaskContext + task.Error.Store(err.Error()) + return + } + if task.Execute == nil { err = ErrInvalidTaskFunc task.Error.Store(err.Error()) return @@ -112,62 +150,39 @@ func (task *Task) IsValid() (err error) { // setStarted handles the start of a task by setting the start time func (task *Task) setStarted() { task.Started.Store(time.Now().UnixNano()) + task.Status = Running } // setCompleted handles the finish of a task by setting the finish time func (task *Task) setCompleted() { task.Completed.Store(time.Now().UnixNano()) + task.Status = Completed } // setCancelled handles the cancellation of a task by setting the cancellation time func (task *Task) setCancelled() { task.Cancelled.Store(time.Now().UnixNano()) + task.Status = Cancelled } -// setRetryDelay sets the retry delay for the task -func (task *Task) setRetryDelay(delay time.Duration) { - task.RetryDelay = delay +// setQueued handles the queuing of a task by setting the status to queued +func (task *Task) setQueued() { + task.Status = Queued } // WaitCancelled waits for the task to be cancelled func (task *Task) WaitCancelled() { - for task.Cancelled.Load() == 0 { - // create a timer with a short duration to check if the task has been cancelled - timer := time.NewTimer(time.Millisecond * 100) - defer timer.Stop() - defer timer.Reset(0) - - // wait for either the timer to fire or the task to be cancelled - select { - case <-timer.C: - // timer expired, check if the task has been cancelled - continue - case <-task.CancelledChan(): - // task has been cancelled, return - return - } + select { + case <-task.Ctx.Done(): + return + case <-time.After(100 * time.Millisecond): + task.WaitCancelled() } } -// CancelledChan returns a channel that will be closed when the task is cancelled +// CancelledChan returns a channel which gets closed when the task is cancelled. func (task *Task) CancelledChan() <-chan struct{} { - if task.Cancelled.Load() > 0 { - ch := make(chan struct{}) - close(ch) - return ch - } - - cancelledChan := make(chan struct{}) - go func() { - defer close(cancelledChan) - for { - if task.Cancelled.Load() > 0 { - return - } - time.Sleep(time.Millisecond * 10) - } - }() - return cancelledChan + return task.Ctx.Done() } // ShouldSchedule returns an error if the task should not be scheduled @@ -175,17 +190,17 @@ func (task *Task) ShouldSchedule() error { // check if the task has been cancelled if task.Cancelled.Load() > 0 && task.Status != Cancelled { - return ErrTaskCancelled + return fmt.Errorf("%w: Task ID %s is already cancelled", ErrTaskCancelled, task.ID) } // check if the task has started if task.Started.Load() > 0 { - return ErrTaskAlreadyStarted + return fmt.Errorf("%w: Task ID %s has already started", ErrTaskAlreadyStarted, task.ID) } // check if the task has completed if task.Completed.Load() > 0 { - return ErrTaskCompleted + return fmt.Errorf("%w: Task ID %s has already completed", ErrTaskCompleted, task.ID) } return nil diff --git a/tests/worker_test.go b/tests/worker_test.go index 6871954..15a33f3 100644 --- a/tests/worker_test.go +++ b/tests/worker_test.go @@ -20,7 +20,7 @@ func TestTaskManager_RegisterTask(t *testing.T) { tm := worker.NewTaskManager(context.TODO(), 4, 10, 5, time.Second*30, time.Second*30, 3) task := worker.Task{ ID: uuid.New(), - Fn: func() (val interface{}, err error) { return nil, err }, + Execute: func() (val interface{}, err error) { return nil, err }, Priority: 10, } @@ -42,7 +42,7 @@ func TestTaskManager_Start(t *testing.T) { tm := worker.NewTaskManager(context.TODO(), 4, 10, 5, time.Second*30, time.Second*30, 3) task := worker.Task{ ID: uuid.New(), - Fn: func() (val interface{}, err error) { return "task", err }, + Execute: func() (val interface{}, err error) { return "task", err }, Priority: 10, } tm.RegisterTask(context.TODO(), task) @@ -57,7 +57,7 @@ func TestTaskManager_GetResults(t *testing.T) { tm := worker.NewTaskManager(context.TODO(), 4, 10, 5, time.Second*30, time.Second*30, 3) task := worker.Task{ ID: uuid.New(), - Fn: func() (val interface{}, err error) { return "task", err }, + Execute: func() (val interface{}, err error) { return "task", err }, Priority: 10, } tm.RegisterTask(context.TODO(), task) @@ -72,7 +72,7 @@ func TestTaskManager_GetTask(t *testing.T) { tm := worker.NewTaskManager(context.TODO(), 4, 10, 5, time.Second*30, time.Second*30, 3) task := worker.Task{ ID: uuid.New(), - Fn: func() (val interface{}, err error) { return "task", err }, + Execute: func() (val interface{}, err error) { return "task", err }, Priority: 10, } tm.RegisterTask(context.TODO(), task) @@ -90,7 +90,7 @@ func TestTaskManager_ExecuteTask(t *testing.T) { tm := worker.NewTaskManager(context.TODO(), 4, 10, 5, time.Second*30, time.Second*30, 3) task := worker.Task{ ID: uuid.New(), - Fn: func() (val interface{}, err error) { return "task", err }, + Execute: func() (val interface{}, err error) { return "task", err }, Priority: 10, } tm.RegisterTask(context.TODO(), task) diff --git a/worker.go b/worker.go index 1c95b63..1292645 100644 --- a/worker.go +++ b/worker.go @@ -96,7 +96,7 @@ func NewTaskManager(ctx context.Context, maxWorkers int, maxTasks int, tasksPerS tm := &TaskManager{ Registry: sync.Map{}, Results: make(chan Result, maxTasks), - Tasks: make(chan Task), + Tasks: make(chan Task, maxTasks), Cancelled: cancelled, Timeout: timeout, MaxWorkers: maxWorkers, @@ -121,6 +121,13 @@ func NewTaskManager(ctx context.Context, maxWorkers int, maxTasks int, tasksPerS return tm } +// IsEmpty checks if the task scheduler queue is empty +func (tm *TaskManager) IsEmpty() bool { + tm.mutex.Lock() + defer tm.mutex.Unlock() + return tm.scheduler.Len() == 0 +} + // StartWorkers starts the task manager and its goroutines func (tm *TaskManager) StartWorkers() { // start the workers @@ -130,10 +137,13 @@ func (tm *TaskManager) StartWorkers() { defer tm.wg.Done() for { select { + case task, ok := <-tm.Tasks: + if (!ok || task.ID == uuid.Nil) && tm.IsEmpty() { + return + } + tm.ExecuteTask(task.ID, tm.Timeout) case <-tm.quit: return - case task := <-tm.Tasks: - tm.ExecuteTask(task.ID, tm.Timeout) } } }() @@ -182,51 +192,32 @@ func (tm *TaskManager) StartWorkers() { // RegisterTask registers a new task to the task manager func (tm *TaskManager) RegisterTask(ctx context.Context, task Task) error { // set the maximum retries and retry delay for the task - task.Retries = tm.MaxRetries - task.RetryDelay = tm.RetryDelay - - defer func() { - // recover from any panics and send the task to the `Cancelled` channel - if r := recover(); r != nil { - cancelled, ok := tm.ctx.Value(ctxKeyCancelled{}).(chan Task) - if ok { - select { - case cancelled <- task: - // task sent successfully - return - default: - // channel buffer is full, task not sent - } - } - } - }() + if task.Retries == 0 || task.Retries > tm.MaxRetries { + task.Retries = tm.MaxRetries + } + if task.RetryDelay == 0 { + task.RetryDelay = tm.RetryDelay + } + + if task.Ctx == nil { + task.Ctx = ctx + } + + // create a new context for this task + task.Ctx, task.CancelFunc = context.WithCancel(context.WithValue(task.Ctx, ctxKeyTaskID{}, task.ID)) // if the task is invalid, send it to the results channel - err := task.IsValid() + err := tm.validateTask(&task) if err != nil { - cancelled, ok := tm.ctx.Value(ctxKeyCancelled{}).(chan Task) - if !ok { - return err - } - select { - case cancelled <- task: - // task sent successfully - default: - // channel buffer is full, task not sent - } return err } - // add a wait group for the task - tm.wg.Add(1) - - // create a new context for this task - task.Ctx, task.CancelFunc = context.WithCancel(context.WithValue(ctx, ctxKeyTaskID{}, task.ID)) - - if ctx.Err() != nil { + // Check if context is done before waiting on the limiter + if err := ctx.Err(); err != nil { tm.cancelTask(&task, Cancelled, false) - return ctx.Err() + return err } + if err := tm.limiter.Wait(ctx); err != nil { // the task has been cancelled at this time tm.cancelTask(&task, RateLimited, err != context.Canceled) @@ -237,6 +228,8 @@ func (tm *TaskManager) RegisterTask(ctx context.Context, task Task) error { tm.Registry.Store(task.ID, &task) tm.mutex.Lock() + // add a wait group for the task + tm.wg.Add(1) // add the task to the scheduler heap.Push(tm.scheduler, task) // send the task to the NewTasks channel @@ -280,14 +273,15 @@ func (tm *TaskManager) Wait(timeout time.Duration) { func (tm *TaskManager) Stop() { // Signal context cancellation <-tm.cancel - // // close the tasks channel - // defer close(tm.Tasks) close(tm.quit) // wait for all tasks to finish before closing the task manager tm.wg.Wait() // close the results and cancelled channels close(tm.Results) close(tm.Cancelled) + + // Close the tasks channel + close(tm.Tasks) } // ExecuteTask executes a task given its ID and returns the result @@ -303,37 +297,17 @@ func (tm *TaskManager) Stop() { // - If the task fails with all retries exhausted, it cancels the task and returns an error. func (tm *TaskManager) ExecuteTask(id uuid.UUID, timeout time.Duration) (interface{}, error) { // defer tm.wg.Done() + // get the task task, err := tm.GetTask(id) if err != nil { return nil, err } - // Lock the mutex to access the task data - tm.mutex.RLock() - defer tm.mutex.RUnlock() - - // check if the context has been cancelled before checking if the task is already running - select { - case <-task.Ctx.Done(): - tm.cancelTask(task, Cancelled, false) - return nil, ErrTaskAlreadyStarted - default: - } - - // if the task is invalid, send it to the cancelled channel - err = task.IsValid() + // validate the task + err = tm.validateTask(task) if err != nil { - cancelled, ok := tm.ctx.Value(ctxKeyCancelled{}).(chan Task) - if !ok { - return nil, ErrTaskCancelled - } - select { - case cancelled <- *task: - // task sent successfully - default: - // channel buffer is full, task not sent - } + tm.Cancelled <- *task return nil, err } @@ -344,90 +318,59 @@ func (tm *TaskManager) ExecuteTask(id uuid.UUID, timeout time.Duration) (interfa } // create a new context for this task - ctx, cancel := context.WithTimeout(context.WithValue(task.Ctx, ctxKeyTaskID{}, task.ID), tm.Timeout) + var cancel = func() {} + task.Ctx, cancel = context.WithTimeout(task.Ctx, tm.Timeout) defer cancel() - // wait for the result to be available and return it - for { + // reserve a token from the limiter + r := tm.limiter.Reserve() + + // if reservation is not okay, wait and retry + if !r.OK() { + // not allowed to execute the task yet select { - case res := <-tm.Results: - if res.Task.ID == task.ID { - return res.Result, nil - } - case cancelledTask := <-tm.GetCancelled(): - if cancelledTask.ID == task.ID { - // the task was cancelled before it could complete - return nil, ErrTaskCancelled - } - case <-ctx.Done(): - // the task has timed out - tm.CancelTask(task.ID) - return nil, ErrTaskTimeout - case <-time.After(timeout): - return nil, ErrTaskTimeout - default: - // execute the task - // reserve a token from the limiter - r := tm.limiter.Reserve() + case <-time.After(r.Delay()): + case <-tm.quit: + tm.cancelTask(task, Cancelled, false) + return nil, ErrTaskCancelled + case <-task.Ctx.Done(): + tm.cancelTask(task, Cancelled, false) + return nil, ErrTaskCancelled + } + } - if !r.OK() { - // not allowed to execute the task yet - waitTime := r.Delay() - select { - case <-tm.quit: - // the task manager has been closed - tm.cancelTask(task, Cancelled, false) - return nil, ErrTaskCancelled - case <-task.Ctx.Done(): - // the task has been cancelled - tm.cancelTask(task, Cancelled, false) - return nil, ErrTaskCancelled - case <-time.After(waitTime): - // continue with the task execution - continue - } - } + // if reservation is okay, execute the task + task.setStarted() + result, err := task.Execute() - select { - case <-tm.quit: - // the task manager has been closed - tm.cancelTask(task, Cancelled, false) - return nil, ErrTaskCancelled - case <-task.Ctx.Done(): - // the task has been cancelled - tm.cancelTask(task, Cancelled, false) - return nil, ErrTaskCancelled - default: - task.setStarted() - - // execute the task - result, err := task.Fn() - if err != nil { - // task failed, retry up to max retries with delay between retries - if task.Retries < tm.MaxRetries { - task.Retries++ - tm.retryTask(task) - return nil, err - } + // if task execution fails, cancel task + if err != nil { + tm.cancelTask(task, Failed, false) + return nil, err + } - // task failed, no more retries - tm.cancelTask(task, Failed, false) - // return nil, err - } else { - // task completed successfully - task.setCompleted() - } + // if task execution is successful, set task as completed and send result + task.setCompleted() + tm.Results <- Result{ + Task: task, + Result: result, + Error: err, + } - // send the result to the results channel - tm.Results <- Result{ - Task: task, - Result: result, - Error: err, - } + return result, err +} - return result, err - } - } +// retryTask retries a task up to its maximum number of retries with a delay between retries +func (tm *TaskManager) retryTask(task *Task) { + select { + case <-time.After(task.RetryDelay): + tm.mutex.Lock() + tm.RegisterTask(task.Ctx, *task) + tm.mutex.Unlock() + case <-task.Ctx.Done(): + return + case <-tm.quit: + return } } @@ -439,9 +382,8 @@ func (tm *TaskManager) CancelAll() { return false } - if task.Started.Load() == 0 { + if task.Started.Load() <= 0 { // task has not been started yet, remove it from the scheduler - fmt.Println("task not started yet") tm.mutex.Lock() heap.Remove(tm.scheduler, task.index) tm.mutex.Unlock() @@ -470,6 +412,72 @@ func (tm *TaskManager) CancelTask(id uuid.UUID) { tm.cancelTask(task, Cancelled, true) } +// cancelTask cancels a task +func (tm *TaskManager) cancelTask(task *Task, status TaskStatus, notifyWG bool) { + tm.mutex.Lock() + defer tm.mutex.Unlock() + + task.CancelFunc() + // set the cancelled time + task.setCancelled() + // set the task status + task.Status = status + + // wait for the task to be cancelled + task.WaitCancelled() + + if notifyWG || status == Cancelled { + tm.wg.Done() + } + + // update the task in the registry + tm.Registry.Store(task.ID, task) + + // send the task to the cancelled channel + cancelled, ok := tm.ctx.Value(ctxKeyCancelled{}).(chan Task) + + if ok { + select { + case cancelled <- *task: + case <-tm.quit: + return + default: + } + } + + // if the task has retries remaining and status is Failed, add it back to the queue with a delay + if task.Retries > 0 && status != Cancelled { + task.setQueued() + task.Retries-- + tm.retryTask(task) + } else { + task.Status = Failed + } +} + +// GetCancelledTasks gets the cancelled tasks channel +// Example usage: +// +// get the cancelled tasks +// cancelledTasks := tm.GetCancelledTasks() + +// select { +// case task := <-cancelledTasks: +// +// fmt.Printf("Task %s was cancelled\n", task.ID.String()) +// +// default: +// +// fmt.Println("No tasks have been cancelled yet") +// } +func (tm *TaskManager) GetCancelledTasks() <-chan Task { + results, ok := tm.ctx.Value(ctxKeyCancelled{}).(chan Task) + if !ok { + return nil + } + return results +} + // GetActiveTasks returns the number of active tasks func (tm *TaskManager) GetActiveTasks() int { return int(tm.limiter.Limit()) - tm.limiter.Burst() @@ -503,11 +511,6 @@ func (tm *TaskManager) GetResults() []Result { return results } -// GetCancelled gets the cancelled tasks channel -func (tm *TaskManager) GetCancelled() <-chan Task { - return tm.ctx.Value(ctxKeyCancelled{}).(chan Task) -} - // GetTask gets a task by its ID func (tm *TaskManager) GetTask(id uuid.UUID) (task *Task, err error) { // tm.mutex.RLock() @@ -543,57 +546,23 @@ func (tm *TaskManager) GetTasks() []Task { return tasks } -// retryTask retries a task up to its maximum number of retries with a delay between retries -func (tm *TaskManager) retryTask(task *Task) { - task.setRetryDelay(tm.RetryDelay) - - // wait for the retry delay to pass - timer := time.NewTimer(tm.RetryDelay) - defer timer.Stop() - <-timer.C - - // re-enqueue the task - tm.RegisterTask(task.Ctx, *task) -} - -// cancelTask cancels a task -func (tm *TaskManager) cancelTask(task *Task, status TaskStatus, notifyWG bool) { - tm.mutex.Lock() - defer tm.mutex.Unlock() - - task.CancelFunc() - // set the cancelled time - task.setCancelled() - // set the task status - task.Status = status - - // wait for the task to be cancelled - task.WaitCancelled() - - if notifyWG || status == Cancelled { - tm.wg.Done() - } - // update the task in the registry - tm.Registry.Store(task.ID, task) - - // send the task to the cancelled channel - cancelled, ok := tm.ctx.Value(ctxKeyCancelled{}).(chan Task) - - if !ok { - return - } - select { - case cancelled <- *task: - // task sent successfully - default: - // channel buffer is full, task not sent +// validateTask validates a task and sends it to the cancelled channel if it is invalid +func (tm *TaskManager) validateTask(task *Task) error { + // if the task is invalid, send it to the cancelled channel + err := task.IsValid() + if err != nil { + cancelled, ok := tm.ctx.Value(ctxKeyCancelled{}).(chan Task) + if !ok { + return ErrTaskCancelled + } + select { + case cancelled <- *task: + // task sent successfully + default: + // channel buffer is full, task not sent + } + return err } - // if the task has retries remaining, add it back to the queue with a delay - if task.Retries > 0 && status != Cancelled { - task.Retries-- - time.AfterFunc(task.RetryDelay, func() { - tm.RegisterTask(tm.ctx, *task) - }) - } + return nil }