diff --git a/changelog.md b/changelog.md index 434d69161e..88bb15b57a 100644 --- a/changelog.md +++ b/changelog.md @@ -17,6 +17,7 @@ * [3170](https://github.com/zeta-chain/node/pull/3170) - revamp TSS package in zetaclient * [3291](https://github.com/zeta-chain/node/pull/3291) - revamp zetaclient initialization (+ graceful shutdown) +* [3319](https://github.com/zeta-chain/node/pull/3319) - implement scheduler for zetaclient ### Fixes diff --git a/pkg/scheduler/metrics.go b/pkg/scheduler/metrics.go new file mode 100644 index 0000000000..96d581ebfb --- /dev/null +++ b/pkg/scheduler/metrics.go @@ -0,0 +1,29 @@ +package scheduler + +import ( + "time" + + "github.com/zeta-chain/node/zetaclient/metrics" +) + +// Note that currently the hard-coded "global" metrics are used. +func recordMetrics(task *Task, startedAt time.Time, err error, skipped bool) { + var status string + switch { + case skipped: + status = "skipped" + case err != nil: + status = "failed" + default: + status = "ok" + } + + var ( + group = string(task.group) + name = task.name + dur = time.Since(startedAt).Seconds() + ) + + metrics.SchedulerTaskInvocationCounter.WithLabelValues(status, group, name).Inc() + metrics.SchedulerTaskExecutionDuration.WithLabelValues(status, group, name).Observe(dur) +} diff --git a/pkg/scheduler/opts.go b/pkg/scheduler/opts.go new file mode 100644 index 0000000000..8e5d54e370 --- /dev/null +++ b/pkg/scheduler/opts.go @@ -0,0 +1,46 @@ +package scheduler + +import ( + "time" + + cometbft "github.com/cometbft/cometbft/types" +) + +// Opt Task option +type Opt func(task *Task, taskOpts *taskOpts) + +// Name sets task name. +func Name(name string) Opt { + return func(t *Task, _ *taskOpts) { t.name = name } +} + +// GroupName sets task group. Otherwise, defaults to DefaultGroup. +func GroupName(group Group) Opt { + return func(t *Task, _ *taskOpts) { t.group = group } +} + +// LogFields augments Task's logger with some fields. +func LogFields(fields map[string]any) Opt { + return func(_ *Task, opts *taskOpts) { opts.logFields = fields } +} + +// Interval sets initial task interval. +func Interval(interval time.Duration) Opt { + return func(_ *Task, opts *taskOpts) { opts.interval = interval } +} + +// Skipper sets task skipper function +func Skipper(skipper func() bool) Opt { + return func(t *Task, _ *taskOpts) { t.skipper = skipper } +} + +// IntervalUpdater sets interval updater function. +func IntervalUpdater(intervalUpdater func() time.Duration) Opt { + return func(_ *Task, opts *taskOpts) { opts.intervalUpdater = intervalUpdater } +} + +// BlockTicker makes Task to listen for new zeta blocks +// instead of using interval ticker. IntervalUpdater is ignored. +func BlockTicker(blocks <-chan cometbft.EventDataNewBlock) Opt { + return func(_ *Task, opts *taskOpts) { opts.blockChan = blocks } +} diff --git a/pkg/scheduler/scheduler.go b/pkg/scheduler/scheduler.go new file mode 100644 index 0000000000..2328cbddd7 --- /dev/null +++ b/pkg/scheduler/scheduler.go @@ -0,0 +1,211 @@ +// Package scheduler provides a background task scheduler that allows for the registration, +// execution, and management of periodic tasks. Tasks can be grouped, named, and configured +// with various options such as custom intervals, log fields, and skip conditions. +// +// The scheduler supports dynamic interval updates and can gracefully stop tasks either +// individually or by group. +package scheduler + +import ( + "context" + "sync" + "time" + + cometbft "github.com/cometbft/cometbft/types" + "github.com/google/uuid" + "github.com/rs/zerolog" + + "github.com/zeta-chain/node/pkg/bg" +) + +// Scheduler represents background task scheduler. +type Scheduler struct { + tasks map[uuid.UUID]*Task + mu sync.RWMutex + logger zerolog.Logger +} + +// Executable arbitrary function that can be executed. +type Executable func(ctx context.Context) error + +// Group represents Task group. Tasks can be grouped for easier management. +type Group string + +// DefaultGroup is the default task group. +const DefaultGroup = Group("default") + +// tickable ticker abstraction to support different implementations +type tickable interface { + Start(ctx context.Context) error + Stop() +} + +// Task represents scheduler's task. +type Task struct { + // ref to the Scheduler is required + scheduler *Scheduler + + id uuid.UUID + group Group + name string + + exec Executable + + // ticker abstraction to support different implementations + ticker tickable + skipper func() bool + + logger zerolog.Logger +} + +type taskOpts struct { + interval time.Duration + intervalUpdater func() time.Duration + + blockChan <-chan cometbft.EventDataNewBlock + + logFields map[string]any +} + +// New Scheduler instance. +func New(logger zerolog.Logger) *Scheduler { + return &Scheduler{ + tasks: make(map[uuid.UUID]*Task), + logger: logger.With().Str("module", "scheduler").Logger(), + } +} + +// Register registers and starts new Task in the background +func (s *Scheduler) Register(ctx context.Context, exec Executable, opts ...Opt) *Task { + id := uuid.New() + task := &Task{ + scheduler: s, + id: id, + group: DefaultGroup, + name: id.String(), + exec: exec, + } + + config := &taskOpts{ + interval: time.Second, + } + + for _, opt := range opts { + opt(task, config) + } + + task.logger = newTaskLogger(task, config, s.logger) + task.ticker = newTickable(task, config) + + task.logger.Info().Msg("Starting scheduler task") + bg.Work(ctx, task.ticker.Start, bg.WithLogger(task.logger)) + + s.mu.Lock() + s.tasks[id] = task + s.mu.Unlock() + + return task +} + +// Stop stops all tasks. +func (s *Scheduler) Stop() { + s.StopGroup("") +} + +// StopGroup stops all tasks in the group. +func (s *Scheduler) StopGroup(group Group) { + var selectedTasks []*Task + + s.mu.RLock() + + // Filter desired tasks + for _, task := range s.tasks { + // "" is for wildcard i.e. all groups + if group == "" || task.group == group { + selectedTasks = append(selectedTasks, task) + } + } + + s.mu.RUnlock() + + if len(selectedTasks) == 0 { + return + } + + // Stop all selected tasks concurrently + var wg sync.WaitGroup + wg.Add(len(selectedTasks)) + + for _, task := range selectedTasks { + go func(task *Task) { + defer wg.Done() + task.Stop() + }(task) + } + + wg.Wait() +} + +// Stop stops the task and offloads it from the scheduler. +func (t *Task) Stop() { + t.logger.Info().Msg("Stopping scheduler task") + start := time.Now() + + t.ticker.Stop() + + t.scheduler.mu.Lock() + delete(t.scheduler.tasks, t.id) + t.scheduler.mu.Unlock() + + timeTakenMS := time.Since(start).Milliseconds() + t.logger.Info().Int64("time_taken_ms", timeTakenMS).Msg("Stopped scheduler task") +} + +// execute executes Task with additional logging and metrics. +func (t *Task) execute(ctx context.Context) error { + startedAt := time.Now().UTC() + + // skip tick + if t.skipper != nil && t.skipper() { + recordMetrics(t, startedAt, nil, true) + return nil + } + + err := t.exec(ctx) + + recordMetrics(t, startedAt, err, false) + + return err +} + +func newTaskLogger(task *Task, opts *taskOpts, logger zerolog.Logger) zerolog.Logger { + logOpts := logger.With(). + Str("task.name", task.name). + Str("task.group", string(task.group)) + + if len(opts.logFields) > 0 { + logOpts = logOpts.Fields(opts.logFields) + } + + taskType := "interval_ticker" + if opts.blockChan != nil { + taskType = "block_ticker" + } + + return logOpts.Str("task.type", taskType).Logger() +} + +func newTickable(task *Task, opts *taskOpts) tickable { + // Block-based ticker + if opts.blockChan != nil { + return newBlockTicker(task.execute, opts.blockChan, task.logger) + } + + return newIntervalTicker( + task.execute, + opts.interval, + opts.intervalUpdater, + task.name, + task.logger, + ) +} diff --git a/pkg/scheduler/scheduler_test.go b/pkg/scheduler/scheduler_test.go new file mode 100644 index 0000000000..a993bc875a --- /dev/null +++ b/pkg/scheduler/scheduler_test.go @@ -0,0 +1,403 @@ +package scheduler + +import ( + "bytes" + "context" + "fmt" + "io" + "sync/atomic" + "testing" + "time" + + cometbft "github.com/cometbft/cometbft/types" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestScheduler(t *testing.T) { + t.Run("Basic case", func(t *testing.T) { + t.Parallel() + + // ARRANGE + ts := newTestSuite(t) + + var counter int32 + + exec := func(ctx context.Context) error { + atomic.AddInt32(&counter, 1) + return nil + } + + // ACT + // Register task and stop it after x1.5 interval. + ts.scheduler.Register(ts.ctx, exec) + time.Sleep(1500 * time.Millisecond) + ts.scheduler.Stop() + + // ASSERT + // Counter should be 2 because we invoke a task once on a start, + // once after 1 second (default interval), + // and then at T=1.5s we stop the scheduler. + assert.Equal(t, int32(2), counter) + + // Check logs + assert.Contains(t, ts.logBuffer.String(), "Stopped scheduler task") + assert.Contains(t, ts.logBuffer.String(), `"task.group":"default"`) + }) + + t.Run("More opts", func(t *testing.T) { + t.Parallel() + + // ARRANGE + ts := newTestSuite(t) + + var counter int32 + + exec := func(ctx context.Context) error { + atomic.AddInt32(&counter, 1) + return nil + } + + // ACT + // Register task and stop it after x1.5 interval. + ts.scheduler.Register( + ts.ctx, + exec, + Name("counter-inc"), + GroupName("my-custom-group"), + Interval(300*time.Millisecond), + LogFields(map[string]any{ + "blockchain": "doge", + "validators": []string{"alice", "bob"}, + }), + ) + + time.Sleep(time.Second) + ts.scheduler.Stop() + + // ASSERT + // Counter should be 1 + 1000/300 = 4 (first run + interval runs) + assert.Equal(t, int32(4), counter) + + // Also check that log fields are present + assert.Contains(t, ts.logBuffer.String(), `"task.name":"counter-inc","task.group":"my-custom-group"`) + assert.Contains(t, ts.logBuffer.String(), `"blockchain":"doge","validators":["alice","bob"]`) + }) + + t.Run("Task can stop itself", func(t *testing.T) { + t.Parallel() + + // ARRANGE + ts := newTestSuite(t) + + var counter int32 + + exec := func(ctx context.Context) error { + atomic.AddInt32(&counter, 1) + return nil + } + + // ACT + // Register task and stop it after x1.5 interval. + task := ts.scheduler.Register(ts.ctx, exec, Interval(300*time.Millisecond)) + + time.Sleep(time.Second) + task.Stop() + + // ASSERT + // Counter should be 1 + 1000/300 = 4 (first run + interval runs) + assert.Equal(t, int32(4), counter) + }) + + t.Run("Skipper option", func(t *testing.T) { + t.Parallel() + + // ARRANGE + ts := newTestSuite(t) + + var counter int32 + + exec := func(ctx context.Context) error { + atomic.AddInt32(&counter, 1) + return nil + } + + const maxValue = 5 + + // Skipper function that drops the task after reaching a certain counter value. + skipper := func() bool { + allowed := atomic.LoadInt32(&counter) < maxValue + return !allowed + } + + // ACT + // Register task and stop it after x1.5 interval. + task := ts.scheduler.Register(ts.ctx, exec, Interval(50*time.Millisecond), Skipper(skipper)) + + time.Sleep(time.Second) + task.Stop() + + // ASSERT + assert.Equal(t, int32(maxValue), counter) + }) + + t.Run("IntervalUpdater option", func(t *testing.T) { + t.Parallel() + + // ARRANGE + ts := newTestSuite(t) + + var counter int32 + + exec := func(ctx context.Context) error { + atomic.AddInt32(&counter, 1) + return nil + } + + // Interval updater that increases the interval by 50ms on each counter increment. + intervalUpdater := func() time.Duration { + return time.Duration(atomic.LoadInt32(&counter)) * 50 * time.Millisecond + } + + // ACT + // Register task and stop it after x1.5 interval. + task := ts.scheduler.Register(ts.ctx, exec, Interval(time.Millisecond), IntervalUpdater(intervalUpdater)) + + time.Sleep(time.Second) + task.Stop() + + // ASSERT + assert.Equal(t, int32(6), counter) + + assert.Contains(t, ts.logBuffer.String(), `"ticker.old_interval":1,"ticker.new_interval":50`) + assert.Contains(t, ts.logBuffer.String(), `"ticker.old_interval":50,"ticker.new_interval":100`) + assert.Contains(t, ts.logBuffer.String(), `"ticker.old_interval":100,"ticker.new_interval":150`) + assert.Contains(t, ts.logBuffer.String(), `"ticker.old_interval":150,"ticker.new_interval":200`) + assert.Contains(t, ts.logBuffer.String(), `"ticker.old_interval":200,"ticker.new_interval":250`) + assert.Contains(t, ts.logBuffer.String(), `"ticker.old_interval":250,"ticker.new_interval":300`) + }) + + t.Run("Multiple tasks in different groups", func(t *testing.T) { + t.Parallel() + + // ARRANGE + ts := newTestSuite(t) + + // Given multiple tasks + var counterA, counterB, counterC int32 + + // Two tasks for Alice + taskAliceA := func(ctx context.Context) error { + atomic.AddInt32(&counterA, 1) + time.Sleep(60 * time.Millisecond) + return nil + } + + taskAliceB := func(ctx context.Context) error { + atomic.AddInt32(&counterB, 1) + time.Sleep(70 * time.Millisecond) + return nil + } + + // One task for Bob + taskBobC := func(ctx context.Context) error { + atomic.AddInt32(&counterC, 1) + time.Sleep(80 * time.Millisecond) + return nil + } + + // ACT + // Register all tasks with different intervals and groups + ts.scheduler.Register(ts.ctx, taskAliceA, Interval(50*time.Millisecond), GroupName("alice"), Name("a")) + ts.scheduler.Register(ts.ctx, taskAliceB, Interval(100*time.Millisecond), GroupName("alice"), Name("b")) + ts.scheduler.Register(ts.ctx, taskBobC, Interval(200*time.Millisecond), GroupName("bob"), Name("c")) + + // Wait and then stop Alice's tasks + time.Sleep(time.Second) + ts.scheduler.StopGroup("alice") + + // ASSERT #1 + shutdownLogPattern := func(group, name string) string { + const pattern = `"task\.name":"%s","task\.group":"%s",.*"message":"Stopped scheduler task"` + return fmt.Sprintf(pattern, name, group) + } + + // Make sure Alice.A and Alice.B are stopped + assert.Regexp(t, shutdownLogPattern("alice", "a"), ts.logBuffer.String()) + assert.Regexp(t, shutdownLogPattern("alice", "b"), ts.logBuffer.String()) + + // But Bob.C is still running + assert.NotRegexp(t, shutdownLogPattern("bob", "c"), ts.logBuffer.String()) + + // ACT #2 + time.Sleep(200 * time.Millisecond) + ts.scheduler.StopGroup("bob") + + // ASSERT #2 + // Bob.C is not running + assert.Regexp(t, shutdownLogPattern("bob", "c"), ts.logBuffer.String()) + }) + + t.Run("Block tick: tick is faster than the block", func(t *testing.T) { + t.Parallel() + + // ARRANGE + ts := newTestSuite(t) + + // Given a task that increments a counter by block height + var counter int64 + + task := func(ctx context.Context) error { + // Note that ctx contains the block event + blockEvent, ok := BlockFromContext(ctx) + require.True(t, ok) + + atomic.AddInt64(&counter, blockEvent.Block.Height) + time.Sleep(100 * time.Millisecond) + return nil + } + + // Given block ticker + blockChan := ts.mockBlockChan(200*time.Millisecond, 0) + + // ACT + // Register block + ts.scheduler.Register(ts.ctx, task, BlockTicker(blockChan)) + time.Sleep(1200 * time.Millisecond) + ts.scheduler.Stop() + + // ASSERT + assert.Equal(t, int64(21), counter) + assert.Contains(t, ts.logBuffer.String(), "Stopped scheduler task") + assert.Contains(t, ts.logBuffer.String(), `"task.type":"block_ticker"`) + }) + + t.Run("Block tick: tick is slower than the block", func(t *testing.T) { + t.Parallel() + + // ARRANGE + ts := newTestSuite(t) + + // Given a task that increments a counter on start + // and then decrements before finish + var counter int64 + + exec := func(ctx context.Context) error { + _, ok := BlockFromContext(ctx) + require.True(t, ok) + + atomic.AddInt64(&counter, 1) + time.Sleep(256 * time.Millisecond) + atomic.AddInt64(&counter, -1) + return nil + } + + // Given block ticker + blockChan := ts.mockBlockChan(100*time.Millisecond, 0) + + // ACT + // Register block + ts.scheduler.Register(ts.ctx, exec, BlockTicker(blockChan)) + time.Sleep(1200 * time.Millisecond) + ts.scheduler.Stop() + + // ASSERT + // zero indicates that Stop() waits for current iteration to finish (graceful shutdown) + assert.Equal(t, int64(0), counter) + }) + + t.Run("Block tick: chan closes unexpectedly", func(t *testing.T) { + t.Parallel() + + // ARRANGE + ts := newTestSuite(t) + + // Given a task that increments a counter on start + // and then decrements before finish + var counter int64 + + exec := func(ctx context.Context) error { + _, ok := BlockFromContext(ctx) + require.True(t, ok) + + atomic.AddInt64(&counter, 1) + time.Sleep(200 * time.Millisecond) + atomic.AddInt64(&counter, -1) + return nil + } + + // Given block ticker that closes after 3 blocks + blockChan := ts.mockBlockChan(100*time.Millisecond, 3) + + // ACT + // Register block + ts.scheduler.Register(ts.ctx, exec, BlockTicker(blockChan), Name("block-tick")) + + // Wait for a while + time.Sleep(1000 * time.Millisecond) + + // Stop the scheduler. + // Note that actually the ticker is already stopped. + ts.scheduler.Stop() + + // ASSERT + // zero indicates that Stop() waits for current iteration to finish (graceful shutdown) + assert.Equal(t, int64(0), counter) + assert.Contains(t, ts.logBuffer.String(), "Block channel closed") + }) +} + +type testSuite struct { + ctx context.Context + scheduler *Scheduler + + logger zerolog.Logger + logBuffer *bytes.Buffer +} + +func newTestSuite(t *testing.T) *testSuite { + logBuffer := &bytes.Buffer{} + logger := zerolog.New(io.MultiWriter(zerolog.NewTestWriter(t), logBuffer)) + + return &testSuite{ + ctx: context.Background(), + scheduler: New(logger), + logger: logger, + logBuffer: logBuffer, + } +} + +// mockBlockChan mocks websocket blocks. Optionally halts after lastBlock. +func (ts *testSuite) mockBlockChan(interval time.Duration, lastBlock int64) chan cometbft.EventDataNewBlock { + producer := make(chan cometbft.EventDataNewBlock) + + go func() { + var blockNumber int64 + + for { + blockNumber++ + ts.logger.Info().Int64("block_number", blockNumber).Msg("Producing new block") + + header := cometbft.Header{ + ChainID: "zeta", + Height: blockNumber, + Time: time.Now(), + } + + producer <- cometbft.EventDataNewBlock{ + Block: &cometbft.Block{Header: header}, + } + + if blockNumber > 0 && blockNumber == lastBlock { + ts.logger.Info().Int64("block_number", blockNumber).Msg("Halting block producer") + close(producer) + return + } + + time.Sleep(interval) + } + }() + + return producer +} diff --git a/pkg/scheduler/tickers.go b/pkg/scheduler/tickers.go new file mode 100644 index 0000000000..613194c44b --- /dev/null +++ b/pkg/scheduler/tickers.go @@ -0,0 +1,172 @@ +package scheduler + +import ( + "context" + "fmt" + "sync" + "time" + + cometbft "github.com/cometbft/cometbft/types" + "github.com/rs/zerolog" + + "github.com/zeta-chain/node/pkg/ticker" +) + +// intervalTicker wrapper for ticker.Ticker. +type intervalTicker struct { + ticker *ticker.Ticker +} + +func newIntervalTicker( + task Executable, + interval time.Duration, + intervalUpdater func() time.Duration, + taskName string, + logger zerolog.Logger, +) *intervalTicker { + wrapper := func(ctx context.Context, t *ticker.Ticker) error { + if err := task(ctx); err != nil { + logger.Error().Err(err).Msg("task failed") + } + + if intervalUpdater != nil { + // noop if interval is not changed + t.SetInterval(intervalUpdater()) + } + + return nil + } + + tt := ticker.New(interval, wrapper, ticker.WithLogger(logger, taskName)) + + return &intervalTicker{ticker: tt} +} + +func (t *intervalTicker) Start(ctx context.Context) error { + return t.ticker.Start(ctx) +} + +func (t *intervalTicker) Stop() { + t.ticker.StopBlocking() +} + +// blockTicker represents custom ticker implementation that ticks on new Zeta block events. +// Pass blockTicker ONLY by pointer. +type blockTicker struct { + exec Executable + + // block channel that will be used to receive new blocks + blockChan <-chan cometbft.EventDataNewBlock + + // stopChan is used to stop the ticker + stopChan chan struct{} + + // doneChan is used to signal that the ticker has stopped (i.e. "blocking stop") + doneChan chan struct{} + + isRunning bool + mu sync.Mutex + + logger zerolog.Logger +} + +type blockCtxKey struct{} + +func newBlockTicker(task Executable, blockChan <-chan cometbft.EventDataNewBlock, logger zerolog.Logger) *blockTicker { + return &blockTicker{ + exec: task, + blockChan: blockChan, + logger: logger, + } +} + +func withBlockEvent(ctx context.Context, event cometbft.EventDataNewBlock) context.Context { + return context.WithValue(ctx, blockCtxKey{}, event) +} + +// BlockFromContext returns cometbft.EventDataNewBlock from the context or false. +func BlockFromContext(ctx context.Context) (cometbft.EventDataNewBlock, bool) { + blockEvent, ok := ctx.Value(blockCtxKey{}).(cometbft.EventDataNewBlock) + return blockEvent, ok +} + +func (t *blockTicker) Start(ctx context.Context) error { + if err := t.init(); err != nil { + return err + } + + defer t.cleanup() + + // release Stop() blocking + defer func() { close(t.doneChan) }() + + for { + select { + case block, ok := <-t.blockChan: + // channel closed + if !ok { + t.logger.Warn().Msg("Block channel closed") + return nil + } + + ctx := withBlockEvent(ctx, block) + + if err := t.exec(ctx); err != nil { + t.logger.Error().Err(err).Msg("Task error") + } + case <-ctx.Done(): + t.logger.Warn().Err(ctx.Err()).Msg("Content error") + return nil + case <-t.stopChan: + // caller invoked t.stop() + return nil + } + } +} + +func (t *blockTicker) Stop() { + t.mu.Lock() + defer t.mu.Unlock() + + // noop + if !t.isRunning { + return + } + + // notify async loop to stop + close(t.stopChan) + + // wait for the loop to stop + <-t.doneChan + + t.isRunning = false +} + +func (t *blockTicker) init() error { + t.mu.Lock() + defer t.mu.Unlock() + + if t.isRunning { + return fmt.Errorf("ticker already started") + } + + t.stopChan = make(chan struct{}) + t.doneChan = make(chan struct{}) + t.isRunning = true + + return nil +} + +// if ticker was stopped NOT by Stop() method, we want to make a cleanup +func (t *blockTicker) cleanup() { + t.mu.Lock() + defer t.mu.Unlock() + + // noop + if !t.isRunning { + return + } + + t.isRunning = false + close(t.stopChan) +} diff --git a/pkg/ticker/ticker.go b/pkg/ticker/ticker.go index ebac3e1a96..9ec0d4cb06 100644 --- a/pkg/ticker/ticker.go +++ b/pkg/ticker/ticker.go @@ -56,6 +56,7 @@ type Ticker struct { stopped bool ctxCancel context.CancelFunc + runCompleteChan chan struct{} externalStopChan <-chan struct{} logger zerolog.Logger } @@ -94,7 +95,7 @@ func New(interval time.Duration, task Task, opts ...Opt) *Ticker { // Run creates and runs a new Ticker. func Run(ctx context.Context, interval time.Duration, task Task, opts ...Opt) error { - return New(interval, task, opts...).Run(ctx) + return New(interval, task, opts...).Start(ctx) } // Run runs the ticker by blocking current goroutine. It also invokes BEFORE ticker starts. @@ -102,8 +103,17 @@ func Run(ctx context.Context, interval time.Duration, task Task, opts ...Opt) er // - context is done (returns ctx.Err()) // - task returns an error or panics // - shutdown signal is received -func (t *Ticker) Run(ctx context.Context) (err error) { +func (t *Ticker) Start(ctx context.Context) (err error) { + // prevent concurrent runs + t.runnerMu.Lock() + defer t.runnerMu.Unlock() + + ctx = t.setStartState(ctx) + defer func() { + // used in StopBlocking() + close(t.runCompleteChan) + if r := recover(); r != nil { stack := string(debug.Stack()) lines := strings.Split(stack, "\n") @@ -116,21 +126,14 @@ func (t *Ticker) Run(ctx context.Context) (err error) { } }() - // prevent concurrent runs - t.runnerMu.Lock() - defer t.runnerMu.Unlock() - - // setup - ctx, t.ctxCancel = context.WithCancel(ctx) - t.ticker = time.NewTicker(t.interval) - t.stopped = false - // initial run - if err := t.task(ctx, t); err != nil { + if err = t.task(ctx, t); err != nil { t.Stop() return fmt.Errorf("ticker task failed (initial run): %w", err) } + defer t.setStopState() + for { select { case <-ctx.Done(): @@ -172,8 +175,35 @@ func (t *Ticker) SetInterval(interval time.Duration) { t.ticker.Reset(interval) } -// Stop stops the ticker. Safe to call concurrently or multiple times. +// Stop stops the ticker in a NON-blocking way. If the task is running in a separate goroutine, +// this call *might* not wait for it to finish. To wait for task finish, use StopBlocking(). +// It's safe to call Stop() multiple times / concurrently / within the task. func (t *Ticker) Stop() { + t.setStopState() +} + +// StopBlocking stops the ticker in a blocking way i.e. it waits for the task to finish. +// DO NOT call this within the task. +func (t *Ticker) StopBlocking() { + t.setStopState() + <-t.runCompleteChan +} + +func (t *Ticker) setStartState(ctx context.Context) context.Context { + t.stateMu.Lock() + defer t.stateMu.Unlock() + + ctx, t.ctxCancel = context.WithCancel(ctx) + t.ticker = time.NewTicker(t.interval) + t.stopped = false + + // this signals that Run() is about to return + t.runCompleteChan = make(chan struct{}) + + return ctx +} + +func (t *Ticker) setStopState() { t.stateMu.Lock() defer t.stateMu.Unlock() diff --git a/pkg/ticker/ticker_test.go b/pkg/ticker/ticker_test.go index 4e03a28d4f..276eb9457f 100644 --- a/pkg/ticker/ticker_test.go +++ b/pkg/ticker/ticker_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "sync/atomic" "testing" "time" @@ -19,6 +20,8 @@ func TestTicker(t *testing.T) { ) t.Run("Basic case with context", func(t *testing.T) { + t.Parallel() + // ARRANGE // Given a counter var counter int @@ -35,7 +38,7 @@ func TestTicker(t *testing.T) { }) // ACT - err := ticker.Run(ctx) + err := ticker.Start(ctx) // ASSERT assert.ErrorIs(t, err, context.DeadlineExceeded) @@ -45,6 +48,8 @@ func TestTicker(t *testing.T) { }) t.Run("Halts when error occurred", func(t *testing.T) { + t.Parallel() + // ARRANGE // Given a counter var counter int @@ -62,7 +67,7 @@ func TestTicker(t *testing.T) { }) // ACT - err := ticker.Run(ctx) + err := ticker.Start(ctx) // ASSERT assert.ErrorContains(t, err, "oops") @@ -70,6 +75,8 @@ func TestTicker(t *testing.T) { }) t.Run("Dynamic interval update", func(t *testing.T) { + t.Parallel() + // ARRANGE // Given a counter var counter int @@ -93,7 +100,7 @@ func TestTicker(t *testing.T) { }) // ACT - err := ticker.Run(ctx) + err := ticker.Start(ctx) // ASSERT assert.ErrorIs(t, err, context.DeadlineExceeded) @@ -104,6 +111,8 @@ func TestTicker(t *testing.T) { }) t.Run("Stop ticker", func(t *testing.T) { + t.Parallel() + // ARRANGE // Given a counter var counter int @@ -124,7 +133,7 @@ func TestTicker(t *testing.T) { }() // ACT - err := ticker.Run(ctx) + err := ticker.Start(ctx) // ASSERT assert.NoError(t, err) @@ -135,7 +144,96 @@ func TestTicker(t *testing.T) { }) }) + t.Run("Stop ticker in a blocking fashion", func(t *testing.T) { + t.Parallel() + + const ( + tickerInterval = 100 * time.Millisecond + workDuration = 600 * time.Millisecond + stopAfterStart = workDuration + tickerInterval/2 + ) + + newLogger := func(t *testing.T) zerolog.Logger { + return zerolog.New(zerolog.NewTestWriter(t)).With().Timestamp().Logger() + } + + // test task that imitates some work + newTask := func(counter *int32, logger zerolog.Logger) Task { + return func(ctx context.Context, _ *Ticker) error { + logger.Info().Msg("Tick start") + atomic.AddInt32(counter, 1) + + time.Sleep(workDuration) + + logger.Info().Msgf("Tick end") + atomic.AddInt32(counter, -1) + + return nil + } + } + + t.Run("Non-blocking stop fails do finish the work", func(t *testing.T) { + t.Parallel() + + // ARRANGE + // Given some test task that imitates some work + testLogger := newLogger(t) + counter := int32(0) + task := newTask(&counter, testLogger) + + // Given a ticker + ticker := New(tickerInterval, task, WithLogger(testLogger, "test-non-blocking-ticker")) + + // ACT + // Imitate the ticker run in the background + go func() { + err := ticker.Start(context.Background()) + require.NoError(t, err) + }() + + // Then stop the ticker after some delay + time.Sleep(stopAfterStart) + testLogger.Info().Msg("Stopping ticker") + ticker.Stop() + testLogger.Info().Msg("Stopped ticker") + + // ASSERT + // If ticker is stopped BEFORE the work is done i.e. "in the middle of work", + // thus the counter would be `1. You can also check the logs + assert.Equal(t, int32(1), counter) + }) + + t.Run("Blocking stop works as expected", func(t *testing.T) { + t.Parallel() + + // ARRANGE + // Now if we have the SAME test but with blocking stop, it should work + testLogger := newLogger(t) + counter := int32(0) + task := newTask(&counter, testLogger) + + ticker := New(tickerInterval, task, WithLogger(testLogger, "test-non-blocking-ticker")) + + // ACT + go func() { + err := ticker.Start(context.Background()) + require.NoError(t, err) + }() + + time.Sleep(stopAfterStart) + testLogger.Info().Msg("Stopping ticker") + ticker.StopBlocking() + testLogger.Info().Msg("Stopped ticker") + + // ASSERT + // If ticker is stopped AFTER the work is done + assert.Equal(t, int32(0), counter) + }) + }) + t.Run("Panic", func(t *testing.T) { + t.Parallel() + // ARRANGE // Given a context ctx := context.Background() @@ -146,15 +244,17 @@ func TestTicker(t *testing.T) { }) // ACT - err := ticker.Run(ctx) + err := ticker.Start(ctx) // ASSERT assert.ErrorContains(t, err, "panic during ticker run: oops") // assert that we get error with the correct line number - assert.ErrorContains(t, err, "ticker_test.go:145") + assert.ErrorContains(t, err, "ticker_test.go:243") }) t.Run("Nil panic", func(t *testing.T) { + t.Parallel() + // ARRANGE // Given a context ctx := context.Background() @@ -167,7 +267,7 @@ func TestTicker(t *testing.T) { }) // ACT - err := ticker.Run(ctx) + err := ticker.Start(ctx) // ASSERT assert.ErrorContains( @@ -176,10 +276,12 @@ func TestTicker(t *testing.T) { "panic during ticker run: runtime error: invalid memory address or nil pointer dereference", ) // assert that we get error with the correct line number - assert.ErrorContains(t, err, "ticker_test.go:165") + assert.ErrorContains(t, err, "ticker_test.go:265") }) t.Run("Run as a single call", func(t *testing.T) { + t.Parallel() + // ARRANGE // Given a counter var counter int @@ -202,6 +304,8 @@ func TestTicker(t *testing.T) { }) t.Run("With stop channel", func(t *testing.T) { + t.Parallel() + // ARRANGE var ( tickerInterval = 100 * time.Millisecond @@ -232,6 +336,8 @@ func TestTicker(t *testing.T) { }) t.Run("With logger", func(t *testing.T) { + t.Parallel() + // ARRANGE out := &bytes.Buffer{} logger := zerolog.New(out) diff --git a/zetaclient/metrics/metrics.go b/zetaclient/metrics/metrics.go index 28fe897504..36dc5ad813 100644 --- a/zetaclient/metrics/metrics.go +++ b/zetaclient/metrics/metrics.go @@ -170,6 +170,27 @@ var ( Name: "num_connected_peers", Help: "The number of connected peers (authenticated keygen peers)", }) + + // SchedulerTaskInvocationCounter tracks invocations categorized by status, group, and name + SchedulerTaskInvocationCounter = promauto.NewCounterVec( + prometheus.CounterOpts{ + Namespace: ZetaClientNamespace, + Name: "scheduler_task_invocations_total", + Help: "Total number of task invocations", + }, + []string{"status", "task_group", "task_name"}, + ) + + // SchedulerTaskExecutionDuration measures the execution duration of tasks + SchedulerTaskExecutionDuration = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Namespace: ZetaClientNamespace, + Name: "scheduler_task_duration_seconds", + Help: "Histogram of task execution duration in seconds", + Buckets: []float64{0.05, 0.1, 0.2, 0.3, 0.5, 1, 1.5, 2, 3, 5, 7.5, 10, 15}, // 50ms to 15s + }, + []string{"status", "task_group", "task_name"}, + ) ) // NewMetrics creates a new Metrics instance