diff --git a/package.json b/package.json index 5764867..eed9a83 100644 --- a/package.json +++ b/package.json @@ -12,7 +12,7 @@ "fmt": "bunx @biomejs/biome check --apply ./src" }, "devDependencies": { - "@upstash/redis": "^1.28.3", + "@upstash/redis": "^1.31.5", "bun-types": "latest", "rome": "^11.0.0", "tsup": "^7.2.0", diff --git a/src/deny-list.test.ts b/src/deny-list/deny-list.test.ts similarity index 60% rename from src/deny-list.test.ts rename to src/deny-list/deny-list.test.ts index 5524817..d04c897 100644 --- a/src/deny-list.test.ts +++ b/src/deny-list/deny-list.test.ts @@ -1,8 +1,8 @@ -import { expect, test, describe, afterAll } from "bun:test"; +import { expect, test, describe, afterAll, beforeAll } from "bun:test"; import { Redis } from "@upstash/redis"; -import { Ratelimit } from "./index"; -import { checkDenyListCache, defaultDeniedResponse, resolveResponses } from "./deny-list"; -import { RatelimitResponseType } from "./types"; +import { Ratelimit } from "../index"; +import { checkDenyListCache, defaultDeniedResponse, resolveLimitPayload } from "./deny-list"; +import { DenyListResponse, RatelimitResponseType } from "../types"; test("should get expected response from defaultDeniedResponse", () => { @@ -20,8 +20,18 @@ test("should get expected response from defaultDeniedResponse", () => { }); }); +describe("should resolve ratelimit and deny list response", async () => { + const redis = Redis.fromEnv(); + const prefix = `test-resolve-prefix`; + + let callCount = 0; + const spyRedis = { + multi: () => { + callCount += 1; + return redis.multi(); + } + } -test.only("should override response in resolveResponses correctly", () => { const initialResponse = { success: true, limit: 100, @@ -31,8 +41,7 @@ test.only("should override response in resolveResponses correctly", () => { reason: undefined, deniedValue: undefined }; - - const denyListResponse = "testValue"; + const expectedResponse = { success: false, limit: 100, @@ -40,12 +49,52 @@ test.only("should override response in resolveResponses correctly", () => { reset: 60, pending: Promise.resolve(), reason: "denyList" as RatelimitResponseType, - deniedValue: denyListResponse + deniedValue: "testValue" }; - const response = resolveResponses([initialResponse, denyListResponse]); - expect(response).toEqual(expectedResponse); -}); + test("should update ip deny list when invalidIpDenyList is true", async () => { + let callCount = 0; + const spyRedis = { + multi: () => { + callCount += 1; + return redis.multi(); + } + } + + const denyListResponse: DenyListResponse = { + deniedValue: "testValue", + invalidIpDenyList: true + }; + + const response = resolveLimitPayload(spyRedis as Redis, prefix, [initialResponse, denyListResponse], 8); + await response.pending; + + expect(response).toEqual(expectedResponse); + expect(callCount).toBe(1) // calls multi once to store ips + }); + + test("should update ip deny list when invalidIpDenyList is true", async () => { + + let callCount = 0; + const spyRedis = { + multi: () => { + callCount += 1; + return redis.multi(); + } + } + + const denyListResponse: DenyListResponse = { + deniedValue: "testValue", + invalidIpDenyList: false + }; + + const response = resolveLimitPayload(spyRedis as Redis, prefix, [initialResponse, denyListResponse], 8); + await response.pending; + + expect(response).toEqual(expectedResponse); + expect(callCount).toBe(0) // doesn't call multi to update deny list + }); +}) describe("should reject in deny list", async () => { @@ -53,18 +102,22 @@ describe("should reject in deny list", async () => { const prefix = `test-prefix`; const denyListKey = [prefix, "denyList", "all"].join(":"); - // Insert a value into the deny list - await redis.sadd(denyListKey, "denyIdentifier", "denyIp", "denyAgent", "denyCountry"); const ratelimit = new Ratelimit({ redis, limiter: Ratelimit.tokenBucket(10, "5 s", 10), prefix, - enableProtection: true + enableProtection: true, + denyListThreshold: 8 }); afterAll(async () => { - redis.del(denyListKey) + await redis.del(denyListKey) + }) + + // Insert a value into the deny list + beforeAll(async () => { + await redis.sadd(denyListKey, "denyIdentifier", "denyIp", "denyAgent", "denyCountry"); }) test("should allow with values not in the deny list", async () => { diff --git a/src/deny-list.ts b/src/deny-list/deny-list.ts similarity index 62% rename from src/deny-list.ts rename to src/deny-list/deny-list.ts index 8fc331a..a7520f8 100644 --- a/src/deny-list.ts +++ b/src/deny-list/deny-list.ts @@ -1,6 +1,8 @@ -import { DeniedValue, LimitPayload, Redis } from "./types" -import { RatelimitResponse } from "./types" -import { Cache } from "./cache"; +import { DeniedValue, DenyListResponse, DenyListExtension, LimitPayload, IpDenyListStatusKey } from "../types" +import { RatelimitResponse, Redis } from "../types" +import { Cache } from "../cache"; +import { checkDenyListScript } from "./scripts"; +import { updateIpDenyList } from "./ip-deny-list"; const denyListCache = new Cache(new Map()); @@ -46,21 +48,28 @@ export const checkDenyList = async ( redis: Redis, prefix: string, members: string[] -): Promise => { - const deniedMembers = await redis.smismember( - [prefix, "denyList", "all"].join(":"), +): Promise => { + const [ deniedValues, ipDenyListStatus ] = await redis.eval( + checkDenyListScript, + [ + [prefix, DenyListExtension, "all"].join(":"), + [prefix, IpDenyListStatusKey].join(":"), + ], members - ); + ) as [boolean[], number]; - let deniedMember: DeniedValue = undefined; - deniedMembers.map((memberDenied, index) => { + let deniedValue: DeniedValue = undefined; + deniedValues.map((memberDenied, index) => { if (memberDenied) { blockMember(members[index]) - deniedMember = members[index] + deniedValue = members[index] } }) - return deniedMember; + return { + deniedValue, + invalidIpDenyList: ipDenyListStatus === -2 + }; }; /** @@ -71,15 +80,28 @@ export const checkDenyList = async ( * @param denyListResponse * @returns */ -export const resolveResponses = ( - [ratelimitResponse, denyListResponse]: LimitPayload +export const resolveLimitPayload = ( + redis: Redis, + prefix: string, + [ratelimitResponse, denyListResponse]: LimitPayload, + threshold: number ): RatelimitResponse => { - if (denyListResponse) { + + if (denyListResponse.deniedValue) { ratelimitResponse.success = false; ratelimitResponse.remaining = 0; ratelimitResponse.reason = "denyList"; - ratelimitResponse.deniedValue = denyListResponse + ratelimitResponse.deniedValue = denyListResponse.deniedValue + } + + if (denyListResponse.invalidIpDenyList) { + const updatePromise = updateIpDenyList(redis, prefix, threshold) + ratelimitResponse.pending = Promise.all([ + ratelimitResponse.pending, + updatePromise + ]) } + return ratelimitResponse; }; diff --git a/src/deny-list/index.ts b/src/deny-list/index.ts new file mode 100644 index 0000000..18944c5 --- /dev/null +++ b/src/deny-list/index.ts @@ -0,0 +1 @@ +export * from "./deny-list" \ No newline at end of file diff --git a/src/deny-list/integration.test.ts b/src/deny-list/integration.test.ts new file mode 100644 index 0000000..c7c52f0 --- /dev/null +++ b/src/deny-list/integration.test.ts @@ -0,0 +1,202 @@ +// test ip deny list from the highest level, using Ratelimit +import { expect, test, describe, afterAll, beforeEach } from "bun:test"; +import { Ratelimit } from "../index"; +import { Redis } from "@upstash/redis"; +import { DenyListExtension, IpDenyListKey, IpDenyListStatusKey, RatelimitResponse } from "../types"; +import { disableIpDenyList } from "./ip-deny-list"; + +describe("should reject in deny list", async () => { + + const redis = Redis.fromEnv(); + const prefix = `test-integration-prefix`; + const statusKey = [prefix, IpDenyListStatusKey].join(":") + const allDenyListsKey = [prefix, DenyListExtension, "all"].join(":"); + const ipDenyListsKey = [prefix, DenyListExtension, IpDenyListKey].join(":"); + + const ratelimit = new Ratelimit({ + redis, + limiter: Ratelimit.tokenBucket(10, "5 s", 10), + prefix, + enableProtection: true, + denyListThreshold: 8 + }); + + beforeEach(async () => { + await redis.flushdb() + // adding different values to avoid the deny list cache + await redis.sadd(allDenyListsKey, "foo", "albatros", "penguin"); + }); + + test("should not check deny list when enableProtection: false", async () => { + const ratelimit = new Ratelimit({ + redis, + limiter: Ratelimit.tokenBucket(10, "5 s", 10), + prefix, + enableProtection: false, + denyListThreshold: 8 + }); + + const result = await ratelimit.limit("albatros") + expect(result.success).toBeTrue() + + const [status, statusTTL, allSize, ipListsize] = await Promise.all([ + redis.get(statusKey), + redis.ttl(statusKey), + redis.scard(allDenyListsKey), + redis.scard(ipDenyListsKey), + ]) + + // no status flag + expect(status).toBe(null) + expect(statusTTL).toBe(-2) + expect(allSize).toBe(3) // foo + albatros + penguin + expect(ipListsize).toBe(0) + }) + + test("should create ip denylist when enableProtection: true and not disabled", async () => { + const { pending, success } = await ratelimit.limit("penguin"); + expect(success).toBeFalse() + await pending; + + const [status, statusTTL, allSize, ipListsize] = await Promise.all([ + redis.get(statusKey), + redis.ttl(statusKey), + redis.scard(allDenyListsKey), + redis.scard(ipDenyListsKey), + ]) + + // status flag exists and has ttl + expect(status).toBe("valid") + expect(statusTTL).toBeGreaterThan(1000) + expect(allSize).toBeGreaterThan(0) + expect(ipListsize).toBe(allSize-3) // foo + albatros + penguin + }) + + test("should not create ip denylist when enableProtection: true but flag is disabled", async () => { + await disableIpDenyList(redis, prefix); + const { pending, success } = await ratelimit.limit("test-user-2"); + expect(success).toBeTrue() + await pending; + + const [status, statusTTL, allSize, ipListsize] = await Promise.all([ + redis.get(statusKey), + redis.ttl(statusKey), + redis.scard(allDenyListsKey), + redis.scard(ipDenyListsKey), + ]) + + // no status flag + expect(status).toBe("disabled") + expect(statusTTL).toBe(-1) + expect(allSize).toBe(3) // foo + albatros + penguin + expect(ipListsize).toBe(0) + }) + + test("should observe that ip denylist is deleted after disabling", async () => { + const { pending, success } = await ratelimit.limit("test-user-3"); + expect(success).toBeTrue() + await pending; + + const [status, statusTTL, allSize, ipListsize] = await Promise.all([ + redis.get(statusKey), + redis.ttl(statusKey), + redis.scard(allDenyListsKey), + redis.scard(ipDenyListsKey), + ]) + + // status flag exists and has ttl + expect(status).toBe("valid") + expect(statusTTL).toBeGreaterThan(1000) + expect(allSize).toBeGreaterThan(0) + expect(ipListsize).toBe(allSize-3) // foo + albatros + penguin + + // DISABLE: called from UI + await disableIpDenyList(redis, prefix); + + // call again + const { pending: newPending } = await ratelimit.limit("test-user"); + await newPending; + + const [newStatus, newStatusTTL, newAllSize, newIpListsize] = await Promise.all([ + redis.get(statusKey), + redis.ttl(statusKey), + redis.scard(allDenyListsKey), + redis.scard(ipDenyListsKey), + ]) + + // status flag exists and has ttl + expect(newStatus).toBe("disabled") + expect(newStatusTTL).toBe(-1) + expect(newAllSize).toBe(3) // foo + albatros + penguin + expect(newIpListsize).toBe(0) + }) + + test("should intialize ip list only once when called consecutively", async () => { + + const requests: RatelimitResponse[] = await Promise.all([ + ratelimit.limit("test-user-X"), + ratelimit.limit("test-user-Y") + ]) + + expect(requests[0].success).toBeTrue() + expect(requests[1].success).toBeTrue() + + // wait for both to finish + const result = await Promise.all([ + requests[0].pending, + requests[1].pending + ]) + /** + * Result is like this: + * [ + * undefined, + * [ + * undefined, + * [ 1, 0, 74, 74, 75, "OK" ] + * ] + * ] + * + * the first is essentially: + * >> Promise.resolve() + * + * Second one is + * >> Promise.all([Promise.resolve(), updateIpDenyListPromise]) + * + * This means that even though the requests were consecutive, only one was + * allowed to update to update the ip list! + */ + + // only one undefined + expect(result.filter((value) => value === undefined).length).toBe(1) + + // other response is defined + const definedResponse = result.filter((value) => value !== undefined)[0] as [undefined, any[]] + expect(definedResponse[0]).toBe(undefined) + expect(definedResponse[1].length).toBe(6) + expect(definedResponse[1][1]).toBe(0) // deleting deny list fails because there is none + expect(definedResponse[1][5]).toBe("OK") // setting TTL returns OK + }) + + test("should block ips from ip deny list", async () => { + const { pending, success } = await ratelimit.limit("test-user"); + expect(success).toBeTrue() + await pending; + + const [ip1, ip2] = await redis.srandmember(ipDenyListsKey, 2) as string[] + + const result = await ratelimit.limit("test-user", {ip: ip1}) + expect(result.success).toBeFalse() + expect(result.reason).toBe("denyList") + + await disableIpDenyList(redis, prefix); + + // first one still returns false because it is cached + const newResult = await ratelimit.limit("test-user", {ip: ip1}) + expect(newResult.success).toBeFalse() + expect(newResult.reason).toBe("denyList") + + // other one returns true + const otherResult = await ratelimit.limit("test-user", {ip: ip2}) + expect(otherResult.success).toBeTrue() + }) +}) \ No newline at end of file diff --git a/src/deny-list/ip-deny-list.test.ts b/src/deny-list/ip-deny-list.test.ts new file mode 100644 index 0000000..7ac98a1 --- /dev/null +++ b/src/deny-list/ip-deny-list.test.ts @@ -0,0 +1,180 @@ +import { Redis } from "@upstash/redis"; +import { beforeEach, describe, expect, test } from "bun:test"; +import { checkDenyList } from "./deny-list"; +import { disableIpDenyList, updateIpDenyList } from "./ip-deny-list"; +import { DenyListExtension, IpDenyListKey, IpDenyListStatusKey } from "../types"; + +describe("should update ip deny list status", async () => { + const redis = Redis.fromEnv(); + const prefix = `test-ip-list-prefix`; + const allDenyListsKey = [prefix, DenyListExtension, "all"].join(":"); + const ipDenyListsKey = [prefix, DenyListExtension, IpDenyListKey].join(":"); + const statusKey = [prefix, IpDenyListStatusKey].join(":") + + beforeEach(async () => { + await redis.flushdb() + await redis.sadd( + allDenyListsKey, "foo", "bar") + }); + + test("should return invalidIpDenyList: true when empty", async () => { + const { deniedValue, invalidIpDenyList } = await checkDenyList( + redis, prefix, ["foo", "bar"] + ) + + expect(deniedValue).toBe("bar") + expect(invalidIpDenyList).toBeTrue() + }) + + test("should return invalidIpDenyList: false when disabled", async () => { + await disableIpDenyList(redis, prefix); + const { deniedValue, invalidIpDenyList } = await checkDenyList( + redis, prefix, ["bar", "foo"] + ) + + expect(deniedValue).toBe("foo") + expect(invalidIpDenyList).toBeFalse() + }) + + test("should return invalidIpDenyList: false after updating", async () => { + await updateIpDenyList(redis, prefix, 8); + const { deniedValue, invalidIpDenyList } = await checkDenyList( + redis, prefix, ["whale", "albatros"] + ) + + expect(typeof deniedValue).toBe("undefined") + expect(invalidIpDenyList).toBeFalse() + }) + + test("should return invalidIpDenyList: false after updating + disabling", async () => { + + // initial values + expect(await redis.ttl(statusKey)).toBe(-2) + const initialStatus = await redis.get(statusKey) + expect(initialStatus).toBe(null) + + // UPDATE + await updateIpDenyList(redis, prefix, 8); + const { deniedValue, invalidIpDenyList } = await checkDenyList( + redis, prefix, ["user"] + ) + + expect(typeof deniedValue).toBe("undefined") + expect(invalidIpDenyList).toBeFalse() + // positive tll on the status key + expect(await redis.ttl(statusKey)).toBeGreaterThan(0) + const status = await redis.get(statusKey) + expect(status).toBe("valid") + + // DISABLE + await disableIpDenyList(redis, prefix); + const { + deniedValue: secondDeniedValue, + invalidIpDenyList: secondInvalidIpDenyList + } = await checkDenyList( + redis, prefix, ["foo", "bar"] + ) + + expect(secondDeniedValue).toBe("bar") + expect(secondInvalidIpDenyList).toBeFalse() + // -1 in the status key + expect(await redis.ttl(statusKey)).toBe(-1) + const secondStatus = await redis.get(statusKey) + expect(secondStatus).toBe("disabled") + }) + + test("should handle timeout correctly", async () => { + + await updateIpDenyList(redis, prefix, 8, 5_000); // update with 5 seconds ttl on status flag + const pipeline = redis.multi() + pipeline.smembers(allDenyListsKey) + pipeline.smembers(ipDenyListsKey) + pipeline.get(statusKey) + pipeline.ttl(statusKey) + + const [allValues, ipDenyListValues, status, statusTTL]: [string[], string[], string | null, number] = await pipeline.exec(); + expect(ipDenyListValues.length).toBeGreaterThan(0) + expect(allValues.length).toBe(ipDenyListValues.length + 2) // + 2 for foo and bar + expect(status).toBe("valid") + expect(statusTTL).toBeGreaterThan(2) // ttl is more than 5 seconds + + // wait 6 seconds + await new Promise((r) => setTimeout(r, 6_000)); + + const [newAllValues, newIpDenyListValues, newStatus, newStatusTTL]: [string[], string[], string | null, number] = await pipeline.exec(); + + // deny lists remain as they are + expect(newIpDenyListValues.length).toBeGreaterThan(0) + expect(newAllValues.length).toBe(allValues.length) + expect(newIpDenyListValues.length).toBe(ipDenyListValues.length) + + // status flag is gone + expect(newStatus).toBe(null) + expect(newStatusTTL).toBe(-2) + }, { timeout: 10_000 }) + + test("should overwrite disabled status with updateIpDenyList", async () => { + await disableIpDenyList(redis, prefix); + + const pipeline = redis.multi() + pipeline.smembers(allDenyListsKey) + pipeline.smembers(ipDenyListsKey) + pipeline.get(statusKey) + pipeline.ttl(statusKey) + + const [allValues, ipDenyListValues, status, statusTTL]: [string[], string[], string | null, number] = await pipeline.exec(); + expect(ipDenyListValues.length).toBe(0) + expect(allValues.length).toBe(2) // + 2 for foo and bar + expect(status).toBe("disabled") + expect(statusTTL).toBe(-1) + + // update status: called from UI or from SDK when status key expires + await updateIpDenyList(redis, prefix, 8); + + const [newAllValues, newIpDenyListValues, newStatus, newStatusTTL]: [string[], string[], string | null, number] = await pipeline.exec(); + + // deny lists remain as they are + expect(newIpDenyListValues.length).toBeGreaterThan(0) + expect(newAllValues.length).toBe(newIpDenyListValues.length + 2) + expect(newStatus).toBe("valid") + expect(newStatusTTL).toBeGreaterThan(1000) + }) +}) + +describe("should only allow threshold values from 1 to 8", async () => { + const redis = Redis.fromEnv(); + const prefix = `test-ip-list-prefix`; + + test("should reject string", async () => { + try { + // @ts-expect-error + await updateIpDenyList(redis, prefix, "test") + } catch (error: any) { + expect(error.name).toEqual("ThresholdError") + } + }) + + test("should reject 0", async () => { + try { + await updateIpDenyList(redis, prefix, 0) + } catch (error: any) { + expect(error.name).toEqual("ThresholdError") + } + }) + + test("should reject negative", async () => { + try { + await updateIpDenyList(redis, prefix, -1) + } catch (error: any) { + expect(error.name).toEqual("ThresholdError") + } + }) + + test("should reject 9", async () => { + try { + await updateIpDenyList(redis, prefix, 9) + } catch (error: any) { + expect(error.name).toEqual("ThresholdError") + } + }) +}) \ No newline at end of file diff --git a/src/deny-list/ip-deny-list.ts b/src/deny-list/ip-deny-list.ts new file mode 100644 index 0000000..d64221b --- /dev/null +++ b/src/deny-list/ip-deny-list.ts @@ -0,0 +1,127 @@ +import { DenyListExtension, IpDenyListKey, IpDenyListStatusKey, Redis } from "../types" +import { getIpListTTL } from "./time" + +const baseUrl = "https://raw.githubusercontent.com/stamparm/ipsum/master/levels" + +export class ThresholdError extends Error { + constructor(threshold: number) { + super(`Allowed threshold values are from 1 to 8, 1 and 8 included. Received: ${threshold}`); + this.name = "ThresholdError"; + } +} + +/** + * Fetches the ips from the ipsum.txt at github + * + * In the repo we are using, 30+ ip lists are aggregated. The results are + * stores in text files from 1 to 8. + * https://github.com/stamparm/ipsum/tree/master/levels + * + * X.txt file holds ips which are in at least X of the lists. + * + * @param threshold ips with less than or equal to the threshold are not included + * @returns list of ips + */ +const getIpDenyList = async (threshold: number) => { + if (typeof threshold !== "number" || threshold < 1 || threshold > 8) { + throw new ThresholdError(threshold) + } + + try { + // Fetch data from the URL + const response = await fetch(`${baseUrl}/${threshold}.txt`) + if (!response.ok) { + throw new Error(`Error fetching data: ${response.statusText}`) + } + const data = await response.text() + + // Process the data + const lines = data.split("\n") + return lines.filter((value) => value.length > 0) // remove empty values + } catch (error) { + throw new Error(`Failed to fetch ip deny list: ${error}`) + } +} + +/** + * Gets the list of ips from the github source which are not in the + * deny list already + * + * First, gets the ip list from github using the threshold. Then, calls redis with + * a transaction which does the following: + * - subtract the current ip deny list from all + * - delete current ip deny list + * - recreate ip deny list with the ips from github. Ips already in the users own lists + * are excluded. + * - status key is set to valid with ttl until next 2 AM UTC, which is a bit later than + * when the list is updated on github. + * + * @param redis redis instance + * @param prefix ratelimit prefix + * @param threshold ips with less than or equal to the threshold are not included + * @param ttl time to live in milliseconds for the status flag. Optional. If not + * passed, ttl is infferred from current time. + * @returns list of ips which are not in the deny list + */ +export const updateIpDenyList = async ( + redis: Redis, + prefix: string, + threshold: number, + ttl?: number +) => { + const allIps = await getIpDenyList(threshold) + + const allDenyLists = [prefix, DenyListExtension, "all"].join(":") + const ipDenyList = [prefix, DenyListExtension, IpDenyListKey].join(":") + const statusKey = [prefix, IpDenyListStatusKey].join(":") + + const transaction = redis.multi() + + // remove the old ip deny list from the all set + transaction.sdiffstore(allDenyLists, allDenyLists, ipDenyList) + + // delete the old ip deny list and create new one + transaction.del(ipDenyList) + transaction.sadd(ipDenyList, ...allIps) + + // make all deny list and ip deny list disjoint by removing duplicate + // ones from ip deny list + transaction.sdiffstore(ipDenyList, ipDenyList, allDenyLists) + + // add remaining ips to all list + transaction.sunionstore(allDenyLists, allDenyLists, ipDenyList) + + // set status key with ttl + transaction.set(statusKey, "valid", {px: ttl ?? getIpListTTL()}) + + return await transaction.exec() +} + +/** + * Disables the ip deny list by removing the ip deny list from the all + * set and removing the ip deny list. Also sets the status key to disabled + * with no ttl. + * + * @param redis redis instance + * @param prefix ratelimit prefix + * @returns + */ +export const disableIpDenyList = async (redis: Redis, prefix: string) => { + const allDenyListsKey = [prefix, DenyListExtension, "all"].join(":") + const ipDenyListKey = [prefix, DenyListExtension, IpDenyListKey].join(":") + const statusKey = [prefix, IpDenyListStatusKey].join(":") + + const transaction = redis.multi() + + // remove the old ip deny list from the all set + transaction.sdiffstore(allDenyListsKey, allDenyListsKey, ipDenyListKey) + + // delete the old ip deny list + transaction.del(ipDenyListKey) + + // set to disabled + // this way, the TTL command in checkDenyListScript will return -1. + transaction.set(statusKey, "disabled") + + return await transaction.exec() +} diff --git a/src/deny-list/scripts.test.ts b/src/deny-list/scripts.test.ts new file mode 100644 index 0000000..43ebbdf --- /dev/null +++ b/src/deny-list/scripts.test.ts @@ -0,0 +1,98 @@ +import { Redis } from "@upstash/redis"; +import { beforeAll, beforeEach, describe, expect, test } from "bun:test"; +import { DenyListExtension, IpDenyListStatusKey, IsDenied } from "../types"; +import { checkDenyListScript } from "./scripts"; +import { disableIpDenyList, updateIpDenyList } from "./deny-list-update"; + +describe("should manage state correctly", async () => { + const redis = Redis.fromEnv(); + const prefix = `test-script-prefix`; + + const allDenyListsKey = [prefix, DenyListExtension, "all"].join(":"); + const ipDenyListStatusKey = [prefix, IpDenyListStatusKey].join(":"); + + beforeEach(async () => { + await redis.flushdb() + await redis.sadd( + allDenyListsKey, "foo", "bar") + }); + + test("should return status: -2 initially", async () => { + const [isMember, status] = await redis.eval( + checkDenyListScript, + [allDenyListsKey, ipDenyListStatusKey], + ["whale", "foo", "bar", "zed"] + ) as [IsDenied[], number]; + + expect(isMember).toEqual([0, 1, 1, 0]) + expect(status).toBe(-2) + }) + + test("should return status: -1 when disabled", async () => { + await disableIpDenyList(redis, prefix); + const [isMember, status] = await redis.eval( + checkDenyListScript, + [allDenyListsKey, ipDenyListStatusKey], + ["whale", "foo", "bar", "zed"] + ) as [IsDenied[], number]; + + expect(isMember).toEqual([0, 1, 1, 0]) + expect(status).toBe(-1) + }) + + test("should return status: number after update", async () => { + await updateIpDenyList(redis, prefix, 8); + const [isMember, status] = await redis.eval( + checkDenyListScript, + [allDenyListsKey, ipDenyListStatusKey], + ["foo", "whale", "bar", "zed"] + ) as [IsDenied[], number]; + + expect(isMember).toEqual([1, 0, 1, 0]) + expect(status).toBeGreaterThan(1000) + }) + + test("should return status: -1 after update and disable", async () => { + await updateIpDenyList(redis, prefix, 8); + await disableIpDenyList(redis, prefix); + const [isMember, status] = await redis.eval( + checkDenyListScript, + [allDenyListsKey, ipDenyListStatusKey], + ["foo", "whale", "bar", "zed"] + ) as [IsDenied[], number]; + + expect(isMember).toEqual([1, 0, 1, 0]) + expect(status).toBe(-1) + }) + + test("should only make one of two consecutive requests update deny list", async () => { + + // running the eval script consecutively when the deny list needs + // to be updated. Only one will update the ip list. It will be + // given 30 seconds before its turn expires. Until then, other requests + // will continue using the old ip deny list + const response = await Promise.all([ + redis.eval( + checkDenyListScript, + [allDenyListsKey, ipDenyListStatusKey], + ["foo", "whale", "bar", "zed"] + ) as Promise<[IsDenied[], number]>, + redis.eval( + checkDenyListScript, + [allDenyListsKey, ipDenyListStatusKey], + ["foo", "whale", "bar", "zed"] + ) as Promise<[IsDenied[], number]> + ]); + + // first request is told that there is no valid ip list (ttl: -2), + // hence it will update the ip deny list + expect(response[0]).toEqual([[1, 0, 1, 0], -2]) + + // second request is told that there is already a valid ip list + // with ttl 30. + expect(response[1]).toEqual([[1, 0, 1, 0], 30]) + + const state = await redis.get(ipDenyListStatusKey) + expect(state).toBe("pending") + }) +}) \ No newline at end of file diff --git a/src/deny-list/scripts.ts b/src/deny-list/scripts.ts new file mode 100644 index 0000000..a17c727 --- /dev/null +++ b/src/deny-list/scripts.ts @@ -0,0 +1,26 @@ +export const checkDenyListScript = ` + -- Checks if values provideed in ARGV are present in the deny lists. + -- This is done using the allDenyListsKey below. + + -- Additionally, checks the status of the ip deny list using the + -- ipDenyListStatusKey below. Here are the possible states of the + -- ipDenyListStatusKey key: + -- * status == -1: set to "disabled" with no TTL + -- * status == -2: not set, meaning that is was set before but expired + -- * status > 0: set to "valid", with a TTL + -- + -- In the case of status == -2, we set the status to "pending" with + -- 30 second ttl. During this time, the process which got status == -2 + -- will update the ip deny list. + + local allDenyListsKey = KEYS[1] + local ipDenyListStatusKey = KEYS[2] + + local results = redis.call('SMISMEMBER', allDenyListsKey, unpack(ARGV)) + local status = redis.call('TTL', ipDenyListStatusKey) + if status == -2 then + redis.call('SETEX', ipDenyListStatusKey, 30, "pending") + end + + return { results, status } +` \ No newline at end of file diff --git a/src/deny-list/time.test.ts b/src/deny-list/time.test.ts new file mode 100644 index 0000000..b343346 --- /dev/null +++ b/src/deny-list/time.test.ts @@ -0,0 +1,46 @@ +import { getIpListTTL } from './time'; +import { beforeAll, beforeEach, describe, expect, test } from "bun:test"; + +describe('getIpListTTL', () => { + test('returns correct TTL when it is before 2 AM UTC', () => { + const before2AM = Date.UTC(2024, 5, 12, 1, 0, 0); // June 12, 2024, 1:00 AM UTC + const expectedTTL = 1 * 60 * 60 * 1000; // 1 hour in milliseconds + + expect(getIpListTTL(before2AM)).toBe(expectedTTL); + }); + + test('returns correct TTL when it is exactly 2 AM UTC', () => { + const exactly2AM = Date.UTC(2024, 5, 12, 2, 0, 0); // June 12, 2024, 2:00 AM UTC + const expectedTTL = 24 * 60 * 60 * 1000; // 24 hours in milliseconds + + expect(getIpListTTL(exactly2AM)).toBe(expectedTTL); + }); + + test('returns correct TTL when it is after 2 AM UTC but before the next 2 AM UTC', () => { + const after2AM = Date.UTC(2024, 5, 12, 3, 0, 0); // June 12, 2024, 3:00 AM UTC + const expectedTTL = 23 * 60 * 60 * 1000; // 23 hours in milliseconds + + expect(getIpListTTL(after2AM)).toBe(expectedTTL); + }); + + test('returns correct TTL when it is much later in the day', () => { + const laterInDay = Date.UTC(2024, 5, 12, 20, 0, 0); // June 12, 2024, 8:00 PM UTC + const expectedTTL = 6 * 60 * 60 * 1000; // 6 hours in milliseconds + + expect(getIpListTTL(laterInDay)).toBe(expectedTTL); + }); + + test('returns correct TTL when it is exactly the next day', () => { + const nextDay = Date.UTC(2024, 5, 13, 2, 0, 0); // June 13, 2024, 2:00 AM UTC + const expectedTTL = 24 * 60 * 60 * 1000; // 24 hours in milliseconds + + expect(getIpListTTL(nextDay)).toBe(expectedTTL); + }); + + test('returns correct TTL when no time is provided (uses current time)', () => { + const now = Date.now(); + const expectedTTL = getIpListTTL(now); + + expect(getIpListTTL()).toBe(expectedTTL); + }); +}); diff --git a/src/deny-list/time.ts b/src/deny-list/time.ts new file mode 100644 index 0000000..9c0f2f9 --- /dev/null +++ b/src/deny-list/time.ts @@ -0,0 +1,20 @@ + +// Number of milliseconds in one hour +const MILLISECONDS_IN_HOUR = 60 * 60 * 1000; + +// Number of milliseconds in one day +const MILLISECONDS_IN_DAY = 24 * MILLISECONDS_IN_HOUR; + +// Number of milliseconds from the current time to 2 AM UTC +const MILLISECONDS_TO_2AM = 2 * MILLISECONDS_IN_HOUR; + +export const getIpListTTL = (time?: number) => { + const now = time || Date.now(); + + // Time since the last 2 AM UTC + const timeSinceLast2AM = (now - MILLISECONDS_TO_2AM) % MILLISECONDS_IN_DAY; + + // Remaining time until the next 2 AM UTC + return MILLISECONDS_IN_DAY - timeSinceLast2AM; +} + \ No newline at end of file diff --git a/src/index.ts b/src/index.ts index 0a9cc4f..e1fb872 100644 --- a/src/index.ts +++ b/src/index.ts @@ -5,6 +5,7 @@ import type { MultiRegionRatelimitConfig } from "./multi"; import { RegionRatelimit as Ratelimit } from "./single"; import type { RegionRatelimitConfig as RatelimitConfig } from "./single"; import type { Algorithm } from "./types"; +import * as IpDenyList from "./deny-list/ip-deny-list" export { Ratelimit, @@ -14,4 +15,5 @@ export { type Algorithm, Analytics, type AnalyticsConfig, + IpDenyList }; diff --git a/src/ratelimit.ts b/src/ratelimit.ts index 0dcb679..e1a3934 100644 --- a/src/ratelimit.ts +++ b/src/ratelimit.ts @@ -1,7 +1,7 @@ -import { Analytics, type Geo } from "./analytics"; +import { Analytics } from "./analytics"; import { Cache } from "./cache"; import type { Algorithm, Context, LimitOptions, LimitPayload, RatelimitResponse, Redis } from "./types"; -import { checkDenyList, checkDenyListCache, defaultDeniedResponse, resolveResponses } from "./deny-list"; +import { checkDenyList, checkDenyListCache, defaultDeniedResponse, resolveLimitPayload } from "./deny-list/index"; export class TimeoutError extends Error { constructor() { @@ -73,6 +73,8 @@ export type RatelimitConfig = { * @default false */ enableProtection?: boolean + + denyListThreshold?: number }; /** @@ -105,12 +107,16 @@ export abstract class Ratelimit { protected readonly enableProtection: boolean; + protected readonly denyListThreshold: number + constructor(config: RatelimitConfig) { this.ctx = config.ctx; this.limiter = config.limiter; this.timeout = config.timeout ?? 5000; this.prefix = config.prefix ?? "@upstash/ratelimit"; + this.enableProtection = config.enableProtection ?? false; + this.denyListThreshold = config.denyListThreshold ?? 6; this.primaryRedis = ("redis" in this.ctx) ? this.ctx.redis : this.ctx.regionContexts[0].redis this.analytics = config.analytics @@ -275,23 +281,21 @@ export abstract class Ratelimit { const key = this.getKey(identifier); const definedMembers = this.getDefinedMembers(identifier, req); - const deniedMember = checkDenyListCache(definedMembers) + const deniedValue = checkDenyListCache(definedMembers) let result: LimitPayload; - if (deniedMember) { - result = [defaultDeniedResponse(deniedMember), deniedMember]; + if (deniedValue) { + result = [defaultDeniedResponse(deniedValue), {deniedValue, invalidIpDenyList: false}]; } else { result = await Promise.all([ this.limiter().limit(this.ctx, key, req?.rate), - checkDenyList( - this.primaryRedis, - this.prefix, - definedMembers - ) + this.enableProtection + ? checkDenyList(this.primaryRedis, this.prefix, definedMembers) + : { deniedValue: undefined, invalidIpDenyList: false } ]); } - return resolveResponses(result) + return resolveLimitPayload(this.primaryRedis, this.prefix, result, this.denyListThreshold) }; /** diff --git a/src/single.ts b/src/single.ts index 15add3b..a6ae095 100644 --- a/src/single.ts +++ b/src/single.ts @@ -84,6 +84,11 @@ export type RegionRatelimitConfig = { * @default false */ enableProtection?: boolean + + /** + * @default 6 + */ + denyListThreshold?: number }; /** @@ -119,6 +124,7 @@ export class RegionRatelimit extends Ratelimit { }, ephemeralCache: config.ephemeralCache, enableProtection: config.enableProtection, + denyListThreshold: config.denyListThreshold }); } diff --git a/src/types.ts b/src/types.ts index 0fd795f..482813d 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,3 +1,4 @@ +import { Pipeline } from "@upstash/redis"; import { Geo } from "./analytics"; /** @@ -87,7 +88,7 @@ export type RatelimitResponse = { /** * The value which was in the deny list if reason: "denyList" */ - deniedValue?: string + deniedValue?: DeniedValue }; export type Algorithm = () => { @@ -106,7 +107,13 @@ export type Algorithm = () => { export type IsDenied = 0 | 1; export type DeniedValue = string | undefined; -export type LimitPayload = [RatelimitResponse, DeniedValue]; +export type DenyListResponse = { deniedValue: DeniedValue, invalidIpDenyList: boolean } + +export const DenyListExtension = "denyList" as const +export const IpDenyListKey = "ipDenyList" as const +export const IpDenyListStatusKey = "ipDenyListStatus" as const + +export type LimitPayload = [RatelimitResponse, DenyListResponse]; export type LimitOptions = { geo?: Geo, rate?: number, @@ -138,4 +145,6 @@ export interface Redis { smismember: ( key: string, members: string[] ) => Promise; + + multi: () => Pipeline }