diff --git a/pkg/utils/syncutil/flexible_wait_group.go b/pkg/utils/syncutil/flexible_wait_group.go index 46ac0e83c74..ae72e7720b3 100644 --- a/pkg/utils/syncutil/flexible_wait_group.go +++ b/pkg/utils/syncutil/flexible_wait_group.go @@ -60,3 +60,11 @@ func (fwg *FlexibleWaitGroup) Wait() { } fwg.Unlock() } + +// getCount returns the current count of the FlexibleWaitGroup. +// It is only used for testing. +func (fwg *FlexibleWaitGroup) getCount() int { + fwg.Lock() + defer fwg.Unlock() + return fwg.count +} diff --git a/pkg/utils/syncutil/flexible_wait_group_test.go b/pkg/utils/syncutil/flexible_wait_group_test.go index 2d5a1974f5b..9d74f9a0695 100644 --- a/pkg/utils/syncutil/flexible_wait_group_test.go +++ b/pkg/utils/syncutil/flexible_wait_group_test.go @@ -1,4 +1,4 @@ -// Copyright 2022 TiKV Project Authors. +// Copyright 2023 TiKV Project Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ import ( func TestFlexibleWaitGroup(t *testing.T) { re := require.New(t) fwg := NewFlexibleWaitGroup() + now := time.Now() for i := 20; i >= 0; i-- { fwg.Add(1) go func(i int) { @@ -31,7 +32,6 @@ func TestFlexibleWaitGroup(t *testing.T) { time.Sleep(time.Millisecond * time.Duration(i*50)) }(i) } - now := time.Now() fwg.Wait() re.GreaterOrEqual(time.Since(now).Milliseconds(), int64(1000)) } @@ -83,7 +83,7 @@ func TestNegativeDelta(t *testing.T) { fwg.Done() }() fwg.Wait() - require.Equal(0, fwg.count) + require.Equal(0, fwg.getCount()) } // TestMultipleWait tests the case where Wait is called multiple times concurrently. @@ -108,7 +108,7 @@ func TestMultipleWait(t *testing.T) { }() <-done <-done - require.Equal(0, fwg.count) + require.Equal(0, fwg.getCount()) } // TestAddAfterWaitFinished tests the case where Add is called after Wait has finished. @@ -126,7 +126,7 @@ func TestAddAfterWaitFinished(t *testing.T) { }() <-done fwg.Add(1) - require.Equal(1, fwg.count) + require.Equal(1, fwg.getCount()) fwg.Done() - require.Equal(0, fwg.count) + require.Equal(0, fwg.getCount()) }