diff --git a/pkg/ratelimit/concurrency_limiter.go b/pkg/ratelimit/concurrency_limiter.go index b1eef3c8101..a7b00bc6242 100644 --- a/pkg/ratelimit/concurrency_limiter.go +++ b/pkg/ratelimit/concurrency_limiter.go @@ -14,24 +14,33 @@ package ratelimit -import "github.com/tikv/pd/pkg/utils/syncutil" +import ( + "context" -type concurrencyLimiter struct { + "github.com/tikv/pd/pkg/utils/syncutil" +) + +// ConcurrencyLimiter is a limiter that limits the number of concurrent tasks. +type ConcurrencyLimiter struct { mu syncutil.RWMutex current uint64 + waiting uint64 limit uint64 // statistic maxLimit uint64 + queue chan *TaskToken } -func newConcurrencyLimiter(limit uint64) *concurrencyLimiter { - return &concurrencyLimiter{limit: limit} +// NewConcurrencyLimiter creates a new ConcurrencyLimiter. +func NewConcurrencyLimiter(limit uint64) *ConcurrencyLimiter { + return &ConcurrencyLimiter{limit: limit, queue: make(chan *TaskToken, limit)} } const unlimit = uint64(0) -func (l *concurrencyLimiter) allow() bool { +// old interface. only used in the ratelimiter package. +func (l *ConcurrencyLimiter) allow() bool { l.mu.Lock() defer l.mu.Unlock() @@ -45,7 +54,8 @@ func (l *concurrencyLimiter) allow() bool { return false } -func (l *concurrencyLimiter) release() { +// old interface. only used in the ratelimiter package. +func (l *ConcurrencyLimiter) release() { l.mu.Lock() defer l.mu.Unlock() @@ -54,28 +64,32 @@ func (l *concurrencyLimiter) release() { } } -func (l *concurrencyLimiter) getLimit() uint64 { +// old interface. only used in the ratelimiter package. +func (l *ConcurrencyLimiter) getLimit() uint64 { l.mu.RLock() defer l.mu.RUnlock() return l.limit } -func (l *concurrencyLimiter) setLimit(limit uint64) { +// old interface. only used in the ratelimiter package. +func (l *ConcurrencyLimiter) setLimit(limit uint64) { l.mu.Lock() defer l.mu.Unlock() l.limit = limit } -func (l *concurrencyLimiter) getCurrent() uint64 { +// old interface. only used in the ratelimiter package. +func (l *ConcurrencyLimiter) getCurrent() uint64 { l.mu.RLock() defer l.mu.RUnlock() return l.current } -func (l *concurrencyLimiter) getMaxConcurrency() uint64 { +// old interface. only used in the ratelimiter package. +func (l *ConcurrencyLimiter) getMaxConcurrency() uint64 { l.mu.Lock() defer func() { l.maxLimit = l.current @@ -84,3 +98,66 @@ func (l *concurrencyLimiter) getMaxConcurrency() uint64 { return l.maxLimit } + +// GetRunningTasksNum returns the number of running tasks. +func (l *ConcurrencyLimiter) GetRunningTasksNum() uint64 { + return l.getCurrent() +} + +// GetWaitingTasksNum returns the number of waiting tasks. +func (l *ConcurrencyLimiter) GetWaitingTasksNum() uint64 { + l.mu.Lock() + defer l.mu.Unlock() + return l.waiting +} + +// Acquire acquires a token from the limiter. which will block until a token is available or ctx is done, like Timeout. +func (l *ConcurrencyLimiter) Acquire(ctx context.Context) (*TaskToken, error) { + l.mu.Lock() + if l.current >= l.limit { + l.waiting++ + l.mu.Unlock() + // block the waiting task on the caller goroutine + select { + case <-ctx.Done(): + l.mu.Lock() + l.waiting-- + l.mu.Unlock() + return nil, ctx.Err() + case token := <-l.queue: + l.mu.Lock() + token.released = false + l.current++ + l.waiting-- + l.mu.Unlock() + return token, nil + } + } + l.current++ + token := &TaskToken{limiter: l} + l.mu.Unlock() + return token, nil +} + +// TaskToken is a token that must be released after the task is done. +type TaskToken struct { + released bool + limiter *ConcurrencyLimiter +} + +// Release releases the token. +func (tt *TaskToken) Release() { + tt.limiter.mu.Lock() + defer tt.limiter.mu.Unlock() + if tt.released { + return + } + if tt.limiter.current == 0 { + panic("release token more than acquire") + } + tt.released = true + tt.limiter.current-- + if len(tt.limiter.queue) < int(tt.limiter.limit) { + tt.limiter.queue <- tt + } +} diff --git a/pkg/ratelimit/concurrency_limiter_test.go b/pkg/ratelimit/concurrency_limiter_test.go index 5fe03740394..216da1ac8a0 100644 --- a/pkg/ratelimit/concurrency_limiter_test.go +++ b/pkg/ratelimit/concurrency_limiter_test.go @@ -15,7 +15,12 @@ package ratelimit import ( + "context" + "fmt" + "sync" + "sync/atomic" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -23,7 +28,7 @@ import ( func TestConcurrencyLimiter(t *testing.T) { t.Parallel() re := require.New(t) - cl := newConcurrencyLimiter(10) + cl := NewConcurrencyLimiter(10) for i := 0; i < 10; i++ { re.True(cl.allow()) } @@ -52,3 +57,72 @@ func TestConcurrencyLimiter(t *testing.T) { re.Equal(uint64(5), cl.getMaxConcurrency()) re.Equal(uint64(0), cl.getMaxConcurrency()) } + +func TestConcurrencyLimiter2(t *testing.T) { + limit := uint64(2) + limiter := NewConcurrencyLimiter(limit) + + require.Equal(t, uint64(0), limiter.GetRunningTasksNum(), "Expected running tasks to be 0") + require.Equal(t, uint64(0), limiter.GetWaitingTasksNum(), "Expected waiting tasks to be 0") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Acquire two tokens + token1, err := limiter.Acquire(ctx) + require.NoError(t, err, "Failed to acquire token") + + token2, err := limiter.Acquire(ctx) + require.NoError(t, err, "Failed to acquire token") + + require.Equal(t, limit, limiter.GetRunningTasksNum(), "Expected running tasks to be 2") + + // Try to acquire third token, it should not be able to acquire immediately due to limit + go func() { + _, err := limiter.Acquire(ctx) + require.NoError(t, err, "Failed to acquire token") + }() + + time.Sleep(100 * time.Millisecond) // Give some time for the goroutine to run + require.Equal(t, uint64(1), limiter.GetWaitingTasksNum(), "Expected waiting tasks to be 1") + + // Release a token + token1.Release() + time.Sleep(100 * time.Millisecond) // Give some time for the goroutine to run + require.Equal(t, uint64(2), limiter.GetRunningTasksNum(), "Expected running tasks to be 2") + require.Equal(t, uint64(0), limiter.GetWaitingTasksNum(), "Expected waiting tasks to be 0") + + // Release the second token + token2.Release() + time.Sleep(100 * time.Millisecond) // Give some time for the goroutine to run + require.Equal(t, uint64(1), limiter.GetRunningTasksNum(), "Expected running tasks to be 1") +} + +func TestConcurrencyLimiterAcquire(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + limiter := NewConcurrencyLimiter(20) + sum := int64(0) + start := time.Now() + wg := &sync.WaitGroup{} + wg.Add(100) + for i := 0; i < 100; i++ { + go func(i int) { + defer wg.Done() + token, err := limiter.Acquire(ctx) + if err != nil { + fmt.Printf("Task %d failed to acquire: %v\n", i, err) + return + } + defer token.Release() + // simulate takes some time + time.Sleep(10 * time.Millisecond) + atomic.AddInt64(&sum, 1) + }(i) + } + wg.Wait() + // We should have 20 tasks running concurrently, so it should take at least 50ms to complete + require.Greater(t, time.Since(start).Milliseconds(), int64(50)) + require.Equal(t, int64(100), sum) +} diff --git a/pkg/ratelimit/limiter.go b/pkg/ratelimit/limiter.go index dc744d9ac1b..7b3eba10325 100644 --- a/pkg/ratelimit/limiter.go +++ b/pkg/ratelimit/limiter.go @@ -36,18 +36,18 @@ type DimensionConfig struct { type limiter struct { mu syncutil.RWMutex - concurrency *concurrencyLimiter + concurrency *ConcurrencyLimiter rate *RateLimiter } func newLimiter() *limiter { lim := &limiter{ - concurrency: newConcurrencyLimiter(0), + concurrency: NewConcurrencyLimiter(0), } return lim } -func (l *limiter) getConcurrencyLimiter() *concurrencyLimiter { +func (l *limiter) getConcurrencyLimiter() *ConcurrencyLimiter { l.mu.RLock() defer l.mu.RUnlock() return l.concurrency @@ -101,7 +101,7 @@ func (l *limiter) updateConcurrencyConfig(limit uint64) UpdateStatus { } l.concurrency.setLimit(limit) } else { - l.concurrency = newConcurrencyLimiter(limit) + l.concurrency = NewConcurrencyLimiter(limit) } return ConcurrencyChanged } diff --git a/pkg/ratelimit/runner.go b/pkg/ratelimit/runner.go new file mode 100644 index 00000000000..a31fc5c0e69 --- /dev/null +++ b/pkg/ratelimit/runner.go @@ -0,0 +1,167 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ratelimit + +import ( + "context" + "errors" + "sync" + "time" + + "github.com/pingcap/log" + "go.uber.org/zap" +) + +const initialCapacity = 100 + +// Runner is the interface for running tasks. +type Runner interface { + RunTask(ctx context.Context, opt TaskOpts, f func(context.Context)) error +} + +// Task is a task to be run. +type Task struct { + Ctx context.Context + Opts TaskOpts + f func(context.Context) + submittedAt time.Time +} + +// ErrMaxWaitingTasksExceeded is returned when the number of waiting tasks exceeds the maximum. +var ErrMaxWaitingTasksExceeded = errors.New("max waiting tasks exceeded") + +// AsyncRunner is a simple task runner that limits the number of concurrent tasks. +type AsyncRunner struct { + name string + maxPendingDuration time.Duration + taskChan chan *Task + pendingTasks []*Task + pendingMu sync.Mutex + stopChan chan struct{} + wg sync.WaitGroup +} + +// NewAsyncRunner creates a new AsyncRunner. +func NewAsyncRunner(name string, maxPendingDuration time.Duration) *AsyncRunner { + s := &AsyncRunner{ + name: name, + maxPendingDuration: maxPendingDuration, + taskChan: make(chan *Task), + pendingTasks: make([]*Task, 0, initialCapacity), + stopChan: make(chan struct{}), + } + s.Start() + return s +} + +// TaskOpts is the options for RunTask. +type TaskOpts struct { + // TaskName is a human-readable name for the operation. TODO: metrics by name. + TaskName string + Limit *ConcurrencyLimiter +} + +// Start starts the runner. +func (s *AsyncRunner) Start() { + s.wg.Add(1) + go func() { + defer s.wg.Done() + for { + select { + case task := <-s.taskChan: + if task.Opts.Limit != nil { + token, err := task.Opts.Limit.Acquire(context.Background()) + if err != nil { + log.Error("failed to acquire semaphore", zap.String("task-name", task.Opts.TaskName), zap.Error(err)) + continue + } + go s.run(task.Ctx, task.f, token) + } else { + go s.run(task.Ctx, task.f, nil) + } + case <-s.stopChan: + return + } + } + }() +} + +func (s *AsyncRunner) run(ctx context.Context, task func(context.Context), token *TaskToken) { + task(ctx) + if token != nil { + token.Release() + s.processPendingTasks() + } +} + +func (s *AsyncRunner) processPendingTasks() { + s.pendingMu.Lock() + defer s.pendingMu.Unlock() + for len(s.pendingTasks) > 0 { + task := s.pendingTasks[0] + select { + case s.taskChan <- task: + s.pendingTasks = s.pendingTasks[1:] + return + default: + return + } + } +} + +// Stop stops the runner. +func (s *AsyncRunner) Stop() { + close(s.stopChan) + s.wg.Wait() +} + +// RunTask runs the task asynchronously. +func (s *AsyncRunner) RunTask(ctx context.Context, opt TaskOpts, f func(context.Context)) error { + task := &Task{ + Ctx: ctx, + Opts: opt, + f: f, + } + s.processPendingTasks() + select { + case s.taskChan <- task: + default: + s.pendingMu.Lock() + defer s.pendingMu.Unlock() + if len(s.pendingTasks) > 0 { + maxWait := time.Since(s.pendingTasks[0].submittedAt) + if maxWait > s.maxPendingDuration { + return errors.New("max pending duration exceeded") + } + } + task.submittedAt = time.Now() + s.pendingTasks = append(s.pendingTasks, task) + } + return nil +} + +// SyncRunner is a simple task runner that limits the number of concurrent tasks. +type SyncRunner struct{} + +// NewSyncRunner creates a new SyncRunner. +func NewSyncRunner() *SyncRunner { + return &SyncRunner{} +} + +// RunTask runs the task synchronously. +func (s *SyncRunner) RunTask(ctx context.Context, opt TaskOpts, f func(context.Context)) error { + f(ctx) + return nil +} diff --git a/pkg/ratelimit/runner_test.go b/pkg/ratelimit/runner_test.go new file mode 100644 index 00000000000..051117f1764 --- /dev/null +++ b/pkg/ratelimit/runner_test.go @@ -0,0 +1,74 @@ +// Copyright 2024 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ratelimit + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestAsyncRunner(t *testing.T) { + t.Run("RunTask", func(t *testing.T) { + limiter := NewConcurrencyLimiter(1) + runner := NewAsyncRunner("test", time.Second) + defer runner.Stop() + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + time.Sleep(50 * time.Millisecond) + wg.Add(1) + err := runner.RunTask(context.Background(), TaskOpts{ + TaskName: "test1", + Limit: limiter, + }, func(ctx context.Context) { + defer wg.Done() + time.Sleep(100 * time.Millisecond) + }) + require.NoError(t, err) + } + wg.Wait() + }) + + t.Run("MaxPendingDuration", func(t *testing.T) { + limiter := NewConcurrencyLimiter(1) + runner := NewAsyncRunner("test", 2*time.Millisecond) + defer runner.Stop() + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + err := runner.RunTask(context.Background(), TaskOpts{ + TaskName: "test2", + Limit: limiter, + }, func(ctx context.Context) { + defer wg.Done() + time.Sleep(100 * time.Millisecond) + }) + if err != nil { + wg.Done() + // task 0 running + // task 1 after recv by runner, blocked by task 1, wait on Acquire. + // task 2 enqueue pendingTasks + // task 3 enqueue pendingTasks + // task 4 enqueue pendingTasks, check pendingTasks[0] timeout, report error + } + time.Sleep(1 * time.Millisecond) + } + wg.Wait() + }) +}