diff --git a/pkg/utils/syncutil/flexible_wait_group.go b/pkg/utils/syncutil/flexible_wait_group.go index 3ee602d7624..46ac0e83c74 100644 --- a/pkg/utils/syncutil/flexible_wait_group.go +++ b/pkg/utils/syncutil/flexible_wait_group.go @@ -18,23 +18,24 @@ import ( "sync" ) -// FlexibleWaitGroup is a flexible wait group. -// Note: we can't use sync.WaitGroup because it doesn't support to call `Add` after `Wait` finished. +// FlexibleWaitGroup is a flexible version of sync.WaitGroup. +// It supports adding to the counter after Wait() has been called, +// which is not allowed in sync.WaitGroup. type FlexibleWaitGroup struct { sync.Mutex count int cond *sync.Cond } -// NewFlexibleWaitGroup creates a FlexibleWaitGroup. +// NewFlexibleWaitGroup creates and returns a new FlexibleWaitGroup. func NewFlexibleWaitGroup() *FlexibleWaitGroup { dwg := &FlexibleWaitGroup{} dwg.cond = sync.NewCond(&dwg.Mutex) return dwg } -// Add adds delta, which may be negative, to the FlexibleWaitGroup counter. -// If the counter becomes zero, all goroutines blocked on Wait are released. +// Add adds delta (which may be negative) to the FlexibleWaitGroup counter. +// If the counter becomes zero or negative, all goroutines blocked on Wait are released. func (fwg *FlexibleWaitGroup) Add(delta int) { fwg.Lock() defer fwg.Unlock() @@ -42,15 +43,16 @@ func (fwg *FlexibleWaitGroup) Add(delta int) { fwg.count += delta if fwg.count <= 0 { fwg.cond.Broadcast() + fwg.count = 0 } } -// Done decrements the FlexibleWaitGroup counter. +// Done decrements the FlexibleWaitGroup counter by one. func (fwg *FlexibleWaitGroup) Done() { fwg.Add(-1) } -// Wait blocks until the FlexibleWaitGroup counter is zero. +// Wait blocks until the FlexibleWaitGroup counter is zero or negative. func (fwg *FlexibleWaitGroup) Wait() { fwg.Lock() for fwg.count > 0 { diff --git a/pkg/utils/syncutil/flexible_wait_group_test.go b/pkg/utils/syncutil/flexible_wait_group_test.go index e98074cf55d..2d5a1974f5b 100644 --- a/pkg/utils/syncutil/flexible_wait_group_test.go +++ b/pkg/utils/syncutil/flexible_wait_group_test.go @@ -36,6 +36,7 @@ func TestFlexibleWaitGroup(t *testing.T) { re.GreaterOrEqual(time.Since(now).Milliseconds(), int64(1000)) } +// TestAddAfterWait tests the case where Add is called after Wait has started and before Wait has finished. func TestAddAfterWait(t *testing.T) { fwg := NewFlexibleWaitGroup() startWait := make(chan struct{}) @@ -66,3 +67,66 @@ func TestAddAfterWait(t *testing.T) { }() <-done } + +// TestNegativeDelta tests the case where Add is called with a negative delta. +func TestNegativeDelta(t *testing.T) { + require := require.New(t) + fwg := NewFlexibleWaitGroup() + fwg.Add(5) + go func() { + fwg.Add(-3) + fwg.Done() + fwg.Done() + }() + go func() { + fwg.Add(-2) + fwg.Done() + }() + fwg.Wait() + require.Equal(0, fwg.count) +} + +// TestMultipleWait tests the case where Wait is called multiple times concurrently. +func TestMultipleWait(t *testing.T) { + require := require.New(t) + fwg := NewFlexibleWaitGroup() + fwg.Add(3) + done := make(chan struct{}) + go func() { + fwg.Wait() + done <- struct{}{} + }() + go func() { + fwg.Wait() + done <- struct{}{} + }() + go func() { + fwg.Done() + time.Sleep(100 * time.Millisecond) // Ensure that Done is called after the Waits + fwg.Done() + fwg.Done() + }() + <-done + <-done + require.Equal(0, fwg.count) +} + +// TestAddAfterWaitFinished tests the case where Add is called after Wait has finished. +func TestAddAfterWaitFinished(t *testing.T) { + require := require.New(t) + fwg := NewFlexibleWaitGroup() + done := make(chan struct{}) + go func() { + fwg.Add(1) + fwg.Done() + }() + go func() { + fwg.Wait() + done <- struct{}{} + }() + <-done + fwg.Add(1) + require.Equal(1, fwg.count) + fwg.Done() + require.Equal(0, fwg.count) +}