From 7c4409e3580cc27796ab6b528645f41533971bdf Mon Sep 17 00:00:00 2001 From: Fred Moolekamp Date: Mon, 23 Oct 2023 16:36:04 -0400 Subject: [PATCH] Responses to reviewer comments (rebase before merging) --- .../lsst/rubintv/analysis/service/__init__.py | 2 +- .../lsst/rubintv/analysis/service/client.py | 90 +++++++------ .../lsst/rubintv/analysis/service/command.py | 21 +-- .../lsst/rubintv/analysis/service/database.py | 2 +- python/lsst/rubintv/analysis/service/query.py | 29 ++-- python/lsst/rubintv/analysis/service/utils.py | 40 ++++++ scripts/mock_server.py | 125 +++++++++--------- scripts/rubintv_worker.py | 5 +- 8 files changed, 177 insertions(+), 137 deletions(-) create mode 100644 python/lsst/rubintv/analysis/service/utils.py diff --git a/python/lsst/rubintv/analysis/service/__init__.py b/python/lsst/rubintv/analysis/service/__init__.py index e144f74..ae91fbd 100644 --- a/python/lsst/rubintv/analysis/service/__init__.py +++ b/python/lsst/rubintv/analysis/service/__init__.py @@ -1 +1 @@ -from . import command, database, query +from . import command, database, query, utils diff --git a/python/lsst/rubintv/analysis/service/client.py b/python/lsst/rubintv/analysis/service/client.py index 9504b0b..38fba82 100644 --- a/python/lsst/rubintv/analysis/service/client.py +++ b/python/lsst/rubintv/analysis/service/client.py @@ -26,56 +26,64 @@ from websocket import WebSocketApp from .command import DatabaseConnection, execute_command +from .utils import printc, Colors logger = logging.getLogger("lsst.rubintv.analysis.service.client") -def on_error(ws: WebSocketApp, error: str) -> None: - """Error received from the server.""" - print(f"\033[91mError: {error}\033[0m") +class Worker: + def __init__(self, address: str, port: int, connection_info: dict[str, dict]): + self._address = address + self._port = port + self._connection_info = connection_info + def on_error(self, ws: WebSocketApp, error: str) -> None: + """Error received from the server.""" + printc(f"Error: {error}", color=Colors.BRIGHT_RED) -def on_close(ws: WebSocketApp, close_status_code: str, close_msg: str) -> None: - """Connection closed by the server.""" - print("\033[93mConnection closed\033[0m") + def on_close(self, ws: WebSocketApp, close_status_code: str, close_msg: str) -> None: + """Connection closed by the server.""" + printc("Connection closed", Colors.BRIGHT_YELLOW) + def run(self) -> None: + """Run the worker and connect to the rubinTV server. -def run_worker(address: str, port: int, connection_info: dict[str, dict]) -> None: - """Run the worker and connect to the rubinTV server. + Parameters + ---------- + address : + Address of the rubinTV web app. + port : + Port of the rubinTV web app websockets. + connection_info : + Connections . + """ + # Load the database connection information + databases: dict[str, DatabaseConnection] = {} - Parameters - ---------- - address : - Address of the rubinTV web app. - port : - Port of the rubinTV web app websockets. - connection_info : - Connections . - """ - # Load the database connection information - databases: dict[str, DatabaseConnection] = {} + for name, info in self._connection_info["databases"].items(): + with open(info["schema"], "r") as file: + engine = sqlalchemy.create_engine(info["url"]) + schema = yaml.safe_load(file) + databases[name] = DatabaseConnection(schema=schema, engine=engine) - for name, info in connection_info["databases"].items(): - with open(info["schema"], "r") as file: - engine = sqlalchemy.create_engine(info["url"]) - schema = yaml.safe_load(file) - databases[name] = DatabaseConnection(schema=schema, engine=engine) + # Load the Butler (if one is available) + butler: Butler | None = None + if "butler" in self._connection_info: + repo = self._connection_info["butler"].pop("repo") + butler = Butler(repo, **self._connection_info["butler"]) - # Load the Butler (if one is available) - butler: Butler | None = None - if "butler" in connection_info: - repo = connection_info["butler"].pop("repo") - butler = Butler(repo, **connection_info["butler"]) + def on_message(ws: WebSocketApp, message: str) -> None: + """Message received from the server.""" + response = execute_command(message, databases, butler) + ws.send(response) - def on_message(ws: WebSocketApp, message: str) -> None: - """Message received from the server.""" - response = execute_command(message, databases, butler) - ws.send(response) - - print(f"\033[92mConnecting to rubinTV at {address}:{port}\033[0m") - # Connect to the WebSocket server - ws = WebSocketApp( - f"ws://{address}:{port}/ws/worker", on_message=on_message, on_error=on_error, on_close=on_close - ) - ws.run_forever() - ws.close() + printc(f"Connecting to rubinTV at {self._address}:{self._port}", Colors.BRIGHT_GREEN) + # Connect to the WebSocket server + ws = WebSocketApp( + f"ws://{self._address}:{self._port}/ws/worker", + on_message=on_message, + on_error=self.on_error, + on_close=self.on_close, + ) + ws.run_forever() + ws.close() diff --git a/python/lsst/rubintv/analysis/service/command.py b/python/lsst/rubintv/analysis/service/command.py index c90e3ab..776ea2b 100644 --- a/python/lsst/rubintv/analysis/service/command.py +++ b/python/lsst/rubintv/analysis/service/command.py @@ -141,6 +141,7 @@ class BaseCommand(ABC): This should be unique for each command. """ + command_registry = {} result: dict | None = None response_type: str @@ -153,7 +154,7 @@ def build_contents(self, databases: dict[str, DatabaseConnection], butler: Butle databases : The database connections. butler : - A conencted Butler. + A connected Butler. Returns ------- @@ -187,11 +188,7 @@ def to_json(self): @classmethod def register(cls, name: str): """Register a command.""" - command_registry[name] = cls - - -# Registry of all commands -command_registry = {} + BaseCommand.command_registry[name] = cls def execute_command(command_str: str, databases: dict[str, DatabaseConnection], butler: Butler | None) -> str: @@ -212,7 +209,7 @@ def execute_command(command_str: str, databases: dict[str, DatabaseConnection], databases : The database connections. butler : - A conencted Butler. + A connected Butler. """ try: command_dict = json.loads(command_str) @@ -226,15 +223,11 @@ def execute_command(command_str: str, databases: dict[str, DatabaseConnection], if "name" not in command_dict.keys(): raise CommandParsingError("No command 'name' given") - if command_dict["name"] not in command_registry.keys(): + if command_dict["name"] not in BaseCommand.command_registry.keys(): raise CommandParsingError(f"Unrecognized command '{command_dict['name']}'") - if "parameters" in command_dict: - parameters = command_dict["parameters"] - else: - parameters = {} - - command = command_registry[command_dict["name"]](**parameters) + parameters = command_dict.get("parameters", {}) + command = BaseCommand.command_registry[command_dict["name"]](**parameters) except Exception as err: logging.exception("Error parsing command.") diff --git a/python/lsst/rubintv/analysis/service/database.py b/python/lsst/rubintv/analysis/service/database.py index 6ba8851..eb06f48 100644 --- a/python/lsst/rubintv/analysis/service/database.py +++ b/python/lsst/rubintv/analysis/service/database.py @@ -219,7 +219,7 @@ def build_contents(self, databases: dict[str, DatabaseConnection], butler: Butle engine=database.engine, ) - if len(data) == 0: + if not data: # There is no column data to return content: dict = { "columns": columns, diff --git a/python/lsst/rubintv/analysis/service/query.py b/python/lsst/rubintv/analysis/service/query.py index 49d4fa4..4f6c0bf 100644 --- a/python/lsst/rubintv/analysis/service/query.py +++ b/python/lsst/rubintv/analysis/service/query.py @@ -123,23 +123,24 @@ class ParentQuery(Query): """ def __init__(self, children: list[Query], operator: str): - self.children = children - self.operator = operator + self._children = children + self._operator = operator def __call__(self, table: sqlalchemy.Table) -> sqlalchemy.sql.elements.BooleanClauseList: - child_results = [child(table) for child in self.children] + child_results = [child(table) for child in self._children] try: - if self.operator == "AND": - return sqlalchemy.and_(*child_results) - if self.operator == "OR": - return sqlalchemy.or_(*child_results) - if self.operator == "NOT": - return sqlalchemy.not_(*child_results) - if self.operator == "XOR": - return sqlalchemy.and_( - sqlalchemy.or_(*child_results), - sqlalchemy.not_(sqlalchemy.and_(*child_results)), - ) + match self._operator: + case "AND": + return sqlalchemy.and_(*child_results) + case "OR": + return sqlalchemy.or_(*child_results) + case "NOT": + return sqlalchemy.not_(*child_results) + case "XOR": + return sqlalchemy.and_( + sqlalchemy.or_(*child_results), + sqlalchemy.not_(sqlalchemy.and_(*child_results)), + ) except Exception: raise QueryError("Error applying a boolean query statement.") diff --git a/python/lsst/rubintv/analysis/service/utils.py b/python/lsst/rubintv/analysis/service/utils.py new file mode 100644 index 0000000..806215b --- /dev/null +++ b/python/lsst/rubintv/analysis/service/utils.py @@ -0,0 +1,40 @@ +from enum import Enum + + +# ANSI color codes for printing to the terminal +class Colors(Enum): + RESET = 0 + BLACK = 30 + RED = 31 + GREEN = 32 + YELLOW = 33 + BLUE = 34 + MAGENTA = 35 + CYAN = 36 + WHITE = 37 + DEFAULT = 39 + BRIGHT_BLACK = 90 + BRIGHT_RED = 91 + BRIGHT_GREEN = 92 + BRIGHT_YELLOW = 93 + BRIGHT_BLUE = 94 + BRIGHT_MAGENTA = 95 + BRIGHT_CYAN = 96 + BRIGHT_WHITE = 97 + + +def printc(message: str, color: Colors, end_color: Colors = Colors.RESET): + """Print a message to the terminal in color. + + After printing reset the color by default. + + Parameters + ---------- + message : + The message to print. + color : + The color to print the message in. + end : + The color future messages should be printed in. + """ + print(f"\033[{color.value}m{message}\033[{end_color.value}m") diff --git a/scripts/mock_server.py b/scripts/mock_server.py index 814103e..eb3797c 100644 --- a/scripts/mock_server.py +++ b/scripts/mock_server.py @@ -19,6 +19,8 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from __future__ import annotations + import uuid from dataclasses import dataclass from enum import Enum @@ -28,40 +30,13 @@ import tornado.web import tornado.websocket +from lsst.rubintv.analysis.service.utils import printc, Colors + # Default port and address to listen on LISTEN_PORT = 2000 LISTEN_ADDRESS = "localhost" -# ANSI color codes for printing to the terminal -ansi_colors = { - "black": "30", - "red": "31", - "green": "32", - "yellow": "33", - "blue": "34", - "magenta": "35", - "cyan": "36", - "white": "37", -} - - -def log(message, color, end="\033[31"): - """Print a message to the terminal in color. - - Parameters - ---------- - message : - The message to print. - color : - The color to print the message in. - end : - The color future messages should be printed in. - """ - _color = ansi_colors[color] - print(f"\033[{_color}m{message}{end}m") - - class WorkerPodStatus(Enum): """Status of a worker pod.""" @@ -74,6 +49,10 @@ class WebSocketHandler(tornado.websocket.WebSocketHandler): Handler that handles WebSocket connections """ + workers: dict[str, WorkerPod] = dict() # Keep track of connected worker pods + clients: dict[str, WebSocketHandler] = dict() # Keep track of connected clients + queue: list[QueueItem] = list() # Queue of messages to be processed + @classmethod def urls(cls) -> list[tuple[str, type[tornado.web.RequestHandler], dict[str, str]]]: """url to handle websocket connections. @@ -85,7 +64,7 @@ def urls(cls) -> list[tuple[str, type[tornado.web.RequestHandler], dict[str, str (r"/ws/([^/]+)", cls, {}), # Route/Handler/kwargs ] - def open(self, type: str) -> None: + def open(self, client_type: str) -> None: """ Client opens a websocket @@ -95,12 +74,20 @@ def open(self, type: str) -> None: The type of client that is connecting. """ self.client_id = str(uuid.uuid4()) - if type == "worker": - workers[self.client_id] = WorkerPod(self.client_id, self) - log(f"New worker {self.client_id} connected. Total workers: {len(workers)}", "blue") - if type == "client": - clients[self.client_id] = self - log(f"New client {self.client_id} connected. Total clients: {len(clients)}", "yellow") + if client_type == "worker": + WebSocketHandler.workers[self.client_id] = WorkerPod(self.client_id, self) + printc( + f"New worker {self.client_id} connected. Total workers: {len(WebSocketHandler.workers)}", + Colors.BLUE, + Colors.RED, + ) + if client_type == "client": + WebSocketHandler.clients[self.client_id] = self + printc( + f"New client {self.client_id} connected. Total clients: {len(WebSocketHandler.clients)}", + Colors.YELLOW, + Colors.RED, + ) def on_message(self, message: str) -> None: """ @@ -111,35 +98,36 @@ def on_message(self, message: str) -> None: message : The message received from the client or worker. """ - if self.client_id in clients: - log(f"Message received from {self.client_id}", "yellow") - client = clients[self.client_id] + if self.client_id in WebSocketHandler.clients: + printc(f"Message received from {self.client_id}", Colors.YELLOW, Colors.RED) + client = WebSocketHandler.clients[self.client_id] # Find an idle worker idle_worker = None - for worker in workers.values(): + for worker in WebSocketHandler.workers.values(): if worker.status == WorkerPodStatus.IDLE: idle_worker = worker break if idle_worker is None: # No idle worker found, add to queue - queue.append(QueueItem(message, client)) + WebSocketHandler.queue.append(QueueItem(message, client)) return idle_worker.process(message, client) return - if self.client_id in workers: - worker = workers[self.client_id] + if self.client_id in WebSocketHandler.workers: + worker = WebSocketHandler.workers[self.client_id] worker.on_finished(message) - log( + printc( f"Message received from worker {self.client_id}. New status {worker.status}", - "blue", + Colors.BLUE, + Colors.RED, ) # Check the queue for any outstanding jobs. - if len(queue) > 0: - queue_item = queue.pop(0) + if len(WebSocketHandler.queue) > 0: + queue_item = WebSocketHandler.queue.pop(0) worker.process(queue_item.message, queue_item.client) return @@ -147,16 +135,24 @@ def on_close(self) -> None: """ Client closes the connection """ - if self.client_id in clients: - del clients[self.client_id] - log(f"Client disconnected. Active clients: {len(clients)}", "yellow") - for worker in workers.values(): + if self.client_id in WebSocketHandler.clients: + del WebSocketHandler.clients[self.client_id] + printc( + f"Client disconnected. Active clients: {len(WebSocketHandler.clients)}", + Colors.YELLOW, + Colors.RED, + ) + for worker in WebSocketHandler.workers.values(): if worker.connected_client == self: worker.on_finished("Client disconnected") break - if self.client_id in workers: - del workers[self.client_id] - log(f"Worker disconnected. Active workers: {len(workers)}", "blue") + if self.client_id in WebSocketHandler.workers: + del WebSocketHandler.workers[self.client_id] + printc( + f"Worker disconnected. Active workers: {len(WebSocketHandler.workers)}", + Colors.BLUE, + Colors.RED, + ) def check_origin(self, origin): """ @@ -183,8 +179,8 @@ class WorkerPod: status: WorkerPodStatus connected_client: WebSocketHandler | None - def __init__(self, id: str, ws: WebSocketHandler): - self.id = id + def __init__(self, wid: str, ws: WebSocketHandler): + self.wid = wid self.ws = ws self.status = WorkerPodStatus.IDLE self.connected_client = None @@ -201,7 +197,11 @@ def process(self, message: str, connected_client: WebSocketHandler): """ self.status = WorkerPodStatus.BUSY self.connected_client = connected_client - log(f"Worker {self.id} processing message from client {connected_client.client_id}", "blue") + printc( + f"Worker {self.wid} processing message from client {connected_client.client_id}", + Colors.BLUE, + Colors.RED, + ) # Send the job to the worker pod self.ws.write_message(message) @@ -215,7 +215,9 @@ def on_finished(self, message): # Send the reply to the client that made the request. self.connected_client.write_message(message) else: - log(f"Worker {self.id} finished processing, but no client was connected.", "red") + printc( + f"Worker {self.wid} finished processing, but no client was connected.", Colors.RED, Colors.RED + ) self.status = WorkerPodStatus.IDLE self.connected_client = None @@ -236,11 +238,6 @@ class QueueItem: client: WebSocketHandler -workers: dict[str, WorkerPod] = dict() # Keep track of connected worker pods -clients: dict[str, WebSocketHandler] = dict() # Keep track of connected clients -queue: list[QueueItem] = list() # Queue of messages to be processed - - def main(): # Create tornado application and supply URL routes app = tornado.web.Application(WebSocketHandler.urls()) # type: ignore @@ -249,7 +246,7 @@ def main(): http_server = tornado.httpserver.HTTPServer(app) http_server.listen(LISTEN_PORT, LISTEN_ADDRESS) - log(f"Listening on address: {LISTEN_ADDRESS}, {LISTEN_PORT}", "green") + printc(f"Listening on address: {LISTEN_ADDRESS}, {LISTEN_PORT}", Colors.GREEN, Colors.RED) # Start IO/Event loop tornado.ioloop.IOLoop.instance().start() diff --git a/scripts/rubintv_worker.py b/scripts/rubintv_worker.py index ad03071..97152bf 100644 --- a/scripts/rubintv_worker.py +++ b/scripts/rubintv_worker.py @@ -24,7 +24,7 @@ import pathlib import yaml -from lsst.rubintv.analysis.service.client import run_worker +from lsst.rubintv.analysis.service.client import Worker default_config = os.path.join(pathlib.Path(__file__).parent.absolute(), "config.yaml") @@ -47,7 +47,8 @@ def main(): config = yaml.safe_load(file) # Run the client and connect to rubinTV via websockets - run_worker(args.address, args.port, config) + worker = Worker(args.address, args.port, config) + worker.run() if __name__ == "__main__":