Skip to content

Commit

Permalink
282 Add excluded_paths to SessionsAuthBackend (#283)
Browse files Browse the repository at this point in the history
* add `excluded_paths` to `SessionsAuthBackend`

* try fixing mypy test

* fix typo
  • Loading branch information
dantownsend authored Apr 10, 2024
1 parent 9cc0966 commit 7dae866
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 26 deletions.
7 changes: 7 additions & 0 deletions docs/source/session_auth/middleware.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ follows:
-------------------------------------------------------------------------------

``excluded_paths``
------------------

This works identically to token auth - see :ref:`excluded_paths`.

-------------------------------------------------------------------------------

Source
------

Expand Down
2 changes: 2 additions & 0 deletions docs/source/token_auth/middleware.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ You'll have to run the migrations for this to work correctly.
``TokenAuthBackend``
--------------------

.. _excluded_paths:

``excluded_paths``
~~~~~~~~~~~~~~~~~~

Expand Down
18 changes: 13 additions & 5 deletions piccolo_api/session_auth/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from piccolo_api.session_auth.tables import SessionsBase
from piccolo_api.shared.auth import UnauthenticatedUser, User
from piccolo_api.shared.auth.excluded_paths import check_excluded_paths


class SessionsAuthBackend(AuthenticationBackend):
Expand All @@ -31,6 +32,7 @@ def __init__(
active_only: bool = True,
increase_expiry: t.Optional[timedelta] = None,
allow_unauthenticated: bool = False,
excluded_paths: t.Optional[t.Sequence[str]] = None,
):
"""
:param auth_table:
Expand All @@ -43,22 +45,26 @@ def __init__(
The name of the session cookie. Override this if it clashes with
other cookies in your application.
:param admin_only:
If True, users which aren't admins will be rejected.
If ``True``, users which aren't admins will be rejected.
:param superuser_only:
If True, users which aren't superusers will be rejected.
If ``True``, users which aren't superusers will be rejected.
:param active_only:
If True, users which aren't active will be rejected.
If ``True``, users which aren't active will be rejected.
:param increase_expiry:
If set, the session expiry will be increased by this amount on each
request, if it's close to expiry. This allows sessions to have a
short expiry date, whilst also providing a good user experience.
:param allow_unauthenticated:
If True, when a matching user session can't be found, the request
If ``True``, when a matching user session can't be found, the request
still continues, but an unauthenticated user is added to the scope.
It's then up to the application's endpoints to check if a user is
authenticated or not using ``request.user.is_authenticated``. If
False, the request is automatically rejected if a user session
``False``, the request is automatically rejected if a user session
can't be found.
:param excluded_paths:
These paths don't require a session cookie - useful if you want to
exclude a few URLs, such as docs.
""" # noqa: E501
super().__init__()
self.auth_table = auth_table
Expand All @@ -69,7 +75,9 @@ def __init__(
self.active_only = active_only
self.increase_expiry = increase_expiry
self.allow_unauthenticated = allow_unauthenticated
self.excluded_paths = excluded_paths or []

@check_excluded_paths
async def authenticate(
self, conn: HTTPConnection
) -> t.Optional[t.Tuple[AuthCredentials, BaseUser]]:
Expand Down
43 changes: 43 additions & 0 deletions piccolo_api/shared/auth/excluded_paths.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

import functools
import typing as t

from starlette.authentication import AuthCredentials, AuthenticationBackend
from starlette.requests import HTTPConnection

from piccolo_api.shared.auth import UnauthenticatedUser


def check_excluded_paths(authenticate_func: t.Callable):

@functools.wraps(authenticate_func)
async def authenticate(self: AuthenticationBackend, conn: HTTPConnection):
conn_path = dict(conn)

excluded_paths = getattr(self, "excluded_paths", None)

if excluded_paths is None:
raise ValueError("excluded_paths isn't defined")

for excluded_path in excluded_paths:
if excluded_path.endswith("*"):
if (
conn_path["raw_path"]
.decode("utf-8")
.startswith(excluded_path.rstrip("*"))
):
return (
AuthCredentials(scopes=[]),
UnauthenticatedUser(),
)
else:
if conn_path["path"] == excluded_path:
return (
AuthCredentials(scopes=[]),
UnauthenticatedUser(),
)

return await authenticate_func(self=self, conn=conn)

return authenticate
27 changes: 6 additions & 21 deletions piccolo_api/token_auth/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
)
from starlette.requests import HTTPConnection

from piccolo_api.shared.auth import UnauthenticatedUser, User
from piccolo_api.shared.auth import User
from piccolo_api.shared.auth.excluded_paths import check_excluded_paths
from piccolo_api.token_auth.tables import TokenAuth


Expand Down Expand Up @@ -90,7 +91,9 @@ def __init__(
:param token_auth_provider:
Used to verify that a token is correct.
:param excluded_paths:
These paths don't require a token.
These paths don't require a token - useful if you want to
exclude a few URLs, such as docs.
"""
super().__init__()
self.token_auth_provider = token_auth_provider
Expand All @@ -104,29 +107,11 @@ def extract_token(self, header: str) -> str:

return token

@check_excluded_paths
async def authenticate(
self, conn: HTTPConnection
) -> t.Optional[t.Tuple[AuthCredentials, BaseUser]]:
auth_header = conn.headers.get("Authorization", None)
conn_path = dict(conn)

for excluded_path in self.excluded_paths:
if excluded_path.endswith("*"):
if (
conn_path["raw_path"]
.decode("utf-8")
.startswith(excluded_path.rstrip("*"))
):
return (
AuthCredentials(scopes=[]),
UnauthenticatedUser(),
)
else:
if conn_path["path"] == excluded_path:
return (
AuthCredentials(scopes=[]),
UnauthenticatedUser(),
)

if not auth_header:
raise AuthenticationError("The Authorization header is missing.")
Expand Down
99 changes: 99 additions & 0 deletions tests/session_auth/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,105 @@ def test_wrong_cookie_value(self):
)


###############################################################################

EXCLUDED_PATHS_APP = Router(
routes=[
Route("/", EchoEndpoint),
Route(
"/foo/",
EchoEndpoint,
),
Route(
"/foo/1/",
EchoEndpoint,
),
Route(
"/bar/",
EchoEndpoint,
),
Route(
"/bar/1/",
EchoEndpoint,
),
]
)


class TestExcludedPaths(SessionTestCase):
"""
Make sure that if `excluded_paths` is set, then the middleware allows the
request to continue without a cookie.
"""

def create_user_and_session(self):
user = BaseUser(
**self.credentials, active=True, admin=True, superuser=True
)
user.save().run_sync()
SessionsBase.create_session_sync(user_id=user.id)

def setUp(self):
super().setUp()

# Add a session to the database to make it more realistic.
self.create_user_and_session()

def test_excluded_paths(self):
"""
Make sure that only the `excluded_paths` are accessible
"""
app = AuthenticationMiddleware(
EXCLUDED_PATHS_APP,
SessionsAuthBackend(
allow_unauthenticated=False,
excluded_paths=["/foo/"],
),
)
client = TestClient(app)

for path in ("/", "/foo/1/", "/bar/", "/bar/1/"):
response = client.get(path)
self.assertEqual(response.status_code, 400)
self.assertEqual(response.content, b"No session cookie found.")

response = client.get("/foo/")
assert response.status_code == 200
self.assertDictEqual(
response.json(),
{"is_unauthenticated_user": True, "is_authenticated": False},
)

def test_excluded_paths_wildcard(self):
"""
Make sure that wildcard paths work correctly.
"""
app = AuthenticationMiddleware(
EXCLUDED_PATHS_APP,
SessionsAuthBackend(
allow_unauthenticated=False,
excluded_paths=["/foo/*"],
),
)
client = TestClient(app)

for path in ("/", "/bar/", "/bar/1/"):
response = client.get(path)
self.assertEqual(response.status_code, 400)
self.assertEqual(response.content, b"No session cookie found.")

for path in ("/foo/", "/foo/1/"):
response = client.get(path)
self.assertEqual(response.status_code, 200)
self.assertDictEqual(
response.json(),
{"is_unauthenticated_user": True, "is_authenticated": False},
)


###############################################################################


class TestHooks(SessionTestCase):
def test_hooks(self):
# TODO Replace these with mocks ...
Expand Down

0 comments on commit 7dae866

Please sign in to comment.