diff --git a/xsync/xsync.go b/xsync/xsync.go index 353384c..d870dc9 100644 --- a/xsync/xsync.go +++ b/xsync/xsync.go @@ -67,10 +67,12 @@ func (c *ContextCond) Wait(ctx context.Context) error { // Group manages a group of goroutines. type Group struct { - baseCtx context.Context - ctx context.Context - cancel context.CancelFunc - wg sync.WaitGroup + ctx context.Context + cancel context.CancelFunc + // held in R when spawning to check if ctx is already cancelled and in W when cancelling ctx to + // make sure we never cause wg to go 0->1 while inside Wait() + m sync.RWMutex + wg sync.WaitGroup } // NewGroup returns a Group ready for use. The context passed to any of the f functions will be a @@ -78,21 +80,32 @@ type Group struct { func NewGroup(ctx context.Context) *Group { bgCtx, cancel := context.WithCancel(ctx) return &Group{ - baseCtx: ctx, - ctx: bgCtx, - cancel: cancel, + ctx: bgCtx, + cancel: cancel, } } -// Once calls f once from another goroutine. -func (g *Group) Once(f func(ctx context.Context)) { +// helper even though it's exactly g.Do so that the goroutine stack for a spawned function doesn't +// confusingly show all of them as created by Do. +func (g *Group) spawn(f func()) { + g.m.RLock() + if g.ctx.Err() != nil { + return + } g.wg.Add(1) + g.m.RUnlock() + go func() { - f(g.ctx) + f() g.wg.Done() }() } +// Do calls f once from another goroutine. +func (g *Group) Do(f func(ctx context.Context)) { + g.spawn(func() { f(g.ctx) }) +} + // returns a random duration in [d - jitter, d + jitter] func jitterDuration(d time.Duration, jitter time.Duration) time.Duration { return d + time.Duration(float64(jitter)*((rand.Float64()*2)-1)) @@ -104,9 +117,7 @@ func (g *Group) Periodic( jitter time.Duration, f func(ctx context.Context), ) { - g.wg.Add(1) - go func() { - defer g.wg.Done() + g.spawn(func() { t := time.NewTimer(jitterDuration(interval, jitter)) defer t.Stop() for { @@ -121,16 +132,15 @@ func (g *Group) Periodic( t.Reset(jitterDuration(interval, jitter)) f(g.ctx) } - }() + }) } // Trigger spawns a goroutine which calls f whenever the returned function is called. If f is // already running when triggered, f will run again immediately when it finishes. func (g *Group) Trigger(f func(ctx context.Context)) func() { c := make(chan struct{}, 1) - g.wg.Add(1) - go func() { - defer g.wg.Done() + + g.spawn(func() { for { if g.ctx.Err() != nil { return @@ -142,7 +152,7 @@ func (g *Group) Trigger(f func(ctx context.Context)) func() { } f(g.ctx) } - }() + }) return func() { select { @@ -161,9 +171,7 @@ func (g *Group) PeriodicOrTrigger( f func(ctx context.Context), ) func() { c := make(chan struct{}, 1) - g.wg.Add(1) - go func() { - defer g.wg.Done() + g.spawn(func() { t := time.NewTimer(jitterDuration(interval, jitter)) defer t.Stop() for { @@ -183,7 +191,7 @@ func (g *Group) PeriodicOrTrigger( } f(g.ctx) } - }() + }) return func() { select { @@ -193,19 +201,19 @@ func (g *Group) PeriodicOrTrigger( } } -// Stop cancels the context passed to spawned goroutines. +// Stop cancels the context passed to spawned goroutines. After the group is stopped, no more +// goroutines will be spawned. func (g *Group) Stop() { + g.m.Lock() g.cancel() + g.m.Unlock() } -// Wait cancels the context passed to any of the spawned goroutines and waits for all spawned -// goroutines to exit. -// -// It is not safe to call Wait concurrently with any other method on g. -func (g *Group) Wait() { - g.cancel() +// StopAndWait cancels the context passed to any of the spawned goroutines and waits for all spawned +// goroutines to exit. After the group is stopped, no more goroutines will be spawned. +func (g *Group) StopAndWait() { + g.Stop() g.wg.Wait() - g.ctx, g.cancel = context.WithCancel(g.baseCtx) } // Lazy makes a lazily-initialized value. On first access, it uses f to create the value. Later diff --git a/xsync/xsync_test.go b/xsync/xsync_test.go index 4283654..e3cbdb2 100644 --- a/xsync/xsync_test.go +++ b/xsync/xsync_test.go @@ -1,6 +1,13 @@ package xsync -import "fmt" +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/bradenaw/juniper/xtime" +) func ExampleLazy() { var ( @@ -18,3 +25,75 @@ func ExampleLazy() { // foo // foo } + +func TestGroup(t *testing.T) { + g := NewGroup(context.Background()) + + dos := make(chan struct{}, 100) + g.Do(func(ctx context.Context) { + for { + err := xtime.SleepContext(ctx, 50*time.Millisecond) + if err != nil { + return + } + + select { + case dos <- struct{}{}: + default: + } + } + }) + + periodics := make(chan struct{}, 100) + g.Periodic(35*time.Millisecond, 0 /*jitter*/, func(ctx context.Context) { + select { + case periodics <- struct{}{}: + default: + } + }) + + periodicOrTriggers := make(chan struct{}, 100) + periodicOrTrigger := g.PeriodicOrTrigger(75*time.Millisecond, 0 /*jitter*/, func(ctx context.Context) { + select { + case periodicOrTriggers <- struct{}{}: + default: + } + }) + + triggers := make(chan struct{}, 100) + trigger := g.Trigger(func(ctx context.Context) { + select { + case triggers <- struct{}{}: + default: + } + }) + + trigger() + periodicOrTrigger() + time.Sleep(200 * time.Millisecond) + trigger() + + <-dos + <-dos + <-dos + <-dos + <-periodics + <-periodics + <-periodics + <-periodics + <-periodics + <-periodicOrTriggers + <-periodicOrTriggers + <-periodicOrTriggers + <-triggers + <-triggers + + g.StopAndWait() + + g.Do(func(ctx context.Context) { + panic("this will never spawn because StopAndWait was already called") + }) + + // Jank, but just in case we'd be safe from the above panic just because the test is over. + time.Sleep(200 * time.Millisecond) +}