diff --git a/packages/python/src/armonik/worker/taskhandler.py b/packages/python/src/armonik/worker/taskhandler.py index 9d8ade962..49eeb8ff3 100644 --- a/packages/python/src/armonik/worker/taskhandler.py +++ b/packages/python/src/armonik/worker/taskhandler.py @@ -1,83 +1,35 @@ from __future__ import annotations +import os from typing import Optional, Dict, List, Tuple, Union, cast from ..common import TaskOptions, TaskDefinition, Task -from ..protogen.common.agent_common_pb2 import Result, CreateTaskRequest, CreateResultsMetaDataRequest, CreateResultsMetaDataResponse -from ..protogen.common.objects_pb2 import TaskRequest, InitKeyedDataStream, DataChunk, InitTaskRequest, TaskRequestHeader, Configuration +from ..protogen.common.agent_common_pb2 import CreateTaskRequest, CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, NotifyResultDataRequest +from ..protogen.common.objects_pb2 import TaskRequest, DataChunk, InitTaskRequest, TaskRequestHeader, Configuration from ..protogen.worker.agent_service_pb2_grpc import AgentStub +from ..protogen.common.worker_common_pb2 import ProcessRequest class TaskHandler: - def __init__(self, request_iterator, agent_client): - self.request_iterator = request_iterator + def __init__(self, request: ProcessRequest, agent_client: AgentStub): self._client: AgentStub = agent_client - self.payload = bytearray() - self.session_id: Optional[str] = None - self.task_id: Optional[str] = None - self.task_options: Optional[TaskOptions] = None - self.token: Optional[str] = None - self.expected_results: List[str] = [] - self.data_dependencies: Dict[str, bytearray] = {} - self.configuration: Optional[Configuration] = None - - @classmethod - def create(cls, request_iterator, agent_client) -> "TaskHandler": - output = cls(request_iterator, agent_client) - output.init() - return output - - def init(self): - current = next(self.request_iterator, None) - if current is None: - raise ValueError("Request stream ended unexpectedly") - - if current.compute.WhichOneof("type") != "init_request": - raise ValueError("Expected a Compute request type with InitRequest to start the stream.") - - init_request = current.compute.init_request - self.session_id = init_request.session_id - self.task_id = init_request.task_id - self.task_options = TaskOptions.from_message(init_request.task_options) - self.expected_results = list(init_request.expected_output_keys) - self.configuration = init_request.configuration - self.token = current.communication_token - - datachunk = init_request.payload - self.payload.extend(datachunk.data) - while not datachunk.data_complete: - current = next(self.request_iterator, None) - if current is None: - raise ValueError("Request stream ended unexpectedly") - if current.compute.WhichOneof("type") != "payload": - raise ValueError("Expected a Compute request type with Payload to continue the stream.") - datachunk = current.compute.payload - self.payload.extend(datachunk.data) - - while True: - current = next(self.request_iterator, None) - if current is None: - raise ValueError("Request stream ended unexpectedly") - if current.compute.WhichOneof("type") != "init_data": - raise ValueError("Expected a Compute request type with InitData to continue the stream.") - init_data = current.compute.init_data - if not (init_data.key is None or init_data.key == ""): - chunk = bytearray() - while True: - current = next(self.request_iterator, None) - if current is None: - raise ValueError("Request stream ended unexpectedly") - if current.compute.WhichOneof("type") != "data": - raise ValueError("Expected a Compute request type with Data to continue the stream.") - datachunk = current.compute.data - if datachunk.WhichOneof("type") == "data": - chunk.extend(datachunk.data) - elif datachunk.WhichOneof("type") is None or datachunk.WhichOneof("type") == "": - raise ValueError("Expected a Compute request type with Datachunk Payload to continue the stream.") - elif datachunk.WhichOneof("type") == "data_complete": - break - self.data_dependencies[init_data.key] = chunk - else: - break + self.session_id: str = request.session_id + self.task_id: str = request.task_id + self.task_options: TaskOptions = TaskOptions.from_message(request.task_options) + self.token: str = request.communication_token + self.expected_results: List[str] = list(request.expected_output_keys) + self.configuration: Configuration = request.configuration + self.payload_id: str = request.payload_id + self.data_folder: str = request.data_folder + + # TODO: Lazy load + with open(os.path.join(self.data_folder, self.payload_id), "rb") as f: + self.payload = f.read() + + # TODO: Lazy load + self.data_dependencies: Dict[str, bytes] = {} + for dd in request.data_dependencies: + with open(os.path.join(self.data_folder, dd), "rb") as f: + self.data_dependencies[dd] = f.read() def create_tasks(self, tasks: List[TaskDefinition], task_options: Optional[TaskOptions] = None) -> Tuple[List[Task], List[str]]: """Create new tasks for ArmoniK @@ -122,29 +74,13 @@ def send_result(self, key: str, data: Union[bytes, bytearray]) -> None: key: Result key data: Result data """ - def result_stream(): - res = Result(communication_token=self.token, init=InitKeyedDataStream(key=key)) - assert self.configuration is not None - yield res - start = 0 - data_len = len(data) - while start < data_len: - chunksize = min(self.configuration.data_chunk_max_size, data_len - start) - res = Result(communication_token=self.token, data=DataChunk(data=data[start:start + chunksize])) - yield res - start += chunksize - res = Result(communication_token=self.token, data=DataChunk(data_complete=True)) - yield res - res = Result(communication_token=self.token, init=InitKeyedDataStream(last_result=True)) - yield res - - result_reply = self._client.SendResult(result_stream()) - if result_reply.WhichOneof("type") == "error": - raise Exception(f"Cannot send result id={key}") - - def get_results_ids(self, names : List[str]) -> Dict[str, str]: - return {r.name : r.result_id for r in cast(CreateResultsMetaDataResponse, self._client.CreateResultsMetaData(CreateResultsMetaDataRequest(results=[CreateResultsMetaDataRequest.ResultCreate(name = n) for n in names], session_id=self.session_id, communication_token=self.token))).results} + with open(os.path.join(self.data_folder, key), "wb") as f: + f.write(data) + self._client.NotifyResultData(NotifyResultDataRequest(ids=[NotifyResultDataRequest.ResultIdentifier(session_id=self.session_id, result_id=key)], communication_token=self.token)) + + def get_results_ids(self, names: List[str]) -> Dict[str, str]: + return {r.name: r.result_id for r in cast(CreateResultsMetaDataResponse, self._client.CreateResultsMetaData(CreateResultsMetaDataRequest(results=[CreateResultsMetaDataRequest.ResultCreate(name=n) for n in names], session_id=self.session_id, communication_token=self.token))).results} def _to_request_stream_internal(request, communication_token, is_last, chunk_max_size): diff --git a/packages/python/src/armonik/worker/worker.py b/packages/python/src/armonik/worker/worker.py index 19db04f38..2de34aad1 100644 --- a/packages/python/src/armonik/worker/worker.py +++ b/packages/python/src/armonik/worker/worker.py @@ -9,7 +9,7 @@ from .seqlogger import ClefLogger from ..common import Output, HealthCheckStatus from ..protogen.common.objects_pb2 import Empty -from ..protogen.common.worker_common_pb2 import ProcessReply, HealthCheckReply +from ..protogen.common.worker_common_pb2 import ProcessReply, ProcessRequest, HealthCheckReply from ..protogen.worker.agent_service_pb2_grpc import AgentStub from ..protogen.worker.worker_service_pb2_grpc import WorkerServicer, add_WorkerServicer_to_server from .taskhandler import TaskHandler @@ -46,11 +46,11 @@ def start(self, endpoint: str): server.start() server.wait_for_termination() - def Process(self, request_iterator, context) -> Union[ProcessReply, None]: + def Process(self, request: ProcessRequest, context) -> Union[ProcessReply, None]: try: self._logger.debug("Received task") - task_handler = TaskHandler.create(request_iterator, self._client) - return ProcessReply(communication_token=task_handler.token, output=self.processing_function(task_handler).to_message()) + task_handler = TaskHandler(request, self._client) + return ProcessReply(output=self.processing_function(task_handler).to_message()) except Exception as e: self._logger.exception(f"Failed task {''.join(traceback.format_exception(type(e) ,e, e.__traceback__))}", exc_info=e) diff --git a/packages/python/tests/taskhandler_test.py b/packages/python/tests/taskhandler_test.py index 88cbc3809..e4f3c181c 100644 --- a/packages/python/tests/taskhandler_test.py +++ b/packages/python/tests/taskhandler_test.py @@ -1,18 +1,35 @@ #!/usr/bin/env python3 +import os + import pytest from typing import Iterator -from .common import DummyChannel + from armonik.common import TaskDefinition + +from .common import DummyChannel from armonik.worker import TaskHandler from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub -from armonik.protogen.common.agent_common_pb2 import CreateTaskRequest, CreateTaskReply, Result, ResultReply +from armonik.protogen.common.agent_common_pb2 import CreateTaskRequest, CreateTaskReply, NotifyResultDataRequest, NotifyResultDataResponse from armonik.protogen.common.worker_common_pb2 import ProcessRequest -from armonik.protogen.common.objects_pb2 import Configuration, DataChunk +from armonik.protogen.common.objects_pb2 import Configuration import logging logging.basicConfig() logging.getLogger().setLevel(logging.INFO) +data_folder = os.getcwd() + + +@pytest.fixture(autouse=True, scope="session") +def setup_teardown(): + with open(os.path.join(data_folder, "payloadid"), "wb") as f: + f.write("payload".encode()) + with open(os.path.join(data_folder, "ddid"), "wb") as f: + f.write("dd".encode()) + yield + os.remove(os.path.join(data_folder, "payloadid")) + os.remove(os.path.join(data_folder, "ddid")) + class DummyAgent(AgentStub): @@ -29,94 +46,18 @@ def CreateTask(self, request_iterator: Iterator[CreateTaskRequest]) -> CreateTas task_info=CreateTaskReply.TaskInfo(task_id="TaskId", expected_output_keys=["EOK"], data_dependencies=["DD"]))])) - def SendResult(self, request_iterator: Iterator[Result]) -> ResultReply: - self.send_result_task_message = [r for r in request_iterator] - return ResultReply() - - -class Reqs: - InitData1 = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest( - init_data=ProcessRequest.ComputeRequest.InitData(key="DataKey1"))) - InitData2 = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest( - init_data=ProcessRequest.ComputeRequest.InitData(key="DataKey2"))) - LastDataTrue = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest( - init_data=ProcessRequest.ComputeRequest.InitData(last_data=True))) - LastDataFalse = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest( - init_data=ProcessRequest.ComputeRequest.InitData(last_data=False))) - InitRequestPayload = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest( - init_request=ProcessRequest.ComputeRequest.InitRequest( - payload=DataChunk(data="test".encode("utf-8")), - configuration=Configuration(data_chunk_max_size=100), - expected_output_keys=["EOK"], session_id="SessionId", - task_id="TaskId"))) - InitRequestEmptyPayload = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest( - init_request=ProcessRequest.ComputeRequest.InitRequest( - configuration=Configuration(data_chunk_max_size=100), - expected_output_keys=["EOK"], session_id="SessionId", - task_id="TaskId"))) - Payload1 = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest( - payload=DataChunk(data="Payload1".encode("utf-8")))) - Payload2 = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest( - payload=DataChunk(data="Payload2".encode("utf-8")))) - PayloadComplete = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest(payload=DataChunk(data_complete=True))) - DataComplete = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest(data=DataChunk(data_complete=True))) - Data1 = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest( - data=DataChunk(data="Data1".encode("utf-8")))) - Data2 = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest( - data=DataChunk(data="Data2".encode("utf-8")))) - - -should_throw_cases = [ - [], - [Reqs.InitData1], - [Reqs.InitData2], - [Reqs.LastDataTrue], - [Reqs.LastDataFalse], - [Reqs.InitRequestPayload], - [Reqs.DataComplete], - [Reqs.InitRequestEmptyPayload], - [Reqs.InitRequestPayload, Reqs.PayloadComplete, Reqs.InitData1, Reqs.Data1, Reqs.LastDataTrue], - [Reqs.InitRequestPayload, Reqs.InitData1, Reqs.Data1, Reqs.DataComplete, Reqs.LastDataTrue], - [Reqs.InitRequestPayload, Reqs.PayloadComplete, Reqs.Data1, Reqs.DataComplete, Reqs.LastDataTrue], -] - -should_succeed_cases = [ - [Reqs.InitRequestPayload, Reqs.Payload1, Reqs.Payload2, Reqs.PayloadComplete, Reqs.InitData1, Reqs.Data1, - Reqs.Data2, Reqs.DataComplete, Reqs.InitData2, Reqs.Data1, Reqs.Data2, Reqs.Data2, Reqs.Data2, - Reqs.DataComplete, Reqs.LastDataTrue], - [Reqs.InitRequestPayload, Reqs.Payload1, Reqs.PayloadComplete, Reqs.InitData1, Reqs.Data1, Reqs.DataComplete, - Reqs.LastDataTrue], - [Reqs.InitRequestPayload, Reqs.PayloadComplete, Reqs.InitData1, Reqs.Data1, Reqs.DataComplete, Reqs.LastDataTrue], -] - - -def get_cases(list_requests): - for r in list_requests: - yield iter(r) - - -@pytest.mark.parametrize("requests", get_cases(should_throw_cases)) -def test_taskhandler_create_should_throw(requests: Iterator[ProcessRequest]): - with pytest.raises(ValueError): - TaskHandler.create(requests, DummyAgent(DummyChannel())) - - -@pytest.mark.parametrize("requests", get_cases(should_succeed_cases)) -def test_taskhandler_create_should_succeed(requests: Iterator[ProcessRequest]): + def NotifyResultData(self, request: NotifyResultDataRequest) -> NotifyResultDataResponse: + self.send_result_task_message.append(request) + return NotifyResultDataResponse(result_ids=[i.result_id for i in request.ids]) + + +should_succeed_case = ProcessRequest(communication_token="token", session_id="sessionid", task_id="taskid", expected_output_keys=["resultid"], payload_id="payloadid", data_dependencies=["ddid"], data_folder=data_folder, configuration=Configuration(data_chunk_max_size=8000)) + + +@pytest.mark.parametrize("requests", [should_succeed_case]) +def test_taskhandler_create_should_succeed(requests: ProcessRequest): agent = DummyAgent(DummyChannel()) - task_handler = TaskHandler.create(requests, agent) + task_handler = TaskHandler(requests, agent) assert task_handler.token is not None and len(task_handler.token) > 0 assert len(task_handler.payload) > 0 assert task_handler.session_id is not None and len(task_handler.session_id) > 0 @@ -125,28 +66,8 @@ def test_taskhandler_create_should_succeed(requests: Iterator[ProcessRequest]): def test_taskhandler_data_are_correct(): agent = DummyAgent(DummyChannel()) - task_handler = TaskHandler.create(iter(should_succeed_cases[0]), agent) + task_handler = TaskHandler(should_succeed_case, agent) assert len(task_handler.payload) > 0 - assert task_handler.payload.decode('utf-8') == "testPayload1Payload2" - assert len(task_handler.data_dependencies) == 2 - assert task_handler.data_dependencies["DataKey1"].decode('utf-8') == "Data1Data2" - assert task_handler.data_dependencies["DataKey2"].decode('utf-8') == "Data1Data2Data2Data2" - assert task_handler.task_id == "TaskId" - assert task_handler.session_id == "SessionId" - assert task_handler.token == "Token" - - task_handler.send_result("test", "TestData".encode("utf-8")) - - results = agent.send_result_task_message - assert len(results) == 4 - assert results[0].WhichOneof("type") == "init" - assert results[0].init.key == "test" - assert results[1].WhichOneof("type") == "data" - assert results[1].data.data == "TestData".encode("utf-8") - assert results[2].WhichOneof("type") == "data" - assert results[2].data.data_complete - assert results[3].WhichOneof("type") == "init" - assert results[3].init.last_result task_handler.create_tasks([TaskDefinition("Payload".encode("utf-8"), ["EOK"], ["DD"])]) diff --git a/packages/python/tests/worker_test.py b/packages/python/tests/worker_test.py index 1e4604a31..032c406ee 100644 --- a/packages/python/tests/worker_test.py +++ b/packages/python/tests/worker_test.py @@ -1,9 +1,11 @@ #!/usr/bin/env python3 import logging - +import os +import pytest from armonik.worker import ArmoniKWorker, TaskHandler, ClefLogger from armonik.common import Output -from .taskhandler_test import should_succeed_cases +from .taskhandler_test import should_succeed_case, data_folder, DummyAgent +from .common import DummyChannel from armonik.protogen.common.objects_pb2 import Empty import grpc @@ -20,10 +22,22 @@ def return_error(_: TaskHandler) -> Output: return Output("TestError") +def return_and_send(th: TaskHandler) -> Output: + th.send_result(th.expected_results[0], b"result") + return Output() + + +@pytest.fixture(autouse=True, scope="function") +def remove_result(): + yield + if os.path.exists(os.path.join(data_folder, "resultid")): + os.remove(os.path.join(data_folder, "resultid")) + + def test_do_nothing_worker(): with grpc.insecure_channel("unix:///tmp/agent.sock") as agent_channel: worker = ArmoniKWorker(agent_channel, do_nothing, logger=ClefLogger("TestLogger", level=logging.CRITICAL)) - reply = worker.Process(iter(should_succeed_cases[0]), None) + reply = worker.Process(should_succeed_case, None) assert Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None).success worker.HealthCheck(Empty(), None) @@ -31,15 +45,29 @@ def test_do_nothing_worker(): def test_worker_should_return_none(): with grpc.insecure_channel("unix:///tmp/agent.sock") as agent_channel: worker = ArmoniKWorker(agent_channel, throw_error, logger=ClefLogger("TestLogger", level=logging.CRITICAL)) - reply = worker.Process(iter(should_succeed_cases[0]), None) + reply = worker.Process(should_succeed_case, None) assert reply is None def test_worker_should_error(): with grpc.insecure_channel("unix:///tmp/agent.sock") as agent_channel: worker = ArmoniKWorker(agent_channel, return_error, logger=ClefLogger("TestLogger", level=logging.CRITICAL)) - reply = worker.Process(iter(should_succeed_cases[0]), None) + reply = worker.Process(should_succeed_case, None) output = Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None) assert not output.success assert output.error == "TestError" + +def test_worker_should_write_result(): + with grpc.insecure_channel("unix:///tmp/agent.sock") as agent_channel: + worker = ArmoniKWorker(agent_channel, return_and_send, logger=ClefLogger("TestLogger", level=logging.DEBUG)) + worker._client = DummyAgent(DummyChannel()) + reply = worker.Process(should_succeed_case, None) + assert reply is not None + output = Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None) + assert output.success + assert os.path.exists(os.path.join(data_folder, should_succeed_case.expected_output_keys[0])) + with open(os.path.join(data_folder, should_succeed_case.expected_output_keys[0]), "rb") as f: + value = f.read() + assert len(value) > 0 +