diff --git a/packages/cpp/ArmoniK.Api.Common/header/utils/string_utils.h b/packages/cpp/ArmoniK.Api.Common/header/utils/string_utils.h new file mode 100644 index 000000000..607b4131b --- /dev/null +++ b/packages/cpp/ArmoniK.Api.Common/header/utils/string_utils.h @@ -0,0 +1,62 @@ +#pragma once + +#include <algorithm> +#include <cctype> +#include <locale> +#include <string> + +namespace armonik { +namespace api { +namespace common { +namespace utils { +// trim from start (in place) +static inline void ltrim(std::string &s) { + s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { return !std::isspace(ch); })); +} + +// trim from end (in place) +static inline void rtrim(std::string &s) { + s.erase(std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base(), s.end()); +} + +// trim from both ends (in place) +static inline void trim(std::string &s) { + rtrim(s); + ltrim(s); +} + +// trim from start (copying) +static inline std::string ltrim_copy(std::string s) { + ltrim(s); + return s; +} + +// trim from end (copying) +static inline std::string rtrim_copy(std::string s) { + rtrim(s); + return s; +} + +// trim from both ends (copying) +static inline std::string trim_copy(std::string s) { + trim(s); + return s; +} + +inline std::string pathJoin(const std::string &p1, const std::string &p2) { +#ifdef _WIN32 + constexpr char sep = '\\'; +#else + constexpr char sep = '/'; +#endif + std::string tmp = trim_copy(p1); + + if (!tmp.empty() && tmp.back() != sep) { + tmp += sep; + } + return tmp + trim_copy(p2); +} +} // namespace utils +} // namespace common +} // namespace api +} // namespace armonik diff --git a/packages/cpp/ArmoniK.Api.Worker.Tests/source/main.cpp b/packages/cpp/ArmoniK.Api.Worker.Tests/source/main.cpp index 39558d0e7..a64e51227 100644 --- a/packages/cpp/ArmoniK.Api.Worker.Tests/source/main.cpp +++ b/packages/cpp/ArmoniK.Api.Worker.Tests/source/main.cpp @@ -36,10 +36,7 @@ class TestWorker : public armonik::api::worker::ArmoniKWorker { try { if
(!taskHandler.getExpectedResults().empty()) { - auto res = taskHandler.send_result(taskHandler.getExpectedResults()[0], taskHandler.getPayload()).get(); - if (res.has_error()) { - throw armonik::api::common::exceptions::ArmoniKApiException(res.error()); - } + taskHandler.send_result(taskHandler.getExpectedResults()[0], taskHandler.getPayload()).get(); } } catch (const std::exception &e) { diff --git a/packages/cpp/ArmoniK.Api.Worker/header/Worker/ArmoniKWorker.h b/packages/cpp/ArmoniK.Api.Worker/header/Worker/ArmoniKWorker.h index ff06921bd..18e9d1893 100644 --- a/packages/cpp/ArmoniK.Api.Worker/header/Worker/ArmoniKWorker.h +++ b/packages/cpp/ArmoniK.Api.Worker/header/Worker/ArmoniKWorker.h @@ -35,15 +35,14 @@ class ArmoniKWorker : public armonik::api::grpc::v1::worker::Worker::Service { * @brief Implements the Process method of the Worker service. * * @param context The ServerContext object. - * @param reader The request iterator + * @param request The Process request * @param response The ProcessReply object. * * @return The status of the method. 
*/ - [[maybe_unused]] ::grpc::Status - Process(::grpc::ServerContext *context, - ::grpc::ServerReader<::armonik::api::grpc::v1::worker::ProcessRequest> *reader, - ::armonik::api::grpc::v1::worker::ProcessReply *response) override; + ::grpc::Status Process(::grpc::ServerContext *context, + const ::armonik::api::grpc::v1::worker::ProcessRequest *request, + ::armonik::api::grpc::v1::worker::ProcessReply *response) override; /** * @brief Function which does the actual work diff --git a/packages/cpp/ArmoniK.Api.Worker/header/Worker/TaskHandler.h b/packages/cpp/ArmoniK.Api.Worker/header/Worker/TaskHandler.h index e50d753ae..886df1aa3 100644 --- a/packages/cpp/ArmoniK.Api.Worker/header/Worker/TaskHandler.h +++ b/packages/cpp/ArmoniK.Api.Worker/header/Worker/TaskHandler.h @@ -22,7 +22,7 @@ class TaskHandler { private: armonik::api::grpc::v1::agent::Agent::Stub &stub_; - ::grpc::ServerReader<armonik::api::grpc::v1::worker::ProcessRequest> &request_iterator_; + const armonik::api::grpc::v1::worker::ProcessRequest &request_; std::string session_id_; std::string task_id_; armonik::api::grpc::v1::TaskOptions task_options_; @@ -31,22 +31,17 @@ class TaskHandler { std::map<std::string, std::string> data_dependencies_; std::string token_; armonik::api::grpc::v1::Configuration config_; + std::string data_folder_; public: /** * @brief Construct a new Task Handler object * * @param client the agent client - * @param request_iterator The request iterator + * @param request The process request */ TaskHandler(armonik::api::grpc::v1::agent::Agent::Stub &client, - ::grpc::ServerReader<armonik::api::grpc::v1::worker::ProcessRequest> &request_iterator); + const armonik::api::grpc::v1::worker::ProcessRequest &request); - - /** - * @brief Initialise the task handler - * - */ - void init(); /** * @brief Create a task_chunk_stream.
@@ -89,7 +84,7 @@ class TaskHandler { * @param data The result data * @return A future containing a vector of ResultReply */ - std::future<armonik::api::grpc::v1::agent::ResultReply> send_result(std::string key, absl::string_view data); + std::future<void> send_result(std::string key, absl::string_view data); /** * @brief Get the result ids object diff --git a/packages/cpp/ArmoniK.Api.Worker/source/Worker/ArmoniKWorker.cpp b/packages/cpp/ArmoniK.Api.Worker/source/Worker/ArmoniKWorker.cpp index 411fb7892..5a4e5c90a 100644 --- a/packages/cpp/ArmoniK.Api.Worker/source/Worker/ArmoniKWorker.cpp +++ b/packages/cpp/ArmoniK.Api.Worker/source/Worker/ArmoniKWorker.cpp @@ -47,32 +47,37 @@ armonik::api::worker::ArmoniKWorker::ArmoniKWorker(std::unique_ptr *reader, +armonik::api::worker::ArmoniKWorker::Process(::grpc::ServerContext *context, + const ::armonik::api::grpc::v1::worker::ProcessRequest *request, ::armonik::api::grpc::v1::worker::ProcessReply *response) { - + (void)context; logger_.debug("Receive new request From C++ Worker"); - TaskHandler task_handler(*agent_, *reader); - - task_handler.init(); try { - ProcessStatus status = Execute(task_handler); - - logger_.debug("Finish call C++"); - - armonik::api::grpc::v1::Output output; - if (status.ok()) { - *output.mutable_ok() = armonik::api::grpc::v1::Empty(); - } else { - output.mutable_error()->set_details(std::move(status).details()); + TaskHandler task_handler(*agent_, *request); + try { + ProcessStatus status = Execute(task_handler); + + logger_.debug("Finish call C++"); + + armonik::api::grpc::v1::Output output; + if (status.ok()) { + *output.mutable_ok() = armonik::api::grpc::v1::Empty(); + } else { + output.mutable_error()->set_details(std::move(status).details()); + } + *response->mutable_output() = std::move(output); + } catch (const std::exception &e) { + logger_.error("Error processing task : {what}", {{"what", e.what()}}); + std::stringstream ss; + ss << "Error processing task : " << e.what(); + return {::grpc::StatusCode::UNAVAILABLE, ss.str(), e.what()};
} - *response->mutable_output() = std::move(output); } catch (const std::exception &e) { - logger_.error("Error processing task : {what}", {{"what", e.what()}}); + logger_.error("Error in the request handling : {what}", {{"what", e.what()}}); std::stringstream ss; - ss << "Error processing task : " << e.what(); - return {::grpc::StatusCode::UNAVAILABLE, ss.str(), e.what()}; + ss << "Error in the request handling : " << e.what(); + return {::grpc::StatusCode::INVALID_ARGUMENT, ss.str(), e.what()}; } return ::grpc::Status::OK; diff --git a/packages/cpp/ArmoniK.Api.Worker/source/Worker/TaskHandler.cpp b/packages/cpp/ArmoniK.Api.Worker/source/Worker/TaskHandler.cpp index 8b2820d92..a1837f2ab 100644 --- a/packages/cpp/ArmoniK.Api.Worker/source/Worker/TaskHandler.cpp +++ b/packages/cpp/ArmoniK.Api.Worker/source/Worker/TaskHandler.cpp @@ -1,5 +1,7 @@ #include "Worker/TaskHandler.h" #include "exceptions/ArmoniKApiException.h" +#include "utils/string_utils.h" +#include <fstream> #include #include #include @@ -28,89 +30,29 @@ using ::grpc::Status; * @param client the agent client * @param request_iterator The request iterator */ -armonik::api::worker::TaskHandler::TaskHandler(Agent::Stub &client, - ::grpc::ServerReader &request_iterator) - : stub_(client), request_iterator_(request_iterator) {} - -/** - * @brief Initialise the task handler - * - */ -void armonik::api::worker::TaskHandler::init() { - ProcessRequest Request; - if (!request_iterator_.Read(&Request)) { - throw std::runtime_error("Request stream ended unexpectedly."); +armonik::api::worker::TaskHandler::TaskHandler(Agent::Stub &client, const ProcessRequest &request) + : stub_(client), request_(request) { + token_ = request_.communication_token(); + session_id_ = request_.session_id(); + task_id_ = request_.task_id(); + task_options_ = request_.task_options(); + const std::string payload_id = request_.payload_id(); + data_folder_ = request_.data_folder(); + std::ostringstream string_stream(std::ios::binary); + string_stream
+ << std::ifstream(armonik::api::common::utils::pathJoin(data_folder_, payload_id), std::fstream::binary).rdbuf(); + payload_ = string_stream.str(); + string_stream.clear(); + config_ = request_.configuration(); + expected_result_.assign(request_.expected_output_keys().begin(), request_.expected_output_keys().end()); + + for (auto &&dd : request_.data_dependencies()) { + // TODO Replace with lazy loading via a custom std::map (to not break compatibility) + string_stream + << std::ifstream(armonik::api::common::utils::pathJoin(data_folder_, dd), std::fstream::binary).rdbuf(); + data_dependencies_[dd] = string_stream.str(); + string_stream.clear(); } - - if (Request.compute().type_case() != armonik::api::grpc::v1::worker::ProcessRequest_ComputeRequest::kInitRequest) { - throw std::runtime_error("Expected a Compute request type with InitRequest to start the stream."); - } - auto *init_request = Request.mutable_compute()->mutable_init_request(); - session_id_ = init_request->session_id(); - task_id_ = init_request->task_id(); - task_options_ = init_request->task_options(); - expected_result_.assign(std::make_move_iterator(init_request->mutable_expected_output_keys()->begin()), - std::make_move_iterator(init_request->mutable_expected_output_keys()->end())); - token_ = Request.communication_token(); - config_ = std::move(*init_request->mutable_configuration()); - - auto *datachunk = &init_request->payload(); - assert(payload_.empty()); - payload_.append(datachunk->data()); - - while (!datachunk->data_complete()) { - if (!request_iterator_.Read(&Request)) { - throw std::runtime_error("Request stream ended unexpectedly."); - } - if (Request.compute().type_case() != armonik::api::grpc::v1::worker::ProcessRequest_ComputeRequest::kPayload) { - throw std::runtime_error("Expected a Compute request type with Payload to continue the stream."); - } - - datachunk = &Request.compute().payload(); - if (datachunk->type_case() == armonik::api::grpc::v1::DataChunk::kData) { - 
payload_.append(datachunk->data()); - } else if (datachunk->type_case() == armonik::api::grpc::v1::DataChunk::TYPE_NOT_SET) { - throw std::runtime_error("Expected a Compute request type with a DataChunk Payload to continue the stream."); - } else if (datachunk->type_case() == armonik::api::grpc::v1::DataChunk::kDataComplete) { - break; - } - } - - armonik::api::grpc::v1::worker::ProcessRequest_ComputeRequest::InitData *init_data; - - do { - if (!request_iterator_.Read(&Request)) { - throw std::runtime_error("Request stream ended unexpectedly."); - } - if (Request.compute().type_case() != armonik::api::grpc::v1::worker::ProcessRequest_ComputeRequest::kInitData) { - throw std::runtime_error("Expected a Compute request type with InitData to continue the stream."); - } - - init_data = Request.mutable_compute()->mutable_init_data(); - if (init_data->type_case() == armonik::api::grpc::v1::worker::ProcessRequest_ComputeRequest_InitData::kKey) { - const std::string &key = init_data->key(); - std::string data_dep; - while (true) { - ProcessRequest dep_request; - if (!request_iterator_.Read(&dep_request)) { - throw std::runtime_error("Request stream ended unexpectedly."); - } - if (dep_request.compute().type_case() != armonik::api::grpc::v1::worker::ProcessRequest_ComputeRequest::kData) { - throw std::runtime_error("Expected a Compute request type with Data to continue the stream."); - } - - auto chunk = dep_request.compute().data(); - if (chunk.type_case() == armonik::api::grpc::v1::DataChunk::kData) { - data_dep.append(chunk.data()); - } else if (datachunk->type_case() == armonik::api::grpc::v1::DataChunk::TYPE_NOT_SET) { - throw std::runtime_error("Expected a Compute request type with a DataChunk Payload to continue the stream."); - } else if (datachunk->type_case() == armonik::api::grpc::v1::DataChunk::kDataComplete) { - break; - } - } - data_dependencies_[key] = data_dep; - } - } while (init_data->type_case() == 
armonik::api::grpc::v1::worker::ProcessRequest_ComputeRequest_InitData::kKey); } /** @@ -273,44 +215,24 @@ armonik::api::worker::TaskHandler::create_tasks_async(TaskOptions task_options, * @param data The result data * @return A future containing a vector of ResultReply */ -std::future<armonik::api::grpc::v1::agent::ResultReply> -armonik::api::worker::TaskHandler::send_result(std::string key, absl::string_view data) { +std::future<void> armonik::api::worker::TaskHandler::send_result(std::string key, absl::string_view data) { return std::async(std::launch::async, [this, key = std::move(key), data]() mutable { - ::grpc::ClientContext context_client_writer; - - armonik::api::grpc::v1::agent::ResultReply reply; - - size_t max_chunk = config_.data_chunk_max_size(); - const size_t data_size = data.size(); - size_t start = 0; + ::grpc::ClientContext context; - auto stream = stub_.SendResult(&context_client_writer, &reply); + std::ofstream output(armonik::api::common::utils::pathJoin(data_folder_, key), + std::fstream::binary | std::fstream::trunc); + output << data; + output.close(); - armonik::api::grpc::v1::agent::Result init_msg; - init_msg.mutable_init()->set_key(std::move(key)); - init_msg.set_communication_token(token_); + armonik::api::grpc::v1::agent::NotifyResultDataResponse reply; + armonik::api::grpc::v1::agent::NotifyResultDataRequest request; + request.set_communication_token(token_); + armonik::api::grpc::v1::agent::NotifyResultDataRequest::ResultIdentifier result_id; + result_id.set_session_id(session_id_); + result_id.set_result_id(key); + *(request.mutable_ids()->Add()) = result_id; - stream->Write(init_msg); - - while (start < data_size) { - size_t chunkSize = std::min(max_chunk, data_size - start); - - armonik::api::grpc::v1::agent::Result msg; - msg.set_communication_token(token_); - msg.mutable_data()->mutable_data()->assign(data.data() + start, chunkSize); - - stream->Write(msg); - - start += chunkSize; - } - - armonik::api::grpc::v1::agent::Result end_msg; - end_msg.set_communication_token(token_);
- end_msg.mutable_data()->set_data_complete(true); - stream->Write(end_msg); - - stream->WritesDone(); - ::grpc::Status status = stream->Finish(); + auto status = stub_.NotifyResultData(&context, request, &reply); if (!status.ok()) { std::stringstream message; @@ -318,7 +240,10 @@ armonik::api::worker::TaskHandler::send_result(std::string key, absl::string_vie << ". details: " << status.error_details() << std::endl; throw armonik::api::common::exceptions::ArmoniKApiException(message.str()); } - return reply; + + if (reply.result_ids_size() != 1) { + throw armonik::api::common::exceptions::ArmoniKApiException("Received erroneous reply for send data"); + } }); } diff --git a/packages/python/src/armonik/worker/taskhandler.py b/packages/python/src/armonik/worker/taskhandler.py index 9d8ade962..49eeb8ff3 100644 --- a/packages/python/src/armonik/worker/taskhandler.py +++ b/packages/python/src/armonik/worker/taskhandler.py @@ -1,83 +1,35 @@ from __future__ import annotations +import os from typing import Optional, Dict, List, Tuple, Union, cast from ..common import TaskOptions, TaskDefinition, Task -from ..protogen.common.agent_common_pb2 import Result, CreateTaskRequest, CreateResultsMetaDataRequest, CreateResultsMetaDataResponse -from ..protogen.common.objects_pb2 import TaskRequest, InitKeyedDataStream, DataChunk, InitTaskRequest, TaskRequestHeader, Configuration +from ..protogen.common.agent_common_pb2 import CreateTaskRequest, CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, NotifyResultDataRequest +from ..protogen.common.objects_pb2 import TaskRequest, DataChunk, InitTaskRequest, TaskRequestHeader, Configuration from ..protogen.worker.agent_service_pb2_grpc import AgentStub +from ..protogen.common.worker_common_pb2 import ProcessRequest class TaskHandler: - def __init__(self, request_iterator, agent_client): - self.request_iterator = request_iterator + def __init__(self, request: ProcessRequest, agent_client: AgentStub): self._client: AgentStub = 
agent_client - self.payload = bytearray() - self.session_id: Optional[str] = None - self.task_id: Optional[str] = None - self.task_options: Optional[TaskOptions] = None - self.token: Optional[str] = None - self.expected_results: List[str] = [] - self.data_dependencies: Dict[str, bytearray] = {} - self.configuration: Optional[Configuration] = None - - @classmethod - def create(cls, request_iterator, agent_client) -> "TaskHandler": - output = cls(request_iterator, agent_client) - output.init() - return output - - def init(self): - current = next(self.request_iterator, None) - if current is None: - raise ValueError("Request stream ended unexpectedly") - - if current.compute.WhichOneof("type") != "init_request": - raise ValueError("Expected a Compute request type with InitRequest to start the stream.") - - init_request = current.compute.init_request - self.session_id = init_request.session_id - self.task_id = init_request.task_id - self.task_options = TaskOptions.from_message(init_request.task_options) - self.expected_results = list(init_request.expected_output_keys) - self.configuration = init_request.configuration - self.token = current.communication_token - - datachunk = init_request.payload - self.payload.extend(datachunk.data) - while not datachunk.data_complete: - current = next(self.request_iterator, None) - if current is None: - raise ValueError("Request stream ended unexpectedly") - if current.compute.WhichOneof("type") != "payload": - raise ValueError("Expected a Compute request type with Payload to continue the stream.") - datachunk = current.compute.payload - self.payload.extend(datachunk.data) - - while True: - current = next(self.request_iterator, None) - if current is None: - raise ValueError("Request stream ended unexpectedly") - if current.compute.WhichOneof("type") != "init_data": - raise ValueError("Expected a Compute request type with InitData to continue the stream.") - init_data = current.compute.init_data - if not (init_data.key is None or 
init_data.key == ""): - chunk = bytearray() - while True: - current = next(self.request_iterator, None) - if current is None: - raise ValueError("Request stream ended unexpectedly") - if current.compute.WhichOneof("type") != "data": - raise ValueError("Expected a Compute request type with Data to continue the stream.") - datachunk = current.compute.data - if datachunk.WhichOneof("type") == "data": - chunk.extend(datachunk.data) - elif datachunk.WhichOneof("type") is None or datachunk.WhichOneof("type") == "": - raise ValueError("Expected a Compute request type with Datachunk Payload to continue the stream.") - elif datachunk.WhichOneof("type") == "data_complete": - break - self.data_dependencies[init_data.key] = chunk - else: - break + self.session_id: str = request.session_id + self.task_id: str = request.task_id + self.task_options: TaskOptions = TaskOptions.from_message(request.task_options) + self.token: str = request.communication_token + self.expected_results: List[str] = list(request.expected_output_keys) + self.configuration: Configuration = request.configuration + self.payload_id: str = request.payload_id + self.data_folder: str = request.data_folder + + # TODO: Lazy load + with open(os.path.join(self.data_folder, self.payload_id), "rb") as f: + self.payload = f.read() + + # TODO: Lazy load + self.data_dependencies: Dict[str, bytes] = {} + for dd in request.data_dependencies: + with open(os.path.join(self.data_folder, dd), "rb") as f: + self.data_dependencies[dd] = f.read() def create_tasks(self, tasks: List[TaskDefinition], task_options: Optional[TaskOptions] = None) -> Tuple[List[Task], List[str]]: """Create new tasks for ArmoniK @@ -122,29 +74,13 @@ def send_result(self, key: str, data: Union[bytes, bytearray]) -> None: key: Result key data: Result data """ - def result_stream(): - res = Result(communication_token=self.token, init=InitKeyedDataStream(key=key)) - assert self.configuration is not None - yield res - start = 0 - data_len = len(data) - while 
start < data_len: - chunksize = min(self.configuration.data_chunk_max_size, data_len - start) - res = Result(communication_token=self.token, data=DataChunk(data=data[start:start + chunksize])) - yield res - start += chunksize - res = Result(communication_token=self.token, data=DataChunk(data_complete=True)) - yield res - res = Result(communication_token=self.token, init=InitKeyedDataStream(last_result=True)) - yield res - - result_reply = self._client.SendResult(result_stream()) - if result_reply.WhichOneof("type") == "error": - raise Exception(f"Cannot send result id={key}") - - def get_results_ids(self, names : List[str]) -> Dict[str, str]: - return {r.name : r.result_id for r in cast(CreateResultsMetaDataResponse, self._client.CreateResultsMetaData(CreateResultsMetaDataRequest(results=[CreateResultsMetaDataRequest.ResultCreate(name = n) for n in names], session_id=self.session_id, communication_token=self.token))).results} + with open(os.path.join(self.data_folder, key), "wb") as f: + f.write(data) + self._client.NotifyResultData(NotifyResultDataRequest(ids=[NotifyResultDataRequest.ResultIdentifier(session_id=self.session_id, result_id=key)], communication_token=self.token)) + + def get_results_ids(self, names: List[str]) -> Dict[str, str]: + return {r.name: r.result_id for r in cast(CreateResultsMetaDataResponse, self._client.CreateResultsMetaData(CreateResultsMetaDataRequest(results=[CreateResultsMetaDataRequest.ResultCreate(name=n) for n in names], session_id=self.session_id, communication_token=self.token))).results} def _to_request_stream_internal(request, communication_token, is_last, chunk_max_size): diff --git a/packages/python/src/armonik/worker/worker.py b/packages/python/src/armonik/worker/worker.py index 19db04f38..2de34aad1 100644 --- a/packages/python/src/armonik/worker/worker.py +++ b/packages/python/src/armonik/worker/worker.py @@ -9,7 +9,7 @@ from .seqlogger import ClefLogger from ..common import Output, HealthCheckStatus from 
..protogen.common.objects_pb2 import Empty -from ..protogen.common.worker_common_pb2 import ProcessReply, HealthCheckReply +from ..protogen.common.worker_common_pb2 import ProcessReply, ProcessRequest, HealthCheckReply from ..protogen.worker.agent_service_pb2_grpc import AgentStub from ..protogen.worker.worker_service_pb2_grpc import WorkerServicer, add_WorkerServicer_to_server from .taskhandler import TaskHandler @@ -46,11 +46,11 @@ def start(self, endpoint: str): server.start() server.wait_for_termination() - def Process(self, request_iterator, context) -> Union[ProcessReply, None]: + def Process(self, request: ProcessRequest, context) -> Union[ProcessReply, None]: try: self._logger.debug("Received task") - task_handler = TaskHandler.create(request_iterator, self._client) - return ProcessReply(communication_token=task_handler.token, output=self.processing_function(task_handler).to_message()) + task_handler = TaskHandler(request, self._client) + return ProcessReply(output=self.processing_function(task_handler).to_message()) except Exception as e: self._logger.exception(f"Failed task {''.join(traceback.format_exception(type(e) ,e, e.__traceback__))}", exc_info=e) diff --git a/packages/python/tests/taskhandler_test.py b/packages/python/tests/taskhandler_test.py index 88cbc3809..e4f3c181c 100644 --- a/packages/python/tests/taskhandler_test.py +++ b/packages/python/tests/taskhandler_test.py @@ -1,18 +1,35 @@ #!/usr/bin/env python3 +import os + import pytest from typing import Iterator -from .common import DummyChannel + from armonik.common import TaskDefinition + +from .common import DummyChannel from armonik.worker import TaskHandler from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub -from armonik.protogen.common.agent_common_pb2 import CreateTaskRequest, CreateTaskReply, Result, ResultReply +from armonik.protogen.common.agent_common_pb2 import CreateTaskRequest, CreateTaskReply, NotifyResultDataRequest, NotifyResultDataResponse from 
armonik.protogen.common.worker_common_pb2 import ProcessRequest -from armonik.protogen.common.objects_pb2 import Configuration, DataChunk +from armonik.protogen.common.objects_pb2 import Configuration import logging logging.basicConfig() logging.getLogger().setLevel(logging.INFO) +data_folder = os.getcwd() + + +@pytest.fixture(autouse=True, scope="session") +def setup_teardown(): + with open(os.path.join(data_folder, "payloadid"), "wb") as f: + f.write("payload".encode()) + with open(os.path.join(data_folder, "ddid"), "wb") as f: + f.write("dd".encode()) + yield + os.remove(os.path.join(data_folder, "payloadid")) + os.remove(os.path.join(data_folder, "ddid")) + class DummyAgent(AgentStub): @@ -29,94 +46,18 @@ def CreateTask(self, request_iterator: Iterator[CreateTaskRequest]) -> CreateTas task_info=CreateTaskReply.TaskInfo(task_id="TaskId", expected_output_keys=["EOK"], data_dependencies=["DD"]))])) - def SendResult(self, request_iterator: Iterator[Result]) -> ResultReply: - self.send_result_task_message = [r for r in request_iterator] - return ResultReply() - - -class Reqs: - InitData1 = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest( - init_data=ProcessRequest.ComputeRequest.InitData(key="DataKey1"))) - InitData2 = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest( - init_data=ProcessRequest.ComputeRequest.InitData(key="DataKey2"))) - LastDataTrue = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest( - init_data=ProcessRequest.ComputeRequest.InitData(last_data=True))) - LastDataFalse = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest( - init_data=ProcessRequest.ComputeRequest.InitData(last_data=False))) - InitRequestPayload = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest( - init_request=ProcessRequest.ComputeRequest.InitRequest( - payload=DataChunk(data="test".encode("utf-8")), - 
configuration=Configuration(data_chunk_max_size=100), - expected_output_keys=["EOK"], session_id="SessionId", - task_id="TaskId"))) - InitRequestEmptyPayload = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest( - init_request=ProcessRequest.ComputeRequest.InitRequest( - configuration=Configuration(data_chunk_max_size=100), - expected_output_keys=["EOK"], session_id="SessionId", - task_id="TaskId"))) - Payload1 = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest( - payload=DataChunk(data="Payload1".encode("utf-8")))) - Payload2 = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest( - payload=DataChunk(data="Payload2".encode("utf-8")))) - PayloadComplete = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest(payload=DataChunk(data_complete=True))) - DataComplete = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest(data=DataChunk(data_complete=True))) - Data1 = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest( - data=DataChunk(data="Data1".encode("utf-8")))) - Data2 = ProcessRequest(communication_token="Token", - compute=ProcessRequest.ComputeRequest( - data=DataChunk(data="Data2".encode("utf-8")))) - - -should_throw_cases = [ - [], - [Reqs.InitData1], - [Reqs.InitData2], - [Reqs.LastDataTrue], - [Reqs.LastDataFalse], - [Reqs.InitRequestPayload], - [Reqs.DataComplete], - [Reqs.InitRequestEmptyPayload], - [Reqs.InitRequestPayload, Reqs.PayloadComplete, Reqs.InitData1, Reqs.Data1, Reqs.LastDataTrue], - [Reqs.InitRequestPayload, Reqs.InitData1, Reqs.Data1, Reqs.DataComplete, Reqs.LastDataTrue], - [Reqs.InitRequestPayload, Reqs.PayloadComplete, Reqs.Data1, Reqs.DataComplete, Reqs.LastDataTrue], -] - -should_succeed_cases = [ - [Reqs.InitRequestPayload, Reqs.Payload1, Reqs.Payload2, Reqs.PayloadComplete, Reqs.InitData1, Reqs.Data1, - Reqs.Data2, Reqs.DataComplete, 
Reqs.InitData2, Reqs.Data1, Reqs.Data2, Reqs.Data2, Reqs.Data2, - Reqs.DataComplete, Reqs.LastDataTrue], - [Reqs.InitRequestPayload, Reqs.Payload1, Reqs.PayloadComplete, Reqs.InitData1, Reqs.Data1, Reqs.DataComplete, - Reqs.LastDataTrue], - [Reqs.InitRequestPayload, Reqs.PayloadComplete, Reqs.InitData1, Reqs.Data1, Reqs.DataComplete, Reqs.LastDataTrue], -] - - -def get_cases(list_requests): - for r in list_requests: - yield iter(r) - - -@pytest.mark.parametrize("requests", get_cases(should_throw_cases)) -def test_taskhandler_create_should_throw(requests: Iterator[ProcessRequest]): - with pytest.raises(ValueError): - TaskHandler.create(requests, DummyAgent(DummyChannel())) - - -@pytest.mark.parametrize("requests", get_cases(should_succeed_cases)) -def test_taskhandler_create_should_succeed(requests: Iterator[ProcessRequest]): + def NotifyResultData(self, request: NotifyResultDataRequest) -> NotifyResultDataResponse: + self.send_result_task_message.append(request) + return NotifyResultDataResponse(result_ids=[i.result_id for i in request.ids]) + + +should_succeed_case = ProcessRequest(communication_token="token", session_id="sessionid", task_id="taskid", expected_output_keys=["resultid"], payload_id="payloadid", data_dependencies=["ddid"], data_folder=data_folder, configuration=Configuration(data_chunk_max_size=8000)) + + +@pytest.mark.parametrize("requests", [should_succeed_case]) +def test_taskhandler_create_should_succeed(requests: ProcessRequest): agent = DummyAgent(DummyChannel()) - task_handler = TaskHandler.create(requests, agent) + task_handler = TaskHandler(requests, agent) assert task_handler.token is not None and len(task_handler.token) > 0 assert len(task_handler.payload) > 0 assert task_handler.session_id is not None and len(task_handler.session_id) > 0 @@ -125,28 +66,8 @@ def test_taskhandler_create_should_succeed(requests: Iterator[ProcessRequest]): def test_taskhandler_data_are_correct(): agent = DummyAgent(DummyChannel()) - task_handler = 
TaskHandler.create(iter(should_succeed_cases[0]), agent) + task_handler = TaskHandler(should_succeed_case, agent) assert len(task_handler.payload) > 0 - assert task_handler.payload.decode('utf-8') == "testPayload1Payload2" - assert len(task_handler.data_dependencies) == 2 - assert task_handler.data_dependencies["DataKey1"].decode('utf-8') == "Data1Data2" - assert task_handler.data_dependencies["DataKey2"].decode('utf-8') == "Data1Data2Data2Data2" - assert task_handler.task_id == "TaskId" - assert task_handler.session_id == "SessionId" - assert task_handler.token == "Token" - - task_handler.send_result("test", "TestData".encode("utf-8")) - - results = agent.send_result_task_message - assert len(results) == 4 - assert results[0].WhichOneof("type") == "init" - assert results[0].init.key == "test" - assert results[1].WhichOneof("type") == "data" - assert results[1].data.data == "TestData".encode("utf-8") - assert results[2].WhichOneof("type") == "data" - assert results[2].data.data_complete - assert results[3].WhichOneof("type") == "init" - assert results[3].init.last_result task_handler.create_tasks([TaskDefinition("Payload".encode("utf-8"), ["EOK"], ["DD"])]) diff --git a/packages/python/tests/worker_test.py b/packages/python/tests/worker_test.py index 1e4604a31..032c406ee 100644 --- a/packages/python/tests/worker_test.py +++ b/packages/python/tests/worker_test.py @@ -1,9 +1,11 @@ #!/usr/bin/env python3 import logging - +import os +import pytest from armonik.worker import ArmoniKWorker, TaskHandler, ClefLogger from armonik.common import Output -from .taskhandler_test import should_succeed_cases +from .taskhandler_test import should_succeed_case, data_folder, DummyAgent +from .common import DummyChannel from armonik.protogen.common.objects_pb2 import Empty import grpc @@ -20,10 +22,22 @@ def return_error(_: TaskHandler) -> Output: return Output("TestError") +def return_and_send(th: TaskHandler) -> Output: + th.send_result(th.expected_results[0], b"result") + return 
Output() + + +@pytest.fixture(autouse=True, scope="function") +def remove_result(): + yield + if os.path.exists(os.path.join(data_folder, "resultid")): + os.remove(os.path.join(data_folder, "resultid")) + + def test_do_nothing_worker(): with grpc.insecure_channel("unix:///tmp/agent.sock") as agent_channel: worker = ArmoniKWorker(agent_channel, do_nothing, logger=ClefLogger("TestLogger", level=logging.CRITICAL)) - reply = worker.Process(iter(should_succeed_cases[0]), None) + reply = worker.Process(should_succeed_case, None) assert Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None).success worker.HealthCheck(Empty(), None) @@ -31,15 +45,29 @@ def test_do_nothing_worker(): def test_worker_should_return_none(): with grpc.insecure_channel("unix:///tmp/agent.sock") as agent_channel: worker = ArmoniKWorker(agent_channel, throw_error, logger=ClefLogger("TestLogger", level=logging.CRITICAL)) - reply = worker.Process(iter(should_succeed_cases[0]), None) + reply = worker.Process(should_succeed_case, None) assert reply is None def test_worker_should_error(): with grpc.insecure_channel("unix:///tmp/agent.sock") as agent_channel: worker = ArmoniKWorker(agent_channel, return_error, logger=ClefLogger("TestLogger", level=logging.CRITICAL)) - reply = worker.Process(iter(should_succeed_cases[0]), None) + reply = worker.Process(should_succeed_case, None) output = Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None) assert not output.success assert output.error == "TestError" + +def test_worker_should_write_result(): + with grpc.insecure_channel("unix:///tmp/agent.sock") as agent_channel: + worker = ArmoniKWorker(agent_channel, return_and_send, logger=ClefLogger("TestLogger", level=logging.DEBUG)) + worker._client = DummyAgent(DummyChannel()) + reply = worker.Process(should_succeed_case, None) + assert reply is not None + output = Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" 
else None) + assert output.success + assert os.path.exists(os.path.join(data_folder, should_succeed_case.expected_output_keys[0])) + with open(os.path.join(data_folder, should_succeed_case.expected_output_keys[0]), "rb") as f: + value = f.read() + assert len(value) > 0 +