Skip to content

Commit

Permalink
pkg/ratelimit: introduce an executor that can run with a limiter
Browse files Browse the repository at this point in the history
Signed-off-by: nolouch <[email protected]>
  • Loading branch information
nolouch committed Apr 3, 2024
1 parent fff288d commit f71c46a
Show file tree
Hide file tree
Showing 5 changed files with 407 additions and 15 deletions.
97 changes: 87 additions & 10 deletions pkg/ratelimit/concurrency_limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,33 @@

package ratelimit

import "github.com/tikv/pd/pkg/utils/syncutil"
import (
"context"

type concurrencyLimiter struct {
"github.com/tikv/pd/pkg/utils/syncutil"
)

// ConcurrencyLimiter is a limiter that limits the number of concurrent tasks.
type ConcurrencyLimiter struct {
mu syncutil.RWMutex
current uint64
waiting uint64
limit uint64

// statistic
maxLimit uint64
queue chan *TaskToken
}

func newConcurrencyLimiter(limit uint64) *concurrencyLimiter {
return &concurrencyLimiter{limit: limit}
// NewConcurrencyLimiter creates a new ConcurrencyLimiter.
func NewConcurrencyLimiter(limit uint64) *ConcurrencyLimiter {
return &ConcurrencyLimiter{limit: limit, queue: make(chan *TaskToken, limit)}
}

const unlimit = uint64(0)

func (l *concurrencyLimiter) allow() bool {
// old interface. only used in the ratelimiter package.
func (l *ConcurrencyLimiter) allow() bool {
l.mu.Lock()
defer l.mu.Unlock()

Expand All @@ -45,7 +54,8 @@ func (l *concurrencyLimiter) allow() bool {
return false
}

func (l *concurrencyLimiter) release() {
// old interface. only used in the ratelimiter package.
func (l *ConcurrencyLimiter) release() {
l.mu.Lock()
defer l.mu.Unlock()

Expand All @@ -54,28 +64,32 @@ func (l *concurrencyLimiter) release() {
}
}

func (l *concurrencyLimiter) getLimit() uint64 {
// old interface. only used in the ratelimiter package.
func (l *ConcurrencyLimiter) getLimit() uint64 {
l.mu.RLock()
defer l.mu.RUnlock()

return l.limit
}

func (l *concurrencyLimiter) setLimit(limit uint64) {
// old interface. only used in the ratelimiter package.
func (l *ConcurrencyLimiter) setLimit(limit uint64) {
l.mu.Lock()
defer l.mu.Unlock()

l.limit = limit
}

func (l *concurrencyLimiter) getCurrent() uint64 {
// old interface. only used in the ratelimiter package.
func (l *ConcurrencyLimiter) getCurrent() uint64 {
l.mu.RLock()
defer l.mu.RUnlock()

return l.current
}

func (l *concurrencyLimiter) getMaxConcurrency() uint64 {
// old interface. only used in the ratelimiter package.
func (l *ConcurrencyLimiter) getMaxConcurrency() uint64 {
l.mu.Lock()
defer func() {
l.maxLimit = l.current
Expand All @@ -84,3 +98,66 @@ func (l *concurrencyLimiter) getMaxConcurrency() uint64 {

return l.maxLimit
}

// GetRunningTasksNum returns the number of running tasks.
func (l *ConcurrencyLimiter) GetRunningTasksNum() uint64 {
return l.getCurrent()
}

// GetWaitingTasksNum returns the number of waiting tasks.
func (l *ConcurrencyLimiter) GetWaitingTasksNum() uint64 {
l.mu.Lock()
defer l.mu.Unlock()
return l.waiting
}

// Acquire acquires a token from the limiter. which will block until a token is available or ctx is done, like Timeout.
func (l *ConcurrencyLimiter) Acquire(ctx context.Context) (*TaskToken, error) {
l.mu.Lock()
if l.current >= l.limit {
l.waiting++
l.mu.Unlock()
// block the waiting task on the caller goroutine
select {
case <-ctx.Done():
l.mu.Lock()
l.waiting--
l.mu.Unlock()
return nil, ctx.Err()
case token := <-l.queue:
l.mu.Lock()
token.released = false
l.current++
l.waiting--
l.mu.Unlock()
return token, nil
}
}
l.current++
token := &TaskToken{limiter: l}
l.mu.Unlock()
return token, nil
}

// TaskToken is a token that must be released after the task is done.
type TaskToken struct {
released bool
limiter *ConcurrencyLimiter
}

// Release releases the token.
func (tt *TaskToken) Release() {
tt.limiter.mu.Lock()
defer tt.limiter.mu.Unlock()
if tt.released {
return
}
if tt.limiter.current == 0 {
panic("release token more than acquire")
}
tt.released = true
tt.limiter.current--
if len(tt.limiter.queue) < int(tt.limiter.limit) {
tt.limiter.queue <- tt
}
}
76 changes: 75 additions & 1 deletion pkg/ratelimit/concurrency_limiter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,20 @@
package ratelimit

import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestConcurrencyLimiter(t *testing.T) {
t.Parallel()
re := require.New(t)
cl := newConcurrencyLimiter(10)
cl := NewConcurrencyLimiter(10)
for i := 0; i < 10; i++ {
re.True(cl.allow())
}
Expand Down Expand Up @@ -52,3 +57,72 @@ func TestConcurrencyLimiter(t *testing.T) {
re.Equal(uint64(5), cl.getMaxConcurrency())
re.Equal(uint64(0), cl.getMaxConcurrency())
}

func TestConcurrencyLimiter2(t *testing.T) {
limit := uint64(2)
limiter := NewConcurrencyLimiter(limit)

require.Equal(t, uint64(0), limiter.GetRunningTasksNum(), "Expected running tasks to be 0")
require.Equal(t, uint64(0), limiter.GetWaitingTasksNum(), "Expected waiting tasks to be 0")

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// Acquire two tokens
token1, err := limiter.Acquire(ctx)
require.NoError(t, err, "Failed to acquire token")

token2, err := limiter.Acquire(ctx)
require.NoError(t, err, "Failed to acquire token")

require.Equal(t, limit, limiter.GetRunningTasksNum(), "Expected running tasks to be 2")

// Try to acquire third token, it should not be able to acquire immediately due to limit
go func() {
_, err := limiter.Acquire(ctx)
require.NoError(t, err, "Failed to acquire token")
}()

time.Sleep(100 * time.Millisecond) // Give some time for the goroutine to run
require.Equal(t, uint64(1), limiter.GetWaitingTasksNum(), "Expected waiting tasks to be 1")

// Release a token
token1.Release()
time.Sleep(100 * time.Millisecond) // Give some time for the goroutine to run
require.Equal(t, uint64(2), limiter.GetRunningTasksNum(), "Expected running tasks to be 2")
require.Equal(t, uint64(0), limiter.GetWaitingTasksNum(), "Expected waiting tasks to be 0")

// Release the second token
token2.Release()
time.Sleep(100 * time.Millisecond) // Give some time for the goroutine to run
require.Equal(t, uint64(1), limiter.GetRunningTasksNum(), "Expected running tasks to be 1")
}

func TestConcurrencyLimiterAcquire(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()

limiter := NewConcurrencyLimiter(20)
sum := int64(0)
start := time.Now()
wg := &sync.WaitGroup{}
wg.Add(100)
for i := 0; i < 100; i++ {
go func(i int) {
defer wg.Done()
token, err := limiter.Acquire(ctx)
if err != nil {
fmt.Printf("Task %d failed to acquire: %v\n", i, err)
return
}
defer token.Release()
// simulate takes some time
time.Sleep(10 * time.Millisecond)
atomic.AddInt64(&sum, 1)
}(i)
}
wg.Wait()
// We should have 20 tasks running concurrently, so it should take at least 50ms to complete
require.Greater(t, time.Since(start).Milliseconds(), int64(50))
require.Equal(t, int64(100), sum)
}
8 changes: 4 additions & 4 deletions pkg/ratelimit/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,18 @@ type DimensionConfig struct {

type limiter struct {
mu syncutil.RWMutex
concurrency *concurrencyLimiter
concurrency *ConcurrencyLimiter
rate *RateLimiter
}

func newLimiter() *limiter {
lim := &limiter{
concurrency: newConcurrencyLimiter(0),
concurrency: NewConcurrencyLimiter(0),
}
return lim
}

func (l *limiter) getConcurrencyLimiter() *concurrencyLimiter {
func (l *limiter) getConcurrencyLimiter() *ConcurrencyLimiter {
l.mu.RLock()
defer l.mu.RUnlock()
return l.concurrency
Expand Down Expand Up @@ -101,7 +101,7 @@ func (l *limiter) updateConcurrencyConfig(limit uint64) UpdateStatus {
}
l.concurrency.setLimit(limit)
} else {
l.concurrency = newConcurrencyLimiter(limit)
l.concurrency = NewConcurrencyLimiter(limit)
}
return ConcurrencyChanged
}
Expand Down
Loading

0 comments on commit f71c46a

Please sign in to comment.