Skip to content

Commit

Permalink
Merge pull request #113 from upstash/DX-989
Browse files Browse the repository at this point in the history
Add Reset Field to `getRemaining`
  • Loading branch information
CahidArda authored Jul 8, 2024
2 parents c264a32 + da08fec commit a239d58
Show file tree
Hide file tree
Showing 10 changed files with 78 additions and 23 deletions.
2 changes: 1 addition & 1 deletion examples/enable-protection/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"lint": "next lint"
},
"dependencies": {
"@upstash/ratelimit": "^1.2.0-canary",
"@upstash/ratelimit": "^1.2.1",
"@vercel/functions": "^1.0.2",
"next": "14.2.3",
"react": "^18",
Expand Down
3 changes: 2 additions & 1 deletion examples/with-vercel-kv/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@
"react-dom": "18.2.0",
"tailwindcss": "3.3.2",
"typescript": "5.0.4"
}
},
"packageManager": "[email protected]"
}
4 changes: 2 additions & 2 deletions src/deny-list/scripts.test.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { Redis } from "@upstash/redis";
import { beforeAll, beforeEach, describe, expect, test } from "bun:test";
import { beforeEach, describe, expect, test } from "bun:test";
import { DenyListExtension, IpDenyListStatusKey, IsDenied } from "../types";
import { checkDenyListScript } from "./scripts";
import { disableIpDenyList, updateIpDenyList } from "./deny-list-update";
import { disableIpDenyList, updateIpDenyList } from "./ip-deny-list";

describe("should manage state correctly", async () => {
const redis = Redis.fromEnv();
Expand Down
21 changes: 17 additions & 4 deletions src/getRemainingTokens.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,32 @@ function run<TContext extends Context>(builder: Ratelimit<TContext>) {
// Stop at any random request call within the limit
const stopAt = Math.floor(Math.random() * (limit - 1) + 1);
for (let i = 1; i <= limit; i++) {
await builder.limit(id);

const [limitResult, remainigResult] = await Promise.all([
builder.limit(id),
builder.getRemaining(id)
])

expect(limitResult.remaining).toBe(remainigResult.remaining)

Check failure on line 26 in src/getRemainingTokens.test.ts

View workflow job for this annotation

GitHub Actions / Tests

error: expect(received).toBe(expected)

Expected: 6 Received: 7 at /home/runner/work/ratelimit-js/ratelimit-js/src/getRemainingTokens.test.ts:26:39
expect(limitResult.reset).toBe(remainigResult.reset)
if (i == stopAt) {
break
}
}

const remaining = await builder.getRemaining(id);
const {remaining} = await builder.getRemaining(id);
expect(remaining).toBe(limit - stopAt);

Check failure on line 34 in src/getRemainingTokens.test.ts

View workflow job for this annotation

GitHub Actions / Tests

error: expect(received).toBe(expected)

Expected: 1 Received: 2 at /home/runner/work/ratelimit-js/ratelimit-js/src/getRemainingTokens.test.ts:34:25

Check failure on line 34 in src/getRemainingTokens.test.ts

View workflow job for this annotation

GitHub Actions / Tests

error: expect(received).toBe(expected)

Expected: 1 Received: 2 at /home/runner/work/ratelimit-js/ratelimit-js/src/getRemainingTokens.test.ts:34:25
}, 10000);
}, {
timeout: 10000,
retry: 3
});
});
}

function newRegion(limiter: Algorithm<RegionContext>): Ratelimit<RegionContext> {
return new RegionRatelimit({
prefix: crypto.randomUUID(),
redis: Redis.fromEnv(),
redis: Redis.fromEnv({enableAutoPipelining: true}),
limiter,
});
}
Expand All @@ -52,14 +62,17 @@ function newMultiRegion(limiter: Algorithm<MultiRegionContext>): Ratelimit<Multi
new Redis({
url: ensureEnv("EU2_UPSTASH_REDIS_REST_URL"),
token: ensureEnv("EU2_UPSTASH_REDIS_REST_TOKEN"),
enableAutoPipelining: true
}),
new Redis({
url: ensureEnv("APN_UPSTASH_REDIS_REST_URL"),
token: ensureEnv("APN_UPSTASH_REDIS_REST_TOKEN"),
enableAutoPipelining: true
}),
new Redis({
url: ensureEnv("US1_UPSTASH_REDIS_REST_URL"),
token: ensureEnv("US1_UPSTASH_REDIS_REST_TOKEN"),
enableAutoPipelining: true
}),
],
limiter,
Expand Down
8 changes: 5 additions & 3 deletions src/lua-scripts/single.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,19 @@ export const tokenBucketLimitScript = `
return {remaining, refilledAt + interval}
`;

