diff --git a/packages/opal-common/opal_common/async_utils.py b/packages/opal-common/opal_common/async_utils.py index a2df90c69..b9714d70a 100644 --- a/packages/opal-common/opal_common/async_utils.py +++ b/packages/opal-common/opal_common/async_utils.py @@ -97,13 +97,23 @@ def __init__(self): self._tasks: List[asyncio.Task] = [] def _cleanup_task(self, done_task): - self._tasks.remove(done_task) + try: + self._tasks.remove(done_task) + except KeyError: + ... def add_task(self, f): t = asyncio.create_task(f) self._tasks.append(t) t.add_done_callback(self._cleanup_task) + async def join(self, cancel=False): + if cancel: + for t in self._tasks: + t.cancel() + await asyncio.gather(*self._tasks, return_exceptions=True) + self._tasks.clear() + async def repeated_call( func: Coroutine, diff --git a/packages/opal-common/opal_common/topics/publisher.py b/packages/opal-common/opal_common/topics/publisher.py index b7b75a24f..4c18c2091 100644 --- a/packages/opal-common/opal_common/topics/publisher.py +++ b/packages/opal-common/opal_common/topics/publisher.py @@ -2,7 +2,8 @@ from typing import Any, Optional, Set from ddtrace import tracer -from fastapi_websocket_pubsub import PubSubClient, PubSubEndpoint, Topic, TopicList +from fastapi_websocket_pubsub import PubSubEndpoint, Topic, TopicList +from opal_common.async_utils import TasksPool from opal_common.logger import logger @@ -12,8 +13,7 @@ class TopicPublisher: def __init__(self): """inits the publisher's asyncio tasks list.""" - self._tasks: Set[asyncio.Task] = set() - self._tasks_lock = asyncio.Lock() + self._pool = TasksPool() async def publish(self, topics: TopicList, data: Any = None): raise NotImplementedError() @@ -29,95 +29,10 @@ def start(self): """starts the publisher.""" logger.debug("started topic publisher") - async def _add_task(self, task: asyncio.Task): - async with self._tasks_lock: - self._tasks.add(task) - task.add_done_callback(self._cleanup_task) - - async def wait(self): - async with self._tasks_lock: - await asyncio.gather(*self._tasks, return_exceptions=True) - self._tasks.clear() - async def stop(self): """stops the publisher (cancels any running publishing tasks)""" logger.debug("stopping topic publisher") - await self.wait() - - def _cleanup_task(self, task: asyncio.Task): - try: - self._tasks.remove(task) - except KeyError: - ... - - -class PeriodicPublisher: - """Wrapper for a task that publishes to topic on fixed interval - periodically.""" - - def __init__( - self, - publisher: TopicPublisher, - time_interval: int, - topic: Topic, - message: Any = None, - task_name: str = "periodic publish task", - ): - """inits the publisher. - - Args: - publisher (TopicPublisher): can publish messages on the pub/sub channel - interval (int): the time interval between publishing consecutive messages - topic (Topic): the topic to publish on - message (Any): the message to publish - """ - self._publisher = publisher - self._interval = time_interval - self._topic = topic - self._message = message - self._task_name = task_name - self._task: Optional[asyncio.Task] = None - - async def __aenter__(self): - self.start() - return self - - async def __aexit__(self, exc_type, exc, tb): - await self.stop() - - def start(self): - """starts the periodic publisher task.""" - if self._task is not None: - logger.warning(f"{self._task_name} already started") - return - - logger.info( - f"started {self._task_name}: topic is '{self._topic}', interval is {self._interval} seconds" - ) - self._task = asyncio.create_task(self._publish_task()) - - async def stop(self): - """stops the publisher (cancels any running publishing tasks)""" - if self._task is not None: - self._task.cancel() - try: - await self._task - except asyncio.CancelledError: - pass - self._task = None - logger.info(f"cancelled {self._task_name} to topic: {self._topic}") - - async def wait_until_done(self): - await self._task - - async def _publish_task(self): - while True: - await asyncio.sleep(self._interval) - logger.info( - f"{self._task_name}: publishing message on topic '{self._topic}', next publish is scheduled in {self._interval} seconds" - ) - async with self._publisher: - await self._publisher.publish(topics=[self._topic], data=self._message) + await self._pool.join() class ServerSideTopicPublisher(TopicPublisher): @@ -132,77 +47,5 @@ def __init__(self, endpoint: PubSubEndpoint): self._endpoint = endpoint super().__init__() - async def _publish_impl(self, topics: TopicList, data: Any = None): - with tracer.trace("topic_publisher.publish", resource=str(topics)): - await self._endpoint.publish(topics=topics, data=data) - async def publish(self, topics: TopicList, data: Any = None): await self._add_task(asyncio.create_task(self._publish_impl(topics, data))) - - -class ClientSideTopicPublisher(TopicPublisher): - """A simple wrapper around a PubSubClient that exposes publish(). - - Provides start() and stop() shortcuts that helps treat this client - as a separate "process" or task that runs in the background. - """ - - def __init__(self, client: PubSubClient, server_uri: str): - """inits the publisher. - - Args: - client (PubSubClient): a configured not-yet-started pub sub client - server_uri (str): the URI of the pub sub server we publish to - """ - self._client = client - self._server_uri = server_uri - super().__init__() - - def start(self): - """starts the pub/sub client as a background asyncio task. - - the client will attempt to connect to the pubsub server until - successful. - """ - super().start() - self._client.start_client(f"{self._server_uri}") - - async def stop(self): - """stops the pubsub client, and cancels any publishing tasks.""" - await self._client.disconnect() - await super().stop() - - async def wait_until_done(self): - """When the publisher is a used as a context manager, this method waits - until the client is done (i.e: terminated) to prevent exiting the - context.""" - return await self._client.wait_until_done() - - async def publish(self, topics: TopicList, data: Any = None): - """publish a message by launching a background task on the event loop. - - Args: - topics (TopicList): a list of topics to publish the message to - data (Any): optional data to publish as part of the message - """ - await self._add_task( - asyncio.create_task(self._publish(topics=topics, data=data)) - ) - - async def _publish(self, topics: TopicList, data: Any = None) -> bool: - """Do not trigger directly, must be triggered via publish() in order to - run as a monitored background asyncio task.""" - await self._client.wait_until_ready() - logger.info("Publishing to topics: {topics}", topics=topics) - return await self._client.publish(topics, data) - - -class ScopedServerSideTopicPublisher(ServerSideTopicPublisher): - def __init__(self, endpoint: PubSubEndpoint, scope_id: str): - super().__init__(endpoint) - self._scope_id = scope_id - - async def publish(self, topics: TopicList, data: Any = None): - scoped_topics = [f"{self._scope_id}:{topic}" for topic in topics] - logger.info("Publishing to topics: {topics}", topics=scoped_topics) - await super().publish(scoped_topics, data) diff --git a/packages/opal-server/opal_server/data/data_update_publisher.py b/packages/opal-server/opal_server/data/data_update_publisher.py index 64bc32bbe..25a41369b 100644 --- a/packages/opal-server/opal_server/data/data_update_publisher.py +++ b/packages/opal-server/opal_server/data/data_update_publisher.py @@ -2,22 +2,18 @@ import os from typing import List -from fastapi_utils.tasks import repeat_every from opal_common.logger import logger -from opal_common.schemas.data import ( - DataSourceEntryWithPollingInterval, - DataUpdate, - ServerDataSourceConfig, -) -from opal_common.topics.publisher import TopicPublisher +from opal_common.schemas.data import DataUpdate +from opal_server.pubsub import PubSub +from opal_server.scopes.scoped_pubsub import ScopedPubSub TOPIC_DELIMITER = "/" PREFIX_DELIMITER = ":" class DataUpdatePublisher: - def __init__(self, publisher: TopicPublisher) -> None: - self._publisher = publisher + def __init__(self, pubsub: PubSub | ScopedPubSub) -> None: + self._pubsub = pubsub @staticmethod def get_topic_combos(topic: str) -> List[str]: @@ -108,6 +104,4 @@ async def publish_data_updates(self, update: DataUpdate): entries=logged_entries, ) - await self._publisher.publish( - list(all_topic_combos), update.dict(by_alias=True) - ) + await self._pubsub.publish(list(all_topic_combos), update.dict(by_alias=True)) diff --git a/packages/opal-server/opal_server/policy/watcher/callbacks.py b/packages/opal-server/opal_server/policy/watcher/callbacks.py index 1b5f65590..c0b168afc 100644 --- a/packages/opal-server/opal_server/policy/watcher/callbacks.py +++ b/packages/opal-server/opal_server/policy/watcher/callbacks.py @@ -16,8 +16,8 @@ PolicyUpdateMessage, PolicyUpdateMessageNotification, ) -from opal_common.topics.publisher import TopicPublisher from opal_common.topics.utils import policy_topics +from opal_server.pubsub import PubSub async def create_update_all_directories_in_repo( @@ -104,7 +104,7 @@ def is_path_affected(path: Path) -> bool: async def publish_changed_directories( old_commit: Commit, new_commit: Commit, - publisher: TopicPublisher, + pubsub: PubSub, file_extensions: Optional[List[str]] = None, bundle_ignore: Optional[List[str]] = None, ): @@ -116,7 +116,4 @@ async def publish_changed_directories( ) if notification: - async with publisher: - await publisher.publish( - topics=notification.topics, data=notification.update.dict() - ) + await pubsub.publish_sync(notification.topics, notification.update.dict()) diff --git a/packages/opal-server/opal_server/policy/watcher/factory.py b/packages/opal-server/opal_server/policy/watcher/factory.py index 6d94d6fc4..985b6a61f 100644 --- a/packages/opal-server/opal_server/policy/watcher/factory.py +++ b/packages/opal-server/opal_server/policy/watcher/factory.py @@ -1,22 +1,20 @@ from functools import partial from typing import Any, List, Optional -from fastapi_websocket_pubsub.pub_sub_server import PubSubEndpoint from opal_common.confi.confi import load_conf_if_none from opal_common.git_utils.repo_cloner import RepoClonePathFinder from opal_common.logger import logger from opal_common.sources.api_policy_source import ApiPolicySource from opal_common.sources.git_policy_source import GitPolicySource -from opal_common.topics.publisher import TopicPublisher from opal_server.config import PolicySourceTypes, opal_server_config from opal_server.policy.watcher.callbacks import publish_changed_directories from opal_server.policy.watcher.task import BasePolicyWatcherTask, PolicyWatcherTask +from opal_server.pubsub import PubSub from opal_server.scopes.task import ScopesPolicyWatcherTask def setup_watcher_task( - publisher: TopicPublisher, - pubsub_endpoint: PubSubEndpoint, + pubsub: PubSub, source_type: str = None, remote_source_url: str = None, clone_path_finder: RepoClonePathFinder = None, @@ -35,7 +33,7 @@ def setup_watcher_task( vars Load all the defaults from config if called without params. Args: - publisher(TopicPublisher): server side publisher to publish changes in policy + pubsub(PubSub): server side pubsub client to publish changes in policy source_type(str): policy source type, can be Git / Api to opa bundle server remote_source_url(str): the base address to request the policy from clone_path_finder(RepoClonePathFinder): from which the local dir path for the repo clone would be retrieved @@ -50,7 +48,7 @@ def setup_watcher_task( bundle_ignore(list(str), optional): list of glob paths to use for excluding files from bundle default is OPA_BUNDLE_IGNORE """ if opal_server_config.SCOPES: - return ScopesPolicyWatcherTask(pubsub_endpoint) + return ScopesPolicyWatcherTask(pubsub) # load defaults source_type = load_conf_if_none(source_type, opal_server_config.POLICY_SOURCE_TYPE) @@ -135,9 +133,9 @@ def setup_watcher_task( watcher.add_on_new_policy_callback( partial( publish_changed_directories, - publisher=publisher, + pubsub=pubsub, file_extensions=extensions, bundle_ignore=bundle_ignore, ) ) - return PolicyWatcherTask(watcher, pubsub_endpoint) + return PolicyWatcherTask(watcher, pubsub) diff --git a/packages/opal-server/opal_server/policy/watcher/task.py b/packages/opal-server/opal_server/policy/watcher/task.py index a2ba57558..7099d2295 100644 --- a/packages/opal-server/opal_server/policy/watcher/task.py +++ b/packages/opal-server/opal_server/policy/watcher/task.py @@ -1,10 +1,10 @@ import asyncio import os import signal -from typing import Any, Coroutine, List, Optional +from typing import Any, List, Optional from fastapi_websocket_pubsub import Topic -from fastapi_websocket_pubsub.pub_sub_server import PubSubEndpoint +from fastapi_websocket_pubsub.pub_sub_server import PubSub from opal_common.logger import logger from opal_common.sources.base_policy_source import BasePolicySource from opal_server.config import opal_server_config @@ -13,21 +13,15 @@ class BasePolicyWatcherTask: """Manages the asyncio tasks of the policy watcher.""" - def __init__(self, pubsub_endpoint: PubSubEndpoint): + def __init__(self, pubsub: PubSub): self._tasks: List[asyncio.Task] = [] self._should_stop: Optional[asyncio.Event] = None - self._pubsub_endpoint = pubsub_endpoint + self._pubsub = pubsub self._webhook_tasks: List[asyncio.Task] = [] - async def __aenter__(self): - await self.start() - return self - - async def __aexit__(self, exc_type, exc, tb): - await self.stop() - async def _on_webhook(self, topic: Topic, data: Any): logger.info(f"Webhook listener triggered ({len(self._webhook_tasks)})") + # TODO: Use TasksPool for task in self._webhook_tasks: if task.done(): # Clean references to finished tasks @@ -39,33 +33,20 @@ async def _listen_to_webhook_notifications(self): # Webhook api route can be hit randomly in all workers, so it publishes a message to the webhook topic. # This listener, running in the leader's context, would actually trigger the repo pull - async def _subscribe_internal(): - logger.info( - "listening on webhook topic: '{topic}'", - topic=opal_server_config.POLICY_REPO_WEBHOOK_TOPIC, - ) - await self._pubsub_endpoint.subscribe( - [opal_server_config.POLICY_REPO_WEBHOOK_TOPIC], - self._on_webhook, - ) - - if self._pubsub_endpoint.broadcaster is not None: - async with self._pubsub_endpoint.broadcaster.get_listening_context(): - await _subscribe_internal() - await self._pubsub_endpoint.broadcaster.get_reader_task() - - # Stop the watcher if broadcaster disconnects - self.signal_stop() - else: - # If no broadcaster is configured, just subscribe, no need to wait on anything - await _subscribe_internal() + logger.info( + "listening on webhook topic: '{topic}'", + topic=opal_server_config.POLICY_REPO_WEBHOOK_TOPIC, + ) + await self._pubsub.subscribe( + [opal_server_config.POLICY_REPO_WEBHOOK_TOPIC], + self._on_webhook, + ) async def start(self): """starts the policy watcher and registers a failure callback to terminate gracefully.""" logger.info("Launching policy watcher") - self._tasks.append(asyncio.create_task(self._listen_to_webhook_notifications())) - self._init_should_stop() + await self._listen_to_webhook_notifications() async def stop(self): """stops all policy watcher tasks.""" @@ -80,29 +61,11 @@ async def trigger(self, topic: Topic, data: Any): pull)""" raise NotImplementedError() - def wait_until_should_stop(self) -> Coroutine: - """waits until self.signal_stop() is called on the watcher. - - allows us to keep the repo watcher context alive until signalled - to stop from outside. - """ - self._init_should_stop() - return self._should_stop.wait() - - def signal_stop(self): - """signal the repo watcher it should stop.""" - self._init_should_stop() - self._should_stop.set() - - def _init_should_stop(self): - if self._should_stop is None: - self._should_stop = asyncio.Event() - async def _fail(self, exc: Exception): """called when the watcher fails, and stops all tasks gracefully.""" logger.error("policy watcher failed with exception: {err}", err=repr(exc)) - self.signal_stop() # trigger uvicorn graceful shutdown + # TODO: Seriously? os.kill(os.getpid(), signal.SIGTERM) diff --git a/packages/opal-server/opal_server/policy/webhook/api.py b/packages/opal-server/opal_server/policy/webhook/api.py index c19595ad2..780c53deb 100644 --- a/packages/opal-server/opal_server/policy/webhook/api.py +++ b/packages/opal-server/opal_server/policy/webhook/api.py @@ -2,7 +2,6 @@ from urllib.parse import SplitResult, urlparse from fastapi import APIRouter, Depends, Request, status -from fastapi_websocket_pubsub.pub_sub_server import PubSubEndpoint from opal_common.authentication.deps import JWTAuthenticator from opal_common.logger import logger from opal_common.schemas.webhook import GitWebhookRequestParams @@ -12,11 +11,10 @@ extracted_git_changes, validate_git_secret_or_throw, ) +from opal_server.pubsub import PubSub -def init_git_webhook_router( - pubsub_endpoint: PubSubEndpoint, authenticator: JWTAuthenticator -): +def init_git_webhook_router(pubsub: PubSub, authenticator: JWTAuthenticator): async def dummy_affected_repo_urls(request: Request) -> List[str]: return [] @@ -32,7 +30,7 @@ async def dummy_affected_repo_urls(request: Request) -> List[str]: [Depends(route_dependency)], Depends(func_dependency), source_type, - pubsub_endpoint.publish, + pubsub.publish_sync, ) diff --git a/packages/opal-server/opal_server/publisher.py b/packages/opal-server/opal_server/publisher.py index 7d22fd86c..806e535cd 100644 --- a/packages/opal-server/opal_server/publisher.py +++ b/packages/opal-server/opal_server/publisher.py @@ -1,35 +1,88 @@ -from fastapi_websocket_pubsub import PubSubClient, Topic -from opal_common.confi.confi import load_conf_if_none -from opal_common.topics.publisher import ( - ClientSideTopicPublisher, - PeriodicPublisher, - ServerSideTopicPublisher, - TopicPublisher, -) -from opal_common.utils import get_authorization_header -from opal_server.config import opal_server_config - - -def setup_publisher_task( - server_uri: str = None, - server_token: str = None, -) -> TopicPublisher: - server_uri = load_conf_if_none( - server_uri, - opal_server_config.OPAL_WS_LOCAL_URL, - ) - server_token = load_conf_if_none( - server_token, - opal_server_config.OPAL_WS_TOKEN, - ) - return ClientSideTopicPublisher( - client=PubSubClient(extra_headers=[get_authorization_header(server_token)]), - server_uri=server_uri, - ) +import asyncio +from typing import Any, Optional, Set + +from fastapi_websocket_pubsub import Topic, TopicList +from opal_common.logger import logger +from opal_server.pubsub import PubSub + + +class PeriodicPublisher: + """Wrapper for a task that publishes to topic on fixed interval + periodically.""" + + def __init__( + self, + pubsub: PubSub, + time_interval: int, + topic: Topic, + message: Any = None, + task_name: str = "periodic publish task", + ): + """inits the publisher. + + Args: + publisher (TopicPublisher): can publish messages on the pub/sub channel + time_interval (int): the time interval between publishing consecutive messages + topic (Topic): the topic to publish on + message (Any): the message to publish + """ + self._pubsub = pubsub + self._interval = time_interval + self._topic = topic + self._message = message + self._task_name = task_name + self._task: Optional[asyncio.Task] = None + + async def __aenter__(self): + self.start() + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.stop() + + def start(self): + """starts the periodic publisher task.""" + if self._task is not None: + logger.warning(f"{self._task_name} already started") + return + + logger.info( + f"started {self._task_name}: topic is '{self._topic}', interval is {self._interval} seconds" + ) + self._task = asyncio.create_task(self._publish_task()) + + async def stop(self): + """stops the publisher (cancels any running publishing tasks)""" + if self._task is not None: + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + self._task = None + logger.info(f"cancelled {self._task_name} to topic: {self._topic}") + + async def _publish_task(self): + while True: + await asyncio.sleep(self._interval) + logger.info( + f"{self._task_name}: publishing message on topic '{self._topic}', next publish is scheduled in {self._interval} seconds" + ) + try: + await self._pubsub.publish_sync([self._topic], self._message) + except asyncio.CancelledError: + logger.debug( + f"{self._task_name} for topic '{self._topic}' was cancelled" + ) + break + except Exception as e: + logger.error( + f"failed to publish periodic message on topic '{self._topic}': {e}" + ) def setup_broadcaster_keepalive_task( - publisher: ServerSideTopicPublisher, + pubsub: PubSub, time_interval: int, topic: Topic = "__broadcast_session_keepalive__", ) -> PeriodicPublisher: @@ -37,5 +90,5 @@ def setup_broadcaster_keepalive_task( broadcast channel, so that the session to the backbone won't become idle and close on the backbone end.""" return PeriodicPublisher( - publisher, time_interval, topic, task_name="broadcaster keepalive task" + pubsub, time_interval, topic, task_name="broadcaster keepalive task" ) diff --git a/packages/opal-server/opal_server/pubsub.py b/packages/opal-server/opal_server/pubsub.py index 26d47c422..5a59d5289 100644 --- a/packages/opal-server/opal_server/pubsub.py +++ b/packages/opal-server/opal_server/pubsub.py @@ -1,10 +1,12 @@ +import asyncio import time from contextlib import contextmanager from contextvars import ContextVar from threading import Lock -from typing import Dict, Generator, List, Optional, Set, Tuple, Union, cast -from uuid import UUID, uuid4 +from typing import Any, Coroutine, Dict, Generator, Optional, Set, Union +from uuid import uuid4 +from ddtrace import tracer from fastapi import APIRouter, Depends, WebSocket from fastapi_websocket_pubsub import ( ALL_TOPICS, @@ -21,14 +23,15 @@ WebSocketRpcEventNotifier, ) from fastapi_websocket_rpc import RpcChannel +from opal_common.async_utils import TasksPool from opal_common.authentication.deps import WebsocketJWTAuthenticator from opal_common.authentication.signer import JWTSigner from opal_common.authentication.types import JWTClaims from opal_common.authentication.verifier import Unauthorized -from opal_common.confi.confi import load_conf_if_none from opal_common.config import opal_common_config from opal_common.logger import logger from opal_server.config import opal_server_config +from opal_server.publisher import setup_broadcaster_keepalive_task from pydantic import BaseModel from starlette.datastructures import QueryParams @@ -121,15 +124,17 @@ class PubSub: """Wrapper for the Pub/Sub channel used for both policy and data updates.""" - def __init__(self, signer: JWTSigner, broadcaster_uri: str = None): + def __init__( + self, + signer: JWTSigner, + broadcaster_uri: str = None, + disconnect_callback: Coroutine = None, + ): """ Args: broadcaster_uri (str, optional): Which server/medium should the PubSub use for broadcasting. Defaults to BROADCAST_URI. None means no broadcasting. """ - broadcaster_uri = load_conf_if_none( - broadcaster_uri, opal_server_config.BROADCAST_URI - ) self.pubsub_router = APIRouter() self.api_router = APIRouter() # Pub/Sub Internals @@ -138,8 +143,8 @@ def __init__(self, signer: JWTSigner, broadcaster_uri: str = None): self.client_tracker = ClientTracker() self.notifier.register_subscribe_event(self.client_tracker.on_subscribe) self.notifier.register_unsubscribe_event(self.client_tracker.on_unsubscribe) + self._publish_pool = TasksPool() - self.broadcaster = None if broadcaster_uri is not None: logger.info(f"Initializing broadcaster for server<->server communication") self.broadcaster = EventBroadcaster( @@ -147,8 +152,19 @@ def __init__(self, signer: JWTSigner, broadcaster_uri: str = None): notifier=self.notifier, channel=opal_server_config.BROADCAST_CHANNEL_NAME, ) + if opal_server_config.BROADCAST_KEEPALIVE_INTERVAL > 0: + self.broadcast_keepalive = setup_broadcaster_keepalive_task( + self, + time_interval=opal_server_config.BROADCAST_KEEPALIVE_INTERVAL, + topic=opal_server_config.BROADCAST_KEEPALIVE_TOPIC, + ) + else: logger.info("Pub/Sub broadcaster is off") + self.broadcaster = None + self.broadcast_keepalive = None + + self._wait_for_broadcaster_closed: Optional[asyncio.Task] = None # The server endpoint self.endpoint = PubSubEndpoint( @@ -202,6 +218,49 @@ async def websocket_rpc_endpoint( finally: await websocket.close() + async def start(self): + if self.broadcaster is not None: + await self.broadcaster.connect() + # TODO: That's still not good + self._wait_for_broadcaster_closed = asyncio.create_task( + self.wait_until_done() + ) + if self.broadcast_keepalive is not None: + self.broadcast_keepalive.start() + + async def stop(self): + stop_tasks = [self._publish_pool.join()] + if self.broadcast_keepalive is not None: + stop_tasks.append(self.broadcast_keepalive.stop()) + if self.broadcaster is not None: + self._wait_for_broadcaster_closed.cancel() + stop_tasks.append(self._wait_for_broadcaster_closed) + + # TODO: return_exceptions? + await asyncio.gather(*stop_tasks, return_exceptions=True) + if self.broadcaster is not None: + await self.broadcaster.close() + self.broadcaster = None + + async def wait_until_done(self, done_callback: Optional[Coroutine] = None): + if self.broadcaster is not None: + await self.broadcaster.wait_until_done() + if done_callback: + # TODO: Improve because main loop also wait for broadcaster to be done + await done_callback + + async def publish_sync(self, topics: TopicList, data: Any = None): + with tracer.trace("topic_publisher.publish", resource=str(topics)): + await self.endpoint.publish(topics=topics, data=data) + + async def publish(self, topics: TopicList, data: Any = None): + await self._publish_pool.add_task(self.publish_sync(topics, data)) + + async def subscribe( + self, topics: Union[TopicList, ALL_TOPICS], callback: EventCallback + ) -> list[Subscription]: + return await self.endpoint.subscribe(topics, callback) + @staticmethod async def _verify_permitted_topics( topics: Union[TopicList, ALL_TOPICS], channel: RpcChannel diff --git a/packages/opal-server/opal_server/scopes/api.py b/packages/opal-server/opal_server/scopes/api.py index 95181866a..1ee3b0173 100644 --- a/packages/opal-server/opal_server/scopes/api.py +++ b/packages/opal-server/opal_server/scopes/api.py @@ -13,7 +13,6 @@ status, ) from fastapi.responses import RedirectResponse -from fastapi_websocket_pubsub import PubSubEndpoint from git import InvalidGitRepositoryError from opal_common.async_utils import run_sync from opal_common.authentication.authz import ( @@ -31,19 +30,17 @@ DataUpdate, ServerDataSourceConfig, ) -from opal_common.schemas.policy import PolicyBundle, PolicyUpdateMessageNotification +from opal_common.schemas.policy import PolicyBundle from opal_common.schemas.policy_source import GitPolicyScopeSource, SSHAuthData from opal_common.schemas.scopes import Scope from opal_common.schemas.security import PeerType -from opal_common.topics.publisher import ( - ScopedServerSideTopicPublisher, - ServerSideTopicPublisher, -) from opal_common.urls import set_url_query_param from opal_server.config import opal_server_config from opal_server.data.data_update_publisher import DataUpdatePublisher from opal_server.git_fetcher import GitPolicyFetcher +from opal_server.pubsub import PubSub from opal_server.scopes.scope_repository import ScopeNotFoundError, ScopeRepository +from opal_server.scopes.scoped_pubsub import ScopedPubSub def verify_private_key(private_key: str, key_format: EncryptionKeyFormat) -> bool: @@ -79,7 +76,7 @@ def verify_private_key_or_throw(scope_in: Scope): def init_scope_router( scopes: ScopeRepository, authenticator: JWTAuthenticator, - pubsub_endpoint: PubSubEndpoint, + pubsub: PubSub, ): router = APIRouter() @@ -117,7 +114,7 @@ async def put_scope( logger.info(f"Sync scope: {scope_in.scope_id}{force_fetch_str}") # All server replicas (leaders) should sync the scope. - await pubsub_endpoint.publish( + await pubsub.publish_sync( opal_server_config.POLICY_REPO_WEBHOOK_TOPIC, {"scope_id": scope_in.scope_id, "force_fetch": force_fetch}, ) @@ -203,7 +200,7 @@ async def refresh_scope( force_fetch = hinted_hash is None # All server replicas (leaders) should sync the scope. - await pubsub_endpoint.publish( + await pubsub.publish_sync( opal_server_config.POLICY_REPO_WEBHOOK_TOPIC, { "scope_id": scope_id, @@ -229,7 +226,7 @@ async def sync_all_scopes(claims: JWTClaims = Depends(authenticator)): raise # All server replicas (leaders) should sync all scopes. - await pubsub_endpoint.publish(opal_server_config.POLICY_REPO_WEBHOOK_TOPIC) + await pubsub.publish_sync(opal_server_config.POLICY_REPO_WEBHOOK_TOPIC) return Response(status_code=status.HTTP_200_OK) @@ -350,7 +347,7 @@ async def publish_data_update_event( entry.topics = [f"data:{topic}" for topic in entry.topics] await DataUpdatePublisher( - ScopedServerSideTopicPublisher(pubsub_endpoint, scope_id) + ScopedPubSub(pubsub, scope_id) ).publish_data_updates(update) except Unauthorized as ex: logger.error(f"Unauthorized to publish update: {repr(ex)}") diff --git a/packages/opal-server/opal_server/scopes/scoped_pubsub.py b/packages/opal-server/opal_server/scopes/scoped_pubsub.py new file mode 100644 index 000000000..ce55e062b --- /dev/null +++ b/packages/opal-server/opal_server/scopes/scoped_pubsub.py @@ -0,0 +1,22 @@ +from typing import Any + +from fastapi_websocket_pubsub import TopicList +from opal_common.logger import logger +from opal_server.pubsub import PubSub + + +class ScopedPubSub: + def __init__(self, pubsub: PubSub, scope_id: str): + self._pubsub = pubsub + self._scope_id = scope_id + + def scope_topics(self, topics: TopicList) -> TopicList: + topics = [f"{self._scope_id}:{topic}" for topic in topics] + logger.debug("Publishing to topics: {topics}", topics=topics) + return topics + + async def publish(self, topics: TopicList, data: Any = None): + await self._pubsub.publish(self.scope_topics(topics), data) + + async def publish_sync(self, topics: TopicList, data: Any = None): + await self._pubsub.publish_sync(self.scope_topics(topics), data) diff --git a/packages/opal-server/opal_server/scopes/service.py b/packages/opal-server/opal_server/scopes/service.py index f0104e7bf..fad2a5b5e 100644 --- a/packages/opal-server/opal_server/scopes/service.py +++ b/packages/opal-server/opal_server/scopes/service.py @@ -6,18 +6,18 @@ import git from ddtrace import tracer -from fastapi_websocket_pubsub import PubSubEndpoint from opal_common.git_utils.commit_viewer import VersionedFile from opal_common.logger import logger from opal_common.schemas.policy import PolicyUpdateMessageNotification from opal_common.schemas.policy_source import GitPolicyScopeSource -from opal_common.topics.publisher import ScopedServerSideTopicPublisher from opal_server.git_fetcher import GitPolicyFetcher, PolicyFetcherCallbacks from opal_server.policy.watcher.callbacks import ( create_policy_update, create_update_all_directories_in_repo, ) +from opal_server.pubsub import PubSub from opal_server.scopes.scope_repository import Scope, ScopeRepository +from opal_server.scopes.scoped_pubsub import ScopedPubSub def is_rego_source_file( @@ -41,12 +41,12 @@ def __init__( base_dir: Path, scope_id: str, source: GitPolicyScopeSource, - pubsub_endpoint: PubSubEndpoint, + pubsub: PubSub, ): self._scope_repo_dir = GitPolicyFetcher.repo_clone_path(base_dir, source) self._scope_id = scope_id self._source = source - self._pubsub_endpoint = pubsub_endpoint + self._pubsub = pubsub async def on_update(self, previous_head: str, head: str): if previous_head == head: @@ -93,10 +93,9 @@ async def trigger_notification(self, notification: PolicyUpdateMessageNotificati logger.info( f"Triggering policy update for scope {self._scope_id}: {notification.dict()}" ) - async with ScopedServerSideTopicPublisher( - self._pubsub_endpoint, self._scope_id - ) as publisher: - await publisher.publish(notification.topics, notification.update) + await ScopedPubSub(self._pubsub, self._scope_id).publish_sync( + notification.topics, notification.update + ) class ScopesService: @@ -104,11 +103,11 @@ def __init__( self, base_dir: Path, scopes: ScopeRepository, - pubsub_endpoint: PubSubEndpoint, + pubsub: Optional[PubSub], ): self._base_dir = base_dir self._scopes = scopes - self._pubsub_endpoint = pubsub_endpoint + self._pubsub = pubsub async def sync_scope( self, @@ -139,7 +138,7 @@ async def sync_scope( base_dir=self._base_dir, scope_id=scope.scope_id, source=source, - pubsub_endpoint=self._pubsub_endpoint, + pubsub=self._pubsub, ) fetcher = GitPolicyFetcher( diff --git a/packages/opal-server/opal_server/scopes/task.py b/packages/opal-server/opal_server/scopes/task.py index 83b2b10f0..868fd0617 100644 --- a/packages/opal-server/opal_server/scopes/task.py +++ b/packages/opal-server/opal_server/scopes/task.py @@ -19,7 +19,7 @@ def __init__(self, *args, **kwargs): self._service = ScopesService( base_dir=Path(opal_server_config.BASE_DIR), scopes=ScopeRepository(RedisDB(opal_server_config.REDIS_URL)), - pubsub_endpoint=self._pubsub_endpoint, + pubsub=self._pubsub, ) async def start(self): @@ -78,7 +78,7 @@ def preload_scopes(): service = ScopesService( base_dir=Path(opal_server_config.BASE_DIR), scopes=ScopeRepository(RedisDB(opal_server_config.REDIS_URL)), - pubsub_endpoint=None, + pubsub=None, ) asyncio.run(service.sync_scopes(notify_on_changes=False)) diff --git a/packages/opal-server/opal_server/server.py b/packages/opal-server/opal_server/server.py index 34d9905c3..1e45e5354 100644 --- a/packages/opal-server/opal_server/server.py +++ b/packages/opal-server/opal_server/server.py @@ -2,12 +2,9 @@ import os import signal import sys -import traceback -from functools import partial from typing import List, Optional from fastapi import Depends, FastAPI -from fastapi_websocket_pubsub.event_broadcaster import EventBroadcasterContextManager from opal_common.authentication.deps import JWTAuthenticator, StaticBearerAuthenticator from opal_common.authentication.signer import JWTSigner from opal_common.confi.confi import load_conf_if_none @@ -17,11 +14,6 @@ from opal_common.monitoring import apm, metrics from opal_common.schemas.data import ServerDataSourceConfig from opal_common.synchronization.named_lock import NamedLock -from opal_common.topics.publisher import ( - PeriodicPublisher, - ServerSideTopicPublisher, - TopicPublisher, -) from opal_server.config import opal_server_config from opal_server.data.api import init_data_updates_router from opal_server.data.data_update_publisher import DataUpdatePublisher @@ -30,7 +22,6 @@ from opal_server.policy.watcher.factory import setup_watcher_task from opal_server.policy.watcher.task import PolicyWatcherTask from opal_server.policy.webhook.api import init_git_webhook_router -from opal_server.publisher import setup_broadcaster_keepalive_task from opal_server.pubsub import PubSub from opal_server.redis_utils import RedisDB from opal_server.scopes.api import init_scope_router @@ -59,7 +50,7 @@ def __init__( """ Args: policy_remote_url (str, optional): the url of the repo watched by policy watcher. - init_publisher (bool, optional): whether or not to launch a publisher pub/sub client. + init_publisher (bool, optional): whether to launch a publisher pub/sub client. this publisher is used by the server processes to publish data to the client. data_sources_config (ServerDataSourceConfig, optional): base data configuration, that opal clients should get the data from. @@ -143,39 +134,18 @@ def __init__( else: self.jwks_endpoint = None - self.pubsub = PubSub(signer=self.signer, broadcaster_uri=broadcaster_uri) - - self.publisher: Optional[TopicPublisher] = None - self.broadcast_keepalive: Optional[PeriodicPublisher] = None - if init_publisher: - self.publisher = ServerSideTopicPublisher(self.pubsub.endpoint) - - if ( - opal_server_config.BROADCAST_KEEPALIVE_INTERVAL > 0 - and self.broadcaster_uri is not None - ): - self.broadcast_keepalive = setup_broadcaster_keepalive_task( - self.publisher, - time_interval=opal_server_config.BROADCAST_KEEPALIVE_INTERVAL, - topic=opal_server_config.BROADCAST_KEEPALIVE_TOPIC, - ) + self.pubsub = PubSub( + signer=self.signer, + broadcaster_uri=broadcaster_uri, + disconnect_callback=self._graceful_shutdown(), + ) + # TODO: Ignore init_publisher if opal_common_config.STATISTICS_ENABLED: - self.opal_statistics = OpalStatistics(self.pubsub.endpoint) + self.opal_statistics = OpalStatistics(self.pubsub) else: self.opal_statistics = None - # if stats are enabled, the server workers must be listening on the broadcast - # channel for their own synchronization, not just for their clients. therefore - # we need a "global" listening context - self.broadcast_listening_context: Optional[ - EventBroadcasterContextManager - ] = None - if self.broadcaster_uri is not None and opal_common_config.STATISTICS_ENABLED: - self.broadcast_listening_context = ( - self.pubsub.endpoint.broadcaster.get_listening_context() - ) - self.watcher: PolicyWatcherTask = None self.leadership_lock: Optional[NamedLock] = None @@ -184,6 +154,8 @@ def __init__( self._scopes = ScopeRepository(self._redis_db) logger.info("OPAL Scopes: server is connected to scopes repository") + self._leadership_task: Optional[asyncio.Task] = None + # init fastapi app self.app: FastAPI = self._init_fast_api_app() @@ -221,9 +193,9 @@ def _configure_api_routes(self, app: FastAPI): """mounts the api routes on the app object.""" authenticator = JWTAuthenticator(self.signer) - data_update_publisher: Optional[DataUpdatePublisher] = None - if self.publisher is not None: - data_update_publisher = DataUpdatePublisher(self.publisher) + data_update_publisher: Optional[DataUpdatePublisher] = DataUpdatePublisher( + self.pubsub + ) # Init api routers with required dependencies data_updates_router = init_data_updates_router( @@ -264,7 +236,7 @@ def _configure_api_routes(self, app: FastAPI): if opal_server_config.SCOPES: app.include_router( - init_scope_router(self._scopes, authenticator, self.pubsub.endpoint), + init_scope_router(self._scopes, authenticator, self.pubsub), tags=["Scopes"], prefix="/scopes", ) @@ -294,12 +266,9 @@ async def startup_event(): logger.info("*** OPAL Server Startup ***") try: - self._task = asyncio.create_task(self.start_server_background_tasks()) - + await self.start_server_background_tasks() except Exception: - logger.critical("Exception while starting OPAL") - traceback.print_exc() - + logger.exception("Exception while starting OPAL") sys.exit(1) @app.on_event("shutdown") @@ -309,6 +278,26 @@ async def shutdown_event(): return app + async def _wait_for_leadership(self): + # We want only one worker to run repo watchers + # (otherwise for each new commit, we will publish multiple updates via pub/sub). + # leadership is determined by the first worker to obtain a lock + self.leadership_lock = NamedLock(opal_server_config.LEADER_LOCK_FILE_PATH) + await self.leadership_lock.acquire() + # only one worker gets here, the others block. in case the leader worker + # is terminated, another one will obtain the lock and become leader. + logger.info( + "leadership lock acquired, leader pid: {pid}", + pid=os.getpid(), + ) + + if opal_server_config.SCOPES: + await load_scopes(self._scopes) + + if self._init_policy_watcher: + self.watcher = setup_watcher_task(self.pubsub) + await self.watcher.start() + async def start_server_background_tasks(self): """starts the background processes (as asyncio tasks) if such are configured. @@ -319,85 +308,36 @@ async def start_server_background_tasks(self): only the leader worker (first to obtain leadership lock) will start these tasks: - (repo) watcher: monitors the policy git repository for changes. """ - if self.publisher is not None: - async with self.publisher: - if self.opal_statistics is not None: - if self.broadcast_listening_context is not None: - logger.info( - "listening on broadcast channel for statistics events..." - ) - await self.broadcast_listening_context.__aenter__() - # if the broadcast channel is closed, we want to restart worker process because statistics can't be reliable anymore - self.broadcast_listening_context._event_broadcaster.get_reader_task().add_done_callback( - lambda _: self._graceful_shutdown() - ) - asyncio.create_task(self.opal_statistics.run()) - self.pubsub.endpoint.notifier.register_unsubscribe_event( - self.opal_statistics.remove_client - ) - - # We want only one worker to run repo watchers - # (otherwise for each new commit, we will publish multiple updates via pub/sub). - # leadership is determined by the first worker to obtain a lock - self.leadership_lock = NamedLock( - opal_server_config.LEADER_LOCK_FILE_PATH - ) - async with self.leadership_lock: - # only one worker gets here, the others block. in case the leader worker - # is terminated, another one will obtain the lock and become leader. - logger.info( - "leadership lock acquired, leader pid: {pid}", - pid=os.getpid(), - ) - - if opal_server_config.SCOPES: - await load_scopes(self._scopes) - - if self.broadcast_keepalive is not None: - self.broadcast_keepalive.start() - if not self._init_policy_watcher: - # Wait on keepalive instead to keep leadership lock acquired - await self.broadcast_keepalive.wait_until_done() - - if self._init_policy_watcher: - self.watcher = setup_watcher_task( - self.publisher, self.pubsub.endpoint - ) - # running the watcher, and waiting until it stops (until self.watcher.signal_stop() is called) - async with self.watcher: - await self.watcher.wait_until_should_stop() - - # Worker should restart when watcher stops - self._graceful_shutdown() - - if ( - self.opal_statistics is not None - and self.broadcast_listening_context is not None - ): - await self.broadcast_listening_context.__aexit__() - logger.info( - "stopped listening for statistics events on the broadcast channel" - ) + await self.pubsub.start() + + if self.opal_statistics is not None: + await self.opal_statistics.start() + + self._leadership_task = asyncio.create_task(self._wait_for_leadership()) async def stop_server_background_tasks(self): logger.info("stopping background tasks...") tasks: List[asyncio.Task] = [] - if self.watcher is not None: - tasks.append(asyncio.create_task(self.watcher.stop())) - if self.publisher is not None: - tasks.append(asyncio.create_task(self.publisher.stop())) - if self.broadcast_keepalive is not None: - tasks.append(asyncio.create_task(self.broadcast_keepalive.stop())) if self.opal_statistics is not None: tasks.append(asyncio.create_task(self.opal_statistics.stop())) + if self.watcher is not None: + tasks.append(asyncio.create_task(self.watcher.stop())) + if self._leadership_task is not None: + self._leadership_task.cancel() + tasks.append(self._leadership_task) + tasks.append(asyncio.create_task(self.pubsub.stop())) try: await asyncio.gather(*tasks) except Exception: logger.exception("exception while shutting down background tasks") - def _graceful_shutdown(self): + if self.leadership_lock.is_locked: + await self.leadership_lock.release() + + @staticmethod + async def _graceful_shutdown(): logger.info("Trigger worker graceful shutdown") os.kill(os.getpid(), signal.SIGTERM) diff --git a/packages/opal-server/opal_server/statistics.py b/packages/opal-server/opal_server/statistics.py index 14ea97f0a..b88ff38da 100644 --- a/packages/opal-server/opal_server/statistics.py +++ b/packages/opal-server/opal_server/statistics.py @@ -10,12 +10,11 @@ import pydantic from fastapi import APIRouter, HTTPException, status from fastapi_websocket_pubsub.event_notifier import Subscription, TopicList -from fastapi_websocket_pubsub.pub_sub_server import PubSubEndpoint from opal_common.async_utils import TasksPool from opal_common.config import opal_common_config from opal_common.logger import get_logger -from opal_common.topics.publisher import PeriodicPublisher from opal_server.config import opal_server_config +from opal_server.pubsub import PubSub from pydantic import BaseModel, Field @@ -75,8 +74,8 @@ class OpalStatistics: The pub/sub server endpoint that allows us to subscribe to the stats channel on the server side """ - def __init__(self, endpoint): - self._endpoint: PubSubEndpoint = endpoint + def __init__(self, pubsub): + self._pubsub: PubSub = pubsub self._uptime = datetime.utcnow() self._workers_count = (lambda envar: int(envar) if envar.isdigit() else 1)( os.environ.get("UVICORN_NUM_WORKERS", "1") @@ -135,7 +134,7 @@ async def _periodic_server_keepalive(self): while True: try: await self._expire_old_servers() - self._publish( + await self._publish( opal_server_config.STATISTICS_SERVER_KEEPALIVE_CHANNEL, ServerKeepalive(worker_id=self._worker_id).dict(), ) @@ -147,43 +146,45 @@ async def _periodic_server_keepalive(self): return except Exception as e: logger.exception("Statistics: periodic server keepalive failed") - logger.exception("Statistics: periodic server keepalive failed") - def _publish(self, channel: str, message: Any): - self._publish_tasks.add_task(self._endpoint.publish([channel], message)) + async def _publish(self, channel: str, message: Any): + await self._pubsub.publish([channel], message) - async def run(self): + async def start(self): """subscribe to two channels to be able to sync add and delete of clients.""" - await self._endpoint.subscribe( + await self._pubsub.subscribe( [opal_server_config.STATISTICS_WAKEUP_CHANNEL], self._receive_other_worker_wakeup_message, ) - await self._endpoint.subscribe( + await self._pubsub.subscribe( [opal_server_config.STATISTICS_STATE_SYNC_CHANNEL], self._receive_other_worker_synced_state, ) - await self._endpoint.subscribe( + await self._pubsub.subscribe( [opal_server_config.STATISTICS_SERVER_KEEPALIVE_CHANNEL], self._receive_other_worker_keepalive_message, ) - await self._endpoint.subscribe( + await self._pubsub.subscribe( [opal_common_config.STATISTICS_ADD_CLIENT_CHANNEL], self._add_client ) - await self._endpoint.subscribe( + await self._pubsub.subscribe( [opal_common_config.STATISTICS_REMOVE_CLIENT_CHANNEL], self._sync_remove_client, ) + self._pubsub.endpoint.notifier.register_unsubscribe_event( + self.remove_client + ) # TODO: Fix that # wait before publishing the wakeup message, due to the fact we are # counting on the broadcaster to listen and to replicate the message # to the other workers / server nodes in the networks. # However, since broadcaster is using asyncio.create_task(), there is a # race condition that is mitigated by this asyncio.sleep() call. - await asyncio.sleep(SLEEP_TIME_FOR_BROADCASTER_READER_TO_START) + # await asyncio.sleep(SLEEP_TIME_FOR_BROADCASTER_READER_TO_START) # Let all the other opal servers know that new opal server started logger.info(f"sending stats wakeup message: {self._worker_id}") - self._publish( + await self._publish( opal_server_config.STATISTICS_WAKEUP_CHANNEL, SyncRequest(requesting_worker_id=self._worker_id).dict(), ) @@ -242,7 +243,7 @@ async def _receive_other_worker_wakeup_message( logger.info( f"[{request.requesting_worker_id}] respond with my own stats" ) - self._publish( + await self._publish( opal_server_config.STATISTICS_STATE_SYNC_CHANNEL, SyncResponse( requesting_worker_id=request.requesting_worker_id, @@ -363,7 +364,7 @@ async def remove_client(self, rpc_id: str, topics: TopicList, publish=True): "Publish rpc_id={rpc_id} to be removed from statistics", rpc_id=rpc_id, ) - self._publish( + await self._publish( opal_common_config.STATISTICS_REMOVE_CLIENT_CHANNEL, rpc_id, )