Skip to content

Commit

Permalink
Merge pull request #428 from supertokens/feat/rate-limting
Browse files Browse the repository at this point in the history
feat: Add 429 rate limting from SaaS
  • Loading branch information
rishabhpoddar authored Aug 28, 2023
2 parents a3f023b + 61bb182 commit 8bb55c6
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 6 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [unreleased]



## [0.15.3] - 2023-09-24

- Handle 429 rate limiting from SaaS core instances

## [0.15.2] - 2023-09-23

- Fixed bugs in thirdparty providers: Bitbucket, Boxy-SAML, and Facebook
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@

setup(
name="supertokens_python",
version="0.15.2",
version="0.15.3",
author="SuperTokens",
license="Apache 2.0",
author_email="[email protected]",
Expand Down
3 changes: 2 additions & 1 deletion supertokens_python/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import annotations

SUPPORTED_CDI_VERSIONS = ["3.0"]
VERSION = "0.15.2"
VERSION = "0.15.3"
TELEMETRY = "/telemetry"
USER_COUNT = "/users/count"
USER_DELETE = "/user/remove"
Expand All @@ -29,3 +29,4 @@
API_VERSION_HEADER = "cdi-version"
DASHBOARD_VERSION = "0.7"
HUNDRED_YEARS_IN_MS = 3153600000000
RATE_LIMIT_STATUS_CODE = 429
4 changes: 3 additions & 1 deletion supertokens_python/normalised_url_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def equals(self, other: NormalisedURLPath) -> bool:

def is_a_recipe_path(self) -> bool:
parts = self.__value.split("/")
return parts[1] == "recipe" or parts[2] == "recipe"
return (len(parts) > 1 and parts[1] == "recipe") or (
len(parts) > 2 and parts[2] == "recipe"
)


def normalise_url_path_or_throw_error(input_str: str) -> str:
Expand Down
32 changes: 29 additions & 3 deletions supertokens_python/querier.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
# under the License.
from __future__ import annotations

import asyncio

from json import JSONDecodeError
from os import environ
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional

from httpx import AsyncClient, ConnectTimeout, NetworkError, Response

Expand All @@ -25,6 +27,7 @@
API_VERSION_HEADER,
RID_KEY_HEADER,
SUPPORTED_CDI_VERSIONS,
RATE_LIMIT_STATUS_CODE,
)
from .normalised_url_path import NormalisedURLPath

Expand Down Expand Up @@ -222,6 +225,7 @@ async def __send_request_helper(
method: str,
http_function: Callable[[str], Awaitable[Response]],
no_of_tries: int,
retry_info_map: Optional[Dict[str, int]] = None,
) -> Any:
if no_of_tries == 0:
raise_general_exception("No SuperTokens core available to query")
Expand All @@ -238,6 +242,14 @@ async def __send_request_helper(
Querier.__last_tried_index %= len(self.__hosts)
url = current_host + path.get_as_string_dangerous()

max_retries = 5

if retry_info_map is None:
retry_info_map = {}

if retry_info_map.get(url) is None:
retry_info_map[url] = max_retries

ProcessState.get_instance().add_state(
AllowedProcessStates.CALLING_SERVICE_IN_REQUEST_HELPER
)
Expand All @@ -247,6 +259,20 @@ async def __send_request_helper(
):
Querier.__hosts_alive_for_testing.add(current_host)

if response.status_code == RATE_LIMIT_STATUS_CODE:
retries_left = retry_info_map[url]

if retries_left > 0:
retry_info_map[url] = retries_left - 1

attempts_made = max_retries - retries_left
delay = (10 + attempts_made * 250) / 1000

await asyncio.sleep(delay)
return await self.__send_request_helper(
path, method, http_function, no_of_tries, retry_info_map
)

if is_4xx_error(response.status_code) or is_5xx_error(response.status_code): # type: ignore
raise_general_exception(
"SuperTokens core threw an error for a "
Expand All @@ -264,9 +290,9 @@ async def __send_request_helper(
except JSONDecodeError:
return response.text

except (ConnectionError, NetworkError, ConnectTimeout):
except (ConnectionError, NetworkError, ConnectTimeout) as _:
return await self.__send_request_helper(
path, method, http_function, no_of_tries - 1
path, method, http_function, no_of_tries - 1, retry_info_map
)
except Exception as e:
raise_general_exception(e)
150 changes: 150 additions & 0 deletions tests/test_querier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
#
# This software is licensed under the Apache License, Version 2.0 (the
# "License") as published by the Apache Software Foundation.
#
# You may not use this file except in compliance with the License. You may
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from pytest import mark
from supertokens_python.recipe import (
session,
emailpassword,
emailverification,
dashboard,
)
import asyncio
import respx
import httpx
from supertokens_python import init, SupertokensConfig
from supertokens_python.querier import Querier, NormalisedURLPath

from tests.utils import get_st_init_args
from tests.utils import (
setup_function,
teardown_function,
start_st,
)

_ = setup_function
_ = teardown_function

pytestmark = mark.asyncio
respx_mock = respx.MockRouter


async def test_network_call_is_retried_as_expected():
# Test that network call is retried properly
# Test that rate limiting errors are thrown back to the user
args = get_st_init_args(
[
session.init(),
emailpassword.init(),
emailverification.init(mode="OPTIONAL"),
dashboard.init(),
]
)
args["supertokens_config"] = SupertokensConfig("http://localhost:6789")
init(**args) # type: ignore
start_st()

Querier.api_version = "3.0"
q = Querier.get_instance()

api2_call_count = 0

def api2_side_effect(_: httpx.Request):
nonlocal api2_call_count
api2_call_count += 1

if api2_call_count == 3:
return httpx.Response(200)

return httpx.Response(429, json={})

with respx_mock() as mocker:
api1 = mocker.get("http://localhost:6789/api1").mock(
httpx.Response(429, json={"status": "RATE_ERROR"})
)
api2 = mocker.get("http://localhost:6789/api2").mock(
side_effect=api2_side_effect
)
api3 = mocker.get("http://localhost:6789/api3").mock(httpx.Response(200))

try:
await q.send_get_request(NormalisedURLPath("/api1"), {})
except Exception as e:
if "with status code: 429" in str(
e
) and 'message: {"status": "RATE_ERROR"}' in str(e):
pass
else:
raise e

await q.send_get_request(NormalisedURLPath("/api2"), {})
await q.send_get_request(NormalisedURLPath("/api3"), {})

# 1 initial request + 5 retries
assert api1.call_count == 6
# 2 403 and 1 200
assert api2.call_count == 3
# 200 in the first attempt
assert api3.call_count == 1


async def test_parallel_calls_have_independent_counters():
args = get_st_init_args(
[
session.init(),
emailpassword.init(),
emailverification.init(mode="OPTIONAL"),
dashboard.init(),
]
)
init(**args) # type: ignore
start_st()

Querier.api_version = "3.0"
q = Querier.get_instance()

call_count1 = 0
call_count2 = 0

def api_side_effect(r: httpx.Request):
nonlocal call_count1, call_count2

id_ = int(r.url.params.get("id"))
if id_ == 1:
call_count1 += 1
elif id_ == 2:
call_count2 += 1

return httpx.Response(429, json={})

with respx_mock() as mocker:
api = mocker.get("http://localhost:3567/api").mock(side_effect=api_side_effect)

async def call_api(id_: int):
try:
await q.send_get_request(NormalisedURLPath("/api"), {"id": id_})
except Exception as e:
if "with status code: 429" in str(e):
pass
else:
raise e

_ = await asyncio.gather(
call_api(1),
call_api(2),
)

# 1 initial request + 5 retries
assert call_count1 == 6
assert call_count2 == 6

assert api.call_count == 12

0 comments on commit 8bb55c6

Please sign in to comment.