Skip to content

Commit c2c894b

Browse files
authored
Merge pull request #100 from Project-OMOTES/99-allow-a-worker-internal-to-connect-to-multiple-workflow-types
99: Basic functioning of multiple workflow types per worker at once.
2 parents eb38742 + cbeb2c5 commit c2c894b

File tree

1 file changed

+45
-30
lines changed

1 file changed

+45
-30
lines changed

src/omotes_sdk/internal/worker/worker.py

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,22 @@
22
import logging
33
import socket
44
import sys
5-
from typing import Callable, Dict, List, Any, Optional, Tuple
5+
from typing import Any, Callable, Dict, List, Optional, Tuple
66
from uuid import UUID
77

88
import streamcapture
99
from billiard.einfo import ExceptionInfo
10-
from celery import Task as CeleryTask, Celery
10+
from celery import Celery
11+
from celery import Task as CeleryTask
1112
from celery.apps.worker import Worker as CeleryWorker
12-
from kombu import Queue as KombuQueue
1313
from esdl import EnergySystem
14+
from kombu import Queue as KombuQueue
15+
from omotes_sdk_protocol.internal.task_pb2 import TaskProgressUpdate, TaskResult
1416

15-
from omotes_sdk.internal.orchestrator_worker_events.esdl_messages import EsdlMessage
17+
from omotes_sdk.internal.common.broker_interface import BrokerInterface
1618
from omotes_sdk.internal.common.esdl_util import pyesdl_from_string
19+
from omotes_sdk.internal.orchestrator_worker_events.esdl_messages import EsdlMessage
1720
from omotes_sdk.internal.worker.configs import WorkerConfig
18-
from omotes_sdk.internal.common.broker_interface import BrokerInterface
19-
from omotes_sdk_protocol.internal.task_pb2 import (
20-
TaskResult,
21-
TaskProgressUpdate,
22-
)
2321
from omotes_sdk.types import ProtobufDict
2422

