diff --git a/main.py b/main.py index 0c20dffb2..0c5f90592 100755 --- a/main.py +++ b/main.py @@ -10,53 +10,65 @@ import asyncio import logging import os -import platform import signal import sys import time from datetime import datetime +from functools import wraps import humanize from docopt import docopt +from prometheus_client import start_http_server import server +from server import info from server.config import config +from server.control import ControlServer from server.game_service import GameService +from server.health import HealthServer from server.ice_servers.nts import TwilioNTS from server.player_service import PlayerService from server.profiler import Profiler from server.protocol import QDataStreamProtocol, SimpleJsonProtocol +def log_signal(func): + @wraps(func) + def wrapped(sig, frame): + logger.info("Received signal %s", signal.Signals(sig)) + return func(sig, frame) + + return wrapped + + async def main(): global startup_time, shutdown_time - version = os.environ.get("VERSION") or "dev" - python_version = platform.python_version() - logger.info( - "Lobby %s (Python %s) on %s", - version, - python_version, - sys.platform + "Lobby %s (Python %s) on %s named %s", + info.VERSION, + info.PYTHON_VERSION, + sys.platform, + info.CONTAINER_NAME, ) + if config.ENABLE_METRICS: + logger.info("Using prometheus on port: %i", config.METRICS_PORT) + start_http_server(config.METRICS_PORT) + loop = asyncio.get_running_loop() done = loop.create_future() logger.info("Event loop: %s", loop) - def signal_handler(sig: int, _frame): - logger.info( - "Received signal %s, shutting down", - signal.Signals(sig) - ) + @log_signal + def done_handler(sig: int, frame): if not done.done(): done.set_result(0) # Make sure we can shutdown gracefully - signal.signal(signal.SIGTERM, signal_handler) - signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, done_handler) + signal.signal(signal.SIGINT, done_handler) database = server.db.FAFDatabase( 
host=config.DB_SERVER, @@ -91,19 +103,21 @@ def signal_handler(sig: int, _frame): config.register_callback("PROFILING_DURATION", profiler.refresh) config.register_callback("PROFILING_INTERVAL", profiler.refresh) - await instance.start_services() - - ctrl_server = await server.run_control_server(player_service, game_service) + health_server = HealthServer(instance) + await health_server.run_from_config() + config.register_callback( + "HEALTH_SERVER_PORT", + health_server.run_from_config + ) - async def restart_control_server(): - nonlocal ctrl_server + control_server = ControlServer(instance) + await control_server.run_from_config() + config.register_callback( + "CONTROL_SERVER_PORT", + control_server.run_from_config + ) - await ctrl_server.shutdown() - ctrl_server = await server.run_control_server( - player_service, - game_service - ) - config.register_callback("CONTROL_SERVER_PORT", restart_control_server) + await instance.start_services() PROTO_CLASSES = { QDataStreamProtocol.__name__: QDataStreamProtocol, @@ -135,8 +149,8 @@ async def restart_control_server(): ) server.metrics.info.info({ - "version": version, - "python_version": python_version, + "version": info.VERSION, + "python_version": info.PYTHON_VERSION, "start_time": datetime.utcnow().strftime("%m-%d %H:%M"), "game_uid": str(game_service.game_id_counter) }) @@ -150,12 +164,27 @@ async def restart_control_server(): shutdown_time = time.perf_counter() # Cleanup - await instance.shutdown() - await ctrl_server.shutdown() + await instance.graceful_shutdown() + + drain_task = asyncio.create_task(instance.drain()) - # Close DB connections + @log_signal + def drain_handler(sig: int, frame): + if not drain_task.done(): + drain_task.cancel() + + # Allow us to force shut down by skipping the drain + signal.signal(signal.SIGTERM, drain_handler) + signal.signal(signal.SIGINT, drain_handler) + + await drain_task + await instance.shutdown() + await control_server.shutdown() await database.close() + # Health server 
should be the last thing to shut down + await health_server.shutdown() + return exit_code @@ -191,7 +220,7 @@ async def restart_control_server(): stop_time = time.perf_counter() logger.info( "Total server uptime: %s", - humanize.naturaldelta(stop_time - startup_time) + humanize.precisedelta(stop_time - startup_time) ) if shutdown_time is not None: diff --git a/minikube-example.yaml b/minikube-example.yaml new file mode 100644 index 000000000..2258f8242 --- /dev/null +++ b/minikube-example.yaml @@ -0,0 +1,83 @@ +apiVersion: v1 +kind: Service +metadata: + name: faf-lobby + labels: + app: faf-lobby +spec: + type: NodePort + selector: + app: faf-lobby + ports: + - port: 8001 + name: qstream + - port: 8002 + name: simplejson +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: faf-lobby +spec: + replicas: 1 + selector: + matchLabels: + app: faf-lobby + template: + metadata: + labels: + app: faf-lobby + spec: + terminationGracePeriodSeconds: 310 + containers: + - name: faf-python-server + image: faf-python-server:graceful + imagePullPolicy: Never + readinessProbe: + httpGet: + path: /ready + port: health + initialDelaySeconds: 4 + periodSeconds: 1 + ports: + - containerPort: 4000 + name: control + - containerPort: 2000 + name: health + - containerPort: 8001 + name: qstream + - containerPort: 8002 + name: simplejson + env: + - name: CONFIGURATION_FILE + value: /config/config.yaml + - name: CONTAINER_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + volumeMounts: + - name: config + mountPath: /config + readOnly: true + volumes: + - name: config + configMap: + name: minikube-dev-config + items: + - key: config.yaml + path: config.yaml +--- +apiVersion: v1 +kind: ConfigMap +metadata: + name: minikube-dev-config +data: + config.yaml: | + LOG_LEVEL: TRACE + USE_POLICY_SERVER: false + QUEUE_POP_TIME_MAX: 30 + SHUTDOWN_GRACE_PERIOD: 300 + SHUTDOWN_KICK_IDLE_PLAYERS: true + + DB_SERVER: host.minikube.internal + MQ_SERVER: host.minikube.internal diff --git 
a/server/__init__.py b/server/__init__.py index c262075e1..e3837209f 100644 --- a/server/__init__.py +++ b/server/__init__.py @@ -90,15 +90,12 @@ import time from typing import Optional -from prometheus_client import start_http_server - import server.metrics as metrics -from .asyncio_extensions import synchronizedmethod +from .asyncio_extensions import map_suppress, synchronizedmethod from .broadcast_service import BroadcastService from .config import TRACE, config from .configuration_service import ConfigurationService -from .control import run_control_server from .core import Service, create_services from .db import FAFDatabase from .game_service import GameService @@ -138,18 +135,12 @@ "RatingService", "ServerInstance", "ViolationService", - "control", "game_service", "protocol", - "run_control_server", ) logger = logging.getLogger("server") -if config.ENABLE_METRICS: - logger.info("Using prometheus on port: %i", config.METRICS_PORT) - start_http_server(config.METRICS_PORT) - class ServerInstance(object): """ @@ -273,40 +264,86 @@ async def listen( return ctx - async def shutdown(self): - results = await asyncio.gather( - *(ctx.stop() for ctx in self.contexts), - return_exceptions=True - ) - for result, ctx in zip(results, self.contexts): - if isinstance(result, BaseException): - self._logger.exception( - "Unexpected error when stopping context %s", - ctx, - exc_info=result - ) + async def graceful_shutdown(self): + """ + Start a graceful shut down of the server. - results = await asyncio.gather( - *(service.shutdown() for service in self.services.values()), - return_exceptions=True - ) - for result, service in zip(results, self.services.values()): - if isinstance(result, BaseException): - self._logger.error( - "Unexpected error when shutting down service %s", - service - ) + 1. 
Notify all services of graceful shutdown + """ + self._logger.info("Initiating graceful shutdown") - results = await asyncio.gather( - *(ctx.shutdown() for ctx in self.contexts), - return_exceptions=True + await map_suppress( + lambda service: service.graceful_shutdown(), + self.services.values(), + logger=self._logger, + msg="when starting graceful shutdown of service " ) - for result, ctx in zip(results, self.contexts): - if isinstance(result, BaseException): - self._logger.error( - "Unexpected error when shutting down context %s", - ctx - ) + + async def shutdown(self): + """ + Immediately shut down the server. + + 1. Stop accepting new connections + 2. Stop all services + 3. Close all existing connections + """ + self._logger.info("Initiating full shutdown") + + await self._stop_contexts() + await self._shutdown_services() + await self._shutdown_contexts() self.contexts.clear() self.started = False + + async def drain(self): + """ + Wait for all games to end. + """ + game_service: GameService = self.services["game_service"] + broadcast_service: BroadcastService = self.services["broadcast_service"] + try: + await asyncio.wait_for( + game_service.drain_games(), + timeout=config.SHUTDOWN_GRACE_PERIOD + ) + except asyncio.CancelledError: + self._logger.debug( + "Stopped waiting for games to end due to forced shutdown" + ) + except asyncio.TimeoutError: + self._logger.warning( + "Graceful shutdown period ended! %s games are still live!", + len(game_service.live_games) + ) + finally: + # The report_dirties loop is responsible for clearing dirty games + # and broadcasting the update messages to players and to RabbitMQ. + # We need to wait here for that loop to complete otherwise it is + # possible for the services to be shut down in between clearing the + # games and posting the messages, causing the posts to fail.
+ await broadcast_service.wait_report_dirtes() + + async def _shutdown_services(self): + await map_suppress( + lambda service: service.shutdown(), + self.services.values(), + logger=self._logger, + msg="when shutting down service " + ) + + async def _stop_contexts(self): + await map_suppress( + lambda ctx: ctx.stop(), + self.contexts, + logger=self._logger, + msg="when stopping context " + ) + + async def _shutdown_contexts(self): + await map_suppress( + lambda ctx: ctx.shutdown(), + self.contexts, + logger=self._logger, + msg="when shutting down context " + ) diff --git a/server/asyncio_extensions.py b/server/asyncio_extensions.py index 2773de53e..3d5915126 100644 --- a/server/asyncio_extensions.py +++ b/server/asyncio_extensions.py @@ -12,8 +12,10 @@ AsyncContextManager, Callable, Coroutine, + Iterable, Optional, Protocol, + TypeVar, cast, overload ) @@ -22,6 +24,7 @@ AsyncFunc = Callable[..., Coroutine[Any, Any, Any]] AsyncDecorator = Callable[[AsyncFunc], AsyncFunc] +T = TypeVar("T") class AsyncLock(Protocol, AsyncContextManager["AsyncLock"]): @@ -30,23 +33,24 @@ async def acquire(self) -> bool: ... def release(self) -> None: ... -async def gather_without_exceptions( - tasks: list[asyncio.Task], - *exceptions: type[BaseException], -) -> list[Any]: - """ - Run coroutines in parallel, raising the first exception that dosen't - match any of the specified exception classes. 
- """ - results = [] - for fut in asyncio.as_completed(tasks): - try: - results.append(await fut) - except exceptions: - logger.debug( - "Ignoring error in gather_without_exceptions", exc_info=True +async def map_suppress( + func: Callable[[T], Coroutine[Any, Any, Any]], + iterable: Iterable[T], + logger: logging.Logger = logger, + msg: str = "" +): + results = await asyncio.gather( + *(func(item) for item in iterable), + return_exceptions=True + ) + for result, item in zip(results, iterable): + if isinstance(result, BaseException): + logger.exception( + "Unexpected error %s%s", + msg, + item, + exc_info=result ) - return results # Based on python3.8 asyncio.Lock diff --git a/server/broadcast_service.py b/server/broadcast_service.py index 538391eec..6eaf6da2f 100644 --- a/server/broadcast_service.py +++ b/server/broadcast_service.py @@ -1,3 +1,6 @@ +import asyncio + +import humanize from aio_pika import DeliveryMode from .config import config @@ -27,13 +30,14 @@ def __init__( self.message_queue_service = message_queue_service self.game_service = game_service self.player_service = player_service + self._report_dirties_event = None async def initialize(self): # Using a lazy interval timer so that the intervals can be changed # without restarting the server. self._broadcast_dirties_timer = LazyIntervalTimer( lambda: config.DIRTY_REPORT_INTERVAL, - self.report_dirties, + self._monitored_report_dirties, start=True ) self._broadcast_ping_timer = LazyIntervalTimer( @@ -42,6 +46,14 @@ async def initialize(self): start=True ) + async def _monitored_report_dirties(self): + event = asyncio.Event() + self._report_dirties_event = event + try: + await self.report_dirties() + finally: + event.set() + async def report_dirties(self): """ Send updates about any dirty (changed) entities to connected players. 
@@ -116,3 +128,48 @@ async def report_dirties(self): def broadcast_ping(self): self.server.write_broadcast({"command": "ping"}) + + async def wait_report_dirtes(self): + """ + Wait for the current report_dirties task to complete. + """ + if self._report_dirties_event is None: + return + + await self._report_dirties_event.wait() + + async def graceful_shutdown(self): + if config.SHUTDOWN_KICK_IDLE_PLAYERS: + message = ( + "If you're in a game you can continue to play, otherwise you " + "will be disconnected. If you aren't reconnected automatically " + "please wait a few minutes and try to connect again." + ) + else: + message = ( + "If you're in a game you can continue to play, however, you " + "will not be able to create any new games until the server has " + "been restarted." + ) + + delta = humanize.precisedelta(config.SHUTDOWN_GRACE_PERIOD) + self.server.write_broadcast({ + "command": "notice", + "style": "info", + "text": ( + f"The server will be shutting down for maintenance in {delta}! " + f"{message}" + ) + }) + + async def shutdown(self): + self.server.write_broadcast({ + "command": "notice", + "style": "info", + "text": ( + "The server has been shut down for maintenance " + "but should be back online soon. If you experience any " + "problems, please restart your client.

" + "We apologize for this interruption." + ) + }) diff --git a/server/config.py b/server/config.py index b0c901e22..aade3ee29 100644 --- a/server/config.py +++ b/server/config.py @@ -67,8 +67,15 @@ def __init__(self): self.DIRTY_REPORT_INTERVAL = 1 self.PING_INTERVAL = 45 + # How many seconds to wait for games to end before doing a hard shutdown. + # If using kubernetes, you must set terminationGracePeriodSeconds + # on the pod to be larger than this value. With docker compose, use + # --timeout (-t) to set a longer timeout. + self.SHUTDOWN_GRACE_PERIOD = 30 * 60 + self.SHUTDOWN_KICK_IDLE_PLAYERS = False self.CONTROL_SERVER_PORT = 4000 + self.HEALTH_SERVER_PORT = 2000 self.METRICS_PORT = 8011 self.ENABLE_METRICS = False @@ -97,10 +104,9 @@ def __init__(self): self.FAF_POLICY_SERVER_BASE_URL = "http://faf-policy-server" self.USE_POLICY_SERVER = True - self.FORCE_STEAM_LINK_AFTER_DATE = 1536105599 # 5 september 2018 by default - self.FORCE_STEAM_LINK = False - self.ALLOW_PASSWORD_LOGIN = True + # How many seconds a connection has to authenticate before being killed + self.LOGIN_TIMEOUT = 5 * 60 self.NEWBIE_BASE_MEAN = 500 self.NEWBIE_MIN_GAMES = 10 @@ -171,10 +177,14 @@ def refresh(self) -> None: with open(config_file) as f: new_values.update(yaml.safe_load(f)) except FileNotFoundError: - self._logger.info("No configuration file found at %s", config_file) + self._logger.warning( + "No configuration file found at %s", + config_file + ) except TypeError: self._logger.info( - "Configuration file at %s appears to be empty", config_file + "Configuration file at %s appears to be empty", + config_file ) triggered_callback_keys = tuple( diff --git a/server/control.py b/server/control.py index caed69319..328b95603 100644 --- a/server/control.py +++ b/server/control.py @@ -3,73 +3,66 @@ """ import socket -from json import dumps from aiohttp import web from .config import config from .decorators import with_logger -from .game_service import GameService -from .player_service 
import PlayerService @with_logger class ControlServer: def __init__( self, - game_service: GameService, - player_service: PlayerService, - host: str, - port: int + lobby_server: "ServerInstance", ): - self.game_service = game_service - self.player_service = player_service - self.host = host - self.port = port + self.lobby_server = lobby_server + self.game_service = lobby_server.services["game_service"] + self.player_service = lobby_server.services["player_service"] + self.host = None + self.port = None self.app = web.Application() self.runner = web.AppRunner(self.app) self.app.add_routes([ web.get("/games", self.games), - web.get("/players", self.players) + web.get("/players", self.players), ]) - async def start(self) -> None: + async def run_from_config(self) -> None: + """ + Initialize the http control server + """ + host = socket.gethostbyname(socket.gethostname()) + port = config.CONTROL_SERVER_PORT + + await self.shutdown() + await self.start(host, port) + + async def start(self, host: str, port: int) -> None: + self.host = host + self.port = port await self.runner.setup() - self.site = web.TCPSite(self.runner, self.host, self.port) + self.site = web.TCPSite(self.runner, host, port) await self.site.start() self._logger.info( - "Control server listening on http://%s:%s", self.host, self.port + "Control server listening on http://%s:%s", host, port ) async def shutdown(self) -> None: await self.runner.cleanup() + self.host = None + self.port = None - async def games(self, request): - body = dumps(to_dict_list(self.game_service.all_games)) - return web.Response(body=body.encode(), content_type="application/json") - - async def players(self, request): - body = dumps(to_dict_list(self.player_service.all_players)) - return web.Response(body=body.encode(), content_type="application/json") - - -async def run_control_server( - player_service: PlayerService, - game_service: GameService -) -> ControlServer: - """ - Initialize the http control server - """ - host = 
socket.gethostbyname(socket.gethostname()) - port = config.CONTROL_SERVER_PORT - - ctrl_server = ControlServer(game_service, player_service, host, port) - await ctrl_server.start() - - return ctrl_server - + async def games(self, request) -> web.Response: + return web.json_response([ + game.to_dict() + for game in self.game_service.all_games + ]) -def to_dict_list(list_): - return list(map(lambda p: p.to_dict(), list_)) + async def players(self, request) -> web.Response: + return web.json_response([ + player.to_dict() + for player in self.player_service.all_players + ]) diff --git a/server/core/service.py b/server/core/service.py index e2bae1bc3..ed5ecaa0d 100644 --- a/server/core/service.py +++ b/server/core/service.py @@ -30,6 +30,17 @@ async def initialize(self) -> None: """ pass # pragma: no cover + async def graceful_shutdown(self) -> None: + """ + Called once after the graceful shutdown period is initiated. + + This signals that the service should stop accepting new events but + continue to wait for existing ones to complete normally. The hook + function `shutdown` will be called after the grace period has ended to + fully shut down the service. + """ + pass # pragma: no cover + async def shutdown(self) -> None: """ Called once after the server received the shutdown signal. diff --git a/server/exceptions.py b/server/exceptions.py index b2102d9bb..5c5a103d6 100644 --- a/server/exceptions.py +++ b/server/exceptions.py @@ -57,3 +57,9 @@ def __init__(self, message, method, *args, **kwargs): super().__init__(*args, **kwargs) self.message = message self.method = method + + +class DisabledError(Exception): + """ + The operation is disabled due to an impending server shutdown.
+ """ diff --git a/server/game_service.py b/server/game_service.py index d0bd44328..1e44bfa42 100644 --- a/server/game_service.py +++ b/server/game_service.py @@ -2,6 +2,7 @@ Manages the lifecycle of active games """ +import asyncio from collections import Counter from typing import Optional, Union, ValuesView @@ -15,6 +16,7 @@ from .db import FAFDatabase from .db.models import game_featuredMods from .decorators import with_logger +from .exceptions import DisabledError from .games import ( CustomGame, FeaturedMod, @@ -52,6 +54,8 @@ def __init__( self._rating_service = rating_service self._message_queue_service = message_queue_service self.game_id_counter = 0 + self._allow_new_games = False + self._drain_event = None # Populated below in really_update_static_ish_data. self.featured_mods = dict() @@ -68,6 +72,7 @@ async def initialize(self) -> None: self._update_cron = aiocron.crontab( "*/10 * * * *", func=self.update_data ) + self._allow_new_games = True async def initialise_game_counter(self): async with self._db.acquire() as conn: @@ -153,6 +158,9 @@ def create_game( """ Main entrypoint for creating new games """ + if not self._allow_new_games: + raise DisabledError() + game_id = self.create_uid() game_args = { "database": self._db, @@ -207,10 +215,17 @@ def update_active_game_metrics(self): rating_type_counter[(rating_type, state)] ) + @property + def all_games(self) -> ValuesView[Game]: + return self._games.values() + @property def live_games(self) -> list[Game]: - return [game for game in self._games.values() - if game.state is GameState.LIVE] + return [ + game + for game in self.all_games + if game.state is GameState.LIVE + ] @property def open_games(self) -> list[Game]: @@ -225,22 +240,32 @@ def open_games(self) -> list[Game]: The client ignores everything "closed". This property fetches all such not-closed games. 
""" - return [game for game in self._games.values() - if game.state is GameState.LOBBY or game.state is GameState.LIVE] - - @property - def all_games(self) -> ValuesView[Game]: - return self._games.values() + return [ + game + for game in self.all_games + if game.state in (GameState.LOBBY, GameState.LIVE) + ] @property def pending_games(self) -> list[Game]: - return [game for game in self._games.values() - if game.state is GameState.LOBBY or game.state is GameState.INITIALIZING] + return [ + game + for game in self.all_games + if game.state in (GameState.LOBBY, GameState.INITIALIZING) + ] def remove_game(self, game: Game): if game.id in self._games: + self._logger.debug("Removing game %s", game) del self._games[game.id] + if ( + self._drain_event is not None + and not self._drain_event.is_set() + and not self._games + ): + self._drain_event.set() + def __getitem__(self, item: int) -> Game: return self._games[item] @@ -262,3 +287,31 @@ async def publish_game_results(self, game_results: EndedGameInfo): metrics.rated_games.labels(game_results.rating_type).inc() # TODO: Remove when rating service starts listening to message queue await self._rating_service.enqueue(result_dict) + + async def drain_games(self): + """ + Wait for all games to finish. 
+ """ + if not self._games: + return + + if not self._drain_event: + self._drain_event = asyncio.Event() + + await self._drain_event.wait() + + async def graceful_shutdown(self): + self._allow_new_games = False + + await self.close_lobby_games() + + async def close_lobby_games(self): + self._logger.info("Closing all games currently in lobby") + for game in self.pending_games: + for game_connection in list(game.connections): + # Tell the client to kill the FA process + game_connection.player.write_message({ + "command": "notice", + "style": "kill" + }) + await game_connection.abort() diff --git a/server/games/game.py b/server/games/game.py index f72874c02..e09f8a3a2 100644 --- a/server/games/game.py +++ b/server/games/game.py @@ -168,7 +168,6 @@ def players(self) -> list[Player]: Depending on the state, it is either: - (LOBBY) The currently connected players - (LIVE) Players who participated in the game - - Empty list """ if self.state is GameState.LOBBY: return self.get_connected_players() diff --git a/server/health.py b/server/health.py new file mode 100644 index 000000000..617d33593 --- /dev/null +++ b/server/health.py @@ -0,0 +1,65 @@ +""" +Kubernetes compatible HTTP health check server. 
+""" + +import http +import socket + +from aiohttp import web + +from .config import config +from .decorators import with_logger + + +@with_logger +class HealthServer: + def __init__( + self, + lobby_server: "ServerInstance", + ): + self.lobby_server = lobby_server + self.host = None + self.port = None + + self.app = web.Application() + self.runner = web.AppRunner(self.app, access_log=None) + + self.app.add_routes([ + web.get("/ready", self.ready) + ]) + + async def run_from_config(self) -> None: + """ + Initialize the http health server + """ + host = socket.gethostbyname(socket.gethostname()) + port = config.HEALTH_SERVER_PORT + + await self.shutdown() + await self.start(host, port) + + async def start(self, host: str, port: int) -> None: + self.host = host + self.port = port + await self.runner.setup() + self.site = web.TCPSite(self.runner, host, port) + await self.site.start() + self._logger.info( + "Health server listening on http://%s:%s", host, port + ) + + async def shutdown(self) -> None: + await self.runner.cleanup() + self.host = None + self.port = None + + async def ready(self, request): + code_map = { + True: http.HTTPStatus.OK.value, + False: http.HTTPStatus.SERVICE_UNAVAILABLE.value + } + + return web.Response( + status=code_map[self.lobby_server.started], + content_type="text/plain" + ) diff --git a/server/info.py b/server/info.py new file mode 100644 index 000000000..5c7767a11 --- /dev/null +++ b/server/info.py @@ -0,0 +1,12 @@ +""" +Static meta information about the container/process +""" + +import os +import platform + +PYTHON_VERSION = platform.python_version() + +# Environment variables +VERSION = os.getenv("VERSION") or "dev" +CONTAINER_NAME = os.getenv("CONTAINER_NAME") or "faf-python-server" diff --git a/server/ladder_service/ladder_service.py b/server/ladder_service/ladder_service.py index 0365552e5..e7046493d 100644 --- a/server/ladder_service/ladder_service.py +++ b/server/ladder_service/ladder_service.py @@ -34,6 +34,7 @@ 
matchmaker_queue_map_pool ) from server.decorators import with_logger +from server.exceptions import DisabledError from server.game_service import GameService from server.games import InitMode, LadderGame from server.games.ladder_game import GameClosedError @@ -70,6 +71,7 @@ def __init__( self.violation_service = violation_service self._searches: dict[Player, dict[str, Search]] = defaultdict(dict) + self._allow_new_searches = True async def initialize(self) -> None: await self.update_data() @@ -274,6 +276,9 @@ def start_search( queue_name: str, on_matched: OnMatchedCallback = lambda _1, _2: None ): + if not self._allow_new_searches: + raise DisabledError() + timeouts = self.violation_service.get_violations(players) if timeouts: self._logger.debug("timeouts: %s", timeouts) @@ -715,10 +720,16 @@ def on_connection_lost(self, conn: "LobbyConnection") -> None: if player in self._informed_players: self._informed_players.remove(player) - async def shutdown(self): + async def graceful_shutdown(self): + self._allow_new_searches = False + for queue in self.queues.values(): queue.shutdown() + for player, searches in self._searches.items(): + for queue_name in list(searches.keys()): + self._cancel_search(player, queue_name) + class NotConnectedError(asyncio.TimeoutError): def __init__(self, players: list[Player]): diff --git a/server/ladder_service/violation_service.py b/server/ladder_service/violation_service.py index 2ef1a9030..f954abe29 100644 --- a/server/ladder_service/violation_service.py +++ b/server/ladder_service/violation_service.py @@ -91,7 +91,7 @@ def register_violations(self, players: list[Player]): }) extra_text = "" if violation.count > 1: - delta_text = humanize.naturaldelta( + delta_text = humanize.precisedelta( violation.get_ban_expiration() - now ) extra_text = f" You can queue again in {delta_text}" diff --git a/server/lobbyconnection.py b/server/lobbyconnection.py index b2fd49aa8..e4c9c0e80 100644 --- a/server/lobbyconnection.py +++ 
b/server/lobbyconnection.py @@ -30,7 +30,12 @@ ) from .db.models import login as t_login from .decorators import timed, with_logger -from .exceptions import AuthenticationError, BanError, ClientError +from .exceptions import ( + AuthenticationError, + BanError, + ClientError, + DisabledError +) from .factions import Faction from .game_service import GameService from .gameconnection import GameConnection @@ -92,6 +97,7 @@ def __init__( self.user_agent = None self.version = None + self._timeout_task = None self._attempted_connectivity_test = False self._logger.debug("LobbyConnection initialized for '%s'", self.session) @@ -110,8 +116,15 @@ def get_user_identifier(self) -> str: async def on_connection_made(self, protocol: Protocol, peername: Address): self.protocol = protocol self.peer_address = peername + self._timeout_task = asyncio.create_task(self.timeout_login()) metrics.server_connections.inc() + async def timeout_login(self): + with contextlib.suppress(asyncio.CancelledError): + await asyncio.sleep(config.LOGIN_TIMEOUT) + if not self._authenticated: + await self.abort("Client took too long to log in.") + async def abort(self, logspam=""): self._authenticated = False @@ -167,37 +180,45 @@ async def on_message_received(self, message): handler = getattr(self, f"command_{cmd}") await handler(message) - except AuthenticationError as ex: - metrics.user_logins.labels("failure", ex.method).inc() + except AuthenticationError as e: + metrics.user_logins.labels("failure", e.method).inc() await self.send({ "command": "authentication_failed", - "text": ex.message + "text": e.message }) - except BanError as ex: + except BanError as e: await self.send({ "command": "notice", "style": "error", - "text": ex.message() + "text": e.message() }) - await self.abort(ex.message()) - except ClientError as ex: - self._logger.warning("Client error: %s", ex.message) + await self.abort(e.message()) + except ClientError as e: + self._logger.warning("Client error: %s", e.message) await 
self.send({ "command": "notice", "style": "error", - "text": ex.message + "text": e.message }) - if not ex.recoverable: - await self.abort(ex.message) - except (KeyError, ValueError) as ex: - self._logger.exception(ex) + if not e.recoverable: + await self.abort(e.message) + except (KeyError, ValueError) as e: + self._logger.exception(e) await self.abort(f"Garbage command: {message}") except ConnectionError as e: # Propagate connection errors to the ServerContext error handler. raise e - except Exception as ex: # pragma: no cover + except DisabledError: + # TODO: Respond with correlation uid for original message + await self.send({"command": "disabled", "request": cmd}) + self._logger.info( + "Ignoring disabled command for %s: %s", + self.get_user_identifier(), + cmd + ) + except Exception as e: # pragma: no cover await self.send({"command": "invalid"}) - self._logger.exception(ex) + self._logger.exception(e) await self.abort("Error processing command") async def command_ping(self, msg): @@ -240,7 +261,11 @@ async def command_coop_list(self, message): async def command_matchmaker_info(self, message): await self.send({ "command": "matchmaker_info", - "queues": [queue.to_dict() for queue in self.ladder_service.queues.values()] + "queues": [ + queue.to_dict() + for queue in self.ladder_service.queues.values() + if queue.is_running + ] }) async def send_game_list(self): @@ -319,11 +344,10 @@ async def command_admin(self, message): "Administrative action: %s closed game for %s", self.player, player ) - with contextlib.suppress(DisconnectedError): - await player.send_message({ - "command": "notice", - "style": "kill", - }) + player.write_message({ + "command": "notice", + "style": "kill", + }) elif action == "closelobby": if await self.player_service.has_permission_role( @@ -1231,6 +1255,9 @@ async def nop(*args, **kwargs): return self.send = nop + if self._timeout_task and not self._timeout_task.done(): + self._timeout_task.cancel() + if self.game_connection: 
self._logger.debug( "Lost lobby connection killing game connection for player %s", diff --git a/server/matchmaker/matchmaker_queue.py b/server/matchmaker/matchmaker_queue.py index 5785d0958..778407981 100644 --- a/server/matchmaker/matchmaker_queue.py +++ b/server/matchmaker/matchmaker_queue.py @@ -70,6 +70,10 @@ def __init__( self.matchmaker = TeamMatchMaker() + @property + def is_running(self) -> bool: + return self._is_running + def add_map_pool( self, map_pool: MapPool, @@ -102,7 +106,7 @@ async def queue_pop_timer(self) -> None: in the queue. """ self._logger.debug("MatchmakerQueue initialized for %s", self.name) - while self._is_running: + while self.is_running: try: await self.timer.next_pop() @@ -268,6 +272,7 @@ def match(self, s1: Search, s2: Search) -> bool: def shutdown(self): self._is_running = False + self.timer.cancel() def to_dict(self): """ @@ -278,7 +283,10 @@ def to_dict(self): "queue_pop_time": datetime.fromtimestamp( self.timer.next_queue_pop, timezone.utc ).isoformat(), - "queue_pop_time_delta": self.timer.next_queue_pop - time.time(), + "queue_pop_time_delta": round( + self.timer.next_queue_pop - time.time(), + ndigits=2 + ), "num_players": self.num_players, "boundary_80s": [search.boundary_80 for search in self._queue.keys()], "boundary_75s": [search.boundary_75 for search in self._queue.keys()], diff --git a/server/matchmaker/pop_timer.py b/server/matchmaker/pop_timer.py index c61bce84f..5425707f9 100644 --- a/server/matchmaker/pop_timer.py +++ b/server/matchmaker/pop_timer.py @@ -34,6 +34,7 @@ def __init__(self, queue: "MatchmakerQueue"): self._last_queue_pop = time() # Optimistically schedule first pop for half of the max pop time self.next_queue_pop = self._last_queue_pop + (config.QUEUE_POP_TIME_MAX / 2) + self._wait_task = None async def next_pop(self): """ Wait for the timer to pop. 
""" @@ -41,7 +42,10 @@ async def next_pop(self): time_remaining = self.next_queue_pop - time() self._logger.info("Next %s wave happening in %is", self.queue.name, time_remaining) metrics.matchmaker_queue_pop.labels(self.queue.name).set(int(time_remaining)) - await asyncio.sleep(time_remaining) + + self._wait_task = asyncio.create_task(asyncio.sleep(time_remaining)) + await self._wait_task + num_players = self.queue.num_players metrics.matchmaker_players.labels(self.queue.name).set(num_players) @@ -81,3 +85,7 @@ def time_until_next_pop(self, num_queued: int, time_queued: float) -> float: ) return config.QUEUE_POP_TIME_MAX return next_pop_time + + def cancel(self): + if self._wait_task: + self._wait_task.cancel() diff --git a/server/player_service.py b/server/player_service.py index 172808e5b..158b204b0 100644 --- a/server/player_service.py +++ b/server/player_service.py @@ -2,7 +2,7 @@ Manages connected and authenticated players """ -import contextlib +import asyncio from typing import Optional, ValuesView import aiocron @@ -10,10 +10,12 @@ from trueskill import Rating import server.metrics as metrics +from server.config import config from server.db import FAFDatabase from server.decorators import with_logger -from server.players import Player +from server.players import Player, PlayerState from server.rating import RatingType +from server.timing import at_interval from .core import Service from .db.models import ( @@ -265,16 +267,20 @@ async def update_data(self): ) self.uniqueid_exempt = frozenset(map(lambda x: x[0], result)) - async def shutdown(self): - for player in self: - if player.lobby_connection is not None: - with contextlib.suppress(Exception): - player.lobby_connection.write_warning( - "The server has been shut down for maintenance, " - "but should be back online soon. If you experience any " - "problems, please restart your client.

" - "We apologize for this interruption." - ) + async def kick_idle_players(self): + for fut in asyncio.as_completed([ + player.lobby_connection.abort("Graceful shutdown.") + for player in self.all_players + if player.state == PlayerState.IDLE + if player.lobby_connection is not None + ]): + try: + await fut + except Exception: + self._logger.debug( + "Error while aborting connection", + exc_info=True + ) def on_connection_lost(self, conn: "LobbyConnection") -> None: if not conn.player: @@ -288,3 +294,7 @@ def on_connection_lost(self, conn: "LobbyConnection") -> None: conn.player.login, conn.session ) + + async def graceful_shutdown(self): + if config.SHUTDOWN_KICK_IDLE_PLAYERS: + self._kick_idle_task = at_interval(1, self.kick_idle_players) diff --git a/server/servercontext.py b/server/servercontext.py index 3ac302864..d6feb95af 100644 --- a/server/servercontext.py +++ b/server/servercontext.py @@ -41,6 +41,7 @@ def __init__( super().__init__() self.name = name self._server = None + self._drain_event = None self._connection_factory = connection_factory self._services = services self.connections: dict[LobbyConnection, Protocol] = {} @@ -116,6 +117,18 @@ async def stop(self): self._server.close() await self._server.wait_closed() + async def drain_connections(self): + """ + Wait for all connections to terminate. 
+ """ + if not self.connections: + return + + if not self._drain_event: + self._drain_event = asyncio.Event() + + await self._drain_event.wait() + def write_broadcast(self, message, validate_fn=lambda _: True): self.write_broadcast_raw( self.protocol_class.encode_message(message), @@ -233,6 +246,14 @@ async def handle_client_connected( self.name, connection.get_user_identifier() ) + + if ( + self._drain_event is not None + and not self._drain_event.is_set() + and not self.connections + ): + self._drain_event.set() + metrics.user_connections.labels( connection.user_agent, connection.version diff --git a/tests/conftest.py b/tests/conftest.py index 4f8e44b9c..e057c0c73 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,8 @@ import hypothesis import pytest +from server import ServerInstance +from server.broadcast_service import BroadcastService from server.config import TRACE, config from server.db import FAFDatabase from server.game_service import GameService @@ -309,6 +311,18 @@ async def player_service(database): return player_service +@pytest.fixture +async def broadcast_service(game_service, message_queue_service, player_service): + broadcast_service = BroadcastService( + mock.create_autospec(ServerInstance), + message_queue_service, + game_service, + player_service + ) + await broadcast_service.initialize() + return broadcast_service + + @pytest.fixture async def rating_service(database, player_service, message_queue_service): service = RatingService(database, player_service, message_queue_service) diff --git a/tests/data/test_conf.yaml b/tests/data/test_conf.yaml index a81e8bef3..ce082aa6e 100644 --- a/tests/data/test_conf.yaml +++ b/tests/data/test_conf.yaml @@ -23,9 +23,6 @@ WWW_URL: "https://www.faforever.com" CONTENT_URL: "http://content.faforever.com" FAF_POLICY_SERVER_BASE_URL: "http://faf-policy-server" -FORCE_STEAM_LINK_AFTER_DATE: 1536105599 -FORCE_STEAM_LINK: false - NEWBIE_BASE_MEAN: 500 NEWBIE_MIN_GAMES: 10 TOP_PLAYER_MIN_RATING: 
1600 diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index ef0d17cf7..ffeaea92d 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -26,6 +26,7 @@ from server.config import config from server.control import ControlServer from server.db.models import login +from server.health import HealthServer from server.protocol import Protocol, QDataStreamProtocol, SimpleJsonProtocol from server.servercontext import ServerContext from tests.utils import exhaust_callbacks @@ -170,7 +171,7 @@ async def make_lobby_server(config): context.__connected_client_protos = [] player_service.is_uniqueid_exempt = lambda id: True - return contexts + return instance, contexts mock_policy = mock.patch( "server.lobbyconnection.config.FAF_POLICY_SERVER_BASE_URL", @@ -190,7 +191,7 @@ async def make_lobby_server(config): @pytest.fixture -async def lobby_contexts(lobby_server_factory): +async def lobby_setup(lobby_server_factory): return await lobby_server_factory({ "qstream": { "ADDRESS": "127.0.0.1", @@ -206,7 +207,7 @@ async def lobby_contexts(lobby_server_factory): @pytest.fixture -async def lobby_contexts_proxy(lobby_server_factory): +async def lobby_setup_proxy(lobby_server_factory): return await lobby_server_factory({ "qstream": { "ADDRESS": "127.0.0.1", @@ -223,8 +224,28 @@ async def lobby_contexts_proxy(lobby_server_factory): }) +@pytest.fixture +def lobby_instance(lobby_setup): + instance, _ = lobby_setup + return instance + + +@pytest.fixture +def lobby_contexts(lobby_setup): + _, contexts = lobby_setup + return contexts + + +@pytest.fixture +def lobby_contexts_proxy(lobby_setup_proxy): + _, contexts = lobby_setup_proxy + return contexts + + +# TODO: This fixture is poorly named since it returns a ServerContext, however, +# it is used in almost every tests, so renaming it is a large task. 
@pytest.fixture(params=("qstream", "json")) -def lobby_server(request, lobby_contexts): +def lobby_server(request, lobby_contexts) -> ServerContext: yield lobby_contexts[request.param] @@ -234,14 +255,25 @@ def lobby_server_proxy(request, lobby_contexts_proxy): @pytest.fixture -async def control_server(player_service, game_service): - server = ControlServer( - game_service, - player_service, +async def control_server(lobby_instance): + server = ControlServer(lobby_instance) + await server.start( "127.0.0.1", config.CONTROL_SERVER_PORT ) - await server.start() + + yield server + + await server.shutdown() + + +@pytest.fixture +async def health_server(lobby_instance): + server = HealthServer(lobby_instance) + await server.start( + "127.0.0.1", + config.HEALTH_SERVER_PORT + ) yield server @@ -466,7 +498,7 @@ async def connect_and_sign_in( credentials, lobby_server: ServerContext, address: Optional[tuple[str, int]] = None -): +) -> tuple[int, int, Protocol]: proto = await connect_client(lobby_server, address) session = await get_session(proto) await perform_login(proto, credentials) diff --git a/tests/integration_tests/test_game.py b/tests/integration_tests/test_game.py index 8fd4f5ae8..929af3155 100644 --- a/tests/integration_tests/test_game.py +++ b/tests/integration_tests/test_game.py @@ -36,6 +36,7 @@ async def host_game( game_id = int(msg["uid"]) await open_fa(proto) + await read_until_command(proto, "HostGame", target="game") return game_id diff --git a/tests/integration_tests/test_health_server.py b/tests/integration_tests/test_health_server.py new file mode 100644 index 000000000..e352ec559 --- /dev/null +++ b/tests/integration_tests/test_health_server.py @@ -0,0 +1,16 @@ +import aiohttp + +from tests.utils import fast_forward + + +@fast_forward(2) +async def test_ready(health_server, lobby_instance): + url = f"http://{health_server.host}:{health_server.port}/ready" + async with aiohttp.ClientSession() as session: + async with session.get(url) as resp: + assert 
resp.status == 200 + + await lobby_instance.shutdown() + + async with session.get(url) as resp: + assert resp.status == 503 diff --git a/tests/integration_tests/test_server.py b/tests/integration_tests/test_server.py index 334d22673..160df9bc3 100644 --- a/tests/integration_tests/test_server.py +++ b/tests/integration_tests/test_server.py @@ -5,6 +5,7 @@ import pytest from sqlalchemy import and_, select +from server.config import config from server.db.models import avatars, avatars_list, ban from server.protocol import DisconnectedError from tests.utils import fast_forward @@ -18,7 +19,13 @@ read_until, read_until_command ) -from .test_game import host_game, join_game, open_fa, send_player_options +from .test_game import ( + host_game, + join_game, + open_fa, + send_player_options, + setup_game_1v1 +) @fast_forward(10) @@ -51,6 +58,41 @@ async def test_server_proxy_mode_direct(lobby_server_proxy, caplog): assert "this may indicate a misconfiguration" in caplog.text +@fast_forward(10) +async def test_login_timeout(lobby_server, monkeypatch, caplog): + monkeypatch.setattr(config, "LOGIN_TIMEOUT", 5) + + proto = await connect_client(lobby_server) + + await asyncio.sleep(5) + + with pytest.raises(DisconnectedError), caplog.at_level("TRACE"): + await proto.read_message() + + assert "Client took too long to log in" in caplog.text + + +@fast_forward(10) +async def test_disconnect_before_login_timeout( + lobby_server, + monkeypatch, + caplog +): + monkeypatch.setattr(config, "LOGIN_TIMEOUT", 5) + + proto = await connect_client(lobby_server) + + await asyncio.sleep(1) + + with pytest.raises(DisconnectedError), caplog.at_level("TRACE"): + await proto.close() + await asyncio.sleep(4) + await proto.read_message() + + assert "Client disconnected" in caplog.text + assert "Client took too long to log in" not in caplog.text + + async def test_server_deprecated_client(lobby_server): proto = await connect_client(lobby_server) @@ -128,6 +170,177 @@ async def 
test_ping_message(lobby_server): await read_until_command(proto, "ping", timeout=46) +@fast_forward(10) +async def test_graceful_shutdown( + lobby_instance, + lobby_server, + tmp_user, + monkeypatch +): + _, _, proto = await connect_and_sign_in( + await tmp_user("Player"), + lobby_server + ) + await read_until_command(proto, "game_info") + + await proto.send_message({ + "command": "matchmaker_info" + }) + msg = await read_until_command(proto, "matchmaker_info", timeout=5) + assert msg["queues"] + + await proto.send_message({ + "command": "game_matchmaking", + "state": "start", + "queue_name": "ladder1v1" + }) + await read_until_command(proto, "search_info", state="start", timeout=5) + + await lobby_instance.graceful_shutdown() + + # First we get the notice message + msg = await read_until_command(proto, "notice", timeout=5) + assert "The server will be shutting down" in msg["text"] + + # Then the queues are shut down + await read_until_command(proto, "search_info", state="stop", timeout=5) + + # Matchmaker info should not include any queues + await proto.send_message({ + "command": "matchmaker_info" + }) + msg = await read_until_command(proto, "matchmaker_info", timeout=5) + assert msg == { + "command": "matchmaker_info", + "queues": [] + } + + # Hosting custom games should be disabled + await proto.send_message({ + "command": "game_host", + "visibility": "public", + }) + msg = await read_until_command(proto, "disabled", timeout=5) + assert msg == { + "command": "disabled", + "request": "game_host" + } + + # Joining matchmaker queues should be disabled + await proto.send_message({ + "command": "game_host", + "visibility": "public", + }) + msg = await read_until_command(proto, "disabled", timeout=5) + assert msg == { + "command": "disabled", + "request": "game_host" + } + + +@fast_forward(10) +async def test_graceful_shutdown_kick( + lobby_instance, + lobby_server, + tmp_user, + monkeypatch +): + monkeypatch.setattr(config, "SHUTDOWN_KICK_IDLE_PLAYERS", True) + + 
async def test_connected(proto) -> bool: + try: + await proto.send_message({"command": "ping"}) + await read_until_command(proto, "pong", timeout=5) + return True + except (DisconnectedError, ConnectionError, asyncio.TimeoutError): + return False + + _, _, idle_proto = await connect_and_sign_in( + await tmp_user("Idle"), + lobby_server + ) + _, _, host_proto = await connect_and_sign_in( + await tmp_user("Host"), + lobby_server + ) + player1_id, _, player1_proto = await connect_and_sign_in( + await tmp_user("Player"), + lobby_server + ) + player2_id, _, player2_proto = await connect_and_sign_in( + await tmp_user("Player"), + lobby_server + ) + protos = ( + idle_proto, + host_proto, + player1_proto, + player2_proto + ) + for proto in protos: + await read_until_command(proto, "game_info") + + # Host is in lobby, not playing a game + await host_game(host_proto, visibility="public") + + # Player1 and Player2 are playing a game + await setup_game_1v1( + player1_proto, + player1_id, + player2_proto, + player2_id + ) + + await lobby_instance.graceful_shutdown() + + # Check that everyone is notified + for proto in protos: + msg = await read_until_command(proto, "notice") + assert "The server will be shutting down" in msg["text"] + + # Players in lobby should be told to close their games + await read_until_command(host_proto, "notice", style="kill") + + # Idle players are kicked every 1 second + await asyncio.sleep(1) + + assert await test_connected(idle_proto) is False + assert await test_connected(host_proto) is False + + assert await test_connected(player1_proto) is True + assert await test_connected(player2_proto) is True + + +@fast_forward(60) +async def test_drain( + lobby_instance, + lobby_server, + tmp_user, + monkeypatch, + caplog, +): + monkeypatch.setattr(config, "SHUTDOWN_GRACE_PERIOD", 10) + + player1_id, _, player1_proto = await connect_and_sign_in( + await tmp_user("Player"), + lobby_server + ) + player2_id, _, player2_proto = await connect_and_sign_in( + 
await tmp_user("Player"), + lobby_server + ) + await setup_game_1v1( + player1_proto, + player1_id, + player2_proto, + player2_id + ) + + await lobby_instance.drain() + + assert "Graceful shutdown period ended! 1 games are still live!" in caplog.messages + + @fast_forward(5) async def test_player_info_broadcast(lobby_server): p1 = await connect_client(lobby_server) diff --git a/tests/integration_tests/test_servercontext.py b/tests/integration_tests/test_servercontext.py index 935e9caa0..c09192ab8 100644 --- a/tests/integration_tests/test_servercontext.py +++ b/tests/integration_tests/test_servercontext.py @@ -9,7 +9,7 @@ from server.core import Service from server.lobbyconnection import LobbyConnection from server.protocol import DisconnectedError, QDataStreamProtocol -from tests.utils import exhaust_callbacks +from tests.utils import exhaust_callbacks, fast_forward class MockConnection: @@ -141,3 +141,22 @@ async def test_unexpected_exception_in_connection_lost(context, caplog): await asyncio.sleep(0.1) assert "Unexpected exception in on_connection_lost" in caplog.text + + +@fast_forward(20) +async def test_drain_connections(context): + srv, ctx = context + _, writer = await asyncio.open_connection(*srv.sockets[0].getsockname()) + + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for( + ctx.drain_connections(), + timeout=10 + ) + + writer.close() + + await asyncio.wait_for( + ctx.drain_connections(), + timeout=3 + ) diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index 0c8217a61..7d65cdbb7 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -127,7 +127,10 @@ def game_stats_service(): def add_connected_player(game: Game, player): game.game_service.player_service[player.id] = player - gc = make_mock_game_connection(state=GameConnectionState.CONNECTED_TO_HOST, player=player) + gc = make_mock_game_connection( + state=GameConnectionState.CONNECTED_TO_HOST, + player=player + ) 
game.set_player_option(player.id, "Army", 0) game.set_player_option(player.id, "StartSpot", 0) game.set_player_option(player.id, "Team", 0) diff --git a/tests/unit_tests/test_asyncio_extensions.py b/tests/unit_tests/test_asyncio_extensions.py index 67aa7dde6..62c70df26 100644 --- a/tests/unit_tests/test_asyncio_extensions.py +++ b/tests/unit_tests/test_asyncio_extensions.py @@ -5,7 +5,7 @@ from server.asyncio_extensions import ( SpinLock, - gather_without_exceptions, + map_suppress, synchronized, synchronizedmethod ) @@ -16,41 +16,37 @@ class CustomError(Exception): pass -async def raises_connection_error(): - raise ConnectionError("Test ConnectionError") +async def test_map_suppress(caplog): + obj1 = mock.AsyncMock() + obj2 = mock.AsyncMock() + obj2.test.side_effect = CustomError("Test Exception") + obj2.__str__.side_effect = lambda: "TestObject" + with caplog.at_level("TRACE"): + await map_suppress( + lambda x: x.test(), + [obj1, obj2] + ) -async def raises_connection_reset_error(): - raise ConnectionResetError("Test ConnectionResetError") + obj1.test.assert_called_once() + obj2.test.assert_called_once() + assert "Unexpected error TestObject" in caplog.messages -async def raises_custom_error(): - raise CustomError("Test Exception") +async def test_map_suppress_message(caplog): + obj1 = mock.AsyncMock() + obj1.test.side_effect = CustomError("Test Exception") + obj1.__str__.side_effect = lambda: "TestObject" + with caplog.at_level("TRACE"): + await map_suppress( + lambda x: x.test(), + [obj1], + msg="when testing " + ) -async def test_gather_without_exceptions(): - completes_correctly = mock.AsyncMock() - - with pytest.raises(CustomError): - await gather_without_exceptions([ - raises_connection_error(), - raises_custom_error(), - completes_correctly() - ], ConnectionError) - - completes_correctly.assert_called_once() - - -async def test_gather_without_exceptions_subclass(): - completes_correctly = mock.AsyncMock() - - await gather_without_exceptions([ - 
raises_connection_error(), - raises_connection_reset_error(), - completes_correctly() - ], ConnectionError) - - completes_correctly.assert_called_once() + obj1.test.assert_called_once() + assert "Unexpected error when testing TestObject" in caplog.messages @fast_forward(15) diff --git a/tests/unit_tests/test_broadcast_service.py b/tests/unit_tests/test_broadcast_service.py new file mode 100644 index 000000000..fc2fa5814 --- /dev/null +++ b/tests/unit_tests/test_broadcast_service.py @@ -0,0 +1,22 @@ +import asyncio + + +async def test_broadcast_shutdown(broadcast_service): + await broadcast_service.shutdown() + + broadcast_service.server.write_broadcast.assert_called_once() + + +async def test_broadcast_ping(broadcast_service): + broadcast_service.broadcast_ping() + + broadcast_service.server.write_broadcast.assert_called_once_with( + {"command": "ping"} + ) + + +async def test_wait_report_dirties(broadcast_service): + await asyncio.wait_for( + broadcast_service.wait_report_dirtes(), + timeout=1 + ) diff --git a/tests/unit_tests/test_games_service.py b/tests/unit_tests/test_games_service.py index 6d3525362..55fd74674 100644 --- a/tests/unit_tests/test_games_service.py +++ b/tests/unit_tests/test_games_service.py @@ -1,6 +1,19 @@ +import asyncio + +import pytest + from server.db.models import game_stats -from server.games import CustomGame, Game, LadderGame, VisibilityState +from server.exceptions import DisabledError +from server.games import ( + CustomGame, + Game, + GameState, + LadderGame, + VisibilityState +) from server.players import PlayerState +from tests.unit_tests.conftest import add_connected_player +from tests.utils import fast_forward async def test_initialization(game_service): @@ -18,6 +31,34 @@ async def test_initialize_game_counter_empty(game_service, database): assert game_service.game_id_counter == 0 +async def test_graceful_shutdown(game_service): + await game_service.graceful_shutdown() + + with pytest.raises(DisabledError): + 
game_service.create_game( + game_mode="faf", + ) + + +@fast_forward(2) +async def test_drain_games(game_service): + game = game_service.create_game( + game_mode="faf", + name="TestGame" + ) + + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(game_service.drain_games(), 1) + + game_service.remove_game(game) + + await asyncio.wait_for(game_service.drain_games(), 1) + assert len(game_service.all_games) == 0 + + # Calling drain_games again should return immediately + await asyncio.wait_for(game_service.drain_games(), 1) + + async def test_create_game(players, game_service): players.hosting.state = PlayerState.IDLE game = game_service.create_game( @@ -46,6 +87,7 @@ async def test_all_games(players, game_service): password=None ) assert game in game_service.pending_games + assert game in game_service.all_games assert isinstance(game, CustomGame) @@ -75,3 +117,21 @@ async def test_create_game_other_gamemode(players, game_service): assert game in game_service.pop_dirty_games() assert isinstance(game, Game) assert game.game_mode == "labwars" + + +async def test_close_lobby_games(players, game_service): + game = game_service.create_game( + visibility=VisibilityState.PUBLIC, + game_mode="faf", + host=players.hosting, + name="Test", + mapname="SCMP_007", + password=None + ) + game.state = GameState.LOBBY + conn = add_connected_player(game, players.hosting) + assert conn in game.connections + + await game_service.close_lobby_games() + + conn.abort.assert_called_once() diff --git a/tests/unit_tests/test_ladder_service.py b/tests/unit_tests/test_ladder_service.py index b48ead13f..b62b063a3 100644 --- a/tests/unit_tests/test_ladder_service.py +++ b/tests/unit_tests/test_ladder_service.py @@ -7,6 +7,7 @@ from server import LadderService from server.db.models import matchmaker_queue, matchmaker_queue_map_pool +from server.exceptions import DisabledError from server.games import LadderGame from server.games.ladder_game import GameClosedError from 
server.ladder_service import game_name @@ -1072,3 +1073,73 @@ async def test_write_rating_progress_other_rating( ladder_service.write_rating_progress(player, RatingType.GLOBAL) player.write_message.assert_not_called() + + +async def test_graceful_shutdown_disables_searching( + ladder_service: LadderService, + player_factory +): + p1 = player_factory( + "Dostya", + player_id=1, + ladder_rating=(1000, 10) + ) + + await ladder_service.graceful_shutdown() + + with pytest.raises(DisabledError): + ladder_service.start_search([p1], "ladder1v1") + + +async def test_graceful_shutdown_clears_queues( + ladder_service: LadderService, + player_factory +): + p1 = player_factory( + "Dostya", + player_id=1, + ladder_rating=(1000, 10) + ) + p1.write_message = mock.Mock() + + p2 = player_factory( + "QAI", + player_id=2, + ladder_rating=(2350, 125), + ) + p2.write_message = mock.Mock() + + p3 = player_factory( + "Brackman", + player_id=3, + ladder_rating=(1000, 10) + ) + p3.write_message = mock.Mock() + + ladder_service.start_search([p1], "ladder1v1") + p1.write_message.reset_mock() + + ladder_service.start_search([p2, p3], "tmm2v2") + p2.write_message.reset_mock() + p3.write_message.reset_mock() + + await ladder_service.graceful_shutdown() + + p1.write_message.assert_called_once_with({ + "command": "search_info", + "state": "stop", + "queue_name": "ladder1v1" + }) + p2.write_message.assert_called_once_with({ + "command": "search_info", + "state": "stop", + "queue_name": "tmm2v2" + }) + p3.write_message.assert_called_once_with({ + "command": "search_info", + "state": "stop", + "queue_name": "tmm2v2" + }) + + assert ladder_service.queues["ladder1v1"]._is_running is False + assert ladder_service.queues["tmm2v2"]._is_running is False diff --git a/tests/unit_tests/test_lobbyconnection.py b/tests/unit_tests/test_lobbyconnection.py index b60250538..0de4aa751 100644 --- a/tests/unit_tests/test_lobbyconnection.py +++ b/tests/unit_tests/test_lobbyconnection.py @@ -585,7 +585,7 @@ async def 
test_command_admin_closeFA(lobbyconnection, player_factory): "user_id": tuna.id }) - tuna.lobby_connection.send.assert_any_call({ + tuna.lobby_connection.write.assert_any_call({ "command": "notice", "style": "kill", }) diff --git a/tests/unit_tests/test_player_service.py b/tests/unit_tests/test_player_service.py index 47ee1b583..1dcf6cf3b 100644 --- a/tests/unit_tests/test_player_service.py +++ b/tests/unit_tests/test_player_service.py @@ -1,7 +1,5 @@ from unittest import mock -from server.lobbyconnection import LobbyConnection -from server.protocol import DisconnectedError from server.rating import RatingType @@ -112,27 +110,3 @@ async def test_mark_dirty(player_factory, player_service): async def test_update_data(player_service): await player_service.update_data() assert player_service.is_uniqueid_exempt(1) is True - - -async def test_broadcast_shutdown(player_factory, player_service): - player = player_factory() - lconn = mock.create_autospec(LobbyConnection) - player.lobby_connection = lconn - player_service[0] = player - - await player_service.shutdown() - - player.lobby_connection.write_warning.assert_called_once() - - -async def test_broadcast_shutdown_error(player_factory, player_service): - player = player_factory() - lconn = mock.create_autospec(LobbyConnection) - lconn.write_warning.side_effect = DisconnectedError - player.lobby_connection = lconn - - player_service[0] = player - - await player_service.shutdown() - - player.lobby_connection.write_warning.assert_called_once()