Skip to content

Commit

Permalink
fixup! Statistics: Count servers and clients
Browse files Browse the repository at this point in the history
  • Loading branch information
roekatz committed May 26, 2024
1 parent 50d237d commit 74b3050
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 16 deletions.
9 changes: 7 additions & 2 deletions packages/opal-server/opal_server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,16 @@ class OpalServerConfig(Confi):
"__opal_stats_state_sync",
description="The topic other servers with statistics provide their state to a waking-up server",
)
STATISTICS_SERVER_HELLO_CHANNEL = confi.str(
"STATISTICS_SERVER_HELLO_CHANNEL",
# Pub/sub topic on which statistics workers announce that they exist and are alive.
# NOTE(review): the default topic value still uses the older "hello" wording —
# presumably kept so the topic name itself does not change; confirm.
STATISTICS_SERVER_KEEPALIVE_CHANNEL = confi.str(
"STATISTICS_SERVER_KEEPALIVE_CHANNEL",
"__opal_stats_server_hello",
description="The topic workers use to signal they exist and are alive",
)
# Number of seconds without a keep-alive after which a server is forgotten.
# Declared as an int (not a str): the statistics code compares this value with
# an elapsed number of seconds and halves it to get the keep-alive interval,
# both of which require a numeric value.
STATISTICS_SERVER_KEEPALIVE_TIMEOUT = confi.int(
    "STATISTICS_SERVER_KEEPALIVE_TIMEOUT",
    20,
    description="Timeout for forgetting a server from which a keep-alive haven't been seen (keep-alive frequency would be half of this value)",
)

# Data updates
ALL_DATA_TOPIC = confi.str(
Expand Down
2 changes: 2 additions & 0 deletions packages/opal-server/opal_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,8 @@ async def stop_server_background_tasks(self):
tasks.append(asyncio.create_task(self.publisher.stop()))
if self.broadcast_keepalive is not None:
tasks.append(asyncio.create_task(self.broadcast_keepalive.stop()))
if self.opal_statistics is not None:
tasks.append(asyncio.create_task(self.opal_statistics.stop()))

try:
await asyncio.gather(*tasks)
Expand Down
72 changes: 58 additions & 14 deletions packages/opal-server/opal_server/statistics.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import asyncio
import os
from datetime import datetime
from importlib.metadata import version as module_version
from random import uniform
from typing import Any, Dict, List, Optional, Set
from uuid import uuid4

import opal_server
import pydantic
from fastapi import APIRouter, HTTPException, status
from fastapi_websocket_pubsub.event_notifier import Subscription, TopicList
from fastapi_websocket_pubsub.pub_sub_server import PubSubEndpoint
from opal_common.async_utils import TasksPool
from opal_common.config import opal_common_config
from opal_common.logger import get_logger
from opal_common.topics.publisher import PeriodicPublisher
from opal_server.config import opal_server_config
from pydantic import BaseModel, Field

Expand Down Expand Up @@ -49,7 +52,7 @@ class SyncResponse(BaseModel):
rpc_id_to_client_id: Dict[str, str]


class ServerHello(BaseModel):
# Payload published on the server keep-alive channel; it carries only the
# id of the worker that sent it.
class ServerKeepalive(BaseModel):
worker_id: str


Expand All @@ -72,6 +75,10 @@ class OpalStatistics:
def __init__(self, endpoint):
self._endpoint: PubSubEndpoint = endpoint
self._uptime = datetime.utcnow()
self._opal_version = module_version(opal_server.__name__)
self._workers_count = (lambda envar: int(envar) if envar.isdigit() else 1)(
os.environ.get("UVICORN_NUM_WORKERS", "1")
)

# helps us realize when another server already responded to a sync request
self._worker_id = uuid4().hex
Expand All @@ -91,6 +98,8 @@ def __init__(self, endpoint):
self._synced_after_wakeup = asyncio.Event()
self._received_sync_messages: Set[str] = set()
self._publish_tasks = TasksPool()
self._seen_servers: Dict[str, datetime] = {}
self._periodic_keepalive_task: asyncio.Task | None = None

@property
def state(self) -> ServerStats:
Expand All @@ -99,9 +108,40 @@ def state(self) -> ServerStats:
@property
def stat_counts(self) -> StatCounts:
    """Return the current client and server counts.

    The server figure is the number of known worker ids divided by the
    workers count read from ``UVICORN_NUM_WORKERS`` at construction time.
    NOTE(review): presumably this converts "workers seen" into "server
    instances"; the division is a true division, so the result may be
    fractional — confirm against the ``StatCounts`` field type.
    """
    known_clients = len(self._state.clients)
    known_workers = len(self._state.servers)
    return StatCounts(
        clients=known_clients,
        servers=known_workers / self._workers_count,
    )

async def _expire_old_servers(self):
    """Forget servers whose last keep-alive is older than the timeout.

    While holding the instance lock, keeps only the entries of
    ``_seen_servers`` that were seen less than
    ``STATISTICS_SERVER_KEEPALIVE_TIMEOUT`` seconds ago, then resets
    ``_state.servers`` to this worker's id plus every surviving id.
    """
    async with self._lock:
        reference_time = datetime.utcnow()
        self._seen_servers = {
            worker_id: seen_at
            for worker_id, seen_at in self._seen_servers.items()
            if (reference_time - seen_at).total_seconds()
            < opal_server_config.STATISTICS_SERVER_KEEPALIVE_TIMEOUT
        }
        # This worker is always counted, whether or not it saw its own keep-alive.
        self._state.servers = {self._worker_id, *self._seen_servers}

async def _periodic_server_keepalive(self):
    """Publish this worker's keep-alive forever and expire silent servers.

    Each iteration first drops servers that have not been heard from within
    the keep-alive timeout, then publishes a ``ServerKeepalive`` for this
    worker, and finally sleeps for half of
    ``STATISTICS_SERVER_KEEPALIVE_TIMEOUT``.

    A failing iteration is logged once and retried after a short pause
    rather than immediately, so a persistent error cannot turn this loop
    into a busy spin that floods the log. Cancellation ends the loop
    quietly; the coroutine then returns ``None``.
    """
    # Pause applied after a failed iteration; deliberately independent of the
    # configured timeout, since reading or halving that value may be what failed.
    retry_delay = 1.0
    try:
        while True:
            try:
                await self._expire_old_servers()
                self._publish(
                    opal_server_config.STATISTICS_SERVER_KEEPALIVE_CHANNEL,
                    ServerKeepalive(worker_id=self._worker_id).dict(),
                )
                delay = opal_server_config.STATISTICS_SERVER_KEEPALIVE_TIMEOUT / 2
            except asyncio.CancelledError:
                # Never treat cancellation as an iteration failure.
                raise
            except Exception:
                logger.exception("Statistics: periodic server keepalive failed")
                delay = retry_delay
            await asyncio.sleep(delay)
    except asyncio.CancelledError:
        logger.debug("Statistics: periodic server keepalive cancelled")

def _publish(self, channel: str, message: Any):
    """Hand a publish of ``message`` on ``channel`` to the publish task pool.

    The endpoint's ``publish`` call is created for the single given channel
    and passed to ``_publish_tasks.add_task``; nothing is awaited here.
    """
    pending_publish = self._endpoint.publish([channel], message)
    self._publish_tasks.add_task(pending_publish)

Expand All @@ -117,8 +157,8 @@ async def run(self):
self._receive_other_worker_synced_state,
)
await self._endpoint.subscribe(
[opal_server_config.STATISTICS_SERVER_HELLO_CHANNEL],
self._receive_other_worker_hello_message,
[opal_server_config.STATISTICS_SERVER_KEEPALIVE_CHANNEL],
self._receive_other_worker_keepalive_message,
)
await self._endpoint.subscribe(
[opal_common_config.STATISTICS_ADD_CLIENT_CHANNEL], self._add_client
Expand All @@ -140,6 +180,15 @@ async def run(self):
opal_server_config.STATISTICS_WAKEUP_CHANNEL,
SyncRequest(requesting_worker_id=self._worker_id).dict(),
)
self._periodic_keepalive_task = asyncio.create_task(
self._periodic_server_keepalive()
)

async def stop(self):
    """Cancel the periodic keep-alive task, if any, and wait for it to end.

    Safe to call when no task was started and safe to call repeatedly.
    The task reference is cleared before awaiting so a concurrent or
    repeated call does not cancel and await the same task twice.
    """
    task = self._periodic_keepalive_task
    if task is None:
        return
    self._periodic_keepalive_task = None
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        # Expected: a task cancelled before its coroutine first ran (or one
        # that lets the cancellation propagate) finishes as "cancelled", and
        # awaiting it re-raises CancelledError here.
        pass

async def _sync_remove_client(self, subscription: Subscription, rpc_id: str):
"""helper function to recall remove client in all servers.
Expand Down Expand Up @@ -178,13 +227,6 @@ async def _receive_other_worker_wakeup_message(

logger.debug(f"received stats wakeup message: {request.requesting_worker_id}")

# Use worker wakeup to reset everyone's "servers" state
self._state.servers = {self._worker_id, request.requesting_worker_id}
self._publish(
opal_server_config.STATISTICS_SERVER_HELLO_CHANNEL,
ServerHello(worker_id=self._worker_id).dict(),
)

if len(self._state.clients):
# wait random time in order to reduce the number of messages sent by all the other opal servers
await asyncio.sleep(uniform(MIN_TIME_TO_WAIT, MAX_TIME_TO_WAIT))
Expand Down Expand Up @@ -230,10 +272,12 @@ async def _receive_other_worker_synced_state(
self._rpc_id_to_client_id = response.rpc_id_to_client_id
self._synced_after_wakeup.set()

async def _receive_other_worker_hello_message(
self, subscription: Subscription, hello_message: dict
async def _receive_other_worker_keepalive_message(
    self, subscription: Subscription, keepalive_message: dict
):
    """Record that a worker is alive.

    Args:
        subscription: the subscription that delivered the message (unused).
        keepalive_message: dict form of a keep-alive payload; its
            ``"worker_id"`` entry identifies the sender.

    Raises:
        KeyError: if the message has no ``"worker_id"`` entry.
    """
    worker_id = keepalive_message["worker_id"]
    async with self._lock:
        # Stamp in naive UTC: _expire_old_servers measures age against
        # datetime.utcnow(), so a local-time stamp would skew every age by
        # the host's UTC offset and expire servers too early or never.
        self._seen_servers[worker_id] = datetime.utcnow()
        self._state.servers.add(worker_id)

async def _add_client(self, subscription: Subscription, stats_message: dict):
"""add client record to statistics state.
Expand Down

0 comments on commit 74b3050

Please sign in to comment.