export const tokenBucketIdentifierNotFound = -1

export const tokenBucketRemainingTokensScript = `
local key = KEYS[1]
local maxTokens = tonumber(ARGV[1])
local bucket = redis.call("HMGET", key, "tokens")
local bucket = redis.call("HMGET", key, "refilledAt", "tokens")
if bucket[1] == false then
return maxTokens
return {maxTokens, ${tokenBucketIdentifierNotFound}}
end
return tonumber(bucket[1])
return {tonumber(bucket[2]), tonumber(bucket[1])}
`;

export const cachedFixedWindowLimitScript = `
Expand Down
10 changes: 8 additions & 2 deletions src/multi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,10 @@ export class MultiRegionRatelimit extends Ratelimit<MultiRegionContext> {
return accTokens + parsedToken;
}, 0);

return Math.max(0, tokens - usedTokens);
return {
remaining: Math.max(0, tokens - usedTokens),
reset: (bucket + 1) * windowDuration
};
},
async resetTokens(ctx: MultiRegionContext, identifier: string) {
const pattern = [identifier, "*"].join(":");
Expand Down Expand Up @@ -514,7 +517,10 @@ export class MultiRegionRatelimit extends Ratelimit<MultiRegionContext> {
}));

const usedTokens = await Promise.any(dbs.map((s) => s.request));
return Math.max(0, tokens - usedTokens);
return {
remaining: Math.max(0, tokens - usedTokens),
reset: (currentWindow + 1) * windowSize
};
},
async resetTokens(ctx: MultiRegionContext, identifier: string) {
const pattern = [identifier, "*"].join(":");
Expand Down
13 changes: 12 additions & 1 deletion src/ratelimit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,18 @@ export abstract class Ratelimit<TContext extends Context> {
await this.limiter().resetTokens(this.ctx, pattern);
};

public getRemaining = async (identifier: string): Promise<number> => {
/**
* Returns the remaining token count together with a reset timestamps
*
* @param identifier identifir to check
* @returns object with `remaining` and reset fields. `remaining` denotes
* the remaining tokens and reset denotes the timestamp when the
* tokens reset.
*/
public getRemaining = async (identifier: string): Promise<{
remaining: number;
reset: number;
}> => {
const pattern = [this.prefix, identifier].join(":");

return await this.limiter().getRemaining(this.ctx, pattern);
Expand Down
2 changes: 1 addition & 1 deletion src/resetUsedTokens.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ function run<TContext extends Context>(builder: Ratelimit<TContext>) {

// reset tokens
await builder.resetUsedTokens(id);
const remaining = await builder.getRemaining(id);
const {remaining} = await builder.getRemaining(id);
expect(remaining).toBe(limit);

Check failure on line 29 in src/resetUsedTokens.test.ts

View workflow job for this annotation

GitHub Actions / Tests

error: expect(received).toBe(expected)

Expected: 10 Received: 0 at /home/runner/work/ratelimit-js/ratelimit-js/src/resetUsedTokens.test.ts:29:25

Check failure on line 29 in src/resetUsedTokens.test.ts

View workflow job for this annotation

GitHub Actions / Tests

error: expect(received).toBe(expected)

Expected: 10 Received: 0 at /home/runner/work/ratelimit-js/ratelimit-js/src/resetUsedTokens.test.ts:29:25

Check failure on line 29 in src/resetUsedTokens.test.ts

View workflow job for this annotation

GitHub Actions / Tests

error: expect(received).toBe(expected)

Expected: 10 Received: 0 at /home/runner/work/ratelimit-js/ratelimit-js/src/resetUsedTokens.test.ts:29:25

Check failure on line 29 in src/resetUsedTokens.test.ts

View workflow job for this annotation

GitHub Actions / Tests

error: expect(received).toBe(expected)

Expected: 10 Received: 0 at /home/runner/work/ratelimit-js/ratelimit-js/src/resetUsedTokens.test.ts:29:25

Check failure on line 29 in src/resetUsedTokens.test.ts

View workflow job for this annotation

GitHub Actions / Tests

error: expect(received).toBe(expected)

Expected: 10 Received: 0 at /home/runner/work/ratelimit-js/ratelimit-js/src/resetUsedTokens.test.ts:29:25

Check failure on line 29 in src/resetUsedTokens.test.ts

View workflow job for this annotation

GitHub Actions / Tests

error: expect(received).toBe(expected)

Expected: 10 Received: 0 at /home/runner/work/ratelimit-js/ratelimit-js/src/resetUsedTokens.test.ts:29:25

Check failure on line 29 in src/resetUsedTokens.test.ts

View workflow job for this annotation

GitHub Actions / Tests

error: expect(received).toBe(expected)

Expected: 10 Received: 0 at /home/runner/work/ratelimit-js/ratelimit-js/src/resetUsedTokens.test.ts:29:25

Check failure on line 29 in src/resetUsedTokens.test.ts

View workflow job for this annotation

GitHub Actions / Tests

error: expect(received).toBe(expected)

Expected: 10 Received: 0 at /home/runner/work/ratelimit-js/ratelimit-js/src/resetUsedTokens.test.ts:29:25

Check failure on line 29 in src/resetUsedTokens.test.ts

View workflow job for this annotation

GitHub Actions / Tests

error: expect(received).toBe(expected)

Expected: 10 Received: 0 at /home/runner/work/ratelimit-js/ratelimit-js/src/resetUsedTokens.test.ts:29:25
}, 10000);
});
Expand Down
33 changes: 26 additions & 7 deletions src/single.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
fixedWindowRemainingTokensScript,
slidingWindowLimitScript,
slidingWindowRemainingTokensScript,
tokenBucketIdentifierNotFound,
tokenBucketLimitScript,
tokenBucketRemainingTokensScript,
} from "./lua-scripts/single";
Expand Down Expand Up @@ -214,7 +215,10 @@ export class RegionRatelimit extends Ratelimit<RegionContext> {
[null],
) as number;

return Math.max(0, tokens - usedTokens);
return {
remaining: Math.max(0, tokens - usedTokens),
reset: (bucket + 1) * windowDuration
};
},
async resetTokens(ctx: RegionContext, identifier: string) {
const pattern = [identifier, "*"].join(":");
Expand Down Expand Up @@ -322,7 +326,10 @@ export class RegionRatelimit extends Ratelimit<RegionContext> {
[now, windowSize],
) as number;

return Math.max(0, tokens - usedTokens);
return {
remaining: Math.max(0, tokens - usedTokens),
reset: (currentWindow + 1) * windowSize
}
},
async resetTokens(ctx: RegionContext, identifier: string) {
const pattern = [identifier, "*"].join(":");
Expand Down Expand Up @@ -416,15 +423,21 @@ export class RegionRatelimit extends Ratelimit<RegionContext> {
},
async getRemaining(ctx: RegionContext, identifier: string) {

const remainingTokens = await safeEval(
const [remainingTokens, refilledAt] = await safeEval(
ctx,
tokenBucketRemainingTokensScript,
"getRemainingHash",
[identifier],
[maxTokens],
) as number;
) as [number, number];

const freshRefillAt = Date.now() + intervalDuration
const identifierRefillsAt = refilledAt + intervalDuration

return remainingTokens;
return {
remaining: remainingTokens,
reset: refilledAt === tokenBucketIdentifierNotFound ? freshRefillAt : identifierRefillsAt
};
},
async resetTokens(ctx: RegionContext, identifier: string) {
const pattern = identifier;
Expand Down Expand Up @@ -541,7 +554,10 @@ export class RegionRatelimit extends Ratelimit<RegionContext> {
const hit = typeof ctx.cache.get(key) === "number";
if (hit) {
const cachedUsedTokens = ctx.cache.get(key) ?? 0;
return Math.max(0, tokens - cachedUsedTokens);
return {
remaining: Math.max(0, tokens - cachedUsedTokens),
reset: (bucket + 1) * windowDuration
};
}

const usedTokens = await safeEval(
Expand All @@ -551,7 +567,10 @@ export class RegionRatelimit extends Ratelimit<RegionContext> {
[key],
[null],
) as number;
return Math.max(0, tokens - usedTokens);
return {
remaining: Math.max(0, tokens - usedTokens),
reset: (bucket + 1) * windowDuration
};
},
async resetTokens(ctx: RegionContext, identifier: string) {
// Empty the cache
Expand Down
5 changes: 4 additions & 1 deletion src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ export type Algorithm<TContext> = () => {
cache?: EphemeralCache;
},
) => Promise<RatelimitResponse>;
getRemaining: (ctx: TContext, identifier: string) => Promise<number>;
getRemaining: (ctx: TContext, identifier: string) => Promise<{
remaining: number,
reset: number
}>;
resetTokens: (ctx: TContext, identifier: string) => Promise<void>;
};

Expand Down

0 comments on commit a239d58

Please sign in to comment.