-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement code run run a worker pod client
This commit also contains a mock_server to run a simple web app that uses the worker pods to process commands sent by clients.
- Loading branch information
Showing
5 changed files
with
213 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,6 +33,7 @@ dependencies = [ | |
"pyyaml", | ||
"sqlalchemy", | ||
"astropy", | ||
"websocket-client", | ||
] | ||
#dynamic = ["version"] | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import logging | ||
import sqlalchemy | ||
import yaml | ||
|
||
import websocket | ||
|
||
from . import command | ||
|
||
logger = logging.getLogger("lsst.rubintv.analysis.service.client") | ||
|
||
|
||
yaml_filename = filename = "/Users/fred3m/temp/visitDb/summit.yaml" | ||
schemas = {} | ||
|
||
with open(filename, "r") as file: | ||
schemas["visit"] = yaml.safe_load(file) | ||
|
||
engine = sqlalchemy.create_engine("sqlite:////Users/fred3m/temp/visitDb/visit.db") | ||
|
||
|
||
def on_message(ws, message): | ||
print(f"Received: {message}") | ||
response = command.execute_command(message, schemas["visit"], engine) | ||
ws.send(response) | ||
|
||
|
||
def on_error(ws, error): | ||
print(f"\033[91mError: {error}\033[0m") | ||
|
||
|
||
def on_close(ws, close_status_code, close_msg): | ||
print("\033[93mConnection closed\033[0m") | ||
|
||
|
||
def run_client(address, port): | ||
print(f"\033[92mConnecting to rubinTV at {address}:{port}\033[0m") | ||
# Connect to the WebSocket server | ||
ws = websocket.WebSocketApp(f"ws://{address}:{port}/ws/worker", | ||
on_message=on_message, | ||
on_error=on_error, | ||
on_close=on_close) | ||
ws.run_forever() | ||
ws.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,3 +5,4 @@ pydantic | |
pyyaml | ||
sqlalchemy | ||
astropy | ||
websocket |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
from enum import Enum | ||
import uuid | ||
|
||
import tornado.web | ||
import tornado.httpserver | ||
import tornado.ioloop | ||
import tornado.websocket | ||
|
||
LISTEN_PORT = 2000 | ||
LISTEN_ADDRESS = 'localhost' | ||
|
||
|
||
ansi_colors = { | ||
"black": "30", | ||
"red": "31", | ||
"green": "32", | ||
"yellow": "33", | ||
"blue": "34", | ||
"magenta": "35", | ||
"cyan": "36", | ||
"white": "37", | ||
} | ||
|
||
|
||
def log(message, color, end="\033[31"): | ||
_color = ansi_colors[color] | ||
print(f"\033[{_color}m{message}{end}m") | ||
|
||
|
||
class WorkerPodStatus(Enum): | ||
IDLE = "idle" | ||
BUSY = "busy" | ||
|
||
|
||
class WebSocketHandler(tornado.websocket.WebSocketHandler): | ||
""" | ||
Handler that handles WebSocket connections | ||
""" | ||
|
||
@classmethod | ||
def urls(cls): | ||
return [ | ||
(r'/ws/([^/]+)', cls, {}), # Route/Handler/kwargs | ||
] | ||
|
||
def open(self, type: str) -> None: | ||
""" | ||
Client opens a websocket | ||
""" | ||
self.client_id = str(uuid.uuid4()) | ||
if type == "worker": | ||
workers[self.client_id] = WorkerPod(self.client_id, self) | ||
log(f"New worker {self.client_id} connected. Total workers: {len(workers)}", "blue") | ||
if type == "client": | ||
clients[self.client_id] = self | ||
log(f"New client {self.client_id} connected. Total clients: {len(clients)}", "yellow") | ||
|
||
def on_message(self, message: str) -> None: | ||
""" | ||
Message received from a client or worker. | ||
""" | ||
log(f"Message received from {self.client_id}: {message}", "yellow") | ||
if self.client_id in clients: | ||
client = clients[self.client_id] | ||
|
||
# Find an idle worker | ||
idle_worker = None | ||
for worker in workers.values(): | ||
if worker.status == WorkerPodStatus.IDLE: | ||
idle_worker = worker | ||
break | ||
|
||
if idle_worker is None: | ||
# No idle worker found, add to queue | ||
queue.append(QueueItem(message, client)) | ||
return | ||
idle_worker.process(message, client) | ||
return | ||
|
||
if self.client_id in workers: | ||
worker = workers[self.client_id] | ||
worker.on_finished(message) | ||
|
||
# Check the queue for any outstanding jobs. | ||
if len(queue) > 0: | ||
queue_item = queue.pop(0) | ||
worker.process(queue_item.message, queue_item.client) | ||
return | ||
|
||
def on_close(self) -> None: | ||
""" | ||
Client closes the connection | ||
""" | ||
if self.client_id in clients: | ||
del clients[self.client_id] | ||
log(f"Client disconnected. Active clients: {len(clients)}", "yellow") | ||
if self.client_id in workers: | ||
del workers[self.client_id] | ||
log(f"Worker disconnected. Active workers: {len(workers)}", "blue") | ||
|
||
def check_origin(self, origin): | ||
""" | ||
Override the origin check if needed | ||
""" | ||
return True | ||
|
||
|
||
class WorkerPod: | ||
status: WorkerPodStatus | ||
connected_client: WebSocketHandler | ||
|
||
def __init__(self, id: str, ws: WebSocketHandler): | ||
self.id = id | ||
self.ws = ws | ||
self.status = WorkerPodStatus.IDLE | ||
self.connected_client = None | ||
|
||
def process(self, message: str, connected_client: WebSocketHandler): | ||
self.status = WorkerPodStatus.BUSY | ||
self.connected_client = connected_client | ||
self.ws.write_message(message) | ||
|
||
def on_finished(self, message): | ||
self.connected_client.write_message(message) | ||
self.status = WorkerPodStatus.IDLE | ||
self.connected_client = None | ||
|
||
|
||
class QueueItem: | ||
def __init__(self, message, client): | ||
self.message = message | ||
self.client = client | ||
|
||
|
||
workers: dict[str, WorkerPod] = dict() # Keep track of connected worker pods | ||
clients: dict[str: WebSocketHandler] = dict() # Keep track of connected clients | ||
queue: list[QueueItem] = list() # Queue of messages to be processed | ||
|
||
|
||
def main(): | ||
# Create tornado application and supply URL routes | ||
app = tornado.web.Application(WebSocketHandler.urls()) | ||
|
||
# Setup HTTP Server | ||
http_server = tornado.httpserver.HTTPServer(app) | ||
http_server.listen(LISTEN_PORT, LISTEN_ADDRESS) | ||
|
||
log(f"Listening on address: {LISTEN_ADDRESS}, {LISTEN_PORT}", "green") | ||
|
||
# Start IO/Event loop | ||
tornado.ioloop.IOLoop.instance().start() | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import sys | ||
|
||
from lsst.rubintv.analysis.service import client | ||
|
||
if __name__ == "__main__": | ||
# Run the client and connect to rubinTV via websockets | ||
args = sys.argv[1:] | ||
if len(args) == 2: | ||
client.run_client(args[0], int(args[1])) | ||
elif len(args) != 0: | ||
raise RuntimeError("Usage: python -m lsst.rubintv.analysis.service.client [address] [port]") | ||
else: | ||
client.run_client("localhost", 2000) |