diff --git a/main.py b/main.py
index 0c20dffb2..0c5f90592 100755
--- a/main.py
+++ b/main.py
@@ -10,53 +10,65 @@
import asyncio
import logging
import os
-import platform
import signal
import sys
import time
from datetime import datetime
+from functools import wraps
import humanize
from docopt import docopt
+from prometheus_client import start_http_server
import server
+from server import info
from server.config import config
+from server.control import ControlServer
from server.game_service import GameService
+from server.health import HealthServer
from server.ice_servers.nts import TwilioNTS
from server.player_service import PlayerService
from server.profiler import Profiler
from server.protocol import QDataStreamProtocol, SimpleJsonProtocol
+def log_signal(func):
+ @wraps(func)
+ def wrapped(sig, frame):
+ logger.info("Received signal %s", signal.Signals(sig))
+ return func(sig, frame)
+
+ return wrapped
+
+
async def main():
global startup_time, shutdown_time
- version = os.environ.get("VERSION") or "dev"
- python_version = platform.python_version()
-
logger.info(
- "Lobby %s (Python %s) on %s",
- version,
- python_version,
- sys.platform
+ "Lobby %s (Python %s) on %s named %s",
+ info.VERSION,
+ info.PYTHON_VERSION,
+ sys.platform,
+ info.CONTAINER_NAME,
)
+ if config.ENABLE_METRICS:
+ logger.info("Using prometheus on port: %i", config.METRICS_PORT)
+ start_http_server(config.METRICS_PORT)
+
loop = asyncio.get_running_loop()
done = loop.create_future()
logger.info("Event loop: %s", loop)
- def signal_handler(sig: int, _frame):
- logger.info(
- "Received signal %s, shutting down",
- signal.Signals(sig)
- )
+ @log_signal
+ def done_handler(sig: int, frame):
if not done.done():
done.set_result(0)
# Make sure we can shutdown gracefully
- signal.signal(signal.SIGTERM, signal_handler)
- signal.signal(signal.SIGINT, signal_handler)
+ signal.signal(signal.SIGTERM, done_handler)
+ signal.signal(signal.SIGINT, done_handler)
database = server.db.FAFDatabase(
host=config.DB_SERVER,
@@ -91,19 +103,21 @@ def signal_handler(sig: int, _frame):
config.register_callback("PROFILING_DURATION", profiler.refresh)
config.register_callback("PROFILING_INTERVAL", profiler.refresh)
- await instance.start_services()
-
- ctrl_server = await server.run_control_server(player_service, game_service)
+ health_server = HealthServer(instance)
+ await health_server.run_from_config()
+ config.register_callback(
+ "HEALTH_SERVER_PORT",
+ health_server.run_from_config
+ )
- async def restart_control_server():
- nonlocal ctrl_server
+ control_server = ControlServer(instance)
+ await control_server.run_from_config()
+ config.register_callback(
+ "CONTROL_SERVER_PORT",
+ control_server.run_from_config
+ )
- await ctrl_server.shutdown()
- ctrl_server = await server.run_control_server(
- player_service,
- game_service
- )
- config.register_callback("CONTROL_SERVER_PORT", restart_control_server)
+ await instance.start_services()
PROTO_CLASSES = {
QDataStreamProtocol.__name__: QDataStreamProtocol,
@@ -135,8 +149,8 @@ async def restart_control_server():
)
server.metrics.info.info({
- "version": version,
- "python_version": python_version,
+ "version": info.VERSION,
+ "python_version": info.PYTHON_VERSION,
"start_time": datetime.utcnow().strftime("%m-%d %H:%M"),
"game_uid": str(game_service.game_id_counter)
})
@@ -150,12 +164,27 @@ async def restart_control_server():
shutdown_time = time.perf_counter()
# Cleanup
- await instance.shutdown()
- await ctrl_server.shutdown()
+ await instance.graceful_shutdown()
+
+ drain_task = asyncio.create_task(instance.drain())
- # Close DB connections
+ @log_signal
+ def drain_handler(sig: int, frame):
+ if not drain_task.done():
+ drain_task.cancel()
+
+ # Allow us to force shut down by skipping the drain
+ signal.signal(signal.SIGTERM, drain_handler)
+ signal.signal(signal.SIGINT, drain_handler)
+
+ await drain_task
+ await instance.shutdown()
+ await control_server.shutdown()
await database.close()
+ # Health server should be the last thing to shut down
+ await health_server.shutdown()
+
return exit_code
@@ -191,7 +220,7 @@ async def restart_control_server():
stop_time = time.perf_counter()
logger.info(
"Total server uptime: %s",
- humanize.naturaldelta(stop_time - startup_time)
+ humanize.precisedelta(stop_time - startup_time)
)
if shutdown_time is not None:
diff --git a/minikube-example.yaml b/minikube-example.yaml
new file mode 100644
index 000000000..2258f8242
--- /dev/null
+++ b/minikube-example.yaml
@@ -0,0 +1,83 @@
+apiVersion: v1
+kind: Service
+metadata:
+ name: faf-lobby
+ labels:
+ app: faf-lobby
+spec:
+ type: NodePort
+ selector:
+ app: faf-lobby
+ ports:
+ - port: 8001
+ name: qstream
+ - port: 8002
+ name: simplejson
+---
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+ name: faf-lobby
+spec:
+ replicas: 1
+ selector:
+ matchLabels:
+ app: faf-lobby
+ template:
+ metadata:
+ labels:
+ app: faf-lobby
+ spec:
+ terminationGracePeriodSeconds: 310
+ containers:
+ - name: faf-python-server
+ image: faf-python-server:graceful
+ imagePullPolicy: Never
+ readinessProbe:
+ httpGet:
+ path: /ready
+ port: health
+ initialDelaySeconds: 4
+ periodSeconds: 1
+ ports:
+ - containerPort: 4000
+ name: control
+ - containerPort: 2000
+ name: health
+ - containerPort: 8001
+ name: qstream
+ - containerPort: 8002
+ name: simplejson
+ env:
+ - name: CONFIGURATION_FILE
+ value: /config/config.yaml
+ - name: CONTAINER_NAME
+ valueFrom:
+ fieldRef:
+ fieldPath: metadata.name
+ volumeMounts:
+ - name: config
+ mountPath: /config
+ readOnly: true
+ volumes:
+ - name: config
+ configMap:
+ name: minikube-dev-config
+ items:
+ - key: config.yaml
+ path: config.yaml
+---
+apiVersion: v1
+kind: ConfigMap
+metadata:
+ name: minikube-dev-config
+data:
+ config.yaml: |
+ LOG_LEVEL: TRACE
+ USE_POLICY_SERVER: false
+ QUEUE_POP_TIME_MAX: 30
+ SHUTDOWN_GRACE_PERIOD: 300
+ SHUTDOWN_KICK_IDLE_PLAYERS: true
+
+ DB_SERVER: host.minikube.internal
+ MQ_SERVER: host.minikube.internal
diff --git a/server/__init__.py b/server/__init__.py
index c262075e1..e3837209f 100644
--- a/server/__init__.py
+++ b/server/__init__.py
@@ -90,15 +90,12 @@
import time
from typing import Optional
-from prometheus_client import start_http_server
-
import server.metrics as metrics
-from .asyncio_extensions import synchronizedmethod
+from .asyncio_extensions import map_suppress, synchronizedmethod
from .broadcast_service import BroadcastService
from .config import TRACE, config
from .configuration_service import ConfigurationService
-from .control import run_control_server
from .core import Service, create_services
from .db import FAFDatabase
from .game_service import GameService
@@ -138,18 +135,12 @@
"RatingService",
"ServerInstance",
"ViolationService",
- "control",
"game_service",
"protocol",
- "run_control_server",
)
logger = logging.getLogger("server")
-if config.ENABLE_METRICS:
- logger.info("Using prometheus on port: %i", config.METRICS_PORT)
- start_http_server(config.METRICS_PORT)
-
class ServerInstance(object):
"""
@@ -273,40 +264,86 @@ async def listen(
return ctx
- async def shutdown(self):
- results = await asyncio.gather(
- *(ctx.stop() for ctx in self.contexts),
- return_exceptions=True
- )
- for result, ctx in zip(results, self.contexts):
- if isinstance(result, BaseException):
- self._logger.exception(
- "Unexpected error when stopping context %s",
- ctx,
- exc_info=result
- )
+ async def graceful_shutdown(self):
+ """
+ Start a graceful shut down of the server.
- results = await asyncio.gather(
- *(service.shutdown() for service in self.services.values()),
- return_exceptions=True
- )
- for result, service in zip(results, self.services.values()):
- if isinstance(result, BaseException):
- self._logger.error(
- "Unexpected error when shutting down service %s",
- service
- )
+ 1. Notify all services of graceful shutdown
+ """
+ self._logger.info("Initiating graceful shutdown")
- results = await asyncio.gather(
- *(ctx.shutdown() for ctx in self.contexts),
- return_exceptions=True
+ await map_suppress(
+ lambda service: service.graceful_shutdown(),
+ self.services.values(),
+ logger=self._logger,
+ msg="when starting graceful shutdown of service "
)
- for result, ctx in zip(results, self.contexts):
- if isinstance(result, BaseException):
- self._logger.error(
- "Unexpected error when shutting down context %s",
- ctx
- )
+
+ async def shutdown(self):
+ """
+ Immediately shut down the server.
+
+ 1. Stop accepting new connections
+ 2. Stop all services
+ 3. Close all existing connections
+ """
+ self._logger.info("Initiating full shutdown")
+
+ await self._stop_contexts()
+ await self._shutdown_services()
+ await self._shutdown_contexts()
self.contexts.clear()
self.started = False
+
+ async def drain(self):
+ """
+ Wait for all games to end.
+ """
+ game_service: GameService = self.services["game_service"]
+ broadcast_service: BroadcastService = self.services["broadcast_service"]
+ try:
+ await asyncio.wait_for(
+ game_service.drain_games(),
+ timeout=config.SHUTDOWN_GRACE_PERIOD
+ )
+ except asyncio.CancelledError:
+ self._logger.debug(
+ "Stopped waiting for games to end due to forced shutdown"
+ )
+ except asyncio.TimeoutError:
+ self._logger.warning(
+ "Graceful shutdown period ended! %s games are still live!",
+ len(game_service.live_games)
+ )
+ finally:
+ # The report_dirties loop is responsible for clearing dirty games
+ # and broadcasting the update messages to players and to RabbitMQ.
+ # We need to wait here for that loop to complete otherwise it is
+ # possible for the services to be shut down in between clearing the
+ # games and posting the messages, causing the posts to fail.
+ await broadcast_service.wait_report_dirtes()
+
+ async def _shutdown_services(self):
+ await map_suppress(
+ lambda service: service.shutdown(),
+ self.services.values(),
+ logger=self._logger,
+ msg="when shutting down service "
+ )
+
+ async def _stop_contexts(self):
+ await map_suppress(
+ lambda ctx: ctx.stop(),
+ self.contexts,
+ logger=self._logger,
+ msg="when stopping context "
+ )
+
+ async def _shutdown_contexts(self):
+ await map_suppress(
+ lambda ctx: ctx.shutdown(),
+ self.contexts,
+ logger=self._logger,
+ msg="when shutting down context "
+ )
diff --git a/server/asyncio_extensions.py b/server/asyncio_extensions.py
index 2773de53e..3d5915126 100644
--- a/server/asyncio_extensions.py
+++ b/server/asyncio_extensions.py
@@ -12,8 +12,10 @@
AsyncContextManager,
Callable,
Coroutine,
+ Iterable,
Optional,
Protocol,
+ TypeVar,
cast,
overload
)
@@ -22,6 +24,7 @@
AsyncFunc = Callable[..., Coroutine[Any, Any, Any]]
AsyncDecorator = Callable[[AsyncFunc], AsyncFunc]
+T = TypeVar("T")
class AsyncLock(Protocol, AsyncContextManager["AsyncLock"]):
@@ -30,23 +33,24 @@ async def acquire(self) -> bool: ...
def release(self) -> None: ...
-async def gather_without_exceptions(
- tasks: list[asyncio.Task],
- *exceptions: type[BaseException],
-) -> list[Any]:
- """
- Run coroutines in parallel, raising the first exception that dosen't
- match any of the specified exception classes.
- """
- results = []
- for fut in asyncio.as_completed(tasks):
- try:
- results.append(await fut)
- except exceptions:
- logger.debug(
- "Ignoring error in gather_without_exceptions", exc_info=True
+async def map_suppress(
+ func: Callable[[T], Coroutine[Any, Any, Any]],
+ iterable: Iterable[T],
+ logger: logging.Logger = logger,
+ msg: str = ""
+):
+ results = await asyncio.gather(
+ *(func(item) for item in iterable),
+ return_exceptions=True
+ )
+ for result, item in zip(results, iterable):
+ if isinstance(result, BaseException):
+ logger.exception(
+ "Unexpected error %s%s",
+ msg,
+ item,
+ exc_info=result
)
- return results
# Based on python3.8 asyncio.Lock
diff --git a/server/broadcast_service.py b/server/broadcast_service.py
index 538391eec..6eaf6da2f 100644
--- a/server/broadcast_service.py
+++ b/server/broadcast_service.py
@@ -1,3 +1,6 @@
+import asyncio
+
+import humanize
from aio_pika import DeliveryMode
from .config import config
@@ -27,13 +30,14 @@ def __init__(
self.message_queue_service = message_queue_service
self.game_service = game_service
self.player_service = player_service
+ self._report_dirties_event = None
async def initialize(self):
# Using a lazy interval timer so that the intervals can be changed
# without restarting the server.
self._broadcast_dirties_timer = LazyIntervalTimer(
lambda: config.DIRTY_REPORT_INTERVAL,
- self.report_dirties,
+ self._monitored_report_dirties,
start=True
)
self._broadcast_ping_timer = LazyIntervalTimer(
@@ -42,6 +46,14 @@ async def initialize(self):
start=True
)
+ async def _monitored_report_dirties(self):
+ event = asyncio.Event()
+ self._report_dirties_event = event
+ try:
+ await self.report_dirties()
+ finally:
+ event.set()
+
async def report_dirties(self):
"""
Send updates about any dirty (changed) entities to connected players.
@@ -116,3 +128,48 @@ async def report_dirties(self):
def broadcast_ping(self):
self.server.write_broadcast({"command": "ping"})
+
+ async def wait_report_dirtes(self):
+ """
+ Wait for the current report_dirties task to complete.
+ """
+ if self._report_dirties_event is None:
+ return
+
+ await self._report_dirties_event.wait()
+
+ async def graceful_shutdown(self):
+ if config.SHUTDOWN_KICK_IDLE_PLAYERS:
+ message = (
+ "If you're in a game you can continue to play, otherwise you "
+ "will be disconnected. If you aren't reconnected automatically "
+ "please wait a few minutes and try to connect again."
+ )
+ else:
+ message = (
+ "If you're in a game you can continue to play, however, you "
+ "will not be able to create any new games until the server has "
+ "been restarted."
+ )
+
+ delta = humanize.precisedelta(config.SHUTDOWN_GRACE_PERIOD)
+ self.server.write_broadcast({
+ "command": "notice",
+ "style": "info",
+ "text": (
+ f"The server will be shutting down for maintenance in {delta}! "
+ f"{message}"
+ )
+ })
+
+ async def shutdown(self):
+ self.server.write_broadcast({
+ "command": "notice",
+ "style": "info",
+ "text": (
+ "The server has been shut down for maintenance "
+ "but should be back online soon. If you experience any "
+ "problems, please restart your client. <br/><br/>"
+ "We apologize for this interruption."
+ )
+ })
diff --git a/server/config.py b/server/config.py
index b0c901e22..aade3ee29 100644
--- a/server/config.py
+++ b/server/config.py
@@ -67,8 +67,15 @@ def __init__(self):
self.DIRTY_REPORT_INTERVAL = 1
self.PING_INTERVAL = 45
+ # How many seconds to wait for games to end before doing a hard shutdown.
+ # If using kubernetes, you must set terminationGracePeriodSeconds
+ # on the pod to be larger than this value. With docker compose, use
+ # --timeout (-t) to set a longer timeout.
+ self.SHUTDOWN_GRACE_PERIOD = 30 * 60
+ self.SHUTDOWN_KICK_IDLE_PLAYERS = False
self.CONTROL_SERVER_PORT = 4000
+ self.HEALTH_SERVER_PORT = 2000
self.METRICS_PORT = 8011
self.ENABLE_METRICS = False
@@ -97,10 +104,9 @@ def __init__(self):
self.FAF_POLICY_SERVER_BASE_URL = "http://faf-policy-server"
self.USE_POLICY_SERVER = True
- self.FORCE_STEAM_LINK_AFTER_DATE = 1536105599 # 5 september 2018 by default
- self.FORCE_STEAM_LINK = False
-
self.ALLOW_PASSWORD_LOGIN = True
+ # How many seconds a connection has to authenticate before being killed
+ self.LOGIN_TIMEOUT = 5 * 60
self.NEWBIE_BASE_MEAN = 500
self.NEWBIE_MIN_GAMES = 10
@@ -171,10 +177,14 @@ def refresh(self) -> None:
with open(config_file) as f:
new_values.update(yaml.safe_load(f))
except FileNotFoundError:
- self._logger.info("No configuration file found at %s", config_file)
+ self._logger.warning(
+ "No configuration file found at %s",
+ config_file
+ )
except TypeError:
self._logger.info(
- "Configuration file at %s appears to be empty", config_file
+ "Configuration file at %s appears to be empty",
+ config_file
)
triggered_callback_keys = tuple(
diff --git a/server/control.py b/server/control.py
index caed69319..328b95603 100644
--- a/server/control.py
+++ b/server/control.py
@@ -3,73 +3,66 @@
"""
import socket
-from json import dumps
from aiohttp import web
from .config import config
from .decorators import with_logger
-from .game_service import GameService
-from .player_service import PlayerService
@with_logger
class ControlServer:
def __init__(
self,
- game_service: GameService,
- player_service: PlayerService,
- host: str,
- port: int
+ lobby_server: "ServerInstance",
):
- self.game_service = game_service
- self.player_service = player_service
- self.host = host
- self.port = port
+ self.lobby_server = lobby_server
+ self.game_service = lobby_server.services["game_service"]
+ self.player_service = lobby_server.services["player_service"]
+ self.host = None
+ self.port = None
self.app = web.Application()
self.runner = web.AppRunner(self.app)
self.app.add_routes([
web.get("/games", self.games),
- web.get("/players", self.players)
+ web.get("/players", self.players),
])
- async def start(self) -> None:
+ async def run_from_config(self) -> None:
+ """
+ Initialize the http control server
+ """
+ host = socket.gethostbyname(socket.gethostname())
+ port = config.CONTROL_SERVER_PORT
+
+ await self.shutdown()
+ await self.start(host, port)
+
+ async def start(self, host: str, port: int) -> None:
+ self.host = host
+ self.port = port
await self.runner.setup()
- self.site = web.TCPSite(self.runner, self.host, self.port)
+ self.site = web.TCPSite(self.runner, host, port)
await self.site.start()
self._logger.info(
- "Control server listening on http://%s:%s", self.host, self.port
+ "Control server listening on http://%s:%s", host, port
)
async def shutdown(self) -> None:
await self.runner.cleanup()
+ self.host = None
+ self.port = None
- async def games(self, request):
- body = dumps(to_dict_list(self.game_service.all_games))
- return web.Response(body=body.encode(), content_type="application/json")
-
- async def players(self, request):
- body = dumps(to_dict_list(self.player_service.all_players))
- return web.Response(body=body.encode(), content_type="application/json")
-
-
-async def run_control_server(
- player_service: PlayerService,
- game_service: GameService
-) -> ControlServer:
- """
- Initialize the http control server
- """
- host = socket.gethostbyname(socket.gethostname())
- port = config.CONTROL_SERVER_PORT
-
- ctrl_server = ControlServer(game_service, player_service, host, port)
- await ctrl_server.start()
-
- return ctrl_server
-
+ async def games(self, request) -> web.Response:
+ return web.json_response([
+ game.to_dict()
+ for game in self.game_service.all_games
+ ])
-def to_dict_list(list_):
- return list(map(lambda p: p.to_dict(), list_))
+ async def players(self, request) -> web.Response:
+ return web.json_response([
+ player.to_dict()
+ for player in self.player_service.all_players
+ ])
diff --git a/server/core/service.py b/server/core/service.py
index e2bae1bc3..ed5ecaa0d 100644
--- a/server/core/service.py
+++ b/server/core/service.py
@@ -30,6 +30,17 @@ async def initialize(self) -> None:
"""
pass # pragma: no cover
+ async def graceful_shutdown(self) -> None:
+ """
+ Called once after the graceful shutdown period is initiated.
+
+ This signals that the service should stop accepting new events but
+ continue to wait for existing ones to complete normally. The hook
+ function `shutdown` will be called after the grace period has ended to
+ fully shutdown the service.
+ """
+ pass # pragma: no cover
+
async def shutdown(self) -> None:
"""
Called once after the server received the shutdown signal.
diff --git a/server/exceptions.py b/server/exceptions.py
index b2102d9bb..5c5a103d6 100644
--- a/server/exceptions.py
+++ b/server/exceptions.py
@@ -57,3 +57,9 @@ def __init__(self, message, method, *args, **kwargs):
super().__init__(*args, **kwargs)
self.message = message
self.method = method
+
+
+class DisabledError(Exception):
+ """
+ The operation is disabled due to an impending server shutdown.
+ """
diff --git a/server/game_service.py b/server/game_service.py
index d0bd44328..1e44bfa42 100644
--- a/server/game_service.py
+++ b/server/game_service.py
@@ -2,6 +2,7 @@
Manages the lifecycle of active games
"""
+import asyncio
from collections import Counter
from typing import Optional, Union, ValuesView
@@ -15,6 +16,7 @@
from .db import FAFDatabase
from .db.models import game_featuredMods
from .decorators import with_logger
+from .exceptions import DisabledError
from .games import (
CustomGame,
FeaturedMod,
@@ -52,6 +54,8 @@ def __init__(
self._rating_service = rating_service
self._message_queue_service = message_queue_service
self.game_id_counter = 0
+ self._allow_new_games = False
+ self._drain_event = None
# Populated below in really_update_static_ish_data.
self.featured_mods = dict()
@@ -68,6 +72,7 @@ async def initialize(self) -> None:
self._update_cron = aiocron.crontab(
"*/10 * * * *", func=self.update_data
)
+ self._allow_new_games = True
async def initialise_game_counter(self):
async with self._db.acquire() as conn:
@@ -153,6 +158,9 @@ def create_game(
"""
Main entrypoint for creating new games
"""
+ if not self._allow_new_games:
+ raise DisabledError()
+
game_id = self.create_uid()
game_args = {
"database": self._db,
@@ -207,10 +215,17 @@ def update_active_game_metrics(self):
rating_type_counter[(rating_type, state)]
)
+ @property
+ def all_games(self) -> ValuesView[Game]:
+ return self._games.values()
+
@property
def live_games(self) -> list[Game]:
- return [game for game in self._games.values()
- if game.state is GameState.LIVE]
+ return [
+ game
+ for game in self.all_games
+ if game.state is GameState.LIVE
+ ]
@property
def open_games(self) -> list[Game]:
@@ -225,22 +240,32 @@ def open_games(self) -> list[Game]:
The client ignores everything "closed". This property fetches all such not-closed games.
"""
- return [game for game in self._games.values()
- if game.state is GameState.LOBBY or game.state is GameState.LIVE]
-
- @property
- def all_games(self) -> ValuesView[Game]:
- return self._games.values()
+ return [
+ game
+ for game in self.all_games
+ if game.state in (GameState.LOBBY, GameState.LIVE)
+ ]
@property
def pending_games(self) -> list[Game]:
- return [game for game in self._games.values()
- if game.state is GameState.LOBBY or game.state is GameState.INITIALIZING]
+ return [
+ game
+ for game in self.all_games
+ if game.state in (GameState.LOBBY, GameState.INITIALIZING)
+ ]
def remove_game(self, game: Game):
if game.id in self._games:
+ self._logger.debug("Removing game %s", game)
del self._games[game.id]
+ if (
+ self._drain_event is not None
+ and not self._drain_event.is_set()
+ and not self._games
+ ):
+ self._drain_event.set()
+
def __getitem__(self, item: int) -> Game:
return self._games[item]
@@ -262,3 +287,31 @@ async def publish_game_results(self, game_results: EndedGameInfo):
metrics.rated_games.labels(game_results.rating_type).inc()
# TODO: Remove when rating service starts listening to message queue
await self._rating_service.enqueue(result_dict)
+
+ async def drain_games(self):
+ """
+ Wait for all games to finish.
+ """
+ if not self._games:
+ return
+
+ if not self._drain_event:
+ self._drain_event = asyncio.Event()
+
+ await self._drain_event.wait()
+
+ async def graceful_shutdown(self):
+ self._allow_new_games = False
+
+ await self.close_lobby_games()
+
+ async def close_lobby_games(self):
+ self._logger.info("Closing all games currently in lobby")
+ for game in self.pending_games:
+ for game_connection in list(game.connections):
+ # Tell the client to kill the FA process
+ game_connection.player.write_message({
+ "command": "notice",
+ "style": "kill"
+ })
+ await game_connection.abort()
diff --git a/server/games/game.py b/server/games/game.py
index f72874c02..e09f8a3a2 100644
--- a/server/games/game.py
+++ b/server/games/game.py
@@ -168,7 +168,6 @@ def players(self) -> list[Player]:
Depending on the state, it is either:
- (LOBBY) The currently connected players
- (LIVE) Players who participated in the game
- - Empty list
"""
if self.state is GameState.LOBBY:
return self.get_connected_players()
diff --git a/server/health.py b/server/health.py
new file mode 100644
index 000000000..617d33593
--- /dev/null
+++ b/server/health.py
@@ -0,0 +1,65 @@
+"""
+Kubernetes compatible HTTP health check server.
+"""
+
+import http
+import socket
+
+from aiohttp import web
+
+from .config import config
+from .decorators import with_logger
+
+
+@with_logger
+class HealthServer:
+ def __init__(
+ self,
+ lobby_server: "ServerInstance",
+ ):
+ self.lobby_server = lobby_server
+ self.host = None
+ self.port = None
+
+ self.app = web.Application()
+ self.runner = web.AppRunner(self.app, access_log=None)
+
+ self.app.add_routes([
+ web.get("/ready", self.ready)
+ ])
+
+ async def run_from_config(self) -> None:
+ """
+ Initialize the http health server
+ """
+ host = socket.gethostbyname(socket.gethostname())
+ port = config.HEALTH_SERVER_PORT
+
+ await self.shutdown()
+ await self.start(host, port)
+
+ async def start(self, host: str, port: int) -> None:
+ self.host = host
+ self.port = port
+ await self.runner.setup()
+ self.site = web.TCPSite(self.runner, host, port)
+ await self.site.start()
+ self._logger.info(
+ "Health server listening on http://%s:%s", host, port
+ )
+
+ async def shutdown(self) -> None:
+ await self.runner.cleanup()
+ self.host = None
+ self.port = None
+
+ async def ready(self, request):
+ code_map = {
+ True: http.HTTPStatus.OK.value,
+ False: http.HTTPStatus.SERVICE_UNAVAILABLE.value
+ }
+
+ return web.Response(
+ status=code_map[self.lobby_server.started],
+ content_type="text/plain"
+ )
diff --git a/server/info.py b/server/info.py
new file mode 100644
index 000000000..5c7767a11
--- /dev/null
+++ b/server/info.py
@@ -0,0 +1,12 @@
+"""
+Static meta information about the container/process
+"""
+
+import os
+import platform
+
+PYTHON_VERSION = platform.python_version()
+
+# Environment variables
+VERSION = os.getenv("VERSION") or "dev"
+CONTAINER_NAME = os.getenv("CONTAINER_NAME") or "faf-python-server"
diff --git a/server/ladder_service/ladder_service.py b/server/ladder_service/ladder_service.py
index 0365552e5..e7046493d 100644
--- a/server/ladder_service/ladder_service.py
+++ b/server/ladder_service/ladder_service.py
@@ -34,6 +34,7 @@
matchmaker_queue_map_pool
)
from server.decorators import with_logger
+from server.exceptions import DisabledError
from server.game_service import GameService
from server.games import InitMode, LadderGame
from server.games.ladder_game import GameClosedError
@@ -70,6 +71,7 @@ def __init__(
self.violation_service = violation_service
self._searches: dict[Player, dict[str, Search]] = defaultdict(dict)
+ self._allow_new_searches = True
async def initialize(self) -> None:
await self.update_data()
@@ -274,6 +276,9 @@ def start_search(
queue_name: str,
on_matched: OnMatchedCallback = lambda _1, _2: None
):
+ if not self._allow_new_searches:
+ raise DisabledError()
+
timeouts = self.violation_service.get_violations(players)
if timeouts:
self._logger.debug("timeouts: %s", timeouts)
@@ -715,10 +720,16 @@ def on_connection_lost(self, conn: "LobbyConnection") -> None:
if player in self._informed_players:
self._informed_players.remove(player)
- async def shutdown(self):
+ async def graceful_shutdown(self):
+ self._allow_new_searches = False
+
for queue in self.queues.values():
queue.shutdown()
+ for player, searches in self._searches.items():
+ for queue_name in list(searches.keys()):
+ self._cancel_search(player, queue_name)
+
class NotConnectedError(asyncio.TimeoutError):
def __init__(self, players: list[Player]):
diff --git a/server/ladder_service/violation_service.py b/server/ladder_service/violation_service.py
index 2ef1a9030..f954abe29 100644
--- a/server/ladder_service/violation_service.py
+++ b/server/ladder_service/violation_service.py
@@ -91,7 +91,7 @@ def register_violations(self, players: list[Player]):
})
extra_text = ""
if violation.count > 1:
- delta_text = humanize.naturaldelta(
+ delta_text = humanize.precisedelta(
violation.get_ban_expiration() - now
)
extra_text = f" You can queue again in {delta_text}"
diff --git a/server/lobbyconnection.py b/server/lobbyconnection.py
index b2fd49aa8..e4c9c0e80 100644
--- a/server/lobbyconnection.py
+++ b/server/lobbyconnection.py
@@ -30,7 +30,12 @@
)
from .db.models import login as t_login
from .decorators import timed, with_logger
-from .exceptions import AuthenticationError, BanError, ClientError
+from .exceptions import (
+ AuthenticationError,
+ BanError,
+ ClientError,
+ DisabledError
+)
from .factions import Faction
from .game_service import GameService
from .gameconnection import GameConnection
@@ -92,6 +97,7 @@ def __init__(
self.user_agent = None
self.version = None
+ self._timeout_task = None
self._attempted_connectivity_test = False
self._logger.debug("LobbyConnection initialized for '%s'", self.session)
@@ -110,8 +116,15 @@ def get_user_identifier(self) -> str:
async def on_connection_made(self, protocol: Protocol, peername: Address):
self.protocol = protocol
self.peer_address = peername
+ self._timeout_task = asyncio.create_task(self.timeout_login())
metrics.server_connections.inc()
+ async def timeout_login(self):
+ with contextlib.suppress(asyncio.CancelledError):
+ await asyncio.sleep(config.LOGIN_TIMEOUT)
+ if not self._authenticated:
+ await self.abort("Client took too long to log in.")
+
async def abort(self, logspam=""):
self._authenticated = False
@@ -167,37 +180,45 @@ async def on_message_received(self, message):
handler = getattr(self, f"command_{cmd}")
await handler(message)
- except AuthenticationError as ex:
- metrics.user_logins.labels("failure", ex.method).inc()
+ except AuthenticationError as e:
+ metrics.user_logins.labels("failure", e.method).inc()
await self.send({
"command": "authentication_failed",
- "text": ex.message
+ "text": e.message
})
- except BanError as ex:
+ except BanError as e:
await self.send({
"command": "notice",
"style": "error",
- "text": ex.message()
+ "text": e.message()
})
- await self.abort(ex.message())
- except ClientError as ex:
- self._logger.warning("Client error: %s", ex.message)
+ await self.abort(e.message())
+ except ClientError as e:
+ self._logger.warning("Client error: %s", e.message)
await self.send({
"command": "notice",
"style": "error",
- "text": ex.message
+ "text": e.message
})
- if not ex.recoverable:
- await self.abort(ex.message)
- except (KeyError, ValueError) as ex:
- self._logger.exception(ex)
+ if not e.recoverable:
+ await self.abort(e.message)
+ except (KeyError, ValueError) as e:
+ self._logger.exception(e)
await self.abort(f"Garbage command: {message}")
except ConnectionError as e:
# Propagate connection errors to the ServerContext error handler.
raise e
- except Exception as ex: # pragma: no cover
+ except DisabledError:
+ # TODO: Respond with correlation uid for original message
+ await self.send({"command": "disabled", "request": cmd})
+ self._logger.info(
+ "Ignoring disabled command for %s: %s",
+ self.get_user_identifier(),
+ cmd
+ )
+ except Exception as e: # pragma: no cover
await self.send({"command": "invalid"})
- self._logger.exception(ex)
+ self._logger.exception(e)
await self.abort("Error processing command")
async def command_ping(self, msg):
@@ -240,7 +261,11 @@ async def command_coop_list(self, message):
async def command_matchmaker_info(self, message):
await self.send({
"command": "matchmaker_info",
- "queues": [queue.to_dict() for queue in self.ladder_service.queues.values()]
+ "queues": [
+ queue.to_dict()
+ for queue in self.ladder_service.queues.values()
+ if queue.is_running
+ ]
})
async def send_game_list(self):
@@ -319,11 +344,10 @@ async def command_admin(self, message):
"Administrative action: %s closed game for %s",
self.player, player
)
- with contextlib.suppress(DisconnectedError):
- await player.send_message({
- "command": "notice",
- "style": "kill",
- })
+ player.write_message({
+ "command": "notice",
+ "style": "kill",
+ })
elif action == "closelobby":
if await self.player_service.has_permission_role(
@@ -1231,6 +1255,9 @@ async def nop(*args, **kwargs):
return
self.send = nop
+ if self._timeout_task and not self._timeout_task.done():
+ self._timeout_task.cancel()
+
if self.game_connection:
self._logger.debug(
"Lost lobby connection killing game connection for player %s",
diff --git a/server/matchmaker/matchmaker_queue.py b/server/matchmaker/matchmaker_queue.py
index 5785d0958..778407981 100644
--- a/server/matchmaker/matchmaker_queue.py
+++ b/server/matchmaker/matchmaker_queue.py
@@ -70,6 +70,10 @@ def __init__(
self.matchmaker = TeamMatchMaker()
+ @property
+ def is_running(self) -> bool:
+ return self._is_running
+
def add_map_pool(
self,
map_pool: MapPool,
@@ -102,7 +106,7 @@ async def queue_pop_timer(self) -> None:
in the queue.
"""
self._logger.debug("MatchmakerQueue initialized for %s", self.name)
- while self._is_running:
+ while self.is_running:
try:
await self.timer.next_pop()
@@ -268,6 +272,7 @@ def match(self, s1: Search, s2: Search) -> bool:
def shutdown(self):
self._is_running = False
+ self.timer.cancel()
def to_dict(self):
"""
@@ -278,7 +283,10 @@ def to_dict(self):
"queue_pop_time": datetime.fromtimestamp(
self.timer.next_queue_pop, timezone.utc
).isoformat(),
- "queue_pop_time_delta": self.timer.next_queue_pop - time.time(),
+ "queue_pop_time_delta": round(
+ self.timer.next_queue_pop - time.time(),
+ ndigits=2
+ ),
"num_players": self.num_players,
"boundary_80s": [search.boundary_80 for search in self._queue.keys()],
"boundary_75s": [search.boundary_75 for search in self._queue.keys()],
diff --git a/server/matchmaker/pop_timer.py b/server/matchmaker/pop_timer.py
index c61bce84f..5425707f9 100644
--- a/server/matchmaker/pop_timer.py
+++ b/server/matchmaker/pop_timer.py
@@ -34,6 +34,7 @@ def __init__(self, queue: "MatchmakerQueue"):
self._last_queue_pop = time()
# Optimistically schedule first pop for half of the max pop time
self.next_queue_pop = self._last_queue_pop + (config.QUEUE_POP_TIME_MAX / 2)
+ self._wait_task = None
async def next_pop(self):
""" Wait for the timer to pop. """
@@ -41,7 +42,10 @@ async def next_pop(self):
time_remaining = self.next_queue_pop - time()
self._logger.info("Next %s wave happening in %is", self.queue.name, time_remaining)
metrics.matchmaker_queue_pop.labels(self.queue.name).set(int(time_remaining))
- await asyncio.sleep(time_remaining)
+
+ self._wait_task = asyncio.create_task(asyncio.sleep(time_remaining))
+ await self._wait_task
+
num_players = self.queue.num_players
metrics.matchmaker_players.labels(self.queue.name).set(num_players)
@@ -81,3 +85,7 @@ def time_until_next_pop(self, num_queued: int, time_queued: float) -> float:
)
return config.QUEUE_POP_TIME_MAX
return next_pop_time
+
+ def cancel(self):
+ if self._wait_task:
+ self._wait_task.cancel()
diff --git a/server/player_service.py b/server/player_service.py
index 172808e5b..158b204b0 100644
--- a/server/player_service.py
+++ b/server/player_service.py
@@ -2,7 +2,7 @@
Manages connected and authenticated players
"""
-import contextlib
+import asyncio
from typing import Optional, ValuesView
import aiocron
@@ -10,10 +10,12 @@
from trueskill import Rating
import server.metrics as metrics
+from server.config import config
from server.db import FAFDatabase
from server.decorators import with_logger
-from server.players import Player
+from server.players import Player, PlayerState
from server.rating import RatingType
+from server.timing import at_interval
from .core import Service
from .db.models import (
@@ -265,16 +267,20 @@ async def update_data(self):
)
self.uniqueid_exempt = frozenset(map(lambda x: x[0], result))
- async def shutdown(self):
- for player in self:
- if player.lobby_connection is not None:
- with contextlib.suppress(Exception):
- player.lobby_connection.write_warning(
- "The server has been shut down for maintenance, "
- "but should be back online soon. If you experience any "
- "problems, please restart your client.
"
- "We apologize for this interruption."
- )
+ async def kick_idle_players(self):
+ for fut in asyncio.as_completed([
+ player.lobby_connection.abort("Graceful shutdown.")
+ for player in self.all_players
+ if player.state == PlayerState.IDLE
+ if player.lobby_connection is not None
+ ]):
+ try:
+ await fut
+ except Exception:
+ self._logger.debug(
+ "Error while aborting connection",
+ exc_info=True
+ )
def on_connection_lost(self, conn: "LobbyConnection") -> None:
if not conn.player:
@@ -288,3 +294,7 @@ def on_connection_lost(self, conn: "LobbyConnection") -> None:
conn.player.login,
conn.session
)
+
+ async def graceful_shutdown(self):
+ if config.SHUTDOWN_KICK_IDLE_PLAYERS:
+ self._kick_idle_task = at_interval(1, self.kick_idle_players)
diff --git a/server/servercontext.py b/server/servercontext.py
index 3ac302864..d6feb95af 100644
--- a/server/servercontext.py
+++ b/server/servercontext.py
@@ -41,6 +41,7 @@ def __init__(
super().__init__()
self.name = name
self._server = None
+ self._drain_event = None
self._connection_factory = connection_factory
self._services = services
self.connections: dict[LobbyConnection, Protocol] = {}
@@ -116,6 +117,18 @@ async def stop(self):
self._server.close()
await self._server.wait_closed()
+ async def drain_connections(self):
+ """
+ Wait for all connections to terminate.
+ """
+ if not self.connections:
+ return
+
+ if not self._drain_event:
+ self._drain_event = asyncio.Event()
+
+ await self._drain_event.wait()
+
def write_broadcast(self, message, validate_fn=lambda _: True):
self.write_broadcast_raw(
self.protocol_class.encode_message(message),
@@ -233,6 +246,14 @@ async def handle_client_connected(
self.name,
connection.get_user_identifier()
)
+
+ if (
+ self._drain_event is not None
+ and not self._drain_event.is_set()
+ and not self.connections
+ ):
+ self._drain_event.set()
+
metrics.user_connections.labels(
connection.user_agent,
connection.version
diff --git a/tests/conftest.py b/tests/conftest.py
index 4f8e44b9c..e057c0c73 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -16,6 +16,8 @@
import hypothesis
import pytest
+from server import ServerInstance
+from server.broadcast_service import BroadcastService
from server.config import TRACE, config
from server.db import FAFDatabase
from server.game_service import GameService
@@ -309,6 +311,18 @@ async def player_service(database):
return player_service
+@pytest.fixture
+async def broadcast_service(game_service, message_queue_service, player_service):
+ broadcast_service = BroadcastService(
+ mock.create_autospec(ServerInstance),
+ message_queue_service,
+ game_service,
+ player_service
+ )
+ await broadcast_service.initialize()
+ return broadcast_service
+
+
@pytest.fixture
async def rating_service(database, player_service, message_queue_service):
service = RatingService(database, player_service, message_queue_service)
diff --git a/tests/data/test_conf.yaml b/tests/data/test_conf.yaml
index a81e8bef3..ce082aa6e 100644
--- a/tests/data/test_conf.yaml
+++ b/tests/data/test_conf.yaml
@@ -23,9 +23,6 @@ WWW_URL: "https://www.faforever.com"
CONTENT_URL: "http://content.faforever.com"
FAF_POLICY_SERVER_BASE_URL: "http://faf-policy-server"
-FORCE_STEAM_LINK_AFTER_DATE: 1536105599
-FORCE_STEAM_LINK: false
-
NEWBIE_BASE_MEAN: 500
NEWBIE_MIN_GAMES: 10
TOP_PLAYER_MIN_RATING: 1600
diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py
index ef0d17cf7..ffeaea92d 100644
--- a/tests/integration_tests/conftest.py
+++ b/tests/integration_tests/conftest.py
@@ -26,6 +26,7 @@
from server.config import config
from server.control import ControlServer
from server.db.models import login
+from server.health import HealthServer
from server.protocol import Protocol, QDataStreamProtocol, SimpleJsonProtocol
from server.servercontext import ServerContext
from tests.utils import exhaust_callbacks
@@ -170,7 +171,7 @@ async def make_lobby_server(config):
context.__connected_client_protos = []
player_service.is_uniqueid_exempt = lambda id: True
- return contexts
+ return instance, contexts
mock_policy = mock.patch(
"server.lobbyconnection.config.FAF_POLICY_SERVER_BASE_URL",
@@ -190,7 +191,7 @@ async def make_lobby_server(config):
@pytest.fixture
-async def lobby_contexts(lobby_server_factory):
+async def lobby_setup(lobby_server_factory):
return await lobby_server_factory({
"qstream": {
"ADDRESS": "127.0.0.1",
@@ -206,7 +207,7 @@ async def lobby_contexts(lobby_server_factory):
@pytest.fixture
-async def lobby_contexts_proxy(lobby_server_factory):
+async def lobby_setup_proxy(lobby_server_factory):
return await lobby_server_factory({
"qstream": {
"ADDRESS": "127.0.0.1",
@@ -223,8 +224,28 @@ async def lobby_contexts_proxy(lobby_server_factory):
})
+@pytest.fixture
+def lobby_instance(lobby_setup):
+ instance, _ = lobby_setup
+ return instance
+
+
+@pytest.fixture
+def lobby_contexts(lobby_setup):
+ _, contexts = lobby_setup
+ return contexts
+
+
+@pytest.fixture
+def lobby_contexts_proxy(lobby_setup_proxy):
+ _, contexts = lobby_setup_proxy
+ return contexts
+
+
+# TODO: This fixture is poorly named since it returns a ServerContext; however,
+# it is used in almost every test, so renaming it is a large task.
@pytest.fixture(params=("qstream", "json"))
-def lobby_server(request, lobby_contexts):
+def lobby_server(request, lobby_contexts) -> ServerContext:
yield lobby_contexts[request.param]
@@ -234,14 +255,25 @@ def lobby_server_proxy(request, lobby_contexts_proxy):
@pytest.fixture
-async def control_server(player_service, game_service):
- server = ControlServer(
- game_service,
- player_service,
+async def control_server(lobby_instance):
+ server = ControlServer(lobby_instance)
+ await server.start(
"127.0.0.1",
config.CONTROL_SERVER_PORT
)
- await server.start()
+
+ yield server
+
+ await server.shutdown()
+
+
+@pytest.fixture
+async def health_server(lobby_instance):
+ server = HealthServer(lobby_instance)
+ await server.start(
+ "127.0.0.1",
+ config.HEALTH_SERVER_PORT
+ )
yield server
@@ -466,7 +498,7 @@ async def connect_and_sign_in(
credentials,
lobby_server: ServerContext,
address: Optional[tuple[str, int]] = None
-):
+) -> tuple[int, int, Protocol]:
proto = await connect_client(lobby_server, address)
session = await get_session(proto)
await perform_login(proto, credentials)
diff --git a/tests/integration_tests/test_game.py b/tests/integration_tests/test_game.py
index 8fd4f5ae8..929af3155 100644
--- a/tests/integration_tests/test_game.py
+++ b/tests/integration_tests/test_game.py
@@ -36,6 +36,7 @@ async def host_game(
game_id = int(msg["uid"])
await open_fa(proto)
+ await read_until_command(proto, "HostGame", target="game")
return game_id
diff --git a/tests/integration_tests/test_health_server.py b/tests/integration_tests/test_health_server.py
new file mode 100644
index 000000000..e352ec559
--- /dev/null
+++ b/tests/integration_tests/test_health_server.py
@@ -0,0 +1,16 @@
+import aiohttp
+
+from tests.utils import fast_forward
+
+
+@fast_forward(2)
+async def test_ready(health_server, lobby_instance):
+ url = f"http://{health_server.host}:{health_server.port}/ready"
+ async with aiohttp.ClientSession() as session:
+ async with session.get(url) as resp:
+ assert resp.status == 200
+
+ await lobby_instance.shutdown()
+
+ async with session.get(url) as resp:
+ assert resp.status == 503
diff --git a/tests/integration_tests/test_server.py b/tests/integration_tests/test_server.py
index 334d22673..160df9bc3 100644
--- a/tests/integration_tests/test_server.py
+++ b/tests/integration_tests/test_server.py
@@ -5,6 +5,7 @@
import pytest
from sqlalchemy import and_, select
+from server.config import config
from server.db.models import avatars, avatars_list, ban
from server.protocol import DisconnectedError
from tests.utils import fast_forward
@@ -18,7 +19,13 @@
read_until,
read_until_command
)
-from .test_game import host_game, join_game, open_fa, send_player_options
+from .test_game import (
+ host_game,
+ join_game,
+ open_fa,
+ send_player_options,
+ setup_game_1v1
+)
@fast_forward(10)
@@ -51,6 +58,41 @@ async def test_server_proxy_mode_direct(lobby_server_proxy, caplog):
assert "this may indicate a misconfiguration" in caplog.text
+@fast_forward(10)
+async def test_login_timeout(lobby_server, monkeypatch, caplog):
+ monkeypatch.setattr(config, "LOGIN_TIMEOUT", 5)
+
+ proto = await connect_client(lobby_server)
+
+ await asyncio.sleep(5)
+
+ with pytest.raises(DisconnectedError), caplog.at_level("TRACE"):
+ await proto.read_message()
+
+ assert "Client took too long to log in" in caplog.text
+
+
+@fast_forward(10)
+async def test_disconnect_before_login_timeout(
+ lobby_server,
+ monkeypatch,
+ caplog
+):
+ monkeypatch.setattr(config, "LOGIN_TIMEOUT", 5)
+
+ proto = await connect_client(lobby_server)
+
+ await asyncio.sleep(1)
+
+ with pytest.raises(DisconnectedError), caplog.at_level("TRACE"):
+ await proto.close()
+ await asyncio.sleep(4)
+ await proto.read_message()
+
+ assert "Client disconnected" in caplog.text
+ assert "Client took too long to log in" not in caplog.text
+
+
async def test_server_deprecated_client(lobby_server):
proto = await connect_client(lobby_server)
@@ -128,6 +170,177 @@ async def test_ping_message(lobby_server):
await read_until_command(proto, "ping", timeout=46)
+@fast_forward(10)
+async def test_graceful_shutdown(
+ lobby_instance,
+ lobby_server,
+ tmp_user,
+ monkeypatch
+):
+ _, _, proto = await connect_and_sign_in(
+ await tmp_user("Player"),
+ lobby_server
+ )
+ await read_until_command(proto, "game_info")
+
+ await proto.send_message({
+ "command": "matchmaker_info"
+ })
+ msg = await read_until_command(proto, "matchmaker_info", timeout=5)
+ assert msg["queues"]
+
+ await proto.send_message({
+ "command": "game_matchmaking",
+ "state": "start",
+ "queue_name": "ladder1v1"
+ })
+ await read_until_command(proto, "search_info", state="start", timeout=5)
+
+ await lobby_instance.graceful_shutdown()
+
+ # First we get the notice message
+ msg = await read_until_command(proto, "notice", timeout=5)
+ assert "The server will be shutting down" in msg["text"]
+
+ # Then the queues are shut down
+ await read_until_command(proto, "search_info", state="stop", timeout=5)
+
+ # Matchmaker info should not include any queues
+ await proto.send_message({
+ "command": "matchmaker_info"
+ })
+ msg = await read_until_command(proto, "matchmaker_info", timeout=5)
+ assert msg == {
+ "command": "matchmaker_info",
+ "queues": []
+ }
+
+ # Hosting custom games should be disabled
+ await proto.send_message({
+ "command": "game_host",
+ "visibility": "public",
+ })
+ msg = await read_until_command(proto, "disabled", timeout=5)
+ assert msg == {
+ "command": "disabled",
+ "request": "game_host"
+ }
+
+ # Joining matchmaker queues should be disabled
+ await proto.send_message({
+ "command": "game_host",
+ "visibility": "public",
+ })
+ msg = await read_until_command(proto, "disabled", timeout=5)
+ assert msg == {
+ "command": "disabled",
+ "request": "game_host"
+ }
+
+
+@fast_forward(10)
+async def test_graceful_shutdown_kick(
+ lobby_instance,
+ lobby_server,
+ tmp_user,
+ monkeypatch
+):
+ monkeypatch.setattr(config, "SHUTDOWN_KICK_IDLE_PLAYERS", True)
+
+ async def test_connected(proto) -> bool:
+ try:
+ await proto.send_message({"command": "ping"})
+ await read_until_command(proto, "pong", timeout=5)
+ return True
+ except (DisconnectedError, ConnectionError, asyncio.TimeoutError):
+ return False
+
+ _, _, idle_proto = await connect_and_sign_in(
+ await tmp_user("Idle"),
+ lobby_server
+ )
+ _, _, host_proto = await connect_and_sign_in(
+ await tmp_user("Host"),
+ lobby_server
+ )
+ player1_id, _, player1_proto = await connect_and_sign_in(
+ await tmp_user("Player"),
+ lobby_server
+ )
+ player2_id, _, player2_proto = await connect_and_sign_in(
+ await tmp_user("Player"),
+ lobby_server
+ )
+ protos = (
+ idle_proto,
+ host_proto,
+ player1_proto,
+ player2_proto
+ )
+ for proto in protos:
+ await read_until_command(proto, "game_info")
+
+ # Host is in lobby, not playing a game
+ await host_game(host_proto, visibility="public")
+
+ # Player1 and Player2 are playing a game
+ await setup_game_1v1(
+ player1_proto,
+ player1_id,
+ player2_proto,
+ player2_id
+ )
+
+ await lobby_instance.graceful_shutdown()
+
+ # Check that everyone is notified
+ for proto in protos:
+ msg = await read_until_command(proto, "notice")
+ assert "The server will be shutting down" in msg["text"]
+
+ # Players in lobby should be told to close their games
+ await read_until_command(host_proto, "notice", style="kill")
+
+ # Idle players are kicked every 1 second
+ await asyncio.sleep(1)
+
+ assert await test_connected(idle_proto) is False
+ assert await test_connected(host_proto) is False
+
+ assert await test_connected(player1_proto) is True
+ assert await test_connected(player2_proto) is True
+
+
+@fast_forward(60)
+async def test_drain(
+ lobby_instance,
+ lobby_server,
+ tmp_user,
+ monkeypatch,
+ caplog,
+):
+ monkeypatch.setattr(config, "SHUTDOWN_GRACE_PERIOD", 10)
+
+ player1_id, _, player1_proto = await connect_and_sign_in(
+ await tmp_user("Player"),
+ lobby_server
+ )
+ player2_id, _, player2_proto = await connect_and_sign_in(
+ await tmp_user("Player"),
+ lobby_server
+ )
+ await setup_game_1v1(
+ player1_proto,
+ player1_id,
+ player2_proto,
+ player2_id
+ )
+
+ await lobby_instance.drain()
+
+ assert "Graceful shutdown period ended! 1 games are still live!" in caplog.messages
+
+
@fast_forward(5)
async def test_player_info_broadcast(lobby_server):
p1 = await connect_client(lobby_server)
diff --git a/tests/integration_tests/test_servercontext.py b/tests/integration_tests/test_servercontext.py
index 935e9caa0..c09192ab8 100644
--- a/tests/integration_tests/test_servercontext.py
+++ b/tests/integration_tests/test_servercontext.py
@@ -9,7 +9,7 @@
from server.core import Service
from server.lobbyconnection import LobbyConnection
from server.protocol import DisconnectedError, QDataStreamProtocol
-from tests.utils import exhaust_callbacks
+from tests.utils import exhaust_callbacks, fast_forward
class MockConnection:
@@ -141,3 +141,22 @@ async def test_unexpected_exception_in_connection_lost(context, caplog):
await asyncio.sleep(0.1)
assert "Unexpected exception in on_connection_lost" in caplog.text
+
+
+@fast_forward(20)
+async def test_drain_connections(context):
+ srv, ctx = context
+ _, writer = await asyncio.open_connection(*srv.sockets[0].getsockname())
+
+ with pytest.raises(asyncio.TimeoutError):
+ await asyncio.wait_for(
+ ctx.drain_connections(),
+ timeout=10
+ )
+
+ writer.close()
+
+ await asyncio.wait_for(
+ ctx.drain_connections(),
+ timeout=3
+ )
diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py
index 0c8217a61..7d65cdbb7 100644
--- a/tests/unit_tests/conftest.py
+++ b/tests/unit_tests/conftest.py
@@ -127,7 +127,10 @@ def game_stats_service():
def add_connected_player(game: Game, player):
game.game_service.player_service[player.id] = player
- gc = make_mock_game_connection(state=GameConnectionState.CONNECTED_TO_HOST, player=player)
+ gc = make_mock_game_connection(
+ state=GameConnectionState.CONNECTED_TO_HOST,
+ player=player
+ )
game.set_player_option(player.id, "Army", 0)
game.set_player_option(player.id, "StartSpot", 0)
game.set_player_option(player.id, "Team", 0)
diff --git a/tests/unit_tests/test_asyncio_extensions.py b/tests/unit_tests/test_asyncio_extensions.py
index 67aa7dde6..62c70df26 100644
--- a/tests/unit_tests/test_asyncio_extensions.py
+++ b/tests/unit_tests/test_asyncio_extensions.py
@@ -5,7 +5,7 @@
from server.asyncio_extensions import (
SpinLock,
- gather_without_exceptions,
+ map_suppress,
synchronized,
synchronizedmethod
)
@@ -16,41 +16,37 @@ class CustomError(Exception):
pass
-async def raises_connection_error():
- raise ConnectionError("Test ConnectionError")
+async def test_map_suppress(caplog):
+ obj1 = mock.AsyncMock()
+ obj2 = mock.AsyncMock()
+ obj2.test.side_effect = CustomError("Test Exception")
+ obj2.__str__.side_effect = lambda: "TestObject"
+ with caplog.at_level("TRACE"):
+ await map_suppress(
+ lambda x: x.test(),
+ [obj1, obj2]
+ )
-async def raises_connection_reset_error():
- raise ConnectionResetError("Test ConnectionResetError")
+ obj1.test.assert_called_once()
+ obj2.test.assert_called_once()
+ assert "Unexpected error TestObject" in caplog.messages
-async def raises_custom_error():
- raise CustomError("Test Exception")
+async def test_map_suppress_message(caplog):
+ obj1 = mock.AsyncMock()
+ obj1.test.side_effect = CustomError("Test Exception")
+ obj1.__str__.side_effect = lambda: "TestObject"
+ with caplog.at_level("TRACE"):
+ await map_suppress(
+ lambda x: x.test(),
+ [obj1],
+ msg="when testing "
+ )
-async def test_gather_without_exceptions():
- completes_correctly = mock.AsyncMock()
-
- with pytest.raises(CustomError):
- await gather_without_exceptions([
- raises_connection_error(),
- raises_custom_error(),
- completes_correctly()
- ], ConnectionError)
-
- completes_correctly.assert_called_once()
-
-
-async def test_gather_without_exceptions_subclass():
- completes_correctly = mock.AsyncMock()
-
- await gather_without_exceptions([
- raises_connection_error(),
- raises_connection_reset_error(),
- completes_correctly()
- ], ConnectionError)
-
- completes_correctly.assert_called_once()
+ obj1.test.assert_called_once()
+ assert "Unexpected error when testing TestObject" in caplog.messages
@fast_forward(15)
diff --git a/tests/unit_tests/test_broadcast_service.py b/tests/unit_tests/test_broadcast_service.py
new file mode 100644
index 000000000..fc2fa5814
--- /dev/null
+++ b/tests/unit_tests/test_broadcast_service.py
@@ -0,0 +1,22 @@
+import asyncio
+
+
+async def test_broadcast_shutdown(broadcast_service):
+ await broadcast_service.shutdown()
+
+ broadcast_service.server.write_broadcast.assert_called_once()
+
+
+async def test_broadcast_ping(broadcast_service):
+ broadcast_service.broadcast_ping()
+
+ broadcast_service.server.write_broadcast.assert_called_once_with(
+ {"command": "ping"}
+ )
+
+
+async def test_wait_report_dirties(broadcast_service):
+ await asyncio.wait_for(
+ broadcast_service.wait_report_dirtes(),
+ timeout=1
+ )
diff --git a/tests/unit_tests/test_games_service.py b/tests/unit_tests/test_games_service.py
index 6d3525362..55fd74674 100644
--- a/tests/unit_tests/test_games_service.py
+++ b/tests/unit_tests/test_games_service.py
@@ -1,6 +1,19 @@
+import asyncio
+
+import pytest
+
from server.db.models import game_stats
-from server.games import CustomGame, Game, LadderGame, VisibilityState
+from server.exceptions import DisabledError
+from server.games import (
+ CustomGame,
+ Game,
+ GameState,
+ LadderGame,
+ VisibilityState
+)
from server.players import PlayerState
+from tests.unit_tests.conftest import add_connected_player
+from tests.utils import fast_forward
async def test_initialization(game_service):
@@ -18,6 +31,34 @@ async def test_initialize_game_counter_empty(game_service, database):
assert game_service.game_id_counter == 0
+async def test_graceful_shutdown(game_service):
+ await game_service.graceful_shutdown()
+
+ with pytest.raises(DisabledError):
+ game_service.create_game(
+ game_mode="faf",
+ )
+
+
+@fast_forward(2)
+async def test_drain_games(game_service):
+ game = game_service.create_game(
+ game_mode="faf",
+ name="TestGame"
+ )
+
+ with pytest.raises(asyncio.TimeoutError):
+ await asyncio.wait_for(game_service.drain_games(), 1)
+
+ game_service.remove_game(game)
+
+ await asyncio.wait_for(game_service.drain_games(), 1)
+ assert len(game_service.all_games) == 0
+
+ # Calling drain_games again should return immediately
+ await asyncio.wait_for(game_service.drain_games(), 1)
+
+
async def test_create_game(players, game_service):
players.hosting.state = PlayerState.IDLE
game = game_service.create_game(
@@ -46,6 +87,7 @@ async def test_all_games(players, game_service):
password=None
)
assert game in game_service.pending_games
+ assert game in game_service.all_games
assert isinstance(game, CustomGame)
@@ -75,3 +117,21 @@ async def test_create_game_other_gamemode(players, game_service):
assert game in game_service.pop_dirty_games()
assert isinstance(game, Game)
assert game.game_mode == "labwars"
+
+
+async def test_close_lobby_games(players, game_service):
+ game = game_service.create_game(
+ visibility=VisibilityState.PUBLIC,
+ game_mode="faf",
+ host=players.hosting,
+ name="Test",
+ mapname="SCMP_007",
+ password=None
+ )
+ game.state = GameState.LOBBY
+ conn = add_connected_player(game, players.hosting)
+ assert conn in game.connections
+
+ await game_service.close_lobby_games()
+
+ conn.abort.assert_called_once()
diff --git a/tests/unit_tests/test_ladder_service.py b/tests/unit_tests/test_ladder_service.py
index b48ead13f..b62b063a3 100644
--- a/tests/unit_tests/test_ladder_service.py
+++ b/tests/unit_tests/test_ladder_service.py
@@ -7,6 +7,7 @@
from server import LadderService
from server.db.models import matchmaker_queue, matchmaker_queue_map_pool
+from server.exceptions import DisabledError
from server.games import LadderGame
from server.games.ladder_game import GameClosedError
from server.ladder_service import game_name
@@ -1072,3 +1073,73 @@ async def test_write_rating_progress_other_rating(
ladder_service.write_rating_progress(player, RatingType.GLOBAL)
player.write_message.assert_not_called()
+
+
+async def test_graceful_shutdown_disables_searching(
+ ladder_service: LadderService,
+ player_factory
+):
+ p1 = player_factory(
+ "Dostya",
+ player_id=1,
+ ladder_rating=(1000, 10)
+ )
+
+ await ladder_service.graceful_shutdown()
+
+ with pytest.raises(DisabledError):
+ ladder_service.start_search([p1], "ladder1v1")
+
+
+async def test_graceful_shutdown_clears_queues(
+ ladder_service: LadderService,
+ player_factory
+):
+ p1 = player_factory(
+ "Dostya",
+ player_id=1,
+ ladder_rating=(1000, 10)
+ )
+ p1.write_message = mock.Mock()
+
+ p2 = player_factory(
+ "QAI",
+ player_id=2,
+ ladder_rating=(2350, 125),
+ )
+ p2.write_message = mock.Mock()
+
+ p3 = player_factory(
+ "Brackman",
+ player_id=3,
+ ladder_rating=(1000, 10)
+ )
+ p3.write_message = mock.Mock()
+
+ ladder_service.start_search([p1], "ladder1v1")
+ p1.write_message.reset_mock()
+
+ ladder_service.start_search([p2, p3], "tmm2v2")
+ p2.write_message.reset_mock()
+ p3.write_message.reset_mock()
+
+ await ladder_service.graceful_shutdown()
+
+ p1.write_message.assert_called_once_with({
+ "command": "search_info",
+ "state": "stop",
+ "queue_name": "ladder1v1"
+ })
+ p2.write_message.assert_called_once_with({
+ "command": "search_info",
+ "state": "stop",
+ "queue_name": "tmm2v2"
+ })
+ p3.write_message.assert_called_once_with({
+ "command": "search_info",
+ "state": "stop",
+ "queue_name": "tmm2v2"
+ })
+
+ assert ladder_service.queues["ladder1v1"]._is_running is False
+ assert ladder_service.queues["tmm2v2"]._is_running is False
diff --git a/tests/unit_tests/test_lobbyconnection.py b/tests/unit_tests/test_lobbyconnection.py
index b60250538..0de4aa751 100644
--- a/tests/unit_tests/test_lobbyconnection.py
+++ b/tests/unit_tests/test_lobbyconnection.py
@@ -585,7 +585,7 @@ async def test_command_admin_closeFA(lobbyconnection, player_factory):
"user_id": tuna.id
})
- tuna.lobby_connection.send.assert_any_call({
+ tuna.lobby_connection.write.assert_any_call({
"command": "notice",
"style": "kill",
})
diff --git a/tests/unit_tests/test_player_service.py b/tests/unit_tests/test_player_service.py
index 47ee1b583..1dcf6cf3b 100644
--- a/tests/unit_tests/test_player_service.py
+++ b/tests/unit_tests/test_player_service.py
@@ -1,7 +1,5 @@
from unittest import mock
-from server.lobbyconnection import LobbyConnection
-from server.protocol import DisconnectedError
from server.rating import RatingType
@@ -112,27 +110,3 @@ async def test_mark_dirty(player_factory, player_service):
async def test_update_data(player_service):
await player_service.update_data()
assert player_service.is_uniqueid_exempt(1) is True
-
-
-async def test_broadcast_shutdown(player_factory, player_service):
- player = player_factory()
- lconn = mock.create_autospec(LobbyConnection)
- player.lobby_connection = lconn
- player_service[0] = player
-
- await player_service.shutdown()
-
- player.lobby_connection.write_warning.assert_called_once()
-
-
-async def test_broadcast_shutdown_error(player_factory, player_service):
- player = player_factory()
- lconn = mock.create_autospec(LobbyConnection)
- lconn.write_warning.side_effect = DisconnectedError
- player.lobby_connection = lconn
-
- player_service[0] = player
-
- await player_service.shutdown()
-
- player.lobby_connection.write_warning.assert_called_once()