python: Fixed taskhandler and tests
dbrasseur-aneo committed Sep 27, 2023
1 parent 77bf4d3 commit 4c404f4
Showing 4 changed files with 98 additions and 213 deletions.
packages/python/src/armonik/worker/taskhandler.py (122 changes: 29 additions & 93 deletions)
@@ -1,83 +1,35 @@
from __future__ import annotations
import os
from typing import Optional, Dict, List, Tuple, Union, cast

from ..common import TaskOptions, TaskDefinition, Task
from ..protogen.common.agent_common_pb2 import Result, CreateTaskRequest, CreateResultsMetaDataRequest, CreateResultsMetaDataResponse
from ..protogen.common.objects_pb2 import TaskRequest, InitKeyedDataStream, DataChunk, InitTaskRequest, TaskRequestHeader, Configuration
from ..protogen.common.agent_common_pb2 import CreateTaskRequest, CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, NotifyResultDataRequest
from ..protogen.common.objects_pb2 import TaskRequest, DataChunk, InitTaskRequest, TaskRequestHeader, Configuration
from ..protogen.worker.agent_service_pb2_grpc import AgentStub
from ..protogen.common.worker_common_pb2 import ProcessRequest


class TaskHandler:
def __init__(self, request_iterator, agent_client):
self.request_iterator = request_iterator
def __init__(self, request: ProcessRequest, agent_client: AgentStub):
self._client: AgentStub = agent_client
self.payload = bytearray()
self.session_id: Optional[str] = None
self.task_id: Optional[str] = None
self.task_options: Optional[TaskOptions] = None
self.token: Optional[str] = None
self.expected_results: List[str] = []
self.data_dependencies: Dict[str, bytearray] = {}
self.configuration: Optional[Configuration] = None

@classmethod
def create(cls, request_iterator, agent_client) -> "TaskHandler":
output = cls(request_iterator, agent_client)
output.init()
return output

def init(self):
current = next(self.request_iterator, None)
if current is None:
raise ValueError("Request stream ended unexpectedly")

if current.compute.WhichOneof("type") != "init_request":
raise ValueError("Expected a Compute request type with InitRequest to start the stream.")

init_request = current.compute.init_request
self.session_id = init_request.session_id
self.task_id = init_request.task_id
self.task_options = TaskOptions.from_message(init_request.task_options)
self.expected_results = list(init_request.expected_output_keys)
self.configuration = init_request.configuration
self.token = current.communication_token

datachunk = init_request.payload
self.payload.extend(datachunk.data)
while not datachunk.data_complete:
current = next(self.request_iterator, None)
if current is None:
raise ValueError("Request stream ended unexpectedly")
if current.compute.WhichOneof("type") != "payload":
raise ValueError("Expected a Compute request type with Payload to continue the stream.")
datachunk = current.compute.payload
self.payload.extend(datachunk.data)

while True:
current = next(self.request_iterator, None)
if current is None:
raise ValueError("Request stream ended unexpectedly")
if current.compute.WhichOneof("type") != "init_data":
raise ValueError("Expected a Compute request type with InitData to continue the stream.")
init_data = current.compute.init_data
if not (init_data.key is None or init_data.key == ""):
chunk = bytearray()
while True:
current = next(self.request_iterator, None)
if current is None:
raise ValueError("Request stream ended unexpectedly")
if current.compute.WhichOneof("type") != "data":
raise ValueError("Expected a Compute request type with Data to continue the stream.")
datachunk = current.compute.data
if datachunk.WhichOneof("type") == "data":
chunk.extend(datachunk.data)
elif datachunk.WhichOneof("type") is None or datachunk.WhichOneof("type") == "":
raise ValueError("Expected a Compute request type with Datachunk Payload to continue the stream.")
elif datachunk.WhichOneof("type") == "data_complete":
break
self.data_dependencies[init_data.key] = chunk
else:
break
self.session_id: str = request.session_id
self.task_id: str = request.task_id
self.task_options: TaskOptions = TaskOptions.from_message(request.task_options)
self.token: str = request.communication_token
self.expected_results: List[str] = list(request.expected_output_keys)
self.configuration: Configuration = request.configuration
self.payload_id: str = request.payload_id
self.data_folder: str = request.data_folder

# TODO: Lazy load
with open(os.path.join(self.data_folder, self.payload_id), "rb") as f:
self.payload = f.read()

# TODO: Lazy load
self.data_dependencies: Dict[str, bytes] = {}
for dd in request.data_dependencies:
with open(os.path.join(self.data_folder, dd), "rb") as f:
self.data_dependencies[dd] = f.read()

def create_tasks(self, tasks: List[TaskDefinition], task_options: Optional[TaskOptions] = None) -> Tuple[List[Task], List[str]]:
"""Create new tasks for ArmoniK
@@ -122,29 +74,13 @@ def send_result(self, key: str, data: Union[bytes, bytearray]) -> None:
key: Result key
data: Result data
"""
def result_stream():
res = Result(communication_token=self.token, init=InitKeyedDataStream(key=key))
assert self.configuration is not None
yield res
start = 0
data_len = len(data)
while start < data_len:
chunksize = min(self.configuration.data_chunk_max_size, data_len - start)
res = Result(communication_token=self.token, data=DataChunk(data=data[start:start + chunksize]))
yield res
start += chunksize
res = Result(communication_token=self.token, data=DataChunk(data_complete=True))
yield res
res = Result(communication_token=self.token, init=InitKeyedDataStream(last_result=True))
yield res

result_reply = self._client.SendResult(result_stream())
if result_reply.WhichOneof("type") == "error":
raise Exception(f"Cannot send result id={key}")

def get_results_ids(self, names : List[str]) -> Dict[str, str]:
return {r.name : r.result_id for r in cast(CreateResultsMetaDataResponse, self._client.CreateResultsMetaData(CreateResultsMetaDataRequest(results=[CreateResultsMetaDataRequest.ResultCreate(name = n) for n in names], session_id=self.session_id, communication_token=self.token))).results}
with open(os.path.join(self.data_folder, key), "wb") as f:
f.write(data)

self._client.NotifyResultData(NotifyResultDataRequest(ids=[NotifyResultDataRequest.ResultIdentifier(session_id=self.session_id, result_id=key)], communication_token=self.token))

def get_results_ids(self, names: List[str]) -> Dict[str, str]:
return {r.name: r.result_id for r in cast(CreateResultsMetaDataResponse, self._client.CreateResultsMetaData(CreateResultsMetaDataRequest(results=[CreateResultsMetaDataRequest.ResultCreate(name=n) for n in names], session_id=self.session_id, communication_token=self.token))).results}


def _to_request_stream_internal(request, communication_token, is_last, chunk_max_size):
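With this change the worker no longer parses a gRPC stream: the payload and every data dependency arrive as files under data_folder, keyed by payload_id and the dependency ids, and send_result writes a file and then notifies the agent once. A minimal sketch of a processing function consuming the reworked handler (the process function and the choice of result key are illustrative, not part of this commit):

def process(handler) -> bytes:
    # handler.payload is plain bytes read from data_folder/<payload_id>.
    combined = bytearray(handler.payload)
    # handler.data_dependencies maps each dependency id to its file's bytes.
    for dep_id, dep_bytes in handler.data_dependencies.items():
        combined.extend(dep_bytes)
    # send_result writes data_folder/<key>, then calls NotifyResultData once.
    handler.send_result(handler.expected_results[0], bytes(combined))
    return bytes(combined)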
packages/python/src/armonik/worker/worker.py (8 changes: 4 additions & 4 deletions)
@@ -9,7 +9,7 @@
from .seqlogger import ClefLogger
from ..common import Output, HealthCheckStatus
from ..protogen.common.objects_pb2 import Empty
from ..protogen.common.worker_common_pb2 import ProcessReply, HealthCheckReply
from ..protogen.common.worker_common_pb2 import ProcessReply, ProcessRequest, HealthCheckReply
from ..protogen.worker.agent_service_pb2_grpc import AgentStub
from ..protogen.worker.worker_service_pb2_grpc import WorkerServicer, add_WorkerServicer_to_server
from .taskhandler import TaskHandler
@@ -46,11 +46,11 @@ def start(self, endpoint: str):
server.start()
server.wait_for_termination()

def Process(self, request_iterator, context) -> Union[ProcessReply, None]:
def Process(self, request: ProcessRequest, context) -> Union[ProcessReply, None]:
try:
self._logger.debug("Received task")
task_handler = TaskHandler.create(request_iterator, self._client)
return ProcessReply(communication_token=task_handler.token, output=self.processing_function(task_handler).to_message())
task_handler = TaskHandler(request, self._client)
return ProcessReply(output=self.processing_function(task_handler).to_message())
except Exception as e:
self._logger.exception(f"Failed task {''.join(traceback.format_exception(type(e) ,e, e.__traceback__))}", exc_info=e)

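Process is now a unary RPC: it receives a single ProcessRequest instead of a request iterator, builds the TaskHandler directly, and the ProcessReply no longer carries a communication token. A hedged sketch of a processing_function compatible with this signature, assuming a bare Output() denotes success (the diff only shows that its to_message() feeds ProcessReply):

from armonik.common import Output
from armonik.worker import TaskHandler

def processor(task_handler: TaskHandler) -> Output:
    # Echo the payload back as the first expected result (illustrative choice).
    task_handler.send_result(task_handler.expected_results[0], task_handler.payload)
    # Assumption: Output() with no arguments converts to a success reply.
    return Output()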
packages/python/tests/taskhandler_test.py (143 changes: 32 additions & 111 deletions)
@@ -1,18 +1,35 @@
#!/usr/bin/env python3
import os

import pytest
from typing import Iterator
from .common import DummyChannel

from armonik.common import TaskDefinition

from .common import DummyChannel
from armonik.worker import TaskHandler
from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub
from armonik.protogen.common.agent_common_pb2 import CreateTaskRequest, CreateTaskReply, Result, ResultReply
from armonik.protogen.common.agent_common_pb2 import CreateTaskRequest, CreateTaskReply, NotifyResultDataRequest, NotifyResultDataResponse
from armonik.protogen.common.worker_common_pb2 import ProcessRequest
from armonik.protogen.common.objects_pb2 import Configuration, DataChunk
from armonik.protogen.common.objects_pb2 import Configuration
import logging

logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)

data_folder = os.getcwd()


@pytest.fixture(autouse=True, scope="session")
def setup_teardown():
with open(os.path.join(data_folder, "payloadid"), "wb") as f:
f.write("payload".encode())
with open(os.path.join(data_folder, "ddid"), "wb") as f:
f.write("dd".encode())
yield
os.remove(os.path.join(data_folder, "payloadid"))
os.remove(os.path.join(data_folder, "ddid"))


class DummyAgent(AgentStub):

@@ -29,94 +46,18 @@ def CreateTask(self, request_iterator: Iterator[CreateTaskRequest]) -> CreateTas
task_info=CreateTaskReply.TaskInfo(task_id="TaskId", expected_output_keys=["EOK"],
data_dependencies=["DD"]))]))

def SendResult(self, request_iterator: Iterator[Result]) -> ResultReply:
self.send_result_task_message = [r for r in request_iterator]
return ResultReply()


class Reqs:
InitData1 = ProcessRequest(communication_token="Token",
compute=ProcessRequest.ComputeRequest(
init_data=ProcessRequest.ComputeRequest.InitData(key="DataKey1")))
InitData2 = ProcessRequest(communication_token="Token",
compute=ProcessRequest.ComputeRequest(
init_data=ProcessRequest.ComputeRequest.InitData(key="DataKey2")))
LastDataTrue = ProcessRequest(communication_token="Token",
compute=ProcessRequest.ComputeRequest(
init_data=ProcessRequest.ComputeRequest.InitData(last_data=True)))
LastDataFalse = ProcessRequest(communication_token="Token",
compute=ProcessRequest.ComputeRequest(
init_data=ProcessRequest.ComputeRequest.InitData(last_data=False)))
InitRequestPayload = ProcessRequest(communication_token="Token",
compute=ProcessRequest.ComputeRequest(
init_request=ProcessRequest.ComputeRequest.InitRequest(
payload=DataChunk(data="test".encode("utf-8")),
configuration=Configuration(data_chunk_max_size=100),
expected_output_keys=["EOK"], session_id="SessionId",
task_id="TaskId")))
InitRequestEmptyPayload = ProcessRequest(communication_token="Token",
compute=ProcessRequest.ComputeRequest(
init_request=ProcessRequest.ComputeRequest.InitRequest(
configuration=Configuration(data_chunk_max_size=100),
expected_output_keys=["EOK"], session_id="SessionId",
task_id="TaskId")))
Payload1 = ProcessRequest(communication_token="Token",
compute=ProcessRequest.ComputeRequest(
payload=DataChunk(data="Payload1".encode("utf-8"))))
Payload2 = ProcessRequest(communication_token="Token",
compute=ProcessRequest.ComputeRequest(
payload=DataChunk(data="Payload2".encode("utf-8"))))
PayloadComplete = ProcessRequest(communication_token="Token",
compute=ProcessRequest.ComputeRequest(payload=DataChunk(data_complete=True)))
DataComplete = ProcessRequest(communication_token="Token",
compute=ProcessRequest.ComputeRequest(data=DataChunk(data_complete=True)))
Data1 = ProcessRequest(communication_token="Token",
compute=ProcessRequest.ComputeRequest(
data=DataChunk(data="Data1".encode("utf-8"))))
Data2 = ProcessRequest(communication_token="Token",
compute=ProcessRequest.ComputeRequest(
data=DataChunk(data="Data2".encode("utf-8"))))


should_throw_cases = [
[],
[Reqs.InitData1],
[Reqs.InitData2],
[Reqs.LastDataTrue],
[Reqs.LastDataFalse],
[Reqs.InitRequestPayload],
[Reqs.DataComplete],
[Reqs.InitRequestEmptyPayload],
[Reqs.InitRequestPayload, Reqs.PayloadComplete, Reqs.InitData1, Reqs.Data1, Reqs.LastDataTrue],
[Reqs.InitRequestPayload, Reqs.InitData1, Reqs.Data1, Reqs.DataComplete, Reqs.LastDataTrue],
[Reqs.InitRequestPayload, Reqs.PayloadComplete, Reqs.Data1, Reqs.DataComplete, Reqs.LastDataTrue],
]

should_succeed_cases = [
[Reqs.InitRequestPayload, Reqs.Payload1, Reqs.Payload2, Reqs.PayloadComplete, Reqs.InitData1, Reqs.Data1,
Reqs.Data2, Reqs.DataComplete, Reqs.InitData2, Reqs.Data1, Reqs.Data2, Reqs.Data2, Reqs.Data2,
Reqs.DataComplete, Reqs.LastDataTrue],
[Reqs.InitRequestPayload, Reqs.Payload1, Reqs.PayloadComplete, Reqs.InitData1, Reqs.Data1, Reqs.DataComplete,
Reqs.LastDataTrue],
[Reqs.InitRequestPayload, Reqs.PayloadComplete, Reqs.InitData1, Reqs.Data1, Reqs.DataComplete, Reqs.LastDataTrue],
]


def get_cases(list_requests):
for r in list_requests:
yield iter(r)


@pytest.mark.parametrize("requests", get_cases(should_throw_cases))
def test_taskhandler_create_should_throw(requests: Iterator[ProcessRequest]):
with pytest.raises(ValueError):
TaskHandler.create(requests, DummyAgent(DummyChannel()))


@pytest.mark.parametrize("requests", get_cases(should_succeed_cases))
def test_taskhandler_create_should_succeed(requests: Iterator[ProcessRequest]):
def NotifyResultData(self, request: NotifyResultDataRequest) -> NotifyResultDataResponse:
self.send_result_task_message.append(request)
return NotifyResultDataResponse(result_ids=[i.result_id for i in request.ids])


should_succeed_case = ProcessRequest(communication_token="token", session_id="sessionid", task_id="taskid", expected_output_keys=["resultid"], payload_id="payloadid", data_dependencies=["ddid"], data_folder=data_folder, configuration=Configuration(data_chunk_max_size=8000))


@pytest.mark.parametrize("requests", [should_succeed_case])
def test_taskhandler_create_should_succeed(requests: ProcessRequest):
agent = DummyAgent(DummyChannel())
task_handler = TaskHandler.create(requests, agent)
task_handler = TaskHandler(requests, agent)
assert task_handler.token is not None and len(task_handler.token) > 0
assert len(task_handler.payload) > 0
assert task_handler.session_id is not None and len(task_handler.session_id) > 0
@@ -125,28 +66,8 @@ def test_taskhandler_create_should_succeed(requests: Iterator[ProcessRequest]):

def test_taskhandler_data_are_correct():
agent = DummyAgent(DummyChannel())
task_handler = TaskHandler.create(iter(should_succeed_cases[0]), agent)
task_handler = TaskHandler(should_succeed_case, agent)
assert len(task_handler.payload) > 0
assert task_handler.payload.decode('utf-8') == "testPayload1Payload2"
assert len(task_handler.data_dependencies) == 2
assert task_handler.data_dependencies["DataKey1"].decode('utf-8') == "Data1Data2"
assert task_handler.data_dependencies["DataKey2"].decode('utf-8') == "Data1Data2Data2Data2"
assert task_handler.task_id == "TaskId"
assert task_handler.session_id == "SessionId"
assert task_handler.token == "Token"

task_handler.send_result("test", "TestData".encode("utf-8"))

results = agent.send_result_task_message
assert len(results) == 4
assert results[0].WhichOneof("type") == "init"
assert results[0].init.key == "test"
assert results[1].WhichOneof("type") == "data"
assert results[1].data.data == "TestData".encode("utf-8")
assert results[2].WhichOneof("type") == "data"
assert results[2].data.data_complete
assert results[3].WhichOneof("type") == "init"
assert results[3].init.last_result

task_handler.create_tasks([TaskDefinition("Payload".encode("utf-8"), ["EOK"], ["DD"])])

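The rewritten tests pin down the new contract: payload_id and each entry of data_dependencies must name a file inside data_folder. A self-contained sketch of that setup against a temporary directory (tempfile is a convenience here, not something this commit uses; passing None as the agent client works only because __init__ makes no agent calls):

import os
import tempfile

from armonik.worker import TaskHandler
from armonik.protogen.common.worker_common_pb2 import ProcessRequest
from armonik.protogen.common.objects_pb2 import Configuration

with tempfile.TemporaryDirectory() as folder:
    # One file per id, mirroring the session fixture above.
    for name, content in (("payloadid", b"payload"), ("ddid", b"dd")):
        with open(os.path.join(folder, name), "wb") as f:
            f.write(content)
    request = ProcessRequest(communication_token="token", session_id="sessionid",
                             task_id="taskid", expected_output_keys=["resultid"],
                             payload_id="payloadid", data_dependencies=["ddid"],
                             data_folder=folder,
                             configuration=Configuration(data_chunk_max_size=8000))
    handler = TaskHandler(request, None)  # no agent interaction during __init__
    assert handler.payload == b"payload"
    assert handler.data_dependencies["ddid"] == b"dd"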
