diff --git a/server/lobbyconnection.py b/server/lobbyconnection.py index ad87aeca7..c50c10964 100644 --- a/server/lobbyconnection.py +++ b/server/lobbyconnection.py @@ -14,7 +14,7 @@ import aiohttp from sqlalchemy import and_, func, select -from sqlalchemy.exc import DBAPIError +from sqlalchemy.exc import DBAPIError, OperationalError import server.metrics as metrics from server.db import FAFDatabase @@ -216,6 +216,16 @@ async def on_message_received(self, message): self.get_user_identifier(), cmd ) + except OperationalError: + # When the database goes down, SqlAlchemy will throw an OperationalError + self._logger.error("Encountered OperationalError on message received. This could indicate DB is down.") + await self.send({ + "command": "notice", + "style": "error", + "text": "Unable to connect to database. Please try again later." + }) + # Make sure to abort here to avoid a thundering herd problem. + await self.abort("Error connecting to database") except Exception as e: # pragma: no cover await self.send({"command": "invalid"}) self._logger.exception(e) diff --git a/tests/unit_tests/test_lobbyconnection.py b/tests/unit_tests/test_lobbyconnection.py index b3977531f..248102f7c 100644 --- a/tests/unit_tests/test_lobbyconnection.py +++ b/tests/unit_tests/test_lobbyconnection.py @@ -5,6 +5,7 @@ import pytest from aiohttp import web from sqlalchemy import and_, select +from sqlalchemy.exc import OperationalError from server.config import config from server.db.models import ban, friends_and_foes @@ -159,6 +160,28 @@ async def test_bad_command_calls_abort(lobbyconnection): lobbyconnection.abort.assert_called_once_with("Error processing command") +async def test_database_outage_error_responds_cleanly(lobbyconnection): + lobbyconnection.abort = mock.AsyncMock() + lobbyconnection.check_policy_conformity = mock.AsyncMock(return_value=True) + lobbyconnection.send = mock.AsyncMock() + + def mock_ensure_authenticated(cmd): + raise OperationalError(statement="", params=[], orig=None) + lobbyconnection.ensure_authenticated = mock_ensure_authenticated + await lobbyconnection.on_message_received({ + "command": "hello", + "login": "test", + "password": sha256(b"test_password").hexdigest(), + "unique_id": "blah" + }) + lobbyconnection.send.assert_called_once_with({ + "command": "notice", + "style": "error", + "text": "Unable to connect to database. Please try again later." + }) + lobbyconnection.abort.assert_called_once_with("Error connecting to database") + + async def test_command_pong_does_nothing(lobbyconnection): lobbyconnection.send = mock.AsyncMock()