diff --git a/pkg/schedule/schedulers/scheduler_controller.go b/pkg/schedule/schedulers/scheduler_controller.go index 4d72699b0fe..58cb802d629 100644 --- a/pkg/schedule/schedulers/scheduler_controller.go +++ b/pkg/schedule/schedulers/scheduler_controller.go @@ -30,6 +30,7 @@ import ( "github.com/tikv/pd/pkg/schedule/plan" "github.com/tikv/pd/pkg/storage/endpoint" "github.com/tikv/pd/pkg/utils/logutil" + "github.com/tikv/pd/pkg/utils/syncutil" "go.uber.org/zap" ) @@ -40,7 +41,7 @@ var denySchedulersByLabelerCounter = labeler.LabelerEventCounter.WithLabelValues // Controller is used to manage all schedulers. type Controller struct { sync.RWMutex - wg sync.WaitGroup + wg *syncutil.FlexibleWaitGroup ctx context.Context cluster sche.SchedulerCluster storage endpoint.ConfigStorage @@ -57,6 +58,7 @@ type Controller struct { func NewController(ctx context.Context, cluster sche.SchedulerCluster, storage endpoint.ConfigStorage, opController *operator.Controller) *Controller { return &Controller{ ctx: ctx, + wg: syncutil.NewFlexibleWaitGroup(), cluster: cluster, storage: storage, schedulers: make(map[string]*ScheduleController), diff --git a/pkg/utils/syncutil/flexible_wait_group.go b/pkg/utils/syncutil/flexible_wait_group.go new file mode 100644 index 00000000000..ae72e7720b3 --- /dev/null +++ b/pkg/utils/syncutil/flexible_wait_group.go @@ -0,0 +1,70 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syncutil + +import ( + "sync" +) + +// FlexibleWaitGroup is a flexible version of sync.WaitGroup. +// It supports adding to the counter after Wait() has been called, +// which is not allowed in sync.WaitGroup. +type FlexibleWaitGroup struct { + sync.Mutex + count int + cond *sync.Cond +} + +// NewFlexibleWaitGroup creates and returns a new FlexibleWaitGroup. +func NewFlexibleWaitGroup() *FlexibleWaitGroup { + dwg := &FlexibleWaitGroup{} + dwg.cond = sync.NewCond(&dwg.Mutex) + return dwg +} + +// Add adds delta (which may be negative) to the FlexibleWaitGroup counter. +// If the counter becomes zero or negative, all goroutines blocked on Wait are released. +func (fwg *FlexibleWaitGroup) Add(delta int) { + fwg.Lock() + defer fwg.Unlock() + + fwg.count += delta + if fwg.count <= 0 { + fwg.cond.Broadcast() + fwg.count = 0 + } +} + +// Done decrements the FlexibleWaitGroup counter by one. +func (fwg *FlexibleWaitGroup) Done() { + fwg.Add(-1) +} + +// Wait blocks until the FlexibleWaitGroup counter is zero or negative. +func (fwg *FlexibleWaitGroup) Wait() { + fwg.Lock() + for fwg.count > 0 { + fwg.cond.Wait() + } + fwg.Unlock() +} + +// getCount returns the current count of the FlexibleWaitGroup. +// It is only used for testing. +func (fwg *FlexibleWaitGroup) getCount() int { + fwg.Lock() + defer fwg.Unlock() + return fwg.count +} diff --git a/pkg/utils/syncutil/flexible_wait_group_test.go b/pkg/utils/syncutil/flexible_wait_group_test.go new file mode 100644 index 00000000000..9d74f9a0695 --- /dev/null +++ b/pkg/utils/syncutil/flexible_wait_group_test.go @@ -0,0 +1,132 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package syncutil + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestFlexibleWaitGroup(t *testing.T) { + re := require.New(t) + fwg := NewFlexibleWaitGroup() + now := time.Now() + for i := 20; i >= 0; i-- { + fwg.Add(1) + go func(i int) { + defer fwg.Done() + time.Sleep(time.Millisecond * time.Duration(i*50)) + }(i) + } + fwg.Wait() + re.GreaterOrEqual(time.Since(now).Milliseconds(), int64(1000)) +} + +// TestAddAfterWait tests the case where Add is called after Wait has started and before Wait has finished. +func TestAddAfterWait(t *testing.T) { + fwg := NewFlexibleWaitGroup() + startWait := make(chan struct{}) + addTwice := make(chan struct{}) + done := make(chan struct{}) + + // First goroutine: Adds a task, then waits for the second task to be added before finishing. + go func() { + defer fwg.Done() + fwg.Add(1) + <-addTwice + }() + + // Second goroutine: adds a second task after ensure the third goroutine has started to wait + // and triggers the first goroutine to finish. + go func() { + defer fwg.Done() + <-startWait + fwg.Add(1) + addTwice <- struct{}{} + }() + + // Third goroutine: waits for all tasks to be added, then finishes. + go func() { + startWait <- struct{}{} + fwg.Wait() + done <- struct{}{} + }() + <-done +} + +// TestNegativeDelta tests the case where Add is called with a negative delta. +func TestNegativeDelta(t *testing.T) { + require := require.New(t) + fwg := NewFlexibleWaitGroup() + fwg.Add(5) + go func() { + fwg.Add(-3) + fwg.Done() + fwg.Done() + }() + go func() { + fwg.Add(-2) + fwg.Done() + }() + fwg.Wait() + require.Equal(0, fwg.getCount()) +} + +// TestMultipleWait tests the case where Wait is called multiple times concurrently. +func TestMultipleWait(t *testing.T) { + require := require.New(t) + fwg := NewFlexibleWaitGroup() + fwg.Add(3) + done := make(chan struct{}) + go func() { + fwg.Wait() + done <- struct{}{} + }() + go func() { + fwg.Wait() + done <- struct{}{} + }() + go func() { + fwg.Done() + time.Sleep(100 * time.Millisecond) // Ensure that Done is called after the Waits + fwg.Done() + fwg.Done() + }() + <-done + <-done + require.Equal(0, fwg.getCount()) +} + +// TestAddAfterWaitFinished tests the case where Add is called after Wait has finished. +func TestAddAfterWaitFinished(t *testing.T) { + require := require.New(t) + fwg := NewFlexibleWaitGroup() + done := make(chan struct{}) + go func() { + fwg.Add(1) + fwg.Done() + }() + go func() { + fwg.Wait() + done <- struct{}{} + }() + <-done + fwg.Add(1) + require.Equal(1, fwg.getCount()) + fwg.Done() + require.Equal(0, fwg.getCount()) +}