Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pkg): scheduler package #3319

Merged
merged 12 commits into from
Dec 20, 2024
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

* [3170](https://github.com/zeta-chain/node/pull/3170) - revamp TSS package in zetaclient
* [3291](https://github.com/zeta-chain/node/pull/3291) - revamp zetaclient initialization (+ graceful shutdown)
* [3319](https://github.com/zeta-chain/node/pull/3319) - implement scheduler for zetaclient

### Fixes

Expand Down
46 changes: 46 additions & 0 deletions pkg/scheduler/opts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package scheduler

import (
"time"

cometbft "github.com/cometbft/cometbft/types"
)

// Opt Task option
type Opt func(task *Task, taskOpts *taskOpts)

// Name sets task name.
func Name(name string) Opt {
return func(t *Task, _ *taskOpts) { t.name = name }
}

// GroupName sets task group. Otherwise, defaults to DefaultGroup.
func GroupName(group Group) Opt {
return func(t *Task, _ *taskOpts) { t.group = group }
}

// LogFields augments Task's logger with some fields.
func LogFields(fields map[string]any) Opt {
return func(_ *Task, opts *taskOpts) { opts.logFields = fields }
}

// Interval sets initial task interval.
func Interval(interval time.Duration) Opt {
return func(_ *Task, opts *taskOpts) { opts.interval = interval }
}

// Skipper sets task skipper function
func Skipper(skipper func() bool) Opt {
return func(t *Task, _ *taskOpts) { t.skipper = skipper }
}

// IntervalUpdater sets interval updater function.
func IntervalUpdater(intervalUpdater func() time.Duration) Opt {
return func(_ *Task, opts *taskOpts) { opts.intervalUpdater = intervalUpdater }
}

// BlockTicker makes Task to listen for new zeta blocks
// instead of using interval ticker. IntervalUpdater is ignored.
func BlockTicker(blocks <-chan cometbft.EventDataNewBlock) Opt {
return func(_ *Task, opts *taskOpts) { opts.blockChan = blocks }
}
216 changes: 216 additions & 0 deletions pkg/scheduler/scheduler.go
swift1337 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
// Package scheduler provides a background task scheduler that allows for the registration,
// execution, and management of periodic tasks. Tasks can be grouped, named, and configured
// with various options such as custom intervals, log fields, and skip conditions.
//
// The scheduler supports dynamic interval updates and can gracefully stop tasks either
// individually or by group.
package scheduler

import (
"context"
"sync"
"time"

cometbft "github.com/cometbft/cometbft/types"
"github.com/google/uuid"
"github.com/rs/zerolog"

"github.com/zeta-chain/node/pkg/bg"
)

// Scheduler represents background task scheduler.
type Scheduler struct {
tasks map[uuid.UUID]*Task
mu sync.RWMutex
logger zerolog.Logger
}

// Executable arbitrary function that can be executed.
type Executable func(ctx context.Context) error

// Group represents Task group. Tasks can be grouped for easier management.
type Group string

// DefaultGroup is the default task group.
const DefaultGroup = Group("default")

// tickable ticker abstraction to support different implementations
type tickable interface {
Start(ctx context.Context) error
Stop()
}

// Task represents scheduler's task.
type Task struct {
// ref to the Scheduler is required
scheduler *Scheduler

id uuid.UUID
group Group
name string

exec Executable

// ticker abstraction to support different implementations
ticker tickable
skipper func() bool

logger zerolog.Logger
}

type taskOpts struct {
interval time.Duration
intervalUpdater func() time.Duration

blockChan <-chan cometbft.EventDataNewBlock

logFields map[string]any
}

// New Scheduler instance.
func New(logger zerolog.Logger) *Scheduler {
return &Scheduler{
tasks: make(map[uuid.UUID]*Task),
logger: logger.With().Str("module", "scheduler").Logger(),
}
}

// Register registers and starts new Task in the background
func (s *Scheduler) Register(ctx context.Context, exec Executable, opts ...Opt) *Task {
id := uuid.New()
task := &Task{
scheduler: s,
id: id,
group: DefaultGroup,
name: id.String(),
exec: exec,
}

config := &taskOpts{
interval: time.Second,
}

for _, opt := range opts {
opt(task, config)
}

task.logger = newTaskLogger(task, config, s.logger)
task.ticker = newTickable(task, config)

task.logger.Info().Msg("Starting scheduler task")
bg.Work(ctx, task.ticker.Start, bg.WithLogger(task.logger))

s.mu.Lock()
s.tasks[id] = task
s.mu.Unlock()

return task
}

// Stop stops all tasks.
func (s *Scheduler) Stop() {
s.StopGroup("")
}

// StopGroup stops all tasks in the group.
func (s *Scheduler) StopGroup(group Group) {
var selectedTasks []*Task

s.mu.RLock()

// Filter desired tasks
for _, task := range s.tasks {
// "" is for wildcard i.e. all groups
if group == "" || task.group == group {
selectedTasks = append(selectedTasks, task)
}
}

s.mu.RUnlock()

if len(selectedTasks) == 0 {
return
}

// Stop all selected tasks concurrently
var wg sync.WaitGroup
wg.Add(len(selectedTasks))

for _, task := range selectedTasks {
go func(task *Task) {
defer wg.Done()
task.Stop()
}(task)
}

wg.Wait()
}

// Stop stops the task and offloads it from the scheduler.
func (t *Task) Stop() {
t.logger.Info().Msg("Stopping scheduler task")
start := time.Now()

t.ticker.Stop()

t.scheduler.mu.Lock()
delete(t.scheduler.tasks, t.id)
t.scheduler.mu.Unlock()

timeTakenMS := time.Since(start).Milliseconds()
t.logger.Info().Int64("time_taken_ms", timeTakenMS).Msg("Stopped scheduler task")
}

// execute executes Task with additional logging and metrics.
func (t *Task) execute(ctx context.Context) error {
// skip tick
if t.skipper != nil && t.skipper() {
return nil
}

t.logger.Debug().Msg("Invoking task")

err := t.exec(ctx)

// todo metrics (TBD)
swift1337 marked this conversation as resolved.
Show resolved Hide resolved
// - duration (time taken)
// - outcome (skip, err, ok)
// - bump invocation counter
// - "last invoked at" timestamp (?)
// - chain_id
// - metrics cardinality: "task_group (?)" "task_name", "status", "chain_id"

return err
}

func newTaskLogger(task *Task, opts *taskOpts, logger zerolog.Logger) zerolog.Logger {
logOpts := logger.With().
Str("task.name", task.name).
Str("task.group", string(task.group))

if len(opts.logFields) > 0 {
logOpts = logOpts.Fields(opts.logFields)
}

taskType := "interval_ticker"
if opts.blockChan != nil {
taskType = "block_ticker"
}

return logOpts.Str("task.type", taskType).Logger()
}

func newTickable(task *Task, opts *taskOpts) tickable {
// Block-based ticker
if opts.blockChan != nil {
return newBlockTicker(task.execute, opts.blockChan, task.logger)
}

return newIntervalTicker(
task.execute,
opts.interval,
opts.intervalUpdater,
task.name,
task.logger,
)
}
Loading
Loading