2523
logger = logging.getLogger("omotes_sdk_internal")
@@ -59,7 +57,7 @@ def send_start(self) -> None:
5957
TaskProgressUpdate(
6058
job_id=str(self.job_id),
6159
celery_task_id=self.task.request.id,
62-
celery_task_type=WORKER_TASK_TYPE,
60+
celery_task_type=self.task.name,
6361
status=TaskProgressUpdate.START,
6462
message="Started job at worker.",
6563
).SerializeToString(),
@@ -84,7 +82,7 @@ def update_progress(self, fraction: float, message: str) -> None:
8482
TaskProgressUpdate(
8583
job_id=str(self.job_id),
8684
celery_task_id=self.task.request.id,
87-
celery_task_type=WORKER_TASK_TYPE,
85+
celery_task_type=self.task.name,
8886
progress=float(fraction),
8987
message=message,
9088
).SerializeToString(),
@@ -165,7 +163,7 @@ def after_return(
165163
result_message = TaskResult(
166164
job_id=str(job_id),
167165
celery_task_id=self.request.id,
168-
celery_task_type=WORKER_TASK_TYPE,
166+
celery_task_type=self.name,
169167
result_type=TaskResult.ResultType.ERROR,
170168
output_esdl="",
171169
logs=logs,
@@ -181,7 +179,7 @@ def after_return(
181179
result_message = TaskResult(
182180
job_id=str(job_id),
183181
celery_task_id=self.request.id,
184-
celery_task_type=WORKER_TASK_TYPE,
182+
celery_task_type=self.name,
185183
result_type=TaskResult.ResultType.SUCCEEDED,
186184
output_esdl=self.output_esdl,
187185
logs=logs,
@@ -258,17 +256,21 @@ def wrapped_worker_task(
258256
:param params_dict: job, non-ESDL, parameters.
259257
"""
260258
logger.info("Worker started new task %s with reference %s", job_id, job_reference)
261-
task_util = TaskUtil(job_id, task, task.broker_if)
259+
task_util = TaskUtil(
260+
job_id,
261+
task,
262+
task.broker_if,
263+
)
262264
task_util.send_start()
263265
output_esdl, esdl_messages = WORKER_TASK_FUNCTION(
264-
input_esdl, params_dict, task_util.update_progress
266+
input_esdl, params_dict, task_util.update_progress, task.name
265267
)
266268

267269
if output_esdl:
268270
input_esh = pyesdl_from_string(input_esdl)
269271
input_energy_system: EnergySystem = input_esh.energy_system
270272
if job_reference is None:
271-
new_name = f"{input_energy_system.name}_{WORKER_TASK_TYPE}"
273+
new_name = f"{input_energy_system.name}_{task.name}"
272274
elif job_reference == "":
273275
new_name = f"{input_energy_system.name}"
274276
else:
@@ -316,9 +318,21 @@ def start(self) -> None:
316318
)
317319

318320
# Config of celery app
319-
self.celery_app.conf.task_queues = [KombuQueue(
320-
WORKER_TASK_TYPE, routing_key=WORKER_TASK_TYPE, queue_arguments={"x-max-priority": 10}
321-
)] # Tell the worker to listen to a specific queue for 1 workflow type.
321+
queues = []
322+
for worker_task_type in WORKER_TASK_TYPES:
323+
logger.info("Starting Worker to work on task %s", worker_task_type)
324+
queues.append(
325+
KombuQueue(
326+
worker_task_type,
327+
routing_key=worker_task_type,
328+
queue_arguments={"x-max-priority": 10},
329+
)
330+
)
331+
self.celery_app.task(
332+
wrapped_worker_task, base=WorkerTask, name=worker_task_type, bind=True
333+
)
334+
335+
self.celery_app.conf.task_queues = queues
322336
self.celery_app.conf.task_acks_late = True
323337
self.celery_app.conf.task_reject_on_worker_lost = True
324338
self.celery_app.conf.task_acks_on_failure_or_timeout = False
@@ -331,9 +345,6 @@ def start(self) -> None:
331345
self.celery_app.conf.worker_hijack_root_logger = False
332346
self.celery_app.conf.worker_redirect_stdouts = False
333347

334-
self.celery_app.task(wrapped_worker_task, base=WorkerTask, name=WORKER_TASK_TYPE, bind=True)
335-
336-
logger.info("Starting Worker to work on task %s", WORKER_TASK_TYPE)
337348
logger.info(
338349
"Connected to broker rabbitmq (%s:%s/%s) as %s",
339350
rabbitmq_config.host,
@@ -343,7 +354,7 @@ def start(self) -> None:
343354
)
344355

345356
self.celery_worker = self.celery_app.Worker(
346-
hostname=f"worker-{WORKER_TASK_TYPE}@{socket.gethostname()}",
357+
hostname=f"worker-{'_'.join(WORKER_TASK_TYPES)}@{socket.gethostname()}",
347358
loglevel=logging.getLevelName(self.config.log_level),
348359
autoscale=(1, 1),
349360
)
@@ -353,7 +364,7 @@ def start(self) -> None:
353364

354365
UpdateProgressHandler = Callable[[float, str], None]
355366
WorkerTaskF = Callable[
356-
[str, ProtobufDict, UpdateProgressHandler],
367+
[str, ProtobufDict, UpdateProgressHandler, str],
357368
Tuple[
358369
Optional[str],
359370
List[EsdlMessage],
@@ -362,21 +373,25 @@ def start(self) -> None:
362373

363374
WORKER: Worker = None # type: ignore [assignment] # noqa
364375
WORKER_TASK_FUNCTION: WorkerTaskF = None # type: ignore [assignment] # noqa
365-
WORKER_TASK_TYPE: str = None # type: ignore [assignment] # noqa
376+
WORKER_TASK_TYPES: list[str] = None # type: ignore [assignment] # noqa
366377

367378

368379
def initialize_worker(
369-
task_type: str,
380+
task_types: list[str],
370381
task_function: WorkerTaskF,
371382
) -> None:
372383
"""Initialize and run the `Worker`.
373384
374-
:param task_type: Technical name of the task. Needs to be equal to the name of the celery task
375-
to which the orchestrator forwards the task.
385+
:param task_types: Technical name of the tasks. Needs to be equal to the name of the celery task
386+
to which the orchestrator forwards the task. May connect to one or more tasks.
376387
:param task_function: Function which performs the Celery task.
377388
"""
378-
global WORKER_TASK_FUNCTION, WORKER_TASK_TYPE, WORKER
379-
WORKER_TASK_TYPE = task_type
389+
global WORKER_TASK_FUNCTION, WORKER_TASK_TYPES, WORKER
390+
WORKER_TASK_TYPES = task_types
391+
if len(WORKER_TASK_TYPES) < 1:
392+
raise RuntimeError(
393+
f"Should connect to one or more worker task types. Only found {len(WORKER_TASK_TYPES)}"
394+
)
380395
WORKER_TASK_FUNCTION = task_function
381396
WORKER = Worker()
382397
WORKER.start()

0 commit comments

Comments
 (0)