22import logging
33import socket
44import sys
5- from typing import Callable , Dict , List , Any , Optional , Tuple
5+ from typing import Any , Callable , Dict , List , Optional , Tuple
66from uuid import UUID
77
88import streamcapture
99from billiard .einfo import ExceptionInfo
10- from celery import Task as CeleryTask , Celery
10+ from celery import Celery
11+ from celery import Task as CeleryTask
1112from celery .apps .worker import Worker as CeleryWorker
12- from kombu import Queue as KombuQueue
1313from esdl import EnergySystem
14+ from kombu import Queue as KombuQueue
15+ from omotes_sdk_protocol .internal .task_pb2 import TaskProgressUpdate , TaskResult
1416
15- from omotes_sdk .internal .orchestrator_worker_events . esdl_messages import EsdlMessage
17+ from omotes_sdk .internal .common . broker_interface import BrokerInterface
1618from omotes_sdk .internal .common .esdl_util import pyesdl_from_string
19+ from omotes_sdk .internal .orchestrator_worker_events .esdl_messages import EsdlMessage
1720from omotes_sdk .internal .worker .configs import WorkerConfig
18- from omotes_sdk .internal .common .broker_interface import BrokerInterface
19- from omotes_sdk_protocol .internal .task_pb2 import (
20- TaskResult ,
21- TaskProgressUpdate ,
22- )
2321from omotes_sdk .types import ProtobufDict
2422
2523logger = logging .getLogger ("omotes_sdk_internal" )
@@ -59,7 +57,7 @@ def send_start(self) -> None:
5957 TaskProgressUpdate (
6058 job_id = str (self .job_id ),
6159 celery_task_id = self .task .request .id ,
62- celery_task_type = WORKER_TASK_TYPE ,
60+ celery_task_type = self . task . name ,
6361 status = TaskProgressUpdate .START ,
6462 message = "Started job at worker." ,
6563 ).SerializeToString (),
@@ -84,7 +82,7 @@ def update_progress(self, fraction: float, message: str) -> None:
8482 TaskProgressUpdate (
8583 job_id = str (self .job_id ),
8684 celery_task_id = self .task .request .id ,
87- celery_task_type = WORKER_TASK_TYPE ,
85+ celery_task_type = self . task . name ,
8886 progress = float (fraction ),
8987 message = message ,
9088 ).SerializeToString (),
@@ -165,7 +163,7 @@ def after_return(
165163 result_message = TaskResult (
166164 job_id = str (job_id ),
167165 celery_task_id = self .request .id ,
168- celery_task_type = WORKER_TASK_TYPE ,
166+ celery_task_type = self . name ,
169167 result_type = TaskResult .ResultType .ERROR ,
170168 output_esdl = "" ,
171169 logs = logs ,
@@ -181,7 +179,7 @@ def after_return(
181179 result_message = TaskResult (
182180 job_id = str (job_id ),
183181 celery_task_id = self .request .id ,
184- celery_task_type = WORKER_TASK_TYPE ,
182+ celery_task_type = self . name ,
185183 result_type = TaskResult .ResultType .SUCCEEDED ,
186184 output_esdl = self .output_esdl ,
187185 logs = logs ,
@@ -258,17 +256,21 @@ def wrapped_worker_task(
258256 :param params_dict: job, non-ESDL, parameters.
259257 """
260258 logger .info ("Worker started new task %s with reference %s" , job_id , job_reference )
261- task_util = TaskUtil (job_id , task , task .broker_if )
259+ task_util = TaskUtil (
260+ job_id ,
261+ task ,
262+ task .broker_if ,
263+ )
262264 task_util .send_start ()
263265 output_esdl , esdl_messages = WORKER_TASK_FUNCTION (
264- input_esdl , params_dict , task_util .update_progress
266+ input_esdl , params_dict , task_util .update_progress , task . name
265267 )
266268
267269 if output_esdl :
268270 input_esh = pyesdl_from_string (input_esdl )
269271 input_energy_system : EnergySystem = input_esh .energy_system
270272 if job_reference is None :
271- new_name = f"{ input_energy_system .name } _{ WORKER_TASK_TYPE } "
273+ new_name = f"{ input_energy_system .name } _{ task . name } "
272274 elif job_reference == "" :
273275 new_name = f"{ input_energy_system .name } "
274276 else :
@@ -316,9 +318,21 @@ def start(self) -> None:
316318 )
317319
318320 # Config of celery app
319- self .celery_app .conf .task_queues = [KombuQueue (
320- WORKER_TASK_TYPE , routing_key = WORKER_TASK_TYPE , queue_arguments = {"x-max-priority" : 10 }
321- )] # Tell the worker to listen to a specific queue for 1 workflow type.
321+ queues = []
322+ for worker_task_type in WORKER_TASK_TYPES :
323+ logger .info ("Starting Worker to work on task %s" , worker_task_type )
324+ queues .append (
325+ KombuQueue (
326+ worker_task_type ,
327+ routing_key = worker_task_type ,
328+ queue_arguments = {"x-max-priority" : 10 },
329+ )
330+ )
331+ self .celery_app .task (
332+ wrapped_worker_task , base = WorkerTask , name = worker_task_type , bind = True
333+ )
334+
335+ self .celery_app .conf .task_queues = queues
322336 self .celery_app .conf .task_acks_late = True
323337 self .celery_app .conf .task_reject_on_worker_lost = True
324338 self .celery_app .conf .task_acks_on_failure_or_timeout = False
@@ -331,9 +345,6 @@ def start(self) -> None:
331345 self .celery_app .conf .worker_hijack_root_logger = False
332346 self .celery_app .conf .worker_redirect_stdouts = False
333347
334- self .celery_app .task (wrapped_worker_task , base = WorkerTask , name = WORKER_TASK_TYPE , bind = True )
335-
336- logger .info ("Starting Worker to work on task %s" , WORKER_TASK_TYPE )
337348 logger .info (
338349 "Connected to broker rabbitmq (%s:%s/%s) as %s" ,
339350 rabbitmq_config .host ,
@@ -343,7 +354,7 @@ def start(self) -> None:
343354 )
344355
345356 self .celery_worker = self .celery_app .Worker (
346- hostname = f"worker-{ WORKER_TASK_TYPE } @{ socket .gethostname ()} " ,
357+ hostname = f"worker-{ '_' . join ( WORKER_TASK_TYPES ) } @{ socket .gethostname ()} " ,
347358 loglevel = logging .getLevelName (self .config .log_level ),
348359 autoscale = (1 , 1 ),
349360 )
@@ -353,7 +364,7 @@ def start(self) -> None:
353364
354365UpdateProgressHandler = Callable [[float , str ], None ]
355366WorkerTaskF = Callable [
356- [str , ProtobufDict , UpdateProgressHandler ],
367+ [str , ProtobufDict , UpdateProgressHandler , str ],
357368 Tuple [
358369 Optional [str ],
359370 List [EsdlMessage ],
@@ -362,21 +373,25 @@ def start(self) -> None:
362373
363374WORKER : Worker = None # type: ignore [assignment] # noqa
364375WORKER_TASK_FUNCTION : WorkerTaskF = None # type: ignore [assignment] # noqa
365- WORKER_TASK_TYPE : str = None # type: ignore [assignment] # noqa
376+ WORKER_TASK_TYPES : list [ str ] = None # type: ignore [assignment] # noqa
366377
367378
368379def initialize_worker (
369- task_type : str ,
380+ task_types : list [ str ] ,
370381 task_function : WorkerTaskF ,
371382) -> None :
372383 """Initialize and run the `Worker`.
373384
374- :param task_type : Technical name of the task . Needs to be equal to the name of the celery task
375- to which the orchestrator forwards the task.
385+ :param task_types : Technical name of the tasks . Needs to be equal to the name of the celery task
386+ to which the orchestrator forwards the task. May connect to one or more tasks.
376387 :param task_function: Function which performs the Celery task.
377388 """
378- global WORKER_TASK_FUNCTION , WORKER_TASK_TYPE , WORKER
379- WORKER_TASK_TYPE = task_type
389+ global WORKER_TASK_FUNCTION , WORKER_TASK_TYPES , WORKER
390+ WORKER_TASK_TYPES = task_types
391+ if len (WORKER_TASK_TYPES ) < 1 :
392+ raise RuntimeError (
393+ f"Should connect to one or more worker task types. Only found { len (WORKER_TASK_TYPES )} "
394+ )
380395 WORKER_TASK_FUNCTION = task_function
381396 WORKER = Worker ()
382397 WORKER .start ()
0 commit comments