diff --git a/fn/context_guard.go b/fn/context_guard.go index 23914702e7..ef07c134bf 100644 --- a/fn/context_guard.go +++ b/fn/context_guard.go @@ -3,6 +3,7 @@ package fn import ( "context" "sync" + "sync/atomic" "time" ) @@ -11,103 +12,234 @@ var ( DefaultTimeout = 30 * time.Second ) -// ContextGuard is an embeddable struct that provides a wait group and main quit -// channel that can be used to create guarded contexts. +// ContextGuard is a struct that provides a wait group and main quit channel +// that can be used to create guarded contexts. type ContextGuard struct { - DefaultTimeout time.Duration - Wg sync.WaitGroup - Quit chan struct{} + mu sync.Mutex + wg sync.WaitGroup + + quit chan struct{} + stopped sync.Once + + // id is used to generate unique ids for each context that should be + // cancelled when the main quit signal is triggered. + id atomic.Uint32 + + // cancelFns is a map of cancel functions that can be used to cancel + // any context that should be cancelled when the main quit signal is + // triggered. The key is the id of the context. The mutex must be held + // when accessing this map. + cancelFns map[uint32]context.CancelFunc } +// NewContextGuard constructs and returns a new instance of ContextGuard. func NewContextGuard() *ContextGuard { return &ContextGuard{ - DefaultTimeout: DefaultTimeout, - Quit: make(chan struct{}), + quit: make(chan struct{}), + cancelFns: make(map[uint32]context.CancelFunc), } } -// WithCtxQuit is used to create a cancellable context that will be cancelled -// if the main quit signal is triggered or after the default timeout occurred. -func (g *ContextGuard) WithCtxQuit() (context.Context, func()) { - return g.WithCtxQuitCustomTimeout(g.DefaultTimeout) -} +// Quit is used to signal the main quit channel, which will cancel all +// non-blocking contexts derived from the ContextGuard. +func (g *ContextGuard) Quit() { + g.stopped.Do(func() { + g.mu.Lock() + defer g.mu.Unlock() -// WithCtxQuitCustomTimeout is used to create a cancellable context that will be -// cancelled if the main quit signal is triggered or after the given timeout -// occurred. -func (g *ContextGuard) WithCtxQuitCustomTimeout( - timeout time.Duration) (context.Context, func()) { + for _, cancel := range g.cancelFns { + cancel() + } - timeoutTimer := time.NewTimer(timeout) - ctx, cancel := context.WithCancel(context.Background()) + close(g.quit) + }) +} - g.Wg.Add(1) - go func() { - defer timeoutTimer.Stop() - defer cancel() - defer g.Wg.Done() +// Done returns a channel that will be closed when the main quit signal is +// triggered. +func (g *ContextGuard) Done() <-chan struct{} { + return g.quit +} - select { - case <-g.Quit: +// WgAdd is used to add delta to the internal wait group of the ContextGuard. +func (g *ContextGuard) WgAdd(delta int) { + g.wg.Add(delta) +} - case <-timeoutTimer.C: +// WgDone is used to decrement the internal wait group of the ContextGuard. +func (g *ContextGuard) WgDone() { + g.wg.Done() +} - case <-ctx.Done(): - } - }() +// WgWait is used to block until the internal wait group of the ContextGuard is +// empty. +func (g *ContextGuard) WgWait() { + g.wg.Wait() +} - return ctx, cancel +// ctxGuardOptions is used to configure the behaviour of the context derived +// via the WithCtx method of the ContextGuard. +type ctxGuardOptions struct { + blocking bool + withTimeout bool + timeout time.Duration } -// CtxBlocking is used to create a cancellable context that will NOT be +// ContextGuardOption defines the signature of a functional option that can be +// used to configure the behaviour of the context derived via the WithCtx method +// of the ContextGuard. +type ContextGuardOption func(*ctxGuardOptions) + +// BlockingCGOpt is used to create a cancellable context that will NOT be // cancelled if the main quit signal is triggered, to block shutdown of -// important tasks. The context will be cancelled if the timeout is reached. -func (g *ContextGuard) CtxBlocking() (context.Context, func()) { - return g.CtxBlockingCustomTimeout(g.DefaultTimeout) +// important tasks. +func BlockingCGOpt() ContextGuardOption { + return func(o *ctxGuardOptions) { + o.blocking = true + } +} + +// CustomTimeoutCGOpt is used to create a cancellable context with a custom +// timeout. Such a context will be cancelled if either the parent context is +// cancelled, the timeout is reached or, if the Blocking option is not provided, +// the main quit signal is triggered. +func CustomTimeoutCGOpt(timeout time.Duration) ContextGuardOption { + return func(o *ctxGuardOptions) { + o.withTimeout = true + o.timeout = timeout + } +} + +// TimeoutCGOpt is used to create a cancellable context with a default timeout. +// Such a context will be cancelled if either the parent context is cancelled, +// the timeout is reached or, if the Blocking option is not provided, the main +// quit signal is triggered. +func TimeoutCGOpt() ContextGuardOption { + return func(o *ctxGuardOptions) { + o.withTimeout = true + o.timeout = DefaultTimeout + } +} + +// WithCtx is used to derive a cancellable context from the parent. Various +// options can be provided to configure the behaviour of the derived context. +func (g *ContextGuard) WithCtx(ctx context.Context, + options ...ContextGuardOption) (context.Context, context.CancelFunc) { + + g.mu.Lock() + defer g.mu.Unlock() + + // Exit early if the parent context has already been cancelled. + select { + case <-ctx.Done(): + return ctx, func() {} + default: + } + + var opts ctxGuardOptions + for _, o := range options { + o(&opts) + } + + var cancel context.CancelFunc + if opts.withTimeout { + ctx, cancel = context.WithTimeout(ctx, opts.timeout) + } else { + ctx, cancel = context.WithCancel(ctx) + } + + if opts.blocking { + g.ctxBlocking(ctx, cancel) + + return ctx, cancel + } + + // If the call is non-blocking, then we can exit early if the main quit + // signal has been triggered. + select { + case <-g.quit: + cancel() + + return ctx, cancel + default: + } + + cancel = g.ctxQuit(ctx, cancel) + + return ctx, cancel } -// CtxBlockingCustomTimeout is used to create a cancellable context with a -// custom timeout that will NOT be cancelled if the main quit signal is -// triggered, to block shutdown of important tasks. The context will be -// cancelled if the timeout is reached. -func (g *ContextGuard) CtxBlockingCustomTimeout( - timeout time.Duration) (context.Context, func()) { +// ctxQuit spins off a goroutine that will block until the passed context +// is cancelled or until the quit channel has been signaled after which it will +// call the passed cancel function and decrement the wait group. +// +// NOTE: the caller must hold the ContextGuard's mutex before calling this +// function. +func (g *ContextGuard) ctxQuit(ctx context.Context, + cancel context.CancelFunc) context.CancelFunc { - timeoutTimer := time.NewTimer(timeout) - ctx, cancel := context.WithCancel(context.Background()) + cancel = g.addCancelFn(cancel) - g.Wg.Add(1) + g.wg.Add(1) go func() { - defer timeoutTimer.Stop() defer cancel() - defer g.Wg.Done() + defer g.wg.Done() select { - case <-timeoutTimer.C: + case <-g.quit: case <-ctx.Done(): } }() - return ctx, cancel + return cancel } -// WithCtxQuitNoTimeout is used to create a cancellable context that will be -// cancelled if the main quit signal is triggered. -func (g *ContextGuard) WithCtxQuitNoTimeout() (context.Context, func()) { - ctx, cancel := context.WithCancel(context.Background()) +// ctxBlocking spins off a goroutine that will block until the passed context +// is cancelled after which it will call the passed cancel function and +// decrement the wait group. +func (g *ContextGuard) ctxBlocking(ctx context.Context, + cancel context.CancelFunc) { - g.Wg.Add(1) + g.wg.Add(1) go func() { defer cancel() - defer g.Wg.Done() + defer g.wg.Done() select { - case <-g.Quit: - case <-ctx.Done(): } }() +} - return ctx, cancel +// addCancelFn adds a context cancel function to the manager and returns a +// call-back which can safely be used to cancel the context. +// +// NOTE: the caller must hold the ContextGuard's mutex before calling this +// function. +func (g *ContextGuard) addCancelFn( + cancel context.CancelFunc) context.CancelFunc { + + id := g.id.Add(1) + g.cancelFns[id] = cancel + + return g.cancelCtxFn(id) +} + +// cancelCtxFn returns a call-back that can be used to cancels the context +// associated with the passed id. +func (g *ContextGuard) cancelCtxFn(id uint32) context.CancelFunc { + return func() { + g.mu.Lock() + defer g.mu.Unlock() + + fn, ok := g.cancelFns[id] + if !ok { + return + } + + fn() + + delete(g.cancelFns, id) + } } diff --git a/fn/context_guard_test.go b/fn/context_guard_test.go new file mode 100644 index 0000000000..2d98c43dbd --- /dev/null +++ b/fn/context_guard_test.go @@ -0,0 +1,434 @@ +package fn + +import ( + "context" + "testing" + "time" +) + +// TestContextGuard tests the behaviour of the ContextGuard. +func TestContextGuard(t *testing.T) { + t.Parallel() + + // Test that the derived context is cancelled when the passed context is + // cancelled. + t.Run("Parent context is cancelled", func(t *testing.T) { + t.Parallel() + var ( + ctx, cancel = context.WithCancel(context.Background()) + g = NewContextGuard() + ) + + ctxc, _ := g.WithCtx(ctx) + + // Cancel the parent context. + cancel() + // Assert that the derived context is cancelled. + select { + case <-ctxc.Done(): + default: + t.Errorf("The derived context should be cancelled at " + + "this point") + } + }) + + // Test that the derived context is cancelled when the returned cancel + // function is called. + t.Run("Derived context is cancelled", func(t *testing.T) { + t.Parallel() + var ( + ctx = context.Background() + g = NewContextGuard() + ) + + ctxc, cancel := g.WithCtx(ctx) + + // Cancel the context. + cancel() + + // Assert that the derived context is cancelled. + select { + case <-ctxc.Done(): + default: + t.Errorf("The derived context should be cancelled at " + + "this point") + } + }) + + // Test that the derived context is cancelled when the quit channel is + // closed. + t.Run("Quit channel is closed", func(t *testing.T) { + t.Parallel() + + var ( + ctx = context.Background() + g = NewContextGuard() + ) + ctxc, _ := g.WithCtx(ctx) + + // Close the quit channel. + g.Quit() + + // Assert that the derived context is cancelled. + select { + case <-ctxc.Done(): + default: + t.Errorf("The derived context should be cancelled at " + + "this point") + } + }) + + t.Run("Parent context is already closed", func(t *testing.T) { + t.Parallel() + + var ( + ctx, cancel = context.WithCancel(context.Background()) + g = NewContextGuard() + ) + cancel() + + ctxc, _ := g.WithCtx(ctx) + // Assert that the derived context is cancelled already + // cancelled. + select { + case <-ctxc.Done(): + default: + t.Errorf("The derived context should be cancelled at " + + "this point") + } + }) + + t.Run("Quit channel is already closed", func(t *testing.T) { + t.Parallel() + + var ( + ctx = context.Background() + g = NewContextGuard() + ) + + g.Quit() + + ctxc, _ := g.WithCtx(ctx) + + // Assert that the derived context is cancelled already + // cancelled. + select { + case <-ctxc.Done(): + default: + t.Errorf("The derived context should be cancelled at " + + "this point") + } + }) + + t.Run("Child context should be cancelled synchronously with "+ + "parent", func(t *testing.T) { + + t.Parallel() + + var ( + ctx, cancel = context.WithCancel(context.Background()) + g = NewContextGuard() + task = make(chan struct{}) + done = make(chan struct{}) + ) + // Derive a child context. + ctxc, _ := g.WithCtx(ctx) + + // Spin off a routine that exists cleaning if the child context + // is cancelled but fails if the task is performed. + go func() { + defer close(done) + select { + case <-ctxc.Done(): + case <-task: + t.Fatalf("should not get here") + } + }() + + // Give the goroutine above a chance to spin up so that it's + // waiting on the select. + time.Sleep(time.Millisecond * 200) + + // First cancel the parent context. Then immediately execute the + // task. + cancel() + close(task) + + // Wait for the goroutine to exit. + select { + case <-done: + case <-time.After(time.Second): + t.Fatalf("timeout") + } + }) + + t.Run("Child context should be cancelled synchronously with the "+ + "close of the quit channel", func(t *testing.T) { + + t.Parallel() + + var ( + ctx = context.Background() + g = NewContextGuard() + task = make(chan struct{}) + done = make(chan struct{}) + ) + + // Derive a child context. + ctxc, _ := g.WithCtx(ctx) + + // Spin off a routine that exists cleaning if the child context + // is cancelled but fails if the task is performed. + go func() { + defer close(done) + select { + case <-ctxc.Done(): + case <-task: + t.Fatalf("should not get here") + } + }() + + // Give the goroutine above a chance to spin up so that it's + // waiting on the select. + time.Sleep(time.Millisecond * 200) + + // First cancel the parent context. Then immediately execute the + // task. + g.Quit() + + // Execute the task. + close(task) + + // Wait for the goroutine to exit. + select { + case <-done: + case <-time.After(time.Second): + t.Fatalf("timeout") + } + }) + + // Test that if we add the BlockingCGOpt option, then the context will + // not be cancelled when the quit channel is closed but will be when the + // cancel function is called. + t.Run("Blocking context no timeout", func(t *testing.T) { + t.Parallel() + + var ( + ctx = context.Background() + g = NewContextGuard() + task = make(chan struct{}) + done = make(chan struct{}) + ) + + // Derive a blocking child context. + ctxc, cancel := g.WithCtx(ctx, BlockingCGOpt()) + + // Spin of a routine that will exit cleanly if the context is + // cancelled but will fail if the task is performed. + go func() { + defer func() { + done <- struct{}{} + }() + + select { + case <-ctxc.Done(): + case <-task: + t.Fatalf("Expected context to be cancelled") + } + }() + + // Give the goroutine above a chance to spin up so that it's + // waiting on the select. + time.Sleep(time.Millisecond * 200) + + // Cancel the context. + cancel() + + // Attempt to perform the task. + select { + case task <- struct{}{}: + t.Fatalf("Expected task to not be performed") + default: + } + + // Assert that the task goroutine has now completed. + select { + case <-done: + case <-time.After(time.Second): + t.Fatalf("timeout") + } + + // Derive a new blocking child context. + ctxc, cancel = g.WithCtx(ctx, BlockingCGOpt()) + + // Repeat the task but this time, we will call Quit first, but + // since this is a blocking context, the context should not be + // cancelled and the task _should_ be performed. + go func() { + defer func() { + done <- struct{}{} + }() + + select { + case <-ctxc.Done(): + t.Fatalf("Expected task to be performed") + case <-task: + } + }() + + // Give the goroutine above a chance to spin up so that it's + // waiting on the select. + time.Sleep(time.Millisecond * 200) + + // Close the quit channel. This should NOT cause the context + // to be cancelled. + g.Quit() + + // Now, perform the task. + select { + case task <- struct{}{}: + case <-time.After(time.Second): + t.Fatalf("timeout") + } + + // Assert that the task goroutine has now completed. + select { + case <-done: + case <-time.After(time.Second): + t.Fatalf("timeout") + } + }) + + // Test that if we add the CustomTimeoutCGOpt option, then the context + // will be not be cancelled when the quit channel is closed but will be + // if either the context is cancelled or the timeout is reached. + t.Run("Blocking context with timeout", func(t *testing.T) { + t.Parallel() + + var ( + ctx = context.Background() + g = NewContextGuard() + task = make(chan struct{}) + done = make(chan struct{}) + timeout = time.Millisecond * 500 + ) + + // Derive a blocking child context. + ctxc, cancel := g.WithCtx( + ctx, BlockingCGOpt(), CustomTimeoutCGOpt(timeout), + ) + + // Spin of a routine that will exit cleanly if the context is + // cancelled but will fail if the task is performed. + go func() { + defer func() { + done <- struct{}{} + }() + + select { + case <-ctxc.Done(): + case <-task: + t.Fatalf("Expected context to be cancelled") + } + }() + + // Give the goroutine above a chance to spin up so that it's + // waiting on the select. + time.Sleep(time.Millisecond * 200) + + // Cancel the context. + cancel() + + // Attempt to perform the task. + select { + case task <- struct{}{}: + t.Fatalf("Expected task to not be performed") + default: + } + + // Assert that the task goroutine has now completed. + select { + case <-done: + case <-time.After(time.Second): + t.Fatalf("timeout") + } + + // Derive a new blocking child context with a timeout. + ctxc, cancel = g.WithCtx( + ctx, BlockingCGOpt(), CustomTimeoutCGOpt(timeout), + ) + + // Repeat the task but this time, but this time, we will assert + // that the context is cancelled if the timeout is reached. + // We will again fail if the task is performed. + go func() { + defer func() { + done <- struct{}{} + }() + + select { + case <-ctxc.Done(): + case <-task: + t.Fatalf("Expected context to be cancelled") + } + }() + + // Wait for the timeout to be reached. + time.Sleep(timeout + time.Millisecond*100) + + // Attempt to perform the task. + select { + case task <- struct{}{}: + t.Fatalf("Expected task to not be performed") + default: + } + + // Assert that the task goroutine has now completed. + select { + case <-done: + case <-time.After(time.Second): + t.Fatalf("timeout") + } + + // Finally, repeat the task but this time show that calling + // Quit does not cancel the context and that the task still gets + // performed if it takes place before the context is timed out. + ctxc, cancel = g.WithCtx( + ctx, BlockingCGOpt(), CustomTimeoutCGOpt(timeout), + ) + + go func() { + defer func() { + done <- struct{}{} + }() + + select { + case <-ctxc.Done(): + t.Fatalf("Expected the task to be performed") + case <-task: + } + }() + + // Give the goroutine above a chance to spin up so that it's + // waiting on the select. + time.Sleep(time.Millisecond * 200) + + // Close the quit channel. This should NOT cause the context + // to be cancelled. + g.Quit() + + // Now, perform the task. + select { + case task <- struct{}{}: + case <-time.After(time.Second): + t.Fatalf("timeout") + } + + // Assert that the task goroutine has now completed. + select { + case <-done: + case <-time.After(time.Second): + t.Fatalf("timeout") + } + }) +}