diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 3fc2f5f..d21becf 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -68,7 +68,7 @@ jobs: cd src uv run manage.py makemigrations --check uv run manage.py migrate - uv run manage.py test tests/ + uv run manage.py test docker: needs: [ lint, unit_test ] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4d585ab..261d776 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,22 +13,20 @@ repos: args: ["--allow-multiple-documents"] - id: debug-statements - id: trailing-whitespace - exclude: >- - ^.*.md$ + exclude: ^.*.md$ - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.7.0 hooks: - id: ruff - args: [ --fix ] + args: [--fix] - id: ruff-format - # local mypy because of stub dependencies - repo: local hooks: - id: typecheck - name: Typecheck - entry: mypy . - types: [python] + name: Typecheck (uv) + entry: uv run mypy language: system pass_filenames: false + args: ["."] diff --git a/dev-compose.yml b/dev-compose.yml index 03b6a8b..123bcbf 100644 --- a/dev-compose.yml +++ b/dev-compose.yml @@ -10,9 +10,15 @@ services: ports: - "5432:5432" + cache: + image: memcached:latest + ports: + - "11211:11211" + web: depends_on: - - db + - db + - cache build: . env_file: ./.env ports: diff --git a/docker-compose.yml b/docker-compose.yml index cdbacbe..0e7639c 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,7 +8,13 @@ services: volumes: - db-data:/var/lib/postgresql/data + cache: + image: memcached:latest + web: + depends_on: + - db + - cache image: unitystation/central-command:latest environment: - DEBUG=0 diff --git a/pyproject.toml b/pyproject.toml index 4f40742..a5a46d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "psycopg2-binary~=2.9.9", "python-dotenv~=0.19.2", "whitenoise~=6.6.0", + "pymemcache>=4.0,<5.0", ] [build-system] diff --git a/src/accounts/api/serializers.py b/src/accounts/api/serializers.py index eba6557..ce42bd7 100644 --- a/src/accounts/api/serializers.py +++ b/src/accounts/api/serializers.py @@ -93,7 +93,7 @@ def validate(self, data): if not account_confirmation.is_token_valid(): raise serializers.ValidationError({"token": "Token is invalid or expired."}) - return {"token": data["token"]} + return {"token": data["token"], "account_confirmation": account_confirmation} class EmailSerializer(serializers.Serializer): diff --git a/src/accounts/api/views.py b/src/accounts/api/views.py index 21aedc7..9a0cfd0 100644 --- a/src/accounts/api/views.py +++ b/src/accounts/api/views.py @@ -324,7 +324,7 @@ def post(self, request): if not serializer.is_valid(): return ErrorResponse(serializer.errors, status.HTTP_400_BAD_REQUEST) - account_confirmation = AccountConfirmation.objects.get(token=serializer.validated_data["token"]) + account_confirmation: AccountConfirmation = serializer.validated_data["account_confirmation"] account = account_confirmation.account account.is_confirmed = True diff --git a/src/baby_serverlist/__init__.py b/src/baby_serverlist/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/baby_serverlist/admin.py b/src/baby_serverlist/admin.py new file mode 100644 index 0000000..fa52808 --- /dev/null +++ b/src/baby_serverlist/admin.py @@ -0,0 +1,10 @@ +from django.contrib import admin + +from .models import BabyServer + + +@admin.register(BabyServer) +class BabyServerAdmin(admin.ModelAdmin): + list_display = ("id", "owner", "whitelisted", "serverlist_token") + search_fields = ("id", "owner__email", "owner__unique_identifier") + list_filter = ("whitelisted",) diff --git a/src/baby_serverlist/api/serializers.py b/src/baby_serverlist/api/serializers.py new file mode 100644 index 0000000..d381fa3 --- /dev/null +++ b/src/baby_serverlist/api/serializers.py @@ -0,0 +1,31 @@ +from rest_framework import serializers + + +class ServerStatusSerializer(serializers.Serializer): + ServerToken = serializers.CharField() + Passworded = serializers.BooleanField() + ServerName = serializers.CharField() + ForkName = serializers.CharField() + BuildVersion = serializers.IntegerField() + CurrentMap = serializers.CharField() + GameMode = serializers.CharField() + IngameTime = serializers.CharField() + RoundTime = serializers.CharField() + PlayerCount = serializers.IntegerField() + PlayerCountMax = serializers.IntegerField() + ServerIP = serializers.CharField() + ServerPort = serializers.IntegerField() + WinDownload = serializers.CharField() + OSXDownload = serializers.CharField() + LinuxDownload = serializers.CharField() + fps = serializers.IntegerField() + GoodFileVersion = serializers.CharField() + + +class ActiveServersSerializer(serializers.Serializer): + CashDateTime = serializers.DateTimeField() + servers = ServerStatusSerializer(many=True) + + +class RegenerateServerlistTokenSerializer(serializers.Serializer): + server_id = serializers.UUIDField() diff --git a/src/baby_serverlist/api/urls.py b/src/baby_serverlist/api/urls.py new file mode 100644 index 0000000..6ebb9d7 --- /dev/null +++ b/src/baby_serverlist/api/urls.py @@ -0,0 +1,18 @@ +from django.urls import path + +from .views import ( + CreateBabyServerView, + ListBabyServersView, + ListOwnedBabyServersView, + PostServerStatusView, + RegenerateServerlistTokenView, +) + +app_name = "baby_serverlist" +urlpatterns = [ + path("status/", PostServerStatusView.as_view(), name="report-status"), + path("servers/create/", CreateBabyServerView.as_view(), name="create"), + path("servers/owned/", ListOwnedBabyServersView.as_view(), name="list-owned"), + path("servers/", ListBabyServersView.as_view(), name="list"), + path("servers/regenerate-token/", RegenerateServerlistTokenView.as_view(), name="regenerate-token"), +] diff --git a/src/baby_serverlist/api/views.py b/src/baby_serverlist/api/views.py new file mode 100644 index 0000000..4b054a6 --- /dev/null +++ b/src/baby_serverlist/api/views.py @@ -0,0 +1,154 @@ +import logging + +from datetime import UTC, datetime +from typing import cast + +from django.core import signing +from rest_framework import status +from rest_framework.generics import GenericAPIView, ListAPIView +from rest_framework.permissions import AllowAny +from rest_framework.response import Response + +from accounts.models import Account +from baby_serverlist.models import SERVERLIST_TOKEN_SALT, BabyServer +from commons.cache import ( + get_many_baby_server_statuses, + set_baby_server_heartbeat, + set_baby_server_status, +) +from commons.error_response import ErrorResponse + +from .serializers import RegenerateServerlistTokenSerializer, ServerStatusSerializer + +logger = logging.getLogger(__name__) + + +class PostServerStatusView(GenericAPIView): + """Accepts signed status payloads from baby servers and stores the latest state in cache.""" + + serializer_class = ServerStatusSerializer + permission_classes = (AllowAny,) + + def post(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + + if not serializer.is_valid(): + return ErrorResponse(serializer.errors, status.HTTP_400_BAD_REQUEST) + + status_payload = serializer.validated_data + server_token = status_payload.get("ServerToken") + + try: + payload = signing.loads(server_token, salt=SERVERLIST_TOKEN_SALT) + except signing.BadSignature: + return ErrorResponse("Invalid or expired token", status.HTTP_400_BAD_REQUEST) + + try: + baby_server = BabyServer.objects.get(id=payload.get("server_id"), serverlist_token=server_token) + except BabyServer.DoesNotExist: + return ErrorResponse("Invalid or expired token", status.HTTP_400_BAD_REQUEST) + + server_id = str(baby_server.id) + status_without_token = {key: value for key, value in status_payload.items() if key != "ServerToken"} + + set_baby_server_status(server_id, status_without_token) + set_baby_server_heartbeat(server_id, datetime.now(tz=UTC).isoformat()) + + logger.debug("Received server status update for server %s: %s", baby_server.id, status_without_token) + + return Response(status=status.HTTP_200_OK) + + +class CreateBabyServerView(GenericAPIView): + """Creates a new baby server for the authenticated user and returns the freshly minted token.""" + + queryset = BabyServer.objects.all() + + def post(self, request, *args, **kwargs): + user = cast(Account, request.user) + + baby_server = BabyServer.objects.create(owner=user) + + return Response( + { + "id": str(baby_server.id), + "serverlist_token": baby_server.serverlist_token, + "whitelisted": baby_server.whitelisted, + }, + status=status.HTTP_201_CREATED, + ) + + +class ListOwnedBabyServersView(ListAPIView): + """Lists the caller's baby servers with a derived `live` flag based on recent heartbeats.""" + + def get_queryset(self): + user = cast(Account, self.request.user) + return BabyServer.objects.filter(owner=user).only("id", "whitelisted") + + def list(self, request, *args, **kwargs): + queryset = self.get_queryset() + + data = [ + { + "id": str(server.id), + "whitelisted": server.whitelisted, + "live": server.is_live(), + } + for server in queryset + ] + return Response(data, status=status.HTTP_200_OK) + + +class ListBabyServersView(ListAPIView): + """Return cached status payloads for all baby servers that have reported recently.""" + + permission_classes = (AllowAny,) + + def list(self, request, *args, **kwargs): + servers = BabyServer.objects.filter(whitelisted=True) + server_ids = [str(server.id) for server in servers] + status_map = get_many_baby_server_statuses(server_ids) + + data = [ + status_map[server_id] for server_id in sorted(status_map.keys()) if isinstance(status_map[server_id], dict) + ] + + return Response({"servers": data}, status=status.HTTP_200_OK) + + +class RegenerateServerlistTokenView(GenericAPIView): + """Regenerates a server's signed token after validating ownership.""" + + serializer_class = RegenerateServerlistTokenSerializer + queryset = BabyServer.objects.all() + + def post(self, request, *args, **kwargs): + serializer = self.get_serializer(data=request.data) + serializer.is_valid(raise_exception=True) + + server_id = serializer.validated_data["server_id"] + + try: + baby_server: BabyServer = BabyServer.objects.get(id=server_id) + except BabyServer.DoesNotExist: + return ErrorResponse("Baby server not found", status.HTTP_404_NOT_FOUND) + + user = cast(Account, request.user) + + if baby_server.owner != user: + return ErrorResponse( + "You do not have permission to modify this baby server", + status.HTTP_403_FORBIDDEN, + ) + + baby_server.serverlist_token = baby_server.generate_serverlist_token() + baby_server.save(update_fields=["serverlist_token"]) + + return Response( + { + "id": str(baby_server.id), + "serverlist_token": baby_server.serverlist_token, + }, + status=status.HTTP_200_OK, + ) diff --git a/src/baby_serverlist/apps.py b/src/baby_serverlist/apps.py new file mode 100644 index 0000000..5a3579e --- /dev/null +++ b/src/baby_serverlist/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class BabyServerlistConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "baby_serverlist" diff --git a/src/baby_serverlist/migrations/0001_initial.py b/src/baby_serverlist/migrations/0001_initial.py new file mode 100644 index 0000000..5766db9 --- /dev/null +++ b/src/baby_serverlist/migrations/0001_initial.py @@ -0,0 +1,27 @@ +# Generated by Django 3.2.25 on 2025-11-01 05:11 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion +import uuid + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name='BabyServer', + fields=[ + ('id', models.UUIDField(default=uuid.uuid4, editable=False, primary_key=True, serialize=False)), + ('serverlist_token', models.TextField(editable=False, unique=True)), + ('whitelisted', models.BooleanField(default=False)), + ('owner', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='baby_servers', to=settings.AUTH_USER_MODEL)), + ], + ), + ] diff --git a/src/baby_serverlist/migrations/__init__.py b/src/baby_serverlist/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/baby_serverlist/models.py b/src/baby_serverlist/models.py new file mode 100644 index 0000000..1e893a9 --- /dev/null +++ b/src/baby_serverlist/models.py @@ -0,0 +1,53 @@ +from datetime import UTC, datetime, timedelta +from secrets import token_urlsafe +from uuid import uuid4 + +from django.core import signing +from django.db import models + +from accounts.models import Account +from commons.cache import get_baby_server_heartbeat + +SERVERLIST_TOKEN_SALT = "baby_serverlist.serverlist_token" + + +class BabyServer(models.Model): + id = models.UUIDField(primary_key=True, default=uuid4, editable=False) + serverlist_token = models.TextField(unique=True, editable=False) + whitelisted = models.BooleanField(default=False) + owner = models.ForeignKey( + Account, + on_delete=models.CASCADE, + related_name="baby_servers", + ) + objects = models.Manager() + + def __str__(self) -> str: + return f"BabyServer(id={self.id}, owner={self.owner.unique_identifier})" + + def save(self, *args, **kwargs): + if not self.serverlist_token: + self.serverlist_token = self.generate_serverlist_token() + super().save(*args, **kwargs) + + def generate_serverlist_token(self) -> str: + """Create a signed token that uniquely identifies this server and can be validated by clients.""" + payload = { + "server_id": str(self.id), + "owner_id": str(self.owner.unique_identifier), + "nonce": token_urlsafe(16), + } + return signing.dumps(payload, salt=SERVERLIST_TOKEN_SALT) + + def is_live(self) -> bool: + """Return True when the server has reported within the last 12 seconds.""" + heartbeat_iso = get_baby_server_heartbeat(str(self.id)) + if not heartbeat_iso: + return False + try: + heartbeat_time = datetime.fromisoformat(heartbeat_iso) + except ValueError: + return False + if heartbeat_time.tzinfo is None: + heartbeat_time = heartbeat_time.replace(tzinfo=UTC) + return datetime.now(tz=UTC) - heartbeat_time <= timedelta(seconds=12) diff --git a/src/central_command/settings.py b/src/central_command/settings.py index 8fe78b7..91431f4 100644 --- a/src/central_command/settings.py +++ b/src/central_command/settings.py @@ -51,6 +51,7 @@ "post_office", "accounts", "persistence", + "baby_serverlist", "drf_spectacular", ] @@ -154,6 +155,28 @@ } } +MEMCACHED_HOST = os.environ.get("MEMCACHED_HOST", "cache") +MEMCACHED_PORT = os.environ.get("MEMCACHED_PORT", "11211") + +CACHES = { + "default": { + "BACKEND": "django.core.cache.backends.memcached.PyMemcacheCache", + "LOCATION": f"{MEMCACHED_HOST}:{MEMCACHED_PORT}", + } +} + + +if "test" in sys.argv: + DATABASES["default"] = { + "ENGINE": "django.db.backends.sqlite3", + "NAME": ":memory:", + } + + CACHES["default"] = { + "BACKEND": "django.core.cache.backends.locmem.LocMemCache", + "LOCATION": "tests", + } + REST_FRAMEWORK = { "DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema", "DEFAULT_AUTHENTICATION_CLASSES": ["knox.auth.TokenAuthentication"], @@ -161,6 +184,14 @@ "DEFAULT_PAGINATION_CLASS": "rest_framework.pagination.PageNumberPagination", "PAGE_SIZE": 10, "EXCEPTION_HANDLER": "commons.error_response.custom_exception_handler", + "DEFAULT_THROTTLE_CLASSES": [ + "rest_framework.throttling.AnonRateThrottle", + "rest_framework.throttling.UserRateThrottle", + ], + "DEFAULT_THROTTLE_RATES": { + "anon": "60/minute", + "user": "120/minute", + }, } SPECTACULAR_SETTINGS = { diff --git a/src/central_command/urls.py b/src/central_command/urls.py index 454b44f..fd3c277 100644 --- a/src/central_command/urls.py +++ b/src/central_command/urls.py @@ -25,4 +25,5 @@ # API REST FRAMEWORK path("accounts/", include("accounts.api.urls", "Accounts API")), path("persistence/", include("persistence.api.urls")), + path("baby-serverlist/", include("baby_serverlist.api.urls")), ] diff --git a/src/commons/cache.py b/src/commons/cache.py new file mode 100644 index 0000000..31e2075 --- /dev/null +++ b/src/commons/cache.py @@ -0,0 +1,57 @@ +from collections.abc import Iterable +from typing import Any + +from django.core.cache import cache + +SERVER_STATUS_KEY_PREFIX = "baby_server_status:" +SERVER_HEARTBEAT_KEY_PREFIX = "baby_server_heartbeat:" + + +def _status_key(server_id: str) -> str: + """Build the cache key used to store a server's status payload.""" + return f"{SERVER_STATUS_KEY_PREFIX}{server_id}" + + +def _heartbeat_key(server_id: str) -> str: + """Build the cache key used to store a server's last heartbeat timestamp.""" + return f"{SERVER_HEARTBEAT_KEY_PREFIX}{server_id}" + + +def set_baby_server_status(server_id: str, status: dict[str, Any]) -> None: + """Persist the latest status payload for a server.""" + cache.set(_status_key(server_id), status) + + +def get_baby_server_status(server_id: str) -> dict[str, Any] | None: + """Fetch the cached status payload for a server, if present.""" + cached = cache.get(_status_key(server_id)) + return cached if isinstance(cached, dict) else None + + +def get_many_baby_server_statuses(server_ids: Iterable[str]) -> dict[str, dict[str, Any]]: + """Fetch cached statuses for a list of servers using a single multi-get call.""" + key_map = {_status_key(server_id): server_id for server_id in server_ids} + raw = cache.get_many(key_map.keys()) + return { + key_map[cache_key]: value for cache_key, value in raw.items() if isinstance(value, dict) and cache_key in key_map + } + + +def set_baby_server_heartbeat(server_id: str, timestamp: str) -> None: + """Persist the last-reported timestamp for a server.""" + cache.set(_heartbeat_key(server_id), timestamp) + + +def get_baby_server_heartbeat(server_id: str) -> str | None: + """Retrieve the cached heartbeat timestamp for a server.""" + cached = cache.get(_heartbeat_key(server_id)) + return cached if isinstance(cached, str) else None + + +def get_many_baby_server_heartbeats(server_ids: Iterable[str]) -> dict[str, str]: + """Fetch heartbeat timestamps for multiple servers in a single call.""" + key_map = {_heartbeat_key(server_id): server_id for server_id in server_ids} + raw = cache.get_many(key_map.keys()) + return { + key_map[cache_key]: value for cache_key, value in raw.items() if isinstance(value, str) and cache_key in key_map + } diff --git a/src/tests/accounts/endpoints/test_confirm_account.py b/src/tests/accounts/endpoints/test_confirm_account.py new file mode 100644 index 0000000..44e07f8 --- /dev/null +++ b/src/tests/accounts/endpoints/test_confirm_account.py @@ -0,0 +1,38 @@ +from django.core.cache import cache +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APITestCase + +from accounts.models import Account, AccountConfirmation + + +class ConfirmAccountTest(APITestCase): + def setUp(self) -> None: + cache.clear() + self.account = Account.objects.create_user( + username="confirmUser", + email="confirm@example.com", + unique_identifier="confirmUser", + ) + self.account.set_password("aValidPss963") + self.account.is_confirmed = False + self.account.save() + + self.confirmation = AccountConfirmation.objects.create(account=self.account, token="confirm-token") # noqa: S106 - test-only credential + self.url = reverse("account:confirm") + + def test_confirm_account_with_valid_token(self) -> None: + response = self.client.post(self.url, {"token": self.confirmation.token}, format="json") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + + self.account.refresh_from_db() + self.assertTrue(self.account.is_confirmed) + self.assertFalse(AccountConfirmation.objects.filter(pk=self.confirmation.pk).exists()) + + def test_confirm_account_with_invalid_token(self) -> None: + response = self.client.post(self.url, {"token": "invalid-token"}, format="json") + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.account.refresh_from_db() + self.assertFalse(self.account.is_confirmed) diff --git a/src/tests/accounts/endpoints/test_login_credentials.py b/src/tests/accounts/endpoints/test_login_credentials.py index e4cbe3b..a483dd9 100644 --- a/src/tests/accounts/endpoints/test_login_credentials.py +++ b/src/tests/accounts/endpoints/test_login_credentials.py @@ -1,3 +1,4 @@ +from django.core.cache import cache from django.urls import reverse from rest_framework import status from rest_framework.test import APITestCase @@ -7,6 +8,7 @@ class LoginCredentialsTest(APITestCase): def setUp(self): + cache.clear() self.url = reverse("account:login-credentials") self.valid_account = Account.objects.create_user( username="validUser", diff --git a/src/tests/accounts/endpoints/test_login_token.py b/src/tests/accounts/endpoints/test_login_token.py index fef0233..3482d1b 100644 --- a/src/tests/accounts/endpoints/test_login_token.py +++ b/src/tests/accounts/endpoints/test_login_token.py @@ -1,3 +1,4 @@ +from django.core.cache import cache from django.urls import reverse from rest_framework import status from rest_framework.test import APITestCase @@ -7,6 +8,7 @@ class LoginTokenTest(APITestCase): def setUp(self): + cache.clear() self.url = reverse("account:login-token") self.valid_account = Account.objects.create_user( username="validUser", diff --git a/src/tests/accounts/endpoints/test_register.py b/src/tests/accounts/endpoints/test_register.py index 9c712fe..796b2c3 100644 --- a/src/tests/accounts/endpoints/test_register.py +++ b/src/tests/accounts/endpoints/test_register.py @@ -1,3 +1,6 @@ +from unittest.mock import patch + +from django.core.cache import cache from django.urls import reverse from rest_framework import status from rest_framework.test import APITestCase @@ -7,6 +10,7 @@ class RegisterTest(APITestCase): def setUp(self): + cache.clear() self.url = reverse("account:register") self.valid_data = { "email": "validUser@valid.com", @@ -101,3 +105,15 @@ def test_confirmed_state_at_registering(self): self.assertEqual(response.status_code, status.HTTP_200_OK) created_account = Account.objects.get(email=self.valid_data["email"]) self.assertFalse(created_account.is_confirmed) + + def test_register_sends_confirmation_email(self): + data = self.valid_data.copy() + data["email"] = "newuser@example.com" + data["unique_identifier"] = "newuser" + data["username"] = "newuser" + + with patch("accounts.models.Account.send_confirmation_mail") as mock_send: + response = self.client.post(self.url, data, format="json") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + mock_send.assert_called_once() diff --git a/src/tests/accounts/endpoints/test_reset_password.py b/src/tests/accounts/endpoints/test_reset_password.py index ffda7e4..870af95 100644 --- a/src/tests/accounts/endpoints/test_reset_password.py +++ b/src/tests/accounts/endpoints/test_reset_password.py @@ -1,3 +1,6 @@ +from unittest.mock import patch + +from django.core.cache import cache from django.urls import reverse from rest_framework import status from rest_framework.test import APITestCase @@ -11,6 +14,7 @@ def get_reset_url(token): class PasswordResetTest(APITestCase): def setUp(self): + cache.clear() self.valid_account = Account.objects.create_user( username="validUser", email="validUser@valid.com", @@ -37,6 +41,16 @@ def test_request_password_reset_with_valid_email(self): password_reset_request = PasswordResetRequestModel.objects.first() self.assertIsNotNone(password_reset_request) + def test_request_password_reset_sends_email(self): + data = {"email": "validUser@valid.com"} + + with patch("accounts.api.views.send_email_with_template") as mock_send: + response = self.client.post(self.url_request, data, format="json") + + self.assertEqual(response.status_code, status.HTTP_200_OK) + mock_send.assert_called_once() + self.assertEqual(mock_send.call_args.kwargs["recipient"], data["email"]) + def test_request_password_reset_with_invalid_email(self): data = {"email": "invalid@mail.com"} diff --git a/src/tests/accounts/endpoints/test_server_verification_token.py b/src/tests/accounts/endpoints/test_server_verification_token.py index 6e9037a..5ef528c 100644 --- a/src/tests/accounts/endpoints/test_server_verification_token.py +++ b/src/tests/accounts/endpoints/test_server_verification_token.py @@ -1,3 +1,4 @@ +from django.core.cache import cache from django.urls import reverse from rest_framework import status from rest_framework.test import APITestCase @@ -7,6 +8,7 @@ class ServerVerificationTokenTest(APITestCase): def setUp(self): + cache.clear() self.valid_account = Account.objects.create_user( username="validUser", email="validUser@valid.com", @@ -57,6 +59,8 @@ def test_verify_with_invalid_token(self): def test_verify_with_invalid_identifier(self): self.client.credentials(HTTP_AUTHORIZATION=f"Token {self.token}") response = self.client.get(self.url_request, format="json") + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn("verification_token", response.data) data = {"verification_token": response.data["verification_token"], "unique_identifier": "invalidIdentifier"} response = self.client.post(self.url_verify, data, format="json") diff --git a/src/tests/baby_serverlist/__init__.py b/src/tests/baby_serverlist/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/tests/baby_serverlist/__init__.py @@ -0,0 +1 @@ + diff --git a/src/tests/baby_serverlist/test_api.py b/src/tests/baby_serverlist/test_api.py new file mode 100644 index 0000000..843fff4 --- /dev/null +++ b/src/tests/baby_serverlist/test_api.py @@ -0,0 +1,186 @@ +from datetime import UTC, datetime, timedelta +from uuid import uuid4 + +from django.core.cache import cache +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APITestCase + +from accounts.models import Account +from baby_serverlist.models import BabyServer +from commons.cache import ( + get_baby_server_heartbeat, + get_baby_server_status, + set_baby_server_heartbeat, + set_baby_server_status, +) + + +def _sample_status_payload(server_token: str) -> dict[str, object]: + return { + "ServerToken": server_token, + "Passworded": False, + "ServerName": "Unitystation - Playtest Server", + "ForkName": "UnityStationDevelop", + "BuildVersion": 25103114, + "CurrentMap": "MainStations/SquareStation.json", + "GameMode": "Secret", + "IngameTime": "a lot", + "RoundTime": "0", + "PlayerCount": 0, + "PlayerCountMax": 45, + "ServerIP": "127.0.0.1", + "ServerPort": 7777, + "WinDownload": "https://example.com/win.zip", + "OSXDownload": "https://example.com/osx.zip", + "LinuxDownload": "https://example.com/linux.zip", + "fps": 98, + "GoodFileVersion": "0.31.0", + } + + +class BabyServerAPITests(APITestCase): + def setUp(self) -> None: + self.user = Account.objects.create_user( + email="owner@example.com", + password="password123", # noqa: S106 - test-only credential + unique_identifier="owner", + username="Owner", + ) + self.other_user = Account.objects.create_user( + email="other@example.com", + password="password123", # noqa: S106 - test-only credential + unique_identifier="otheruser", + username="Other", + ) + + def tearDown(self) -> None: + cache.clear() + + def test_create_baby_server_returns_token(self) -> None: + self.client.force_authenticate(self.user) + + response = self.client.post(reverse("baby_serverlist:create")) + + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + payload = response.json() + self.assertIn("serverlist_token", payload) + self.assertTrue(BabyServer.objects.filter(id=payload["id"]).exists()) + + def test_regenerate_token_changes_token_for_owner(self) -> None: + self.client.force_authenticate(self.user) + baby_server = BabyServer.objects.create(owner=self.user) + original_token = baby_server.serverlist_token + + response = self.client.post( + reverse("baby_serverlist:regenerate-token"), + {"server_id": str(baby_server.id)}, + format="json", + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + baby_server.refresh_from_db() + self.assertNotEqual(original_token, baby_server.serverlist_token) + + def test_regenerate_token_rejects_non_owner(self) -> None: + baby_server = BabyServer.objects.create(owner=self.user) + self.client.force_authenticate(self.other_user) + + response = self.client.post( + reverse("baby_serverlist:regenerate-token"), + {"server_id": str(baby_server.id)}, + format="json", + ) + + self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) + + def test_regenerate_token_not_found_returns_404(self) -> None: + self.client.force_authenticate(self.user) + + response = self.client.post( + reverse("baby_serverlist:regenerate-token"), + {"server_id": str(uuid4())}, + format="json", + ) + + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + + def test_post_server_status_stores_payload_in_cache(self) -> None: + baby_server = BabyServer.objects.create(owner=self.user, whitelisted=True) + payload = _sample_status_payload(baby_server.serverlist_token) + + response = self.client.post( + reverse("baby_serverlist:report-status"), + payload, + format="json", + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + cached_status = get_baby_server_status(str(baby_server.id)) + self.assertIsNotNone(cached_status) + cached_status = cached_status or {} + self.assertNotIn("ServerToken", cached_status) + self.assertEqual(cached_status["ServerName"], payload["ServerName"]) + + cached_heartbeat = get_baby_server_heartbeat(str(baby_server.id)) + self.assertIsNotNone(cached_heartbeat) + + def test_post_server_status_rejects_invalid_token(self) -> None: + baby_server = BabyServer.objects.create(owner=self.user) + payload = _sample_status_payload("invalid-token") + + response = self.client.post( + reverse("baby_serverlist:report-status"), + payload, + format="json", + ) + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + cached_status = get_baby_server_status(str(baby_server.id)) + self.assertIsNone(cached_status) + + def test_list_owned_baby_servers_live_flag(self) -> None: + self.client.force_authenticate(self.user) + baby_server = BabyServer.objects.create(owner=self.user) + + response = self.client.get(reverse("baby_serverlist:list-owned")) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertFalse(response.json()[0]["live"]) + + set_baby_server_heartbeat(str(baby_server.id), datetime.now(tz=UTC).isoformat()) + + response = self.client.get(reverse("baby_serverlist:list-owned")) + self.assertTrue(response.json()[0]["live"]) + + stale_time = datetime.now(tz=UTC) - timedelta(seconds=13) + set_baby_server_heartbeat(str(baby_server.id), stale_time.isoformat()) + + response = self.client.get(reverse("baby_serverlist:list-owned")) + self.assertFalse(response.json()[0]["live"]) + + def test_list_baby_servers_returns_whitelisted_status(self) -> None: + baby_server = BabyServer.objects.create(owner=self.user, whitelisted=True) + status_data = _sample_status_payload(baby_server.serverlist_token) + status_without_token = status_data.copy() + status_without_token.pop("ServerToken") + + set_baby_server_status(str(baby_server.id), status_without_token) + + response = self.client.get(reverse("baby_serverlist:list")) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + payload = response.json() + self.assertIn("servers", payload) + self.assertEqual(len(payload["servers"]), 1) + self.assertEqual(payload["servers"][0]["ServerName"], status_without_token["ServerName"]) + + def test_list_baby_servers_ignores_non_whitelisted(self) -> None: + non_whitelisted = BabyServer.objects.create(owner=self.user, whitelisted=False) + set_baby_server_status(str(non_whitelisted.id), {"ServerName": "Hidden"}) + + response = self.client.get(reverse("baby_serverlist:list")) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.json(), {"servers": []}) diff --git a/uv.lock b/uv.lock index 9fdcec6..b35a0d0 100644 --- a/uv.lock +++ b/uv.lock @@ -49,6 +49,7 @@ dependencies = [ { name = "drf-spectacular" }, { name = "gunicorn" }, { name = "psycopg2-binary" }, + { name = "pymemcache" }, { name = "python-dotenv" }, { name = "whitenoise" }, ] @@ -71,6 +72,7 @@ requires-dist = [ { name = "drf-spectacular", specifier = "~=0.27.1" }, { name = "gunicorn", specifier = "~=20.1.0" }, { name = "psycopg2-binary", specifier = "~=2.9.9" }, + { name = "pymemcache", specifier = ">=4.0,<5.0" }, { name = "python-dotenv", specifier = "~=0.19.2" }, { name = "whitenoise", specifier = "~=6.6.0" }, ] @@ -479,6 +481,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/13/a3/a812df4e2dd5696d1f351d58b8fe16a405b234ad2886a0dab9183fb78109/pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc", size = 117552 }, ] +[[package]] +name = "pymemcache" +version = "4.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d9/b6/4541b664aeaad025dfb8e851dcddf8e25ab22607e674dd2b562ea3e3586f/pymemcache-4.0.0.tar.gz", hash = "sha256:27bf9bd1bbc1e20f83633208620d56de50f14185055e49504f4f5e94e94aff94", size = 70176 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/41/ba/2f7b22d8135b51c4fefb041461f8431e1908778e6539ff5af6eeaaee367a/pymemcache-4.0.0-py2.py3-none-any.whl", hash = "sha256:f507bc20e0dc8d562f8df9d872107a278df049fa496805c1431b926f3ddd0eab", size = 60772 }, +] + [[package]] name = "python-dotenv" version = "0.19.2"