diff --git a/src/RedisRateLimiting/FixedWindow/RedisFixedWindowManager.cs b/src/RedisRateLimiting/FixedWindow/RedisFixedWindowManager.cs index 651720a..9dff775 100644 --- a/src/RedisRateLimiting/FixedWindow/RedisFixedWindowManager.cs +++ b/src/RedisRateLimiting/FixedWindow/RedisFixedWindowManager.cs @@ -13,6 +13,8 @@ internal class RedisFixedWindowManager private static readonly LuaScript Script = LuaScript.Prepare( @"local expires_at = tonumber(redis.call(""get"", @expires_at_key)) + local limit = tonumber(@permit_limit) + local inc = tonumber(@increment_amount) if not expires_at or expires_at < tonumber(@current_time) then -- this is either a brand new window, @@ -29,11 +31,25 @@ internal class RedisFixedWindowManager expires_at = @next_expires_at end - -- now that the window either already exists or it was freshly initialized, + -- now that the window either already exists or it was freshly initialized -- increment the counter(`incrby` returns a number) - local current = redis.call(""incrby"", @rate_limit_key, @increment_amount) - return { current, expires_at }"); + local current = redis.call(""get"", @rate_limit_key) + + if not current then + current = 0 + else + current = tonumber(current) + end + + local allowed = current + inc <= limit + + if allowed then + current = redis.call(""incrby"", @rate_limit_key, inc) + end + + return { current, expires_at, allowed } + "); public RedisFixedWindowManager( string partitionKey, @@ -46,7 +62,7 @@ public RedisFixedWindowManager( RateLimitExpireKey = new RedisKey($"rl:fw:{{{partitionKey}}}:exp"); } - internal async Task TryAcquireLeaseAsync() + internal async Task TryAcquireLeaseAsync(int permitCount) { var now = DateTimeOffset.UtcNow; var nowUnixTimeSeconds = now.ToUnixTimeSeconds(); @@ -61,7 +77,8 @@ internal async Task TryAcquireLeaseAsync() expires_at_key = RateLimitExpireKey, next_expires_at = (RedisValue)now.Add(_options.Window).ToUnixTimeSeconds(), current_time = (RedisValue)nowUnixTimeSeconds, - increment_amount = (RedisValue)1D, + permit_limit = (RedisValue)_options.PermitLimit, + increment_amount = (RedisValue)permitCount, }); var result = new RedisFixedWindowResponse(); @@ -70,6 +87,7 @@ internal async Task TryAcquireLeaseAsync() { result.Count = (long)response[0]; result.ExpiresAt = (long)response[1]; + result.Allowed = (bool)response[2]; result.RetryAfter = TimeSpan.FromSeconds(result.ExpiresAt - nowUnixTimeSeconds); } @@ -112,5 +130,6 @@ internal class RedisFixedWindowResponse internal long ExpiresAt { get; set; } internal TimeSpan RetryAfter { get; set; } internal long Count { get; set; } + internal bool Allowed { get; set; } } } diff --git a/src/RedisRateLimiting/FixedWindow/RedisFixedWindowRateLimiter.cs b/src/RedisRateLimiting/FixedWindow/RedisFixedWindowRateLimiter.cs index efc9841..616adc2 100644 --- a/src/RedisRateLimiting/FixedWindow/RedisFixedWindowRateLimiter.cs +++ b/src/RedisRateLimiting/FixedWindow/RedisFixedWindowRateLimiter.cs @@ -57,7 +57,7 @@ protected override ValueTask AcquireAsyncCore(int permitCount, C throw new ArgumentOutOfRangeException(nameof(permitCount), permitCount, string.Format("{0} permit(s) exceeds the permit limit of {1}.", permitCount, _options.PermitLimit)); } - return AcquireAsyncCoreInternal(); + return AcquireAsyncCoreInternal(permitCount); } protected override RateLimitLease AttemptAcquireCore(int permitCount) @@ -66,7 +66,7 @@ protected override RateLimitLease AttemptAcquireCore(int permitCount) return FailedLease; } - private async ValueTask AcquireAsyncCoreInternal() + private async ValueTask AcquireAsyncCoreInternal(int permitCount) { var leaseContext = new FixedWindowLeaseContext { @@ -74,18 +74,13 @@ private async ValueTask AcquireAsyncCoreInternal() Window = _options.Window, }; - var response = await _redisManager.TryAcquireLeaseAsync(); + var response = await _redisManager.TryAcquireLeaseAsync(permitCount); leaseContext.Count = response.Count; leaseContext.RetryAfter = response.RetryAfter; leaseContext.ExpiresAt = response.ExpiresAt; - if (leaseContext.Count > _options.PermitLimit) - { - return new FixedWindowLease(isAcquired: false, leaseContext); - } - - return new FixedWindowLease(isAcquired: true, leaseContext); + return new FixedWindowLease(isAcquired: response.Allowed, leaseContext); } private sealed class FixedWindowLeaseContext diff --git a/test/RedisRateLimiting.Tests/UnitTests/FixedWindowUnitTests.cs b/test/RedisRateLimiting.Tests/UnitTests/FixedWindowUnitTests.cs index 12edf89..dd33180 100644 --- a/test/RedisRateLimiting.Tests/UnitTests/FixedWindowUnitTests.cs +++ b/test/RedisRateLimiting.Tests/UnitTests/FixedWindowUnitTests.cs @@ -79,5 +79,27 @@ public async Task CanAcquireAsyncResource() using var lease2 = await limiter.AcquireAsync(); Assert.False(lease2.IsAcquired); } + + [Fact] + public async Task CanAcquireMultiplePermits() + { + using var limiter = new RedisFixedWindowRateLimiter( + partitionKey: Guid.NewGuid().ToString(), + new RedisFixedWindowRateLimiterOptions + { + PermitLimit = 5, + Window = TimeSpan.FromMinutes(1), + ConnectionMultiplexerFactory = Fixture.ConnectionMultiplexerFactory, + }); + + using var lease = await limiter.AcquireAsync(permitCount: 3); + Assert.True(lease.IsAcquired); + + using var lease2 = await limiter.AcquireAsync(permitCount: 3); + Assert.False(lease2.IsAcquired); + + using var lease3 = await limiter.AcquireAsync(permitCount: 2); + Assert.True(lease3.IsAcquired); + } } }