From 89fad0f908f88fe2f9c25951fda43ac21429f66f Mon Sep 17 00:00:00 2001 From: gson Liang Date: Thu, 27 Mar 2025 00:05:06 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20handle=20concurrency=20with=20sync.?= =?UTF-8?q?Map?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 27 +++++++++++++++++++ internal/service/rate_limiter/rate-limiter.go | 15 +++++------ 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 1bb5f2b..cfce655 100644 --- a/README.md +++ b/README.md @@ -90,4 +90,31 @@ func (r *RateLimiter) RateLimiterMiddleware(next http.Handler, limit rate.Limit, next.ServeHTTP(w, req) }) } +``` + +## handle concurrency problem with sync.Map + +```golang +var ipLimiterMap sync.Map + +// RateLimiterMiddleware - 建立 ratelimiter middleware +func (r *RateLimiter) RateLimiterMiddleware(next http.Handler, limit rate.Limit, burst int) http.Handler { + + // var mu sync.Mutex + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + // Fetch IP + ip := r.getIP(req) + // Create limiter if not present for IP + limiterAny, _ := ipLimiterMap.LoadOrStore(ip, rate.NewLimiter(limit, burst)) + limiter := limiterAny.(*rate.Limiter) + // return error if the limit has been reached + if !limiter.Allow() { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusTooManyRequests) + json.NewEncoder(w).Encode(map[string]string{"error": "Too many requests"}) + return + } + next.ServeHTTP(w, req) + }) +} ``` \ No newline at end of file diff --git a/internal/service/rate_limiter/rate-limiter.go b/internal/service/rate_limiter/rate-limiter.go index e562916..69add5e 100644 --- a/internal/service/rate_limiter/rate-limiter.go +++ b/internal/service/rate_limiter/rate-limiter.go @@ -22,21 +22,18 @@ func (r *RateLimiter) getIP(req *http.Request) string { return host } +var ipLimiterMap sync.Map + // RateLimiterMiddleware - 建立 ratelimiter middleware func (r *RateLimiter) RateLimiterMiddleware(next http.Handler, limit rate.Limit, burst int) http.Handler { - ipLimiterMap := make(map[string]*rate.Limiter) - var mu sync.Mutex + + // var mu sync.Mutex return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { // Fetch IP ip := r.getIP(req) // Create limiter if not present for IP - mu.Lock() - limiter, exists := ipLimiterMap[ip] - if !exists { - limiter = rate.NewLimiter(limit, burst) - ipLimiterMap[ip] = limiter - } - mu.Unlock() + limiterAny, _ := ipLimiterMap.LoadOrStore(ip, rate.NewLimiter(limit, burst)) + limiter := limiterAny.(*rate.Limiter) // return error if the limit has been reached if !limiter.Allow() { w.Header().Set("Content-Type", "application/json")