Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support to multiple authentication #901

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 54 additions & 3 deletions integration_tests/base_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pathlib
from collections import defaultdict
from typing import Optional
from base64 import b64encode

from robyn import (
Request,
Expand All @@ -13,7 +14,7 @@
serve_html,
WebSocketConnector,
)
from robyn.authentication import AuthenticationHandler, BearerGetter, Identity
from robyn.authentication import AuthenticationHandler, BearerGetter, Identity, BasicGetter
from robyn.robyn import Headers
from robyn.templating import JinjaTemplate

Expand Down Expand Up @@ -794,6 +795,34 @@ async def async_auth(request: Request):
return "authenticated"


@app.get("/sync/auth/basic", auth_required=True, auth_middleware_name="basic")
def sync_auth_basic(request: Request):
assert request.identity is not None
assert request.identity.claims == {"key": "value"}
return "authenticated"


@app.get("/async/auth/basic", auth_required=True, auth_middleware_name="basic")
async def async_auth_basic(request: Request):
assert request.identity is not None
assert request.identity.claims == {"key": "value"}
return "authenticated"


@app.get("/sync/auth/bearer-2", auth_required=True, auth_middleware_name="bearer-2")
def sync_auth_bearer_2(request: Request):
assert request.identity is not None
assert request.identity.claims == {"key": "value"}
return "authenticated"


@app.get("/async/auth/bearer-2", auth_required=True, auth_middleware_name="bearer-2")
async def async_auth_bearer_2(request: Request):
assert request.identity is not None
assert request.identity.claims == {"key": "value"}
return "authenticated"


# ===== Main =====


Expand Down Expand Up @@ -845,7 +874,7 @@ def main():
app.include_router(sub_router)
app.include_router(di_subrouter)

class BasicAuthHandler(AuthenticationHandler):
class BearerAuthHandler(AuthenticationHandler):
def authenticate(self, request: Request) -> Optional[Identity]:
token = self.token_getter.get_token(request)
if token is not None:
Expand All @@ -855,7 +884,29 @@ def authenticate(self, request: Request) -> Optional[Identity]:
return Identity(claims={"key": "value"})
return None

app.configure_authentication(BasicAuthHandler(token_getter=BearerGetter()))
class OtherBearerAuthHandler(AuthenticationHandler):
def authenticate(self, request: Request) -> Optional[Identity]:
token = self.token_getter.get_token(request)
if token is not None:
# Useless but we call the set_token method for testing purposes
self.token_getter.set_token(request, token)
if token == "valid-2":
return Identity(claims={"key": "value"})
return None

class BasicAuthHandler(AuthenticationHandler):
def authenticate(self, request: Request) -> Optional[Identity]:
username, password = self.token_getter.get_credentials(request)
if username is not None and password is not None:
# Useless but we call the set_token method for testing purposes
self.token_getter.set_token(request, b64encode(f"{username}:{password}".encode()).decode())
if username == "valid" and password == "valid":
return Identity(claims={"key": "value"})
return None

app.configure_authentication(BasicAuthHandler(token_getter=BasicGetter(), name="basic"))
app.configure_authentication(BearerAuthHandler(token_getter=BearerGetter(), name="bearer", default=True))
app.configure_authentication(OtherBearerAuthHandler(token_getter=BearerGetter(), name="bearer-2"))
app.start(port=8080, _check_port=False)


Expand Down
79 changes: 79 additions & 0 deletions integration_tests/test_authentication.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from base64 import b64encode

from integration_tests.helpers.http_methods_helpers import get

Expand Down Expand Up @@ -40,3 +41,81 @@ def test_invalid_authentication_no_token(session, function_type: str):
r = get(f"/{function_type}/auth", should_check_response=False)
assert r.status_code == 401
assert r.headers.get("WWW-Authenticate") == "BearerGetter"


@pytest.mark.benchmark
@pytest.mark.parametrize("function_type", ["sync", "async"])
def test_valid_authentication_bearer_2(session, function_type: str):
r = get(f"/{function_type}/auth/bearer-2", headers={"Authorization": "Bearer valid-2"})
assert r.text == "authenticated"


@pytest.mark.benchmark
@pytest.mark.parametrize("function_type", ["sync", "async"])
def test_invalid_authentication_token_bearer_2(session, function_type: str):
r = get(
f"/{function_type}/auth/bearer-2",
headers={"Authorization": "Bearer invalid"},
should_check_response=False,
)
assert r.status_code == 401
assert r.headers.get("WWW-Authenticate") == "BearerGetter"


@pytest.mark.benchmark
@pytest.mark.parametrize("function_type", ["sync", "async"])
def test_invalid_authentication_header_bearer_2(session, function_type: str):
r = get(
f"/{function_type}/auth/bearer-2",
headers={"Authorization": "Bear valid-2"},
should_check_response=False,
)
assert r.status_code == 401
assert r.headers.get("WWW-Authenticate") == "BearerGetter"


@pytest.mark.benchmark
@pytest.mark.parametrize("function_type", ["sync", "async"])
def test_invalid_authentication_no_token_bearer_2(session, function_type: str):
r = get(f"/{function_type}/auth/bearer-2", should_check_response=False)
assert r.status_code == 401
assert r.headers.get("WWW-Authenticate") == "BearerGetter"


@pytest.mark.benchmark
@pytest.mark.parametrize("function_type", ["sync", "async"])
def test_valid_authentication_basic(session, function_type: str):
r = get(f"/{function_type}/auth/basic", headers={"Authorization": f"Basic {b64encode('valid:valid'.encode()).decode()}"})
assert r.text == "authenticated"


@pytest.mark.benchmark
@pytest.mark.parametrize("function_type", ["sync", "async"])
def test_invalid_authentication_token_basic(session, function_type: str):
r = get(
f"/{function_type}/auth/basic",
headers={"Authorization": "Basic invalid"},
should_check_response=False,
)
assert r.status_code == 401
assert r.headers.get("WWW-Authenticate") == "BasicGetter"


@pytest.mark.benchmark
@pytest.mark.parametrize("function_type", ["sync", "async"])
def test_invalid_authentication_header_basic(session, function_type: str):
r = get(
f"/{function_type}/auth/basic",
headers={"Authorization": "Bear valid-2"},
should_check_response=False,
)
assert r.status_code == 401
assert r.headers.get("WWW-Authenticate") == "BasicGetter"


@pytest.mark.benchmark
@pytest.mark.parametrize("function_type", ["sync", "async"])
def test_invalid_authentication_no_token_basic(session, function_type: str):
r = get(f"/{function_type}/auth/basic", should_check_response=False)
assert r.status_code == 401
assert r.headers.get("WWW-Authenticate") == "BasicGetter"
68 changes: 47 additions & 21 deletions robyn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
self.directories: List[Directory] = []
self.event_handlers = {}
self.exception_handler: Optional[Callable] = None
self.authentication_handler: Optional[AuthenticationHandler] = None
self.authentication_handler: List[AuthenticationHandler] = []

def _handle_dev_mode(self):
cli_dev_mode = self.config.dev # --dev
Expand All @@ -84,13 +84,36 @@ def _handle_dev_mode(self):
logger.error("Ignoring ROBYN_DEV_MODE environment variable. Dev mode is not supported in the python wrapper.")
raise SystemExit("Dev mode is not supported in the python wrapper. Please use the Robyn CLI. e.g. python3 -m robyn app.py")

def auth_handler_configured(self):
handler_count = len(self.authentication_handler)
if handler_count == 0:
return
if handler_count == 1:
self.authentication_handler[0].default = True
return

default_handlers = [handler for handler in self.authentication_handler if handler.default]

if len(default_handlers) == 0:
raise ValueError(
"Multiple authentication handlers are configured, but none is set as the default. "
"Please set one of the authentication handlers as the default."
)

if len(default_handlers) > 1:
raise ValueError(
"Multiple authentication handlers are configured with more than one default. "
"Please ensure only one authentication handler is set as the default."
)

def add_route(
self,
route_type: Union[HttpMethod, str],
endpoint: str,
handler: Callable,
is_const: bool = False,
auth_required: bool = False,
auth_middleware_name: Optional[str] = None,
):
"""
Connect a URI to a handler
Expand All @@ -100,14 +123,15 @@ def add_route(
:param handler function: represents the sync or async function passed as a handler for the route
:param is_const bool: represents if the handler is a const function or not
:param auth_required bool: represents if the route needs authentication or not
:param auth_middleware_name str: represents auth handler name for the route
"""

""" We will add the status code here only
"""
injected_dependencies = self.dependencies.get_dependency_map(self)

if auth_required:
self.middleware_router.add_auth_middleware(endpoint)(handler)
self.middleware_router.add_auth_middleware(endpoint, auth_middleware_name)(handler)

if isinstance(route_type, str):
http_methods = {
Expand Down Expand Up @@ -226,6 +250,8 @@ def start(self, host: str = "127.0.0.1", port: int = 8080, _check_port: bool = T
port = int(os.getenv("ROBYN_PORT", port))
open_browser = bool(os.getenv("ROBYN_BROWSER_OPEN", self.config.open_browser))

self.auth_handler_configured()

if _check_port:
while self.is_port_in_use(port):
logger.error("Port %s is already in use. Please use a different port.", port)
Expand Down Expand Up @@ -302,111 +328,111 @@ def inner(handler):

return inner

def get(self, endpoint: str, const: bool = False, auth_required: bool = False):
def get(self, endpoint: str, const: bool = False, auth_required: bool = False, auth_middleware_name: Optional[str] = None):
"""
The @app.get decorator to add a route with the GET method

:param endpoint str: endpoint to server the route
"""

def inner(handler):
return self.add_route(HttpMethod.GET, endpoint, handler, const, auth_required)
return self.add_route(HttpMethod.GET, endpoint, handler, const, auth_required, auth_middleware_name=auth_middleware_name)

return inner

def post(self, endpoint: str, auth_required: bool = False):
def post(self, endpoint: str, auth_required: bool = False, auth_middleware_name: Optional[str] = None):
"""
The @app.post decorator to add a route with POST method

:param endpoint str: endpoint to server the route
"""

def inner(handler):
return self.add_route(HttpMethod.POST, endpoint, handler, auth_required=auth_required)
return self.add_route(HttpMethod.POST, endpoint, handler, auth_required=auth_required, auth_middleware_name=auth_middleware_name)

return inner

def put(self, endpoint: str, auth_required: bool = False):
def put(self, endpoint: str, auth_required: bool = False, auth_middleware_name: Optional[str] = None):
"""
The @app.put decorator to add a get route with PUT method

:param endpoint str: endpoint to server the route
"""

def inner(handler):
return self.add_route(HttpMethod.PUT, endpoint, handler, auth_required=auth_required)
return self.add_route(HttpMethod.PUT, endpoint, handler, auth_required=auth_required, auth_middleware_name=auth_middleware_name)

return inner

def delete(self, endpoint: str, auth_required: bool = False):
def delete(self, endpoint: str, auth_required: bool = False, auth_middleware_name: Optional[str] = None):
"""
The @app.delete decorator to add a route with DELETE method

:param endpoint str: endpoint to server the route
"""

def inner(handler):
return self.add_route(HttpMethod.DELETE, endpoint, handler, auth_required=auth_required)
return self.add_route(HttpMethod.DELETE, endpoint, handler, auth_required=auth_required, auth_middleware_name=auth_middleware_name)

return inner

def patch(self, endpoint: str, auth_required: bool = False):
def patch(self, endpoint: str, auth_required: bool = False, auth_middleware_name: Optional[str] = None):
"""
The @app.patch decorator to add a route with PATCH method

:param endpoint [str]: [endpoint to server the route]
"""

def inner(handler):
return self.add_route(HttpMethod.PATCH, endpoint, handler, auth_required=auth_required)
return self.add_route(HttpMethod.PATCH, endpoint, handler, auth_required=auth_required, auth_middleware_name=auth_middleware_name)

return inner

def head(self, endpoint: str, auth_required: bool = False):
def head(self, endpoint: str, auth_required: bool = False, auth_middleware_name: Optional[str] = None):
"""
The @app.head decorator to add a route with HEAD method

:param endpoint str: endpoint to server the route
"""

def inner(handler):
return self.add_route(HttpMethod.HEAD, endpoint, handler, auth_required=auth_required)
return self.add_route(HttpMethod.HEAD, endpoint, handler, auth_required=auth_required, auth_middleware_name=auth_middleware_name)

return inner

def options(self, endpoint: str, auth_required: bool = False):
def options(self, endpoint: str, auth_required: bool = False, auth_middleware_name: Optional[str] = None):
"""
The @app.options decorator to add a route with OPTIONS method

:param endpoint str: endpoint to server the route
"""

def inner(handler):
return self.add_route(HttpMethod.OPTIONS, endpoint, handler, auth_required=auth_required)
return self.add_route(HttpMethod.OPTIONS, endpoint, handler, auth_required=auth_required, auth_middleware_name=auth_middleware_name)

return inner

def connect(self, endpoint: str, auth_required: bool = False):
def connect(self, endpoint: str, auth_required: bool = False, auth_middleware_name: Optional[str] = None):
"""
The @app.connect decorator to add a route with CONNECT method

:param endpoint str: endpoint to server the route
"""

def inner(handler):
return self.add_route(HttpMethod.CONNECT, endpoint, handler, auth_required=auth_required)
return self.add_route(HttpMethod.CONNECT, endpoint, handler, auth_required=auth_required, auth_middleware_name=auth_middleware_name)

return inner

def trace(self, endpoint: str, auth_required: bool = False):
def trace(self, endpoint: str, auth_required: bool = False, auth_middleware_name: Optional[str] = None):
"""
The @app.trace decorator to add a route with TRACE method

:param endpoint str: endpoint to server the route
"""

def inner(handler):
return self.add_route(HttpMethod.TRACE, endpoint, handler, auth_required=auth_required)
return self.add_route(HttpMethod.TRACE, endpoint, handler, auth_required=auth_required, auth_middleware_name=auth_middleware_name)

return inner

Expand Down Expand Up @@ -434,7 +460,7 @@ def configure_authentication(self, authentication_handler: AuthenticationHandler

:param authentication_handler: the instance of a class inheriting the AuthenticationHandler base class
"""
self.authentication_handler = authentication_handler
self.authentication_handler.append(authentication_handler)
self.middleware_router.set_authentication_handler(authentication_handler)


Expand Down
Loading
Loading