From 4694748c3cf634a8088d75ca1c0a3fffc772ef70 Mon Sep 17 00:00:00 2001 From: Corentin de Boisset Date: Sun, 14 Jan 2024 22:02:01 +0100 Subject: [PATCH] feat: add some python unit tests --- docker_build/entrypoint.sh | 2 +- pyneutrino/commands/wait_db.py | 2 +- pyneutrino/hooks/csrf.py | 6 ++++++ pyneutrino/web/auth/session.py | 11 +++++++--- tests/_utils.py | 15 ++++++++++++++ tests/conftest.py | 1 + tests/test_session.py | 37 ++++++++++++++++++++++++++++++++++ tests/test_wait_db.py | 20 ++++++++++++++++++ 8 files changed, 89 insertions(+), 5 deletions(-) create mode 100644 tests/_utils.py create mode 100644 tests/test_session.py create mode 100644 tests/test_wait_db.py diff --git a/docker_build/entrypoint.sh b/docker_build/entrypoint.sh index a2a0318..03b2ddb 100644 --- a/docker_build/entrypoint.sh +++ b/docker_build/entrypoint.sh @@ -4,7 +4,7 @@ set -o pipefail if [ "$1" = "/neutrino/.venv/bin/python" ] && [ "$2" = "-m" ] && [ "$3" = "gunicorn" ]; then # If the main command is to start gunicorn, we automatically run the migrations - /neutrino/.venv/bin/python -m flask --app pyneutrino test wait-db --timeout 10 + /neutrino/.venv/bin/python -m flask --app pyneutrino wait-db --timeout 10 [ $? -eq 0 ] && /neutrino/.venv/bin/python -m alembic upgrade heads fi diff --git a/pyneutrino/commands/wait_db.py b/pyneutrino/commands/wait_db.py index b1e1fed..809c317 100644 --- a/pyneutrino/commands/wait_db.py +++ b/pyneutrino/commands/wait_db.py @@ -6,7 +6,7 @@ from pyneutrino.db import db from sqlalchemy.exc import OperationalError -WaitDbBp = Blueprint("wait_db", __name__, cli_group="test") +WaitDbBp = Blueprint("wait_db", __name__, cli_group=None) @WaitDbBp.cli.command("wait-db") diff --git a/pyneutrino/hooks/csrf.py b/pyneutrino/hooks/csrf.py index 2f519ef..4c5019e 100644 --- a/pyneutrino/hooks/csrf.py +++ b/pyneutrino/hooks/csrf.py @@ -19,6 +19,9 @@ def check_token(header_value: str, expected_value: str): :param header_value: The value of the CSRF token sent by the user in a header. :param expected_value: The used to serialize the token """ + if current_app.config.get("DISABLE_CSRF", False): + return + s = URLSafeSerializer(current_app.config["SECRET_KEY"]) try: @@ -34,6 +37,9 @@ def check_token(header_value: str, expected_value: str): @CsrfBp.before_app_request def check_csrf_token(): + if current_app.config.get("DISABLE_CSRF", False): + return + # Skip CSRF token validation on non-mutating requests if request.method in ("GET", "HEAD"): return diff --git a/pyneutrino/web/auth/session.py b/pyneutrino/web/auth/session.py index 36ce490..0ea9366 100644 --- a/pyneutrino/web/auth/session.py +++ b/pyneutrino/web/auth/session.py @@ -1,6 +1,6 @@ from flask import Blueprint, request, session from argon2 import PasswordHasher -from pyneutrino.services import validate_schema, login_user, serialize +from pyneutrino.services import validate_schema, login_user, serialize, authguard from pyneutrino.db import db, UserAccount from werkzeug.exceptions import Unauthorized from sqlalchemy.exc import NoResultFound @@ -36,8 +36,12 @@ def login_route(): if (not session.new) and ("user_id" in session) and (session["user_id"] != user.id): session.clear() - ph = PasswordHasher() - if not ph.verify(user.password_hash, json_body["password"]): + try: + ph = PasswordHasher() + # If the password is not valid, argon2 raises an exception + ph.verify(user.password_hash, json_body["password"]) + except BaseException: + # TODO: improve logging if the hash or another error occurs session.clear() raise Unauthorized() @@ -49,6 +53,7 @@ def login_route(): @SessionBp.route("/logout", methods=["POST"]) +@authguard def logout_route(): session.clear() diff --git a/tests/_utils.py b/tests/_utils.py new file mode 100644 index 0000000..eb74ac2 --- /dev/null +++ b/tests/_utils.py @@ -0,0 +1,15 @@ +from flask import Flask + + +def create_basic_user(app: Flask): + client = app.test_client() + client.post( + "/api/auth/register/new-account", + json={ + "email": "supertest@gmail.com", + "username": "super_username", + "password": "123secret", + "public_key": "PGP_PUBLIC_KEY", + "private_key": "PGP_SECRET_KEY", + }, + ) diff --git a/tests/conftest.py b/tests/conftest.py index 8d6a624..3a75360 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,6 +27,7 @@ def app(alembic_runner): { "TESTING": True, "SQLALCHEMY_DATABASE_URI": SQLALCHEMY_DATABASE_URL, + "DISABLE_CSRF": True, } ) diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 0000000..d6855b3 --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,37 @@ +from flask import Flask +from ._utils import create_basic_user + + +def test_valid_login(app: Flask): + create_basic_user(app) + client = app.test_client() + res = client.post("/api/auth/session/login", json={"email": "supertest@gmail.com", "password": "123secret"}) + assert res.json["id"] is not None + + headers = " ".join(res.headers.getlist("Set-Cookie")) + + # Check that a session coookie is created + assert "session=" in headers + + +def test_invalid_login(app: Flask): + create_basic_user(app) + client = app.test_client() + res = client.post("/api/auth/session/login", json={"email": "supertest@gmail.com", "password": "wrong"}) + assert res.status_code == 401 + + +def test_logout(app: Flask): + create_basic_user(app) + client = app.test_client() + + # Check we cannot logout without being authenticated + res = client.post("/api/auth/session/logout") + assert res.status_code == 401 + + # Login and Logout + client.post("/api/auth/session/login", json={"email": "supertest@gmail.com", "password": "123secret"}) + res = client.post("/api/auth/session/logout") + assert res.status_code == 204 + print(" ".join(res.headers.getlist("Set-Cookie"))) + assert "session=;" in " ".join(res.headers.getlist("Set-Cookie")) diff --git a/tests/test_wait_db.py b/tests/test_wait_db.py new file mode 100644 index 0000000..9de0f89 --- /dev/null +++ b/tests/test_wait_db.py @@ -0,0 +1,20 @@ +from flask import Flask +from pyneutrino import create_app + + +def test_wait_db_command_success(app: Flask): + runner = app.test_cli_runner() + result = runner.invoke(args=["wait-db", "--timeout", "1"]) + assert "The database is available" in result.output + + +def test_wait_db_command_failure(): + app = create_app( + { + "TESTING": True, + "SQLALCHEMY_DATABASE_URI": "postgresql://127.0.0.2:5432/nope?connect_timeout=1", + } + ) + runner = app.test_cli_runner() + result = runner.invoke(args=["wait-db", "--timeout", "0"]) + assert "Failed to connect to the database" in result.output