From 9a3513d0393fc9638fe372fb3f79c4b701a05d75 Mon Sep 17 00:00:00 2001 From: Fred Moolekamp Date: Tue, 10 Oct 2023 10:40:06 -0400 Subject: [PATCH] Implement code run run a worker pod client This commit also contains a mock_server to run a simple web app that uses the worker pods to process commands sent by clients. --- pyproject.toml | 1 + .../lsst/rubintv/analysis/service/client.py | 43 +++++ requirements.txt | 1 + scripts/mock_server.py | 155 ++++++++++++++++++ scripts/rubinTv_analysis.py | 13 ++ 5 files changed, 213 insertions(+) create mode 100644 python/lsst/rubintv/analysis/service/client.py create mode 100644 scripts/mock_server.py create mode 100644 scripts/rubinTv_analysis.py diff --git a/pyproject.toml b/pyproject.toml index 5a4e43d..69d3b57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "pyyaml", "sqlalchemy", "astropy", + "websocket-client", ] #dynamic = ["version"] diff --git a/python/lsst/rubintv/analysis/service/client.py b/python/lsst/rubintv/analysis/service/client.py new file mode 100644 index 0000000..ffc7529 --- /dev/null +++ b/python/lsst/rubintv/analysis/service/client.py @@ -0,0 +1,43 @@ +import logging +import sqlalchemy +import yaml + +import websocket + +from . import command + +logger = logging.getLogger("lsst.rubintv.analysis.service.client") + + +yaml_filename = filename = "/Users/fred3m/temp/visitDb/summit.yaml" +schemas = {} + +with open(filename, "r") as file: + schemas["visit"] = yaml.safe_load(file) + +engine = sqlalchemy.create_engine("sqlite:////Users/fred3m/temp/visitDb/visit.db") + + +def on_message(ws, message): + print(f"Received: {message}") + response = command.execute_command(message, schemas["visit"], engine) + ws.send(response) + + +def on_error(ws, error): + print(f"\033[91mError: {error}\033[0m") + + +def on_close(ws, close_status_code, close_msg): + print("\033[93mConnection closed\033[0m") + + +def run_client(address, port): + print(f"\033[92mConnecting to rubinTV at {address}:{port}\033[0m") + # Connect to the WebSocket server + ws = websocket.WebSocketApp(f"ws://{address}:{port}/ws/worker", + on_message=on_message, + on_error=on_error, + on_close=on_close) + ws.run_forever() + ws.close() diff --git a/requirements.txt b/requirements.txt index cae54f4..51bdbf2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ pydantic pyyaml sqlalchemy astropy +websocket diff --git a/scripts/mock_server.py b/scripts/mock_server.py new file mode 100644 index 0000000..409e065 --- /dev/null +++ b/scripts/mock_server.py @@ -0,0 +1,155 @@ +from enum import Enum +import uuid + +import tornado.web +import tornado.httpserver +import tornado.ioloop +import tornado.websocket + +LISTEN_PORT = 2000 +LISTEN_ADDRESS = 'localhost' + + +ansi_colors = { + "black": "30", + "red": "31", + "green": "32", + "yellow": "33", + "blue": "34", + "magenta": "35", + "cyan": "36", + "white": "37", +} + + +def log(message, color, end="\033[31"): + _color = ansi_colors[color] + print(f"\033[{_color}m{message}{end}m") + + +class WorkerPodStatus(Enum): + IDLE = "idle" + BUSY = "busy" + + +class WebSocketHandler(tornado.websocket.WebSocketHandler): + """ + Handler that handles WebSocket connections + """ + + @classmethod + def urls(cls): + return [ + (r'/ws/([^/]+)', cls, {}), # Route/Handler/kwargs + ] + + def open(self, type: str) -> None: + """ + Client opens a websocket + """ + self.client_id = str(uuid.uuid4()) + if type == "worker": + workers[self.client_id] = WorkerPod(self.client_id, self) + log(f"New worker {self.client_id} connected. Total workers: {len(workers)}", "blue") + if type == "client": + clients[self.client_id] = self + log(f"New client {self.client_id} connected. Total clients: {len(clients)}", "yellow") + + def on_message(self, message: str) -> None: + """ + Message received from a client or worker. + """ + log(f"Message received from {self.client_id}: {message}", "yellow") + if self.client_id in clients: + client = clients[self.client_id] + + # Find an idle worker + idle_worker = None + for worker in workers.values(): + if worker.status == WorkerPodStatus.IDLE: + idle_worker = worker + break + + if idle_worker is None: + # No idle worker found, add to queue + queue.append(QueueItem(message, client)) + return + idle_worker.process(message, client) + return + + if self.client_id in workers: + worker = workers[self.client_id] + worker.on_finished(message) + + # Check the queue for any outstanding jobs. + if len(queue) > 0: + queue_item = queue.pop(0) + worker.process(queue_item.message, queue_item.client) + return + + def on_close(self) -> None: + """ + Client closes the connection + """ + if self.client_id in clients: + del clients[self.client_id] + log(f"Client disconnected. Active clients: {len(clients)}", "yellow") + if self.client_id in workers: + del workers[self.client_id] + log(f"Worker disconnected. Active workers: {len(workers)}", "blue") + + def check_origin(self, origin): + """ + Override the origin check if needed + """ + return True + + +class WorkerPod: + status: WorkerPodStatus + connected_client: WebSocketHandler + + def __init__(self, id: str, ws: WebSocketHandler): + self.id = id + self.ws = ws + self.status = WorkerPodStatus.IDLE + self.connected_client = None + + def process(self, message: str, connected_client: WebSocketHandler): + self.status = WorkerPodStatus.BUSY + self.connected_client = connected_client + self.ws.write_message(message) + + def on_finished(self, message): + self.connected_client.write_message(message) + self.status = WorkerPodStatus.IDLE + self.connected_client = None + + +class QueueItem: + def __init__(self, message, client): + self.message = message + self.client = client + + +workers: dict[str, WorkerPod] = dict() # Keep track of connected worker pods +clients: dict[str: WebSocketHandler] = dict() # Keep track of connected clients +queue: list[QueueItem] = list() # Queue of messages to be processed + + +def main(): + # Create tornado application and supply URL routes + app = tornado.web.Application(WebSocketHandler.urls()) + + # Setup HTTP Server + http_server = tornado.httpserver.HTTPServer(app) + http_server.listen(LISTEN_PORT, LISTEN_ADDRESS) + + log(f"Listening on address: {LISTEN_ADDRESS}, {LISTEN_PORT}", "green") + + # Start IO/Event loop + tornado.ioloop.IOLoop.instance().start() + + +if __name__ == '__main__': + main() diff --git a/scripts/rubinTv_analysis.py b/scripts/rubinTv_analysis.py new file mode 100644 index 0000000..77f2652 --- /dev/null +++ b/scripts/rubinTv_analysis.py @@ -0,0 +1,13 @@ +import sys + +from lsst.rubintv.analysis.service import client + +if __name__ == "__main__": + # Run the client and connect to rubinTV via websockets + args = sys.argv[1:] + if len(args) == 2: + client.run_client(args[0], int(args[1])) + elif len(args) != 0: + raise RuntimeError("Usage: python -m lsst.rubintv.analysis.service.client [address] [port]") + else: + client.run_client("localhost", 2000)