From 8bd86e2be4f60e7577aeb4002ca4b5d7e162527e Mon Sep 17 00:00:00 2001 From: roggervalf Date: Thu, 25 Apr 2024 11:25:06 -0500 Subject: [PATCH] feat(rate-limiter-redis): allow passing max duration as argv in lua script --- lib/RateLimiterRedis.js | 10 +++--- test/RateLimiterRedis.ioredis.test.js | 49 ++++++++++++++++++++++++++- test/RateLimiterRedis.redis.test.js | 48 ++++++++++++++++++++++++++ 3 files changed, 101 insertions(+), 6 deletions(-) diff --git a/lib/RateLimiterRedis.js b/lib/RateLimiterRedis.js index 98d52b1..dc8c60c 100644 --- a/lib/RateLimiterRedis.js +++ b/lib/RateLimiterRedis.js @@ -106,7 +106,7 @@ class RateLimiterRedis extends RateLimiterStoreAbstract { if (secDuration > 0) { if(!this.useRedisPackage && !this.useRedis3AndLowerPackage){ return this.client.rlflxIncr( - [rlKey].concat([String(points), String(secDuration), String(this.points)])); + [rlKey].concat([String(points), String(secDuration), String(this.points), String(this.duration)])); } if (this.useRedis3AndLowerPackage) { return new Promise((resolve, reject) => { @@ -119,15 +119,15 @@ class RateLimiterRedis extends RateLimiterStoreAbstract { }; if (typeof this.client.rlflxIncr === 'function') { - this.client.rlflxIncr(rlKey, points, secDuration, this.points, incrCallback); + this.client.rlflxIncr(rlKey, points, secDuration, this.points, this.duration, incrCallback); } else { - this.client.eval(this._incrTtlLuaScript, 1, rlKey, points, secDuration, this.points, incrCallback); + this.client.eval(this._incrTtlLuaScript, 1, rlKey, points, secDuration, this.points, this.duration, incrCallback); } }); - } else { + } else { return this.client.eval(this._incrTtlLuaScript, { keys: [rlKey], - arguments: [String(points), String(secDuration), String(this.points)], + arguments: [String(points), String(secDuration), String(this.points), String(this.duration)], }); } } else { diff --git a/test/RateLimiterRedis.ioredis.test.js b/test/RateLimiterRedis.ioredis.test.js index 247ae0d..44d0e87 100644 --- a/test/RateLimiterRedis.ioredis.test.js +++ b/test/RateLimiterRedis.ioredis.test.js @@ -75,7 +75,7 @@ describe('RateLimiterRedis with fixed window', function RateLimiterRedisTest() { else \ local maxPoints = tonumber(ARGV[3]) \ if maxPoints > 0 and (consumed-1) % maxPoints == 0 and not ok then \ - local expireTime = ttl + tonumber(ARGV[2]) * 1000 \ + local expireTime = ttl + tonumber(ARGV[4]) * 1000 \ redis.call('pexpire', KEYS[1], expireTime) \ return {consumed, expireTime} \ end \ @@ -104,6 +104,53 @@ describe('RateLimiterRedis with fixed window', function RateLimiterRedisTest() { done(err); }); }); + + describe('when passing custom duration', () => { + it('rejected when consume more than maximum points and multiply delay', (done) => { + const testKey = 'consume2'; + const rateLimiter = new RateLimiterRedis({ + storeClient: redisMockClient, + points: 1, + duration: 5, + customIncrTtlLuaScript: `local ok = redis.call('set', KEYS[1], 0, 'EX', ARGV[2], 'NX') \ + local consumed = redis.call('incrby', KEYS[1], ARGV[1]) \ + local ttl = redis.call('pttl', KEYS[1]) \ + if ttl == -1 then \ + redis.call('expire', KEYS[1], ARGV[2]) \ + ttl = 1000 * ARGV[2] \ + else \ + local maxPoints = tonumber(ARGV[3]) \ + if maxPoints > 0 and (consumed-1) % maxPoints == 0 and not ok then \ + local expireTime = ttl + tonumber(ARGV[4]) * 1000 \ + redis.call('pexpire', KEYS[1], expireTime) \ + return {consumed, expireTime} \ + end \ + end \ + return {consumed, ttl} \ + ` + }); + rateLimiter + .consume(testKey, 1, {customDuration: 1}) + .then(() => { + rateLimiter + .consume(testKey) + .then(() => {}) + .catch((rejRes) => { + expect(rejRes.msBeforeNext >= 1000).to.equal(true); + rateLimiter + .consume(testKey) + .then(() => {}) + .catch((rejRes2) => { + expect(rejRes2.msBeforeNext >= 6000).to.equal(true); + done(); + }); + }); + }) + .catch((err) => { + done(err); + }); + }); + }); }); it('execute evenly over duration', (done) => { diff --git a/test/RateLimiterRedis.redis.test.js b/test/RateLimiterRedis.redis.test.js index c1c5cb6..0f3cff9 100644 --- a/test/RateLimiterRedis.redis.test.js +++ b/test/RateLimiterRedis.redis.test.js @@ -104,6 +104,54 @@ describe('RateLimiterRedis with fixed window', function RateLimiterRedisTest() { done(err); }); }); + + describe('when passing custom duration', () => { + it('rejected when consume more than maximum points and multiply delay', (done) => { + const testKey = 'consume2'; + const rateLimiter = new RateLimiterRedis({ + storeClient: redisMockClient, + points: 1, + duration: 5, + customIncrTtlLuaScript: `local ok = redis.call('set', KEYS[1], 0, 'EX', ARGV[2], 'NX') \ + local consumed = redis.call('incrby', KEYS[1], ARGV[1]) \ + local ttl = redis.call('pttl', KEYS[1]) \ + if ttl == -1 then \ + redis.call('expire', KEYS[1], ARGV[2]) \ + ttl = 1000 * ARGV[2] \ + else \ + local maxPoints = tonumber(ARGV[3]) \ + if maxPoints > 0 and (consumed-1) % maxPoints == 0 and not ok then \ + local expireTime = ttl + tonumber(ARGV[4]) * 1000 \ + redis.call('pexpire', KEYS[1], expireTime) \ + return {consumed, expireTime} \ + end \ + end \ + return {consumed, ttl} \ + `, + useRedisPackage: true, + }); + rateLimiter + .consume(testKey, 1, {customDuration: 1}) + .then(() => { + rateLimiter + .consume(testKey) + .then(() => {}) + .catch((rejRes) => { + expect(rejRes.msBeforeNext >= 1000).to.equal(true); + rateLimiter + .consume(testKey) + .then(() => {}) + .catch((rejRes2) => { + expect(rejRes2.msBeforeNext >= 6000).to.equal(true); + done(); + }); + }); + }) + .catch((err) => { + done(err); + }); + }); + }); }); it('execute evenly over duration', (done) => {