From dfdce4e1ab3dd62f0b847ed380ffcde026c83c7b Mon Sep 17 00:00:00 2001 From: qdelamea Date: Mon, 18 Dec 2023 12:16:53 +0100 Subject: [PATCH 01/15] Python API update Task and Result service --- packages/python/src/armonik/client/results.py | 139 +++++++++++++++++- packages/python/src/armonik/client/tasks.py | 101 ++++++++++++- packages/python/src/armonik/common/objects.py | 4 +- packages/python/src/armonik/utils.py | 26 ++++ 4 files changed, 264 insertions(+), 6 deletions(-) create mode 100644 packages/python/src/armonik/utils.py diff --git a/packages/python/src/armonik/client/results.py b/packages/python/src/armonik/client/results.py index 942add9c1..9bb842e01 100644 --- a/packages/python/src/armonik/client/results.py +++ b/packages/python/src/armonik/client/results.py @@ -4,7 +4,7 @@ from typing import List, Dict, cast, Tuple from ..protogen.client.results_service_pb2_grpc import ResultsStub -from ..protogen.common.results_common_pb2 import CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, ListResultsRequest, ListResultsResponse +from ..protogen.common.results_common_pb2 import CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, ListResultsRequest, ListResultsResponse, GetOwnerTaskIdRequest, GetOwnerTaskIdResponse, CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, CreateResultsRequest, CreateResultsResponse, ResultsServiceConfigurationResponse, DeleteResultsDataRequest, DeleteResultsDataResponse, UploadResultDataRequest, UploadResultDataResponse, DownloadResultDataRequest, DownloadResultDataResponse from ..protogen.common.results_filters_pb2 import Filters as rawFilters, FiltersAnd as rawFilterAnd, FilterField as rawFilterField, FilterStatus as rawFilterStatus from ..protogen.common.results_fields_pb2 import ResultField from ..common.filter import StringFilter, StatusFilter, DateFilter, NumberFilter, Filter @@ -12,6 +12,9 @@ from ..common import Direction , Result from ..protogen.common.results_fields_pb2 import ResultField, ResultRawField, ResultRawEnumField, RESULT_RAW_ENUM_FIELD_STATUS +from ..utils import batched + + class ResultFieldFilter: STATUS = StatusFilter(ResultField(result_raw_field=ResultRawField(field=RESULT_RAW_ENUM_FIELD_STATUS)), rawFilters, rawFilterAnd, rawFilterField, rawFilterStatus) @@ -49,3 +52,137 @@ def list_results(self, result_filter: Filter, page: int = 0, page_size: int = 10 ) list_response: ListResultsResponse = self._client.ListResults(request) return list_response.total, [Result.from_message(r) for r in list_response.results] + + def get_owner_task_id(self, result_ids: List[str], session_id: str, batch_size: int = 500) -> Dict[str, str]: + """Get the IDs of the tasks that should produce the results. + + Args: + result_ids: A list of results. + session_id: The ID of the session to which the results belongs. + batch_size: Batch size for querying. + + Return: + A dictionnary mapping results to owner task ID. + """ + results = {} + for result_ids_batch in batched(result_ids, batch_size): + request = GetOwnerTaskIdRequest(session_id=session_id, result_id=result_ids_batch) + response: GetOwnerTaskIdResponse = self._client.GetOwnerTaskId(request) + for result_task in response.result_task: + results[result_task.result_id] = result_task.task_id + return results + + def create_result_metadata(self, result_names: List[str], session_id: str, batch_size: int) -> List[Result]: + """Create the metadata of multiple results at once. + Data have to be uploaded separately. 
+ + Args: + result_names: The list of the names of the results to create. + session_id: The ID of the session to which the results belongs. + batch_size: Batch size for querying. + + Return: + The list of created results. + """ + results = [] + for result_names_batch in batched(result_names): + request = CreateResultsMetaDataRequest( + results=[CreateResultsMetaDataRequest.ResultCreate(name=result_name) for result_name in result_names_batch], + session_id=session_id + ) + response: CreateResultsMetaDataResponse = self._client.CreateResultsMetaData(request) + results.extend([Result.from_message(result_message) for result_message in response.results]) + return results + + def create_results(self, results_data: Dict[str, bytes], session_id: str, batch_size: int = 1) -> List[Result]: + """Create one result with data included in the request. + + Args: + results_data: A dictionnary mapping the result names to their actual data. + session_id: The ID of the session to which the results belongs. + batch_size: Batch size for querying. + + Return: + The list of created results. + """ + results = [] + for results_data_batch in batched(results_data, batch_size): + request = CreateResultsRequest( + results=[CreateResultsRequest.ResultCreate(name=name, data=data) for name, data in results_data_batch.items()], + session_id=session_id + ) + response: CreateResultsResponse = self._client.CreateResults(request) + results.extend([Result.from_message(message) for message in response.results]) + return results + + def upload_result_data(self, result_id: str, result_data: bytes | bytearray, session_id: str) -> None: + """Upload data for an empty result already created. + + Args: + result_id: The ID of the result. + result_data: The result data. + session_id: The ID of the session. + """ + data_chunk_max_size = self.get_service_config() + + def upload_result_stream(): + request = UploadResultDataRequest( + id=UploadResultDataRequest.ResultIdentifier( + session_id=session_id, result_id=result_id + ) + ) + yield request + + start = 0 + data_len = len(result_data) + while start < data_len: + chunk_size = min(data_chunk_max_size, data_len - start) + request = UploadResultDataRequest( + data_chunk=result_data[start : start + chunksize] + ) + yield request + start += chunk_size + + self._client.UploadResultData(upload_result_stream()) + + def download_result_data(self, result_id: str, session_id: str) -> bytes: + """Retrieve data of a result. + + Args: + result_id: The ID of the result. + session_id: The session of the result. + + Return: + Result data. + """ + request = DownloadResultDataRequest( + result_id=result_id, + session=session_id + ) + streaming_call = self._client.DownloadResultData(request) + return b''.join(streaming_call) + + def delete_result_data(self, result_ids: List[str], session_id: str, batch_size: int) -> None: + """Delete data from multiple results + + Args: + result_ids: The IDs of the results which data must be deleted. + session_id: The ID of the session to which the results belongs. + batch_size: Batch size for querying. + """ + for result_ids_batch in batched(result_ids): + request = DeleteResultsDataRequest( + result_id=result_ids_batch, + session_id=session_id + ) + response: DeleteResultsDataResponse = self._client.DeleteResultsData(request) + assert sorted(result_ids_batch) == sorted(response.result_id) + + def get_service_config(self) -> int: + """Get the configuration of the service. + + Return: + Maximum size supported by a data chunk for the result service. 
+ """ + response: ResultsServiceConfigurationResponse = self._client.GetServiceConfiguration() + return response.data_chunk_max_size diff --git a/packages/python/src/armonik/client/tasks.py b/packages/python/src/armonik/client/tasks.py index 18f7c3478..66a107fac 100644 --- a/packages/python/src/armonik/client/tasks.py +++ b/packages/python/src/armonik/client/tasks.py @@ -1,15 +1,16 @@ from __future__ import annotations from grpc import Channel -from typing import cast, Tuple, List +from typing import cast, Dict, Optional, Tuple, List -from ..common import Task, Direction +from ..common import Task, Direction, TaskDefinition, TaskOptions, TaskStatus from ..common.filter import StringFilter, StatusFilter, DateFilter, NumberFilter, Filter, DurationFilter +from ..protogen.common import objects_pb2 from ..protogen.client.tasks_service_pb2_grpc import TasksStub -from ..protogen.common.tasks_common_pb2 import GetTaskRequest, GetTaskResponse, ListTasksRequest, ListTasksDetailedResponse +from ..protogen.common.tasks_common_pb2 import GetTaskRequest, GetTaskResponse, ListTasksRequest, ListTasksDetailedResponse, CancelTasksRequest, CancelTasksResponse, GetResultIdsRequest, GetResultIdsResponse, SubmitTasksRequest, SubmitTasksResponse, CountTasksByStatusRequest, CountTasksByStatusResponse from ..protogen.common.tasks_filters_pb2 import Filters as rawFilters, FiltersAnd as rawFilterAnd, FilterField as rawFilterField, FilterStatus as rawFilterStatus from ..protogen.common.sort_direction_pb2 import SortDirection - from ..protogen.common.tasks_fields_pb2 import * +from ..utils import batched class TaskFieldFilter: @@ -102,3 +103,95 @@ def list_tasks(self, task_filter: Filter, with_errors: bool = False, page: int = with_errors=with_errors) list_response: ListTasksDetailedResponse = self._client.ListTasksDetailed(request) return list_response.total, [Task.from_message(t) for t in list_response.tasks] + + def cancel_tasks(self, task_ids: List[str], chunk_size: Optional[int] = 500) -> List[Task]: + """Cancel tasks. + + Args: + task_ids: IDs of the tasks. + chunk_size: Batch size for cancelling. + + Return: + The list of cancelled tasks. + """ + cancelled_tasks = [] + for task_id_batch in batched(task_ids, chunk_size): + request = CancelTasksRequest(task_ids=task_id_batch) + cancel_tasks_response: CancelTasksResponse = self._client.CancelTasks(request) + cancelled_tasks.extend([Task.from_message(t) for t in cancel_tasks_response.tasks]) + return cancelled_tasks + + def get_result_ids(self, task_ids: List[str], chunk_size: Optional[int] = 500) -> Dict[str, List[str]]: + """Get result IDs of a list of tasks. + + Args: + task_ids: The IDs of the tasks. + chunk_size: Batch size for retrieval. + + Return: + A dictionary mapping the ID of a task to the IDs of its results.. + """ + tasks_result_ids = {} + + for task_id_batch in batched(task_ids, chunk_size): + request = GetResultIdsRequest(task_ids=task_id_batch) + result_ids_response: GetResultIdsResponse = self._client.GetResultIds(request) + for t in result_ids_response.task_results: + tasks_result_ids[t.task_id] = [id for id in t.result_ids] + return tasks_result_ids + + def count_tasks_by_status(self, task_filter: List[Filter]) -> Dict[TaskStatus, int]: + """Get number of tasks by status. + + Args: + task_filter: Filter for the tasks to be listed + + Return: + A dictionnary mapping each status to the number of filtered tasks. 
+ """ + request = CountTasksByStatusRequest(filters=cast(rawFilters, task_filter.to_disjunction().to_message())) + count_tasks_by_status_response: CountTasksByStatusResponse = self._client.CountTasksByStatus(request) + return {TaskStatus(status_count.status): status_count.count for status_count in count_tasks_by_status_response.status} + + def submit_tasks(self, session_id: str, tasks: List[TaskDefinition], default_task_options: Optional[TaskOptions] = None, chunk_size: Optional[int] = 100) -> List[Task]: + """Submit tasks to ArmoniK. + + Args: + session_id: Session Id + tasks: List of task definitions + default_task_options: Default Task Options used if a task has its options not set + chunk_size: Batch size for submission + + Returns: + Tuple containing the list of successfully sent tasks, and + the list of submission errors if any + """ + + tasks_submitted = [] + + for task_batch in batched(tasks, chunk_size): + request = SubmitTasksRequest(session_id=session_id, task_options=TaskOptions(objects_pb2.TaskOptions(**default_task_options.__dict__()))) + + task_creations = [] + + for t in task_batch: + task_creation = SubmitTasksRequest.TaskCreation( + expected_output_keys=t.expected_output_ids, + payload_id=t.payload_id, + data_dependencies=t.data_dependencies if t.data_dependencies else None, + task_options=objects_pb2.TaskOptions(**t.options.__dict__()) + ) + task_creations.append(task_creation) + + request.task_creations = task_creations + + submit_tasks_reponse:SubmitTasksResponse = self._client.SubmitTasks(request) + + for task_info in submit_tasks_reponse.task_infos: + tasks_submitted.append(Task(id=task_info.task_id, + session_id=session_id, + expected_output_ids=[k for k in task_info.expected_output_ids], + data_dependencies=[k for k in task_info.data_dependencies]), + payload_id=task_info.payload_id + ) + return tasks_submitted diff --git a/packages/python/src/armonik/common/objects.py b/packages/python/src/armonik/common/objects.py index 4f9527377..e83ce75f6 100644 --- a/packages/python/src/armonik/common/objects.py +++ b/packages/python/src/armonik/common/objects.py @@ -70,9 +70,10 @@ def to_message(self): @dataclass() class TaskDefinition: - payload: bytes + payload_id: str expected_output_ids: List[str] = field(default_factory=list) data_dependencies: List[str] = field(default_factory=list) + options: Optional[TaskOptions] = field(default_factory=TaskOptions) def __post_init__(self): if len(self.expected_output_ids) <= 0: @@ -89,6 +90,7 @@ class Task: expected_output_ids: List[str] = field(default_factory=list) retry_of_ids: List[str] = field(default_factory=list) status: RawTaskStatus = TaskStatus.UNSPECIFIED + payload_id: Optional[str] = None status_message: Optional[str] = None options: Optional[TaskOptions] = None created_at: Optional[datetime] = None diff --git a/packages/python/src/armonik/utils.py b/packages/python/src/armonik/utils.py new file mode 100644 index 000000000..7d79c9105 --- /dev/null +++ b/packages/python/src/armonik/utils.py @@ -0,0 +1,26 @@ +from typing import Iterable, List, TypeVar + +T = TypeVar('T') + +def batched(iterable: Iterable[T], n: int) -> Iterable[List[T]]: + """ + Batches elements from an iterable into lists of size at most 'n'. + + Args: + iterable : The input iterable. + n : The batch size. + + Yields: + A generator yielding batches of elements from the input iterable. 
+ """ + it = iter(iterable) + while True: + batch = [] + try: + for i in range(n): + batch.append(next(it)) + except StopIteration: + if len(batch) > 0: + yield batch + break + yield batch From 62c86b983b977518d31efa3bfbfc841c7fd1e922 Mon Sep 17 00:00:00 2001 From: qdelamea Date: Wed, 20 Dec 2023 10:21:27 +0100 Subject: [PATCH 02/15] Python API update Results Tasks Events Partitions and Versions services --- .../python/src/armonik/client/__init__.py | 4 +- packages/python/src/armonik/client/events.py | 63 ++++++++++++++++++ .../python/src/armonik/client/partitions.py | 65 +++++++++++++++++++ packages/python/src/armonik/client/results.py | 63 +++++++++++------- .../python/src/armonik/client/sessions.py | 22 +++++-- packages/python/src/armonik/client/tasks.py | 9 ++- .../python/src/armonik/client/versions.py | 26 ++++++++ .../python/src/armonik/common/__init__.py | 7 +- .../python/src/armonik/common/enumwrapper.py | 14 ++++ packages/python/src/armonik/common/events.py | 49 ++++++++++++++ packages/python/src/armonik/common/helpers.py | 29 ++++++++- packages/python/src/armonik/common/objects.py | 23 +++++++ packages/python/src/armonik/utils.py | 26 -------- 13 files changed, 337 insertions(+), 63 deletions(-) create mode 100644 packages/python/src/armonik/client/events.py create mode 100644 packages/python/src/armonik/client/partitions.py create mode 100644 packages/python/src/armonik/client/versions.py create mode 100644 packages/python/src/armonik/common/events.py delete mode 100644 packages/python/src/armonik/utils.py diff --git a/packages/python/src/armonik/client/__init__.py b/packages/python/src/armonik/client/__init__.py index e94d7dde9..79391671e 100644 --- a/packages/python/src/armonik/client/__init__.py +++ b/packages/python/src/armonik/client/__init__.py @@ -1,3 +1,5 @@ +from .partitions import ArmoniKPartitions +from .sessions import ArmoniKSessions from .submitter import ArmoniKSubmitter from .tasks import ArmoniKTasks -from .results import ArmoniKResult +from .results import ArmoniKResults diff --git a/packages/python/src/armonik/client/events.py b/packages/python/src/armonik/client/events.py new file mode 100644 index 000000000..4fc96d0bc --- /dev/null +++ b/packages/python/src/armonik/client/events.py @@ -0,0 +1,63 @@ +from typing import Any, Callable, cast, List + +from grpc import Channel + +from ..common import EventTypes, Filter, NewTaskEvent, NewResultEvent, ResultOwnerUpdateEvent, ResultStatusUpdateEvent, TaskStatusUpdateEvent, ResultStatus +from .results import ResultFieldFilter +from ..protogen.client.events_service_pb2_grpc import EventsStub +from ..protogen.common.events_common_pb2 import EventSubscriptionRequest, EventSubscriptionResponse +from ..protogen.common.results_filters_pb2 import Filters as rawResultFilters +from ..protogen.common.tasks_filters_pb2 import Filters as rawTaskFilters + +class ArmoniKEvents: + + _events_obj_mapping = { + "new_result": NewResultEvent, + "new_task": NewTaskEvent, + "result_owner_update": ResultOwnerUpdateEvent, + "result_status_update": ResultStatusUpdateEvent, + "task_status_update": TaskStatusUpdateEvent + } + + def __init__(self, grpc_channel: Channel): + """Events service client + + Args: + grpc_channel: gRPC channel to use + """ + self._client = EventsStub(grpc_channel) + + def get_events(self, session_id: str, event_types: List[EventTypes], event_handlers: List[Callable[[str, Any], bool]], task_filter: Filter, result_filter: Filter) -> None: + """Get events that represents updates of result and tasks data. 
+ + Args: + session_id: The ID of the session. + event_types: The list of the types of event to catch. + event_handlers: The list of handlers that process the events. Handlers are evaluated in he order they are provided. + An handler takes three positional arguments: the ID of the session, the type of event and the event as an object. + An handler returns a boolean, if True the process continues, otherwise the stream is closed and the service stops + listening to new events. + task_filter: A filter on tasks. + result_filter: A filter on results. + + """ + request = EventSubscriptionRequest( + session_id=session_id, + tasks_filters=cast(rawTaskFilters, task_filter.to_disjunction().to_message()), + results_filters=cast(rawResultFilters, result_filter.to_disjunction().to_message()), + returned_events=event_types + ) + streaming_call = self._client.GetEvents(request) + for message in streaming_call: + event_type = message.WhichOneof("update") + if any([event_handler(session_id, EventTypes.from_string(event_type), self._events_obj_mapping[event_type].from_raw_event(getattr(message, event_type))) for event_handler in event_handlers]): + break + + def wait_for_result_availability(self, result_id: str, session_id: str) -> bool: + def handler(session_id, event_type, event): + if not isinstance(event, ResultStatusUpdateEvent): + raise ValueError("Handler should receive event of type 'ResultStatusUpdateEvent'.") + if event.status == ResultStatus.COMPLETED: + return False + return True + self.get_events(session_id, [EventTypes.RESULT_STATUS_UPDATE], lambda _: False, result_filter=(ResultFieldFilter.RESULT_ID == result_id)) diff --git a/packages/python/src/armonik/client/partitions.py b/packages/python/src/armonik/client/partitions.py new file mode 100644 index 000000000..fb5e0421b --- /dev/null +++ b/packages/python/src/armonik/client/partitions.py @@ -0,0 +1,65 @@ +from typing import cast, List, Tuple + +from grpc import Channel + +from ..common import Direction, Partition +from ..common.filter import Filter, NumberFilter +from ..protogen.client.partitions_service_pb2_grpc import PartitionsStub +from ..protogen.common.partitions_common_pb2 import ListPartitionsRequest, ListPartitionsResponse, GetPartitionRequest, GetPartitionResponse +from ..protogen.common.partitions_fields_pb2 import PartitionField, PartitionRawField, PARTITION_RAW_ENUM_FIELD_PRIORITY +from ..protogen.common.partitions_filters_pb2 import Filters as rawFilters, FiltersAnd as rawFiltersAnd, FilterField as rawFilterField +from ..protogen.common.sort_direction_pb2 import SortDirection + + +class PartitionFieldFilter: + PRIORITY = NumberFilter( + PartitionField(partition_raw_field=PartitionRawField(field=PARTITION_RAW_ENUM_FIELD_PRIORITY)), + rawFilters, + rawFiltersAnd, + rawFilterField + ) + + +class ArmoniKPartitions: + def __init__(self, grpc_channel: Channel): + """ Result service client + + Args: + grpc_channel: gRPC channel to use + """ + self._client = PartitionsStub(grpc_channel) + + def list_partitions(self, partition_filter: Filter, page: int = 0, page_size: int = 1000, sort_field: Filter = PartitionFieldFilter.PRIORITY, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Partition]]: + """List partitions based on a filter. 
+ + Args: + partition_filter: Filter to apply when listing partitions + page: page number to request, useful for pagination, defaults to 0 + page_size: size of a page, defaults to 1000 + sort_field: field to sort the resulting list by, defaults to the status + sort_direction: direction of the sort, defaults to ascending + + Returns: + A tuple containing : + - The total number of results for the given filter + - The obtained list of results + """ + request = ListPartitionsRequest( + page=page, + page_size=page_size, + filters=cast(rawFilters, partition_filter.to_disjunction().to_message()), + sort=ListPartitionsRequest.Sort(field=cast(PartitionField, sort_field.field), direction=sort_direction), + ) + response: ListPartitionsResponse = self._client.ListPartitions(request) + return response.total, [Partition.from_message(p) for p in response.partitions] + + def get_partition(self, partition_id: str) -> Partition: + """Get a partition by its ID. + + Args: + partition_id: The partition ID. + + Return: + The partition summary. + """ + return Partition.from_message(self._client.GetPartition(GetPartitionRequest(id=partition_id)).partition) diff --git a/packages/python/src/armonik/client/results.py b/packages/python/src/armonik/client/results.py index 9bb842e01..81e1d9554 100644 --- a/packages/python/src/armonik/client/results.py +++ b/packages/python/src/armonik/client/results.py @@ -4,21 +4,22 @@ from typing import List, Dict, cast, Tuple from ..protogen.client.results_service_pb2_grpc import ResultsStub -from ..protogen.common.results_common_pb2 import CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, ListResultsRequest, ListResultsResponse, GetOwnerTaskIdRequest, GetOwnerTaskIdResponse, CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, CreateResultsRequest, CreateResultsResponse, ResultsServiceConfigurationResponse, DeleteResultsDataRequest, DeleteResultsDataResponse, UploadResultDataRequest, UploadResultDataResponse, DownloadResultDataRequest, DownloadResultDataResponse -from ..protogen.common.results_filters_pb2 import Filters as rawFilters, FiltersAnd as rawFilterAnd, FilterField as rawFilterField, FilterStatus as rawFilterStatus +from ..protogen.common.results_common_pb2 import CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, ListResultsRequest, ListResultsResponse, GetOwnerTaskIdRequest, GetOwnerTaskIdResponse, CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, CreateResultsRequest, CreateResultsResponse, ResultsServiceConfigurationResponse, DeleteResultsDataRequest, DeleteResultsDataResponse, UploadResultDataRequest, UploadResultDataResponse, DownloadResultDataRequest, DownloadResultDataResponse, GetResultRequest, GetResultResponse +from ..protogen.common.results_filters_pb2 import Filters as rawFilters, FiltersAnd as rawFilterAnd, FilterField as rawFilterField, FilterStatus as rawFilterStatus, FilterString as rawFilterString from ..protogen.common.results_fields_pb2 import ResultField +from ..protogen.common.objects_pb2 import Empty from ..common.filter import StringFilter, StatusFilter, DateFilter, NumberFilter, Filter from ..protogen.common.sort_direction_pb2 import SortDirection from ..common import Direction , Result -from ..protogen.common.results_fields_pb2 import ResultField, ResultRawField, ResultRawEnumField, RESULT_RAW_ENUM_FIELD_STATUS - -from ..utils import batched +from ..protogen.common.results_fields_pb2 import ResultField, ResultRawField, ResultRawEnumField, RESULT_RAW_ENUM_FIELD_STATUS, RESULT_RAW_ENUM_FIELD_RESULT_ID +from 
..common.helpers import batched class ResultFieldFilter: STATUS = StatusFilter(ResultField(result_raw_field=ResultRawField(field=RESULT_RAW_ENUM_FIELD_STATUS)), rawFilters, rawFilterAnd, rawFilterField, rawFilterStatus) + RESULT_ID = StringFilter(ResultField(result_raw_field=ResultRawField(field=RESULT_RAW_ENUM_FIELD_RESULT_ID)), rawFilters, rawFilterAnd, rawFilterField, rawFilterString) -class ArmoniKResult: +class ArmoniKResults: def __init__(self, grpc_channel: Channel): """ Result service client @@ -53,6 +54,19 @@ def list_results(self, result_filter: Filter, page: int = 0, page_size: int = 10 list_response: ListResultsResponse = self._client.ListResults(request) return list_response.total, [Result.from_message(r) for r in list_response.results] + def get_result(self, result_id: str) -> Result: + """Get a result by id. + + Args: + result_id: The ID of the result. + + Return: + The result summary. + """ + request = GetResultRequest(result_id=result_id) + response: GetResultResponse = self._client.GetResult(request) + return Result.from_message(response.result) + def get_owner_task_id(self, result_ids: List[str], session_id: str, batch_size: int = 500) -> Dict[str, str]: """Get the IDs of the tasks that should produce the results. @@ -72,7 +86,7 @@ def get_owner_task_id(self, result_ids: List[str], session_id: str, batch_size: results[result_task.result_id] = result_task.task_id return results - def create_result_metadata(self, result_names: List[str], session_id: str, batch_size: int) -> List[Result]: + def create_result_metadata(self, result_names: List[str], session_id: str, batch_size: int = 100) -> Dict[str, Result]: """Create the metadata of multiple results at once. Data have to be uploaded separately. @@ -82,19 +96,20 @@ def create_result_metadata(self, result_names: List[str], session_id: str, batch batch_size: Batch size for querying. Return: - The list of created results. + A dictionnary mapping each result name to its corresponding result summary. """ - results = [] - for result_names_batch in batched(result_names): + results = {} + for result_names_batch in batched(result_names, batch_size): request = CreateResultsMetaDataRequest( results=[CreateResultsMetaDataRequest.ResultCreate(name=result_name) for result_name in result_names_batch], session_id=session_id ) response: CreateResultsMetaDataResponse = self._client.CreateResultsMetaData(request) - results.extend([Result.from_message(result_message) for result_message in response.results]) + for result_message in response.results: + results[result_message.name] = Result.from_message(result_message) return results - def create_results(self, results_data: Dict[str, bytes], session_id: str, batch_size: int = 1) -> List[Result]: + def create_results(self, results_data: Dict[str, bytes], session_id: str, batch_size: int = 1) -> Dict[str, Result]: """Create one result with data included in the request. Args: @@ -103,19 +118,20 @@ def create_results(self, results_data: Dict[str, bytes], session_id: str, batch_ batch_size: Batch size for querying. Return: - The list of created results. + A dictionnary mappin each result name to its corresponding result summary. 
""" - results = [] + results = {} for results_data_batch in batched(results_data, batch_size): request = CreateResultsRequest( results=[CreateResultsRequest.ResultCreate(name=name, data=data) for name, data in results_data_batch.items()], session_id=session_id ) response: CreateResultsResponse = self._client.CreateResults(request) - results.extend([Result.from_message(message) for message in response.results]) + for message in response.results: + results[message.name] = Result.from_message(message) return results - def upload_result_data(self, result_id: str, result_data: bytes | bytearray, session_id: str) -> None: + def upload_result_data(self, result_id: str, session_id: str, result_data: bytes | bytearray) -> None: """Upload data for an empty result already created. Args: @@ -138,7 +154,7 @@ def upload_result_stream(): while start < data_len: chunk_size = min(data_chunk_max_size, data_len - start) request = UploadResultDataRequest( - data_chunk=result_data[start : start + chunksize] + data_chunk=result_data[start : start + chunk_size] ) yield request start += chunk_size @@ -157,12 +173,12 @@ def download_result_data(self, result_id: str, session_id: str) -> bytes: """ request = DownloadResultDataRequest( result_id=result_id, - session=session_id + session_id=session_id ) streaming_call = self._client.DownloadResultData(request) - return b''.join(streaming_call) + return b''.join([message.data_chunk for message in streaming_call]) - def delete_result_data(self, result_ids: List[str], session_id: str, batch_size: int) -> None: + def delete_result_data(self, result_ids: List[str], session_id: str, batch_size: int = 100) -> None: """Delete data from multiple results Args: @@ -170,7 +186,7 @@ def delete_result_data(self, result_ids: List[str], session_id: str, batch_size: session_id: The ID of the session to which the results belongs. batch_size: Batch size for querying. """ - for result_ids_batch in batched(result_ids): + for result_ids_batch in batched(result_ids, batch_size): request = DeleteResultsDataRequest( result_id=result_ids_batch, session_id=session_id @@ -184,5 +200,8 @@ def get_service_config(self) -> int: Return: Maximum size supported by a data chunk for the result service. """ - response: ResultsServiceConfigurationResponse = self._client.GetServiceConfiguration() + response: ResultsServiceConfigurationResponse = self._client.GetServiceConfiguration(Empty()) return response.data_chunk_max_size + + def watch_results(self): + raise NotImplementedError() diff --git a/packages/python/src/armonik/client/sessions.py b/packages/python/src/armonik/client/sessions.py index 8f676144d..b6795334b 100644 --- a/packages/python/src/armonik/client/sessions.py +++ b/packages/python/src/armonik/client/sessions.py @@ -60,7 +60,20 @@ def create_session(self, default_task_options: TaskOptions, partition_ids: Optio request.partition_ids.append(partition) return self._client.CreateSession(request).session_id - def list_sessions(self, task_filter: Filter, page: int = 0, page_size: int = 1000, sort_field: Filter = SessionFieldFilter.STATUS, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Session]]: + def get_session(self, session_id: str): + """Get a session by its ID. + + Args: + session_id: The ID of the session. + + Return: + The session summary. 
+ """ + request = GetSessionRequest(session_id=session_id) + response: GetSessionResponse = self._client.GetSession(request) + return Session.from_message(response.session) + + def list_sessions(self, session_filter: Filter , page: int = 0, page_size: int = 1000, sort_field: Filter = SessionFieldFilter.STATUS, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Session]]: """ List sessions @@ -76,14 +89,14 @@ def list_sessions(self, task_filter: Filter, page: int = 0, page_size: int = 100 - The total number of sessions for the given filter - The obtained list of sessions """ - request : ListSessionsRequest = ListSessionsRequest( + request = ListSessionsRequest( page=page, page_size=page_size, - filters=cast(rawFilters, task_filter.to_disjunction().to_message()), + filters=cast(rawFilters, session_filter.to_disjunction().to_message()), sort=ListSessionsRequest.Sort(field=cast(SessionField, sort_field.field), direction=sort_direction), ) list_response : ListSessionsResponse = self._client.ListSessions(request) - return list_response.total, [Session.from_message(t) for t in list_response.sessions] + return list_response.total, [Session.from_message(s) for s in list_response.sessions] def cancel_session(self, session_id: str) -> None: """Cancel a session @@ -92,4 +105,3 @@ def cancel_session(self, session_id: str) -> None: session_id: Id of the session to b cancelled """ self._client.CancelSession(CancelSessionRequest(session_id=session_id)) - \ No newline at end of file diff --git a/packages/python/src/armonik/client/tasks.py b/packages/python/src/armonik/client/tasks.py index 66a107fac..2905c7477 100644 --- a/packages/python/src/armonik/client/tasks.py +++ b/packages/python/src/armonik/client/tasks.py @@ -4,7 +4,6 @@ from ..common import Task, Direction, TaskDefinition, TaskOptions, TaskStatus from ..common.filter import StringFilter, StatusFilter, DateFilter, NumberFilter, Filter, DurationFilter -from ..protogen.common import objects_pb2 from ..protogen.client.tasks_service_pb2_grpc import TasksStub from ..protogen.common.tasks_common_pb2 import GetTaskRequest, GetTaskResponse, ListTasksRequest, ListTasksDetailedResponse, CancelTasksRequest, CancelTasksResponse, GetResultIdsRequest, GetResultIdsResponse, SubmitTasksRequest, SubmitTasksResponse, CountTasksByStatusRequest, CountTasksByStatusResponse from ..protogen.common.tasks_filters_pb2 import Filters as rawFilters, FiltersAnd as rawFilterAnd, FilterField as rawFilterField, FilterStatus as rawFilterStatus @@ -169,17 +168,17 @@ def submit_tasks(self, session_id: str, tasks: List[TaskDefinition], default_tas tasks_submitted = [] - for task_batch in batched(tasks, chunk_size): - request = SubmitTasksRequest(session_id=session_id, task_options=TaskOptions(objects_pb2.TaskOptions(**default_task_options.__dict__()))) + for tasks_batch in batched(tasks, chunk_size): + request = SubmitTasksRequest(session_id=session_id, task_options=default_task_options.to_message()) task_creations = [] - for t in task_batch: + for t in tasks_batch: task_creation = SubmitTasksRequest.TaskCreation( expected_output_keys=t.expected_output_ids, payload_id=t.payload_id, data_dependencies=t.data_dependencies if t.data_dependencies else None, - task_options=objects_pb2.TaskOptions(**t.options.__dict__()) + task_options=t.options.to_message() ) task_creations.append(task_creation) diff --git a/packages/python/src/armonik/client/versions.py b/packages/python/src/armonik/client/versions.py new file mode 100644 index 000000000..db6f23a69 --- /dev/null +++ 
b/packages/python/src/armonik/client/versions.py @@ -0,0 +1,26 @@ +from typing import Dict + +from grpc import Channel + +from ..protogen.client.versions_service_pb2_grpc import VersionsStub +from ..protogen.common.versions_common_pb2 import ListVersionsRequest, ListVersionsResponse + + +class ArmoniKVersions: + def __init__(self, grpc_channel: Channel): + """ Result service client + + Args: + grpc_channel: gRPC channel to use + """ + self._client = VersionsStub(grpc_channel) + + def list_versions(self) -> Dict[str, str]: + """Get versions of ArmoniK components. + + Return: + A dictionnary mapping each component to its version. + """ + request = ListVersionsRequest() + response: ListVersionsResponse = self._client.ListVersions(request) + return {"core": response.core, "api": response.api} diff --git a/packages/python/src/armonik/common/__init__.py b/packages/python/src/armonik/common/__init__.py index 001721868..fd1c14e79 100644 --- a/packages/python/src/armonik/common/__init__.py +++ b/packages/python/src/armonik/common/__init__.py @@ -1,4 +1,5 @@ from .helpers import datetime_to_timestamp, timestamp_to_datetime, duration_to_timedelta, timedelta_to_duration, get_task_filter -from .objects import Task, TaskDefinition, TaskOptions, Output, ResultAvailability, Session, Result -from .enumwrapper import HealthCheckStatus, TaskStatus, Direction -from .filter import StringFilter, StatusFilter +from .objects import Task, TaskDefinition, TaskOptions, Output, ResultAvailability, Session, Result, Partition +from .enumwrapper import HealthCheckStatus, TaskStatus, Direction, SessionStatus, ResultStatus, EventTypes +from .events import * +from .filter import Filter, StringFilter, StatusFilter diff --git a/packages/python/src/armonik/common/enumwrapper.py b/packages/python/src/armonik/common/enumwrapper.py index 9c19a9a82..9aba6d2cd 100644 --- a/packages/python/src/armonik/common/enumwrapper.py +++ b/packages/python/src/armonik/common/enumwrapper.py @@ -1,6 +1,7 @@ from __future__ import annotations from ..protogen.common.task_status_pb2 import TaskStatus as RawStatus, _TASKSTATUS, TASK_STATUS_CANCELLED, TASK_STATUS_CANCELLING, TASK_STATUS_COMPLETED, TASK_STATUS_CREATING, TASK_STATUS_DISPATCHED, TASK_STATUS_ERROR, TASK_STATUS_PROCESSED, TASK_STATUS_PROCESSING, TASK_STATUS_SUBMITTED, TASK_STATUS_TIMEOUT, TASK_STATUS_UNSPECIFIED, TASK_STATUS_RETRIED +from ..protogen.common.events_common_pb2 import EventsEnum as rawEventsEnum, EVENTS_ENUM_UNSPECIFIED, EVENTS_ENUM_NEW_TASK, EVENTS_ENUM_TASK_STATUS_UPDATE, EVENTS_ENUM_NEW_RESULT, EVENTS_ENUM_RESULT_STATUS_UPDATE, EVENTS_ENUM_RESULT_OWNER_UPDATE from ..protogen.common.session_status_pb2 import SessionStatus as RawSessionStatus, _SESSIONSTATUS, SESSION_STATUS_UNSPECIFIED, SESSION_STATUS_CANCELLED, SESSION_STATUS_RUNNING from ..protogen.common.result_status_pb2 import ResultStatus as RawResultStatus, _RESULTSTATUS, RESULT_STATUS_UNSPECIFIED, RESULT_STATUS_CREATED, RESULT_STATUS_COMPLETED, RESULT_STATUS_ABORTED, RESULT_STATUS_NOTFOUND from ..protogen.common.worker_common_pb2 import HealthCheckReply @@ -58,3 +59,16 @@ def name_from_value(status: RawResultStatus) -> str: COMPLETED = RESULT_STATUS_COMPLETED ABORTED = RESULT_STATUS_ABORTED NOTFOUND = RESULT_STATUS_NOTFOUND + + +class EventTypes: + UNSPECIFIED = EVENTS_ENUM_UNSPECIFIED + NEW_TASK = EVENTS_ENUM_NEW_TASK + TASK_STATUS_UPDATE = EVENTS_ENUM_TASK_STATUS_UPDATE + NEW_RESULT = EVENTS_ENUM_NEW_RESULT + RESULT_STATUS_UPDATE = EVENTS_ENUM_RESULT_STATUS_UPDATE + RESULT_OWNER_UPDATE = 
EVENTS_ENUM_RESULT_OWNER_UPDATE + + @classmethod + def from_string(cls, name: str): + return getattr(cls, name.upper()) diff --git a/packages/python/src/armonik/common/events.py b/packages/python/src/armonik/common/events.py new file mode 100644 index 000000000..3fa9fe225 --- /dev/null +++ b/packages/python/src/armonik/common/events.py @@ -0,0 +1,49 @@ +from abc import ABC +from typing import List + +from dataclasses import dataclass, fields + +from .enumwrapper import TaskStatus, ResultStatus + + +@dataclass +class Event(ABC): + @classmethod + def from_raw_event(cls, raw_event): + values = {} + for raw_field in fields(cls): + values[raw_field] = getattr(raw_event, raw_field) + return cls(**values) + + +class TaskStatusUpdateEvent(Event): + task_id: str + status: TaskStatus + + +class ResultStatusUpdateEvent(Event): + result_id: str + status: ResultStatus + + +class ResultOwnerUpdateEvent(Event): + result_id: str + previous_owner_id: str + current_owner_id: str + + +class NewTaskEvent(Event): + task_id: str + payload_id: str + origin_task_id: str + status: TaskStatus + expected_output_keys: List[str] + data_dependencies: List[str] + retry_of_ids: List[str] + parent_task_ids: List[str] + + +class NewResultEvent(Event): + result_id: str + owner_id: str + status: ResultStatus diff --git a/packages/python/src/armonik/common/helpers.py b/packages/python/src/armonik/common/helpers.py index e174e2f42..22d389fca 100644 --- a/packages/python/src/armonik/common/helpers.py +++ b/packages/python/src/armonik/common/helpers.py @@ -1,6 +1,6 @@ from __future__ import annotations from datetime import timedelta, datetime, timezone -from typing import List, Optional +from typing import List, Optional, Iterable, TypeVar import google.protobuf.duration_pb2 as duration import google.protobuf.timestamp_pb2 as timestamp @@ -9,6 +9,9 @@ from .enumwrapper import TaskStatus +T = TypeVar('T') + + def get_task_filter(session_ids: Optional[List[str]] = None, task_ids: Optional[List[str]] = None, included_statuses: Optional[List[TaskStatus]] = None, excluded_statuses: Optional[List[TaskStatus]] = None) -> TaskFilter: @@ -96,3 +99,27 @@ def timedelta_to_duration(delta: timedelta) -> duration.Duration: d = duration.Duration() d.FromTimedelta(delta) return d + + +def batched(iterable: Iterable[T], n: int) -> Iterable[List[T]]: + """ + Batches elements from an iterable into lists of size at most 'n'. + + Args: + iterable : The input iterable. + n : The batch size. + + Yields: + A generator yielding batches of elements from the input iterable. 
+ """ + it = iter(iterable) + while True: + batch = [] + try: + for i in range(n): + batch.append(next(it)) + except StopIteration: + if len(batch) > 0: + yield batch + break + yield batch diff --git a/packages/python/src/armonik/common/objects.py b/packages/python/src/armonik/common/objects.py index e83ce75f6..745ba7e4a 100644 --- a/packages/python/src/armonik/common/objects.py +++ b/packages/python/src/armonik/common/objects.py @@ -8,6 +8,7 @@ from ..protogen.common.objects_pb2 import Empty, Output as WorkerOutput, TaskOptions as RawTaskOptions from ..protogen.common.task_status_pb2 import TaskStatus as RawTaskStatus from .enumwrapper import TaskStatus, SessionStatus, ResultStatus +from ..protogen.common.partitions_common_pb2 import PartitionRaw from ..protogen.common.session_status_pb2 import SessionStatus as RawSessionStatus from ..protogen.common.sessions_common_pb2 import SessionRaw from ..protogen.common.result_status_pb2 import ResultStatus as RawResultStatus @@ -213,3 +214,25 @@ def from_message(cls, result_raw: ResultRaw) -> "Result": result_id=result_raw.result_id, size=result_raw.size ) + +@dataclass +class Partition: + id: str + parent_partition_ids: List[str] + pod_reserved: int + pod_max: int + pod_configuration: Dict[str, str] + preemption_percentage: int + priority: int + + @classmethod + def from_message(cls, partition_raw: PartitionRaw) -> "Partition": + return cls( + id=partition_raw.id, + parent_partition_ids=partition_raw.parent_partition_ids, + pod_reserved=partition_raw.pod_reserved, + pod_max=partition_raw.pod_max, + pod_configuration=partition_raw.pod_configuration, + preemption_percentage=partition_raw.preemption_percentage, + priority=partition_raw.priority + ) diff --git a/packages/python/src/armonik/utils.py b/packages/python/src/armonik/utils.py deleted file mode 100644 index 7d79c9105..000000000 --- a/packages/python/src/armonik/utils.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Iterable, List, TypeVar - -T = TypeVar('T') - -def batched(iterable: Iterable[T], n: int) -> Iterable[List[T]]: - """ - Batches elements from an iterable into lists of size at most 'n'. - - Args: - iterable : The input iterable. - n : The batch size. - - Yields: - A generator yielding batches of elements from the input iterable. 
- """ - it = iter(iterable) - while True: - batch = [] - try: - for i in range(n): - batch.append(next(it)) - except StopIteration: - if len(batch) > 0: - yield batch - break - yield batch From f89b0ac4464aef957b1579c3fd101125dd999faf Mon Sep 17 00:00:00 2001 From: qdelamea Date: Wed, 20 Dec 2023 11:32:23 +0100 Subject: [PATCH 03/15] Python API update task handler --- .../python/src/armonik/worker/taskhandler.py | 99 ++++++++++++++++--- 1 file changed, 88 insertions(+), 11 deletions(-) diff --git a/packages/python/src/armonik/worker/taskhandler.py b/packages/python/src/armonik/worker/taskhandler.py index 49eeb8ff3..b321f651b 100644 --- a/packages/python/src/armonik/worker/taskhandler.py +++ b/packages/python/src/armonik/worker/taskhandler.py @@ -2,11 +2,12 @@ import os from typing import Optional, Dict, List, Tuple, Union, cast -from ..common import TaskOptions, TaskDefinition, Task -from ..protogen.common.agent_common_pb2 import CreateTaskRequest, CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, NotifyResultDataRequest +from ..common import TaskOptions, TaskDefinition, Task, Result +from ..protogen.common.agent_common_pb2 import CreateTaskRequest, CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, NotifyResultDataRequest, CreateResultsRequest, CreateResultsResponse, SubmitTasksRequest, SubmitTasksResponse from ..protogen.common.objects_pb2 import TaskRequest, DataChunk, InitTaskRequest, TaskRequestHeader, Configuration from ..protogen.worker.agent_service_pb2_grpc import AgentStub from ..protogen.common.worker_common_pb2 import ProcessRequest +from ..common.helpers import batched class TaskHandler: @@ -67,21 +68,97 @@ def create_tasks(self, tasks: List[TaskDefinition], task_options: Optional[TaskO raise Exception("Unknown value") return tasks_created, tasks_creation_failed - def send_result(self, key: str, data: Union[bytes, bytearray]) -> None: - """ Send task result + def submit_tasks(self, tasks: List[TaskDefinition], default_task_options: Optional[TaskOptions] = None, batch_size: Optional[int] = 100) -> None: + """Submit tasks to the agent. 
Args: - key: Result key - data: Result data + tasks: List of task definitions + default_task_options: Default Task Options used if a task has its options not set + batch_size: Batch size for submission """ - with open(os.path.join(self.data_folder, key), "wb") as f: - f.write(data) + for tasks_batch in batched(tasks, batch_size): + request = SubmitTasksRequest( + session_id=self.session_id, + task_options=default_task_options.to_message(), + communication_token=self.token + ) + + task_creations = [] + + for t in tasks_batch: + task_creation = SubmitTasksRequest.TaskCreation( + expected_output_keys=t.expected_output_ids, + payload_id=t.payload_id, + data_dependencies=t.data_dependencies if t.data_dependencies else None, + task_options=t.options.to_message() + ) + task_creations.append(task_creation) + + request.task_creations = task_creations - self._client.NotifyResultData(NotifyResultDataRequest(ids=[NotifyResultDataRequest.ResultIdentifier(session_id=self.session_id, result_id=key)], communication_token=self.token)) + self._client.SubmitTasks(request) - def get_results_ids(self, names: List[str]) -> Dict[str, str]: - return {r.name: r.result_id for r in cast(CreateResultsMetaDataResponse, self._client.CreateResultsMetaData(CreateResultsMetaDataRequest(results=[CreateResultsMetaDataRequest.ResultCreate(name=n) for n in names], session_id=self.session_id, communication_token=self.token))).results} + def send_results(self, result_data: Dict[str, bytes | bytearray]) -> None: + """Send results. + + Args: + result_data: A dictionnary mapping each result ID to its data. + """ + for result_id, result_data in result_data.items(): + with open(os.path.join(self.data_folder, result_id), "wb") as f: + f.write(result_data) + + request = NotifyResultDataRequest( + ids=[NotifyResultDataRequest.ResultIdentifier(session_id=self.session_id, result_id=result_id) for result_id in result_data.keys()], + communication_token=self.token + ) + self._client.NotifyResultData(request) + def create_result_metadata(self, result_names: List[str], batch_size: int = 100) -> Dict[str, List[Result]]: + """ + Create the metadata of multiple results at once. + Data have to be uploaded separately. + + Args: + result_names: The names of the results to create. + batch_size: Batch size for querying. + + Return: + A dictionnary mapping each result name to its result summary. + """ + results = {} + for result_names_batch in batched(result_names, batch_size): + request = CreateResultsMetaDataRequest( + results=[CreateResultsMetaDataRequest.ResultCreate(name=result_name) for result_name in result_names], + session_id=self.session_id, + communication_token=self.token + ) + response: CreateResultsMetaDataResponse = self._client.CreateResultsMetaData(request) + for result_message in response.results: + results[result_message.name] = Result.from_message(result_message) + return results + + def create_results(self, results_data: Dict[str, bytes], batch_size: int = 1) -> Dict[str, Result]: + """Create one result with data included in the request. + + Args: + results_data: A dictionnary mapping the result names to their actual data. + batch_size: Batch size for querying. + + Return: + A dictionnary mappin each result name to its corresponding result summary. 
+ """ + results = {} + for results_data_batch in batched(results_data, batch_size): + request = CreateResultsRequest( + results=[CreateResultsRequest.ResultCreate(name=name, data=data) for name, data in results_data_batch.items()], + session_id=self.session_id, + communication_token=self.token + ) + response: CreateResultsResponse = self._client.CreateResults(request) + for message in response.results: + results[message.name] = Result.from_message(message) + return results def _to_request_stream_internal(request, communication_token, is_last, chunk_max_size): req = CreateTaskRequest( From 450cd6868a663ab16ae47dbdd041d4aa18b2d6f4 Mon Sep 17 00:00:00 2001 From: qdelamea Date: Fri, 22 Dec 2023 16:02:48 +0100 Subject: [PATCH 04/15] Add tests for Tasks Sessions and Versions services --- packages/python/.gitignore | 1 + packages/python/proto2python.sh | 2 +- .../python/src/armonik/client/__init__.py | 10 ++ packages/python/src/armonik/client/events.py | 28 +++- packages/python/src/armonik/client/results.py | 4 +- .../python/src/armonik/client/sessions.py | 9 +- packages/python/src/armonik/client/tasks.py | 76 +++++----- .../python/src/armonik/common/__init__.py | 37 ++++- packages/python/src/armonik/common/helpers.py | 20 +-- packages/python/src/armonik/common/objects.py | 2 +- .../python/src/armonik/worker/__init__.py | 6 + .../python/src/armonik/worker/taskhandler.py | 2 + packages/python/tests/conftest.py | 143 ++++++++++++++++++ packages/python/tests/test_sessions.py | 63 ++++++++ packages/python/tests/test_tasks.py | 96 ++++++++++++ packages/python/tests/test_versions.py | 39 +++++ 16 files changed, 477 insertions(+), 61 deletions(-) create mode 100644 packages/python/tests/conftest.py create mode 100644 packages/python/tests/test_sessions.py create mode 100644 packages/python/tests/test_tasks.py create mode 100644 packages/python/tests/test_versions.py diff --git a/packages/python/.gitignore b/packages/python/.gitignore index 53df0df6c..6fe9f8e9f 100644 --- a/packages/python/.gitignore +++ b/packages/python/.gitignore @@ -4,3 +4,4 @@ build/ *.egg-info **/_version.py **/.pytest_cache +.ruff_cache diff --git a/packages/python/proto2python.sh b/packages/python/proto2python.sh index c2fe45244..6fe79c0e7 100755 --- a/packages/python/proto2python.sh +++ b/packages/python/proto2python.sh @@ -33,7 +33,7 @@ python -m pip install --upgrade pip python -m venv $PYTHON_VENV source $PYTHON_VENV/bin/activate # We need to fix grpc to 1.56 until this bug is solved : https://github.com/grpc/grpc/issues/34305 -python -m pip install build grpcio==1.56.2 grpcio-tools==1.56.2 click pytest setuptools_scm[toml] +python -m pip install build grpcio==1.56.2 grpcio-tools==1.56.2 click pytest setuptools_scm[toml] ruff requests unset proto_files for proto in ${armonik_worker_files[@]}; do diff --git a/packages/python/src/armonik/client/__init__.py b/packages/python/src/armonik/client/__init__.py index 79391671e..8d1c923bc 100644 --- a/packages/python/src/armonik/client/__init__.py +++ b/packages/python/src/armonik/client/__init__.py @@ -3,3 +3,13 @@ from .submitter import ArmoniKSubmitter from .tasks import ArmoniKTasks from .results import ArmoniKResults +from .versions import ArmoniKVersions + +__all__ = [ + 'ArmoniKPartitions', + 'ArmoniKSessions', + 'ArmoniKSubmitter', + 'ArmoniKTasks', + 'ArmoniKResults', + "ArmoniKVersions", +] diff --git a/packages/python/src/armonik/client/events.py b/packages/python/src/armonik/client/events.py index 4fc96d0bc..440f6a382 100644 --- 
a/packages/python/src/armonik/client/events.py +++ b/packages/python/src/armonik/client/events.py @@ -2,6 +2,7 @@ from grpc import Channel +from .results import ArmoniKResults from ..common import EventTypes, Filter, NewTaskEvent, NewResultEvent, ResultOwnerUpdateEvent, ResultStatusUpdateEvent, TaskStatusUpdateEvent, ResultStatus from .results import ResultFieldFilter from ..protogen.client.events_service_pb2_grpc import EventsStub @@ -26,8 +27,9 @@ def __init__(self, grpc_channel: Channel): grpc_channel: gRPC channel to use """ self._client = EventsStub(grpc_channel) + self._results_client = ArmoniKResults(grpc_channel) - def get_events(self, session_id: str, event_types: List[EventTypes], event_handlers: List[Callable[[str, Any], bool]], task_filter: Filter, result_filter: Filter) -> None: + def get_events(self, session_id: str, event_types: List[EventTypes], event_handlers: List[Callable[[str, Any], bool]], task_filter: Filter | None = None, result_filter: Filter | None = None) -> None: """Get events that represents updates of result and tasks data. Args: @@ -43,7 +45,7 @@ def get_events(self, session_id: str, event_types: List[EventTypes], event_handl """ request = EventSubscriptionRequest( session_id=session_id, - tasks_filters=cast(rawTaskFilters, task_filter.to_disjunction().to_message()), +tasks_filters=cast(rawTaskFilters, task_filter.to_disjunction().to_message()), results_filters=cast(rawResultFilters, result_filter.to_disjunction().to_message()), returned_events=event_types ) @@ -53,11 +55,29 @@ def get_events(self, session_id: str, event_types: List[EventTypes], event_handl if any([event_handler(session_id, EventTypes.from_string(event_type), self._events_obj_mapping[event_type].from_raw_event(getattr(message, event_type))) for event_handler in event_handlers]): break - def wait_for_result_availability(self, result_id: str, session_id: str) -> bool: + def wait_for_result_availability(self, result_id: str, session_id: str) -> None: + """Wait until a result is ready i.e its status updates to COMPLETED. + + Args: + result_id: The ID of the result. + session_id: The ID of the session. + + Raises: + RuntimeError: If the result status is ABORTED. 
+ """ def handler(session_id, event_type, event): if not isinstance(event, ResultStatusUpdateEvent): raise ValueError("Handler should receive event of type 'ResultStatusUpdateEvent'.") if event.status == ResultStatus.COMPLETED: return False + elif event.status == ResultStatus.ABORTED: + raise RuntimeError(f"Result {result.name} with ID {result_id} is aborted.") return True - self.get_events(session_id, [EventTypes.RESULT_STATUS_UPDATE], lambda _: False, result_filter=(ResultFieldFilter.RESULT_ID == result_id)) + + result = self._results_client.get_result(result_id) + if result.status == ResultStatus.COMPLETED: + return + elif result.status == ResultStatus.ABORTED: + raise RuntimeError(f"Result {result.name} with ID {result_id} is aborted.") + + self.get_events(session_id, [EventTypes.RESULT_STATUS_UPDATE], [handler], result_filter=(ResultFieldFilter.RESULT_ID == result_id)) diff --git a/packages/python/src/armonik/client/results.py b/packages/python/src/armonik/client/results.py index 81e1d9554..74c5f4c6e 100644 --- a/packages/python/src/armonik/client/results.py +++ b/packages/python/src/armonik/client/results.py @@ -5,7 +5,7 @@ from ..protogen.client.results_service_pb2_grpc import ResultsStub from ..protogen.common.results_common_pb2 import CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, ListResultsRequest, ListResultsResponse, GetOwnerTaskIdRequest, GetOwnerTaskIdResponse, CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, CreateResultsRequest, CreateResultsResponse, ResultsServiceConfigurationResponse, DeleteResultsDataRequest, DeleteResultsDataResponse, UploadResultDataRequest, UploadResultDataResponse, DownloadResultDataRequest, DownloadResultDataResponse, GetResultRequest, GetResultResponse -from ..protogen.common.results_filters_pb2 import Filters as rawFilters, FiltersAnd as rawFilterAnd, FilterField as rawFilterField, FilterStatus as rawFilterStatus, FilterString as rawFilterString +from ..protogen.common.results_filters_pb2 import Filters as rawFilters, FiltersAnd as rawFilterAnd, FilterField as rawFilterField, FilterStatus as rawFilterStatus from ..protogen.common.results_fields_pb2 import ResultField from ..protogen.common.objects_pb2 import Empty from ..common.filter import StringFilter, StatusFilter, DateFilter, NumberFilter, Filter @@ -17,7 +17,7 @@ class ResultFieldFilter: STATUS = StatusFilter(ResultField(result_raw_field=ResultRawField(field=RESULT_RAW_ENUM_FIELD_STATUS)), rawFilters, rawFilterAnd, rawFilterField, rawFilterStatus) - RESULT_ID = StringFilter(ResultField(result_raw_field=ResultRawField(field=RESULT_RAW_ENUM_FIELD_RESULT_ID)), rawFilters, rawFilterAnd, rawFilterField, rawFilterString) + RESULT_ID = StringFilter(ResultField(result_raw_field=ResultRawField(field=RESULT_RAW_ENUM_FIELD_RESULT_ID)), rawFilters, rawFilterAnd, rawFilterField) class ArmoniKResults: def __init__(self, grpc_channel: Channel): diff --git a/packages/python/src/armonik/client/sessions.py b/packages/python/src/armonik/client/sessions.py index b6795334b..a15fddb01 100644 --- a/packages/python/src/armonik/client/sessions.py +++ b/packages/python/src/armonik/client/sessions.py @@ -73,7 +73,7 @@ def get_session(self, session_id: str): response: GetSessionResponse = self._client.GetSession(request) return Session.from_message(response.session) - def list_sessions(self, session_filter: Filter , page: int = 0, page_size: int = 1000, sort_field: Filter = SessionFieldFilter.STATUS, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Session]]: + def 
list_sessions(self, session_filter: Filter | None = None, page: int = 0, page_size: int = 1000, sort_field: Filter = SessionFieldFilter.STATUS, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Session]]: """ List sessions @@ -92,11 +92,12 @@ def list_sessions(self, session_filter: Filter , page: int = 0, page_size: int = request = ListSessionsRequest( page=page, page_size=page_size, - filters=cast(rawFilters, session_filter.to_disjunction().to_message()), sort=ListSessionsRequest.Sort(field=cast(SessionField, sort_field.field), direction=sort_direction), ) + if session_filter: + request.filters.CopyFrom(cast(rawFilters, session_filter.to_disjunction().to_message())) + response : ListSessionsResponse = self._client.ListSessions(request) + return response.total, [Session.from_message(s) for s in response.sessions] def cancel_session(self, session_id: str) -> None: """Cancel a session diff --git a/packages/python/src/armonik/client/tasks.py b/packages/python/src/armonik/client/tasks.py index 2905c7477..26879991e 100644 --- a/packages/python/src/armonik/client/tasks.py +++ b/packages/python/src/armonik/client/tasks.py @@ -5,11 +5,11 @@ from ..common import Task, Direction, TaskDefinition, TaskOptions, TaskStatus from ..common.filter import StringFilter, StatusFilter, DateFilter, NumberFilter, Filter, DurationFilter from ..protogen.client.tasks_service_pb2_grpc import TasksStub -from ..protogen.common.tasks_common_pb2 import GetTaskRequest, GetTaskResponse, ListTasksRequest, ListTasksDetailedResponse, CancelTasksRequest, CancelTasksResponse, GetResultIdsRequest, GetResultIdsResponse, SubmitTasksRequest, SubmitTasksResponse, CountTasksByStatusRequest, CountTasksByStatusResponse +from ..protogen.common.tasks_common_pb2 import GetTaskRequest, GetTaskResponse, ListTasksRequest, ListTasksDetailedResponse, CancelTasksRequest, CancelTasksResponse, GetResultIdsRequest, GetResultIdsResponse, SubmitTasksRequest, SubmitTasksResponse, CountTasksByStatusRequest, CountTasksByStatusResponse, ListTasksResponse from ..protogen.common.tasks_filters_pb2 import Filters as rawFilters, FiltersAnd as rawFilterAnd, FilterField as rawFilterField, FilterStatus as rawFilterStatus from ..protogen.common.sort_direction_pb2 import SortDirection from ..protogen.common.tasks_fields_pb2 import * -from ..utils import batched +from ..common.helpers import batched class TaskFieldFilter: @@ -77,7 +77,7 @@ def get_task(self, task_id: str) -> Task: task_response: GetTaskResponse = self._client.GetTask(GetTaskRequest(task_id=task_id)) return Task.from_message(task_response.task) - def list_tasks(self, task_filter: Filter, with_errors: bool = False, page: int = 0, page_size: int = 1000, sort_field: Filter = TaskFieldFilter.TASK_ID, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Task]]: + def list_tasks(self, task_filter: Filter | None = None, with_errors: bool = False, page: int = 0, page_size: int = 1000, sort_field: Filter = TaskFieldFilter.TASK_ID, sort_direction: SortDirection = Direction.ASC, detailed: bool = True) -> Tuple[int, List[Task]]: """List tasks If the total returned exceeds the requested page size, you may want to use this function again and ask for subsequent pages.
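An illustrative sketch of the reworked listing API shown above (not part of the patch; it assumes a reachable control plane at a hypothetical localhost:5001 endpoint):

    import grpc
    from armonik.client import ArmoniKTasks
    from armonik.client.tasks import TaskFieldFilter

    with grpc.insecure_channel("localhost:5001") as channel:  # hypothetical endpoint
        tasks_client = ArmoniKTasks(channel)
        # With no filter, every task is listed; ListTasksDetailed is used by default.
        total, tasks = tasks_client.list_tasks()
        # detailed=False switches to the lighter ListTasks summary RPC.
        total, tasks = tasks_client.list_tasks(detailed=False)
        # A filter still narrows the listing, as before the change.
        total, tasks = tasks_client.list_tasks(TaskFieldFilter.TASK_ID == "some-task-id")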
@@ -89,6 +89,7 @@ def list_tasks(self, task_filter: Filter, with_errors: bool = False, page: int = page_size: size of a page, defaults to 1000 sort_field: field on which to sort the resulting list, defaults to the task_id sort_direction: direction of the sort, defaults to ascending + detailed: Whether to retrieve the detailed description of the task. Returns: A tuple containing: @@ -96,14 +97,19 @@ def list_tasks(self, task_filter: Filter, with_errors: bool = False, page: int = - The obtained list of tasks """ request = ListTasksRequest(page=page, - page_size=page_size, - filters=cast(rawFilters, task_filter.to_disjunction().to_message()), - sort=ListTasksRequest.Sort(field=cast(TaskField, sort_field.field), direction=sort_direction), - with_errors=with_errors) - list_response: ListTasksDetailedResponse = self._client.ListTasksDetailed(request) - return list_response.total, [Task.from_message(t) for t in list_response.tasks] - - def cancel_tasks(self, task_ids: List[str], chunk_size: Optional[int] = 500) -> List[Task]: + page_size=page_size, + sort=ListTasksRequest.Sort(field=cast(TaskField, sort_field.field), direction=sort_direction), + with_errors=with_errors + ) + if task_filter: + request.filters.CopyFrom(cast(rawFilters, task_filter.to_disjunction().to_message())) + if detailed: + response: ListTasksDetailedResponse = self._client.ListTasksDetailed(request) + return response.total, [Task.from_message(t) for t in response.tasks] + response: ListTasksResponse = self._client.ListTasks(request) + return response.total, [Task.from_message(t) for t in response.tasks] + + def cancel_tasks(self, task_ids: List[str], chunk_size: Optional[int] = 500): """Cancel tasks. Args: @@ -113,12 +119,9 @@ def cancel_tasks(self, task_ids: List[str], chunk_size: Optional[int] = 500) -> Return: The list of cancelled tasks. """ - cancelled_tasks = [] for task_id_batch in batched(task_ids, chunk_size): request = CancelTasksRequest(task_ids=task_id_batch) - cancel_tasks_response: CancelTasksResponse = self._client.CancelTasks(request) - cancelled_tasks.extend([Task.from_message(t) for t in cancel_tasks_response.tasks]) - return cancelled_tasks + self._client.CancelTasks(request) def get_result_ids(self, task_ids: List[str], chunk_size: Optional[int] = 500) -> Dict[str, List[str]]: """Get result IDs of a list of tasks. @@ -132,14 +135,14 @@ def get_result_ids(self, task_ids: List[str], chunk_size: Optional[int] = 500) - """ tasks_result_ids = {} - for task_id_batch in batched(task_ids, chunk_size): - request = GetResultIdsRequest(task_ids=task_id_batch) + for task_ids_batch in batched(task_ids, chunk_size): + request = GetResultIdsRequest(task_id=task_ids_batch) result_ids_response: GetResultIdsResponse = self._client.GetResultIds(request) for t in result_ids_response.task_results: - tasks_result_ids[t.task_id] = [id for id in t.result_ids] + tasks_result_ids[t.task_id] = list(t.result_ids) return tasks_result_ids - def count_tasks_by_status(self, task_filter: List[Filter]) -> Dict[TaskStatus, int]: + def count_tasks_by_status(self, task_filter: Filter | None = None) -> Dict[TaskStatus, int]: """Get number of tasks by status. Args: @@ -148,11 +151,14 @@ def count_tasks_by_status(self, task_filter: List[Filter]) -> Dict[TaskStatus, i Return: A dictionary mapping each status to the number of filtered tasks.
""" - request = CountTasksByStatusRequest(filters=cast(rawFilters, task_filter.to_disjunction().to_message())) + if task_filter: + request = CountTasksByStatusRequest(filters=cast(rawFilters, task_filter.to_disjunction().to_message())) + else: + request = CountTasksByStatusRequest() count_tasks_by_status_response: CountTasksByStatusResponse = self._client.CountTasksByStatus(request) return {TaskStatus(status_count.status): status_count.count for status_count in count_tasks_by_status_response.status} - def submit_tasks(self, session_id: str, tasks: List[TaskDefinition], default_task_options: Optional[TaskOptions] = None, chunk_size: Optional[int] = 100) -> List[Task]: + def submit_tasks(self, session_id: str, tasks: List[TaskDefinition], default_task_options: Optional[TaskOptions | None] = None, chunk_size: Optional[int] = 100) -> List[Task]: """Submit tasks to ArmoniK. Args: @@ -165,32 +171,24 @@ def submit_tasks(self, session_id: str, tasks: List[TaskDefinition], default_tas Tuple containing the list of successfully sent tasks, and the list of submission errors if any """ - - tasks_submitted = [] - for tasks_batch in batched(tasks, chunk_size): - request = SubmitTasksRequest(session_id=session_id, task_options=default_task_options.to_message()) - task_creations = [] for t in tasks_batch: task_creation = SubmitTasksRequest.TaskCreation( expected_output_keys=t.expected_output_ids, payload_id=t.payload_id, - data_dependencies=t.data_dependencies if t.data_dependencies else None, - task_options=t.options.to_message() + data_dependencies=t.data_dependencies, ) + if t.options: + task_creation.task_options = t.options.to_message() task_creations.append(task_creation) - request.task_creations = task_creations - - submit_tasks_reponse:SubmitTasksResponse = self._client.SubmitTasks(request) + request = SubmitTasksRequest( + session_id=session_id, + task_creations=task_creations + ) + if default_task_options: + request.task_options = default_task_options.to_message() - for task_info in submit_tasks_reponse.task_infos: - tasks_submitted.append(Task(id=task_info.task_id, - session_id=session_id, - expected_output_ids=[k for k in task_info.expected_output_ids], - data_dependencies=[k for k in task_info.data_dependencies]), - payload_id=task_info.payload_id - ) - return tasks_submitted + self._client.SubmitTasks(request) diff --git a/packages/python/src/armonik/common/__init__.py b/packages/python/src/armonik/common/__init__.py index fd1c14e79..836872785 100644 --- a/packages/python/src/armonik/common/__init__.py +++ b/packages/python/src/armonik/common/__init__.py @@ -1,5 +1,40 @@ -from .helpers import datetime_to_timestamp, timestamp_to_datetime, duration_to_timedelta, timedelta_to_duration, get_task_filter +from .helpers import ( + datetime_to_timestamp, + timestamp_to_datetime, + duration_to_timedelta, + timedelta_to_duration, + get_task_filter, + batched +) from .objects import Task, TaskDefinition, TaskOptions, Output, ResultAvailability, Session, Result, Partition from .enumwrapper import HealthCheckStatus, TaskStatus, Direction, SessionStatus, ResultStatus, EventTypes from .events import * from .filter import Filter, StringFilter, StatusFilter + +__all__ = [ + 'datetime_to_timestamp', + 'timestamp_to_datetime', + 'duration_to_timedelta', + 'timedelta_to_duration', + 'get_task_filter', + 'batched', + 'Task', + 'TaskDefinition', + 'TaskOptions', + 'Output', + 'ResultAvailability', + 'Session', + 'Result', + 'Partition', + 'HealthCheckStatus', + 'TaskStatus', + 'Direction', + 'SessionStatus', + 
'ResultStatus', + 'EventTypes', + # Include all names from events module + # Add names from filter module + 'Filter', + 'StringFilter', + 'StatusFilter', +] diff --git a/packages/python/src/armonik/common/helpers.py b/packages/python/src/armonik/common/helpers.py index 22d389fca..3a9cb8324 100644 --- a/packages/python/src/armonik/common/helpers.py +++ b/packages/python/src/armonik/common/helpers.py @@ -113,13 +113,15 @@ def batched(iterable: Iterable[T], n: int) -> Iterable[List[T]]: A generator yielding batches of elements from the input iterable. """ it = iter(iterable) - while True: - batch = [] - try: - for i in range(n): - batch.append(next(it)) - except StopIteration: - if len(batch) > 0: - yield batch - break + + sentinel = object() + batch = [] + c = next(it, sentinel) + while c is not sentinel: + batch.append(c) + if len(batch) == n: + yield batch + batch.clear() + c = next(it, sentinel) + if len(batch) > 0: yield batch diff --git a/packages/python/src/armonik/common/objects.py b/packages/python/src/armonik/common/objects.py index 745ba7e4a..8e580dbdc 100644 --- a/packages/python/src/armonik/common/objects.py +++ b/packages/python/src/armonik/common/objects.py @@ -74,7 +74,7 @@ class TaskDefinition: payload_id: str expected_output_ids: List[str] = field(default_factory=list) data_dependencies: List[str] = field(default_factory=list) - options: Optional[TaskOptions] = field(default_factory=TaskOptions) + options: Optional[TaskOptions] = None def __post_init__(self): if len(self.expected_output_ids) <= 0: diff --git a/packages/python/src/armonik/worker/__init__.py b/packages/python/src/armonik/worker/__init__.py index 508d49ae5..78a61174c 100644 --- a/packages/python/src/armonik/worker/__init__.py +++ b/packages/python/src/armonik/worker/__init__.py @@ -1,3 +1,9 @@ from .worker import ArmoniKWorker from .taskhandler import TaskHandler from .seqlogger import ClefLogger + +__all__ = [ + 'ArmoniKWorker', + 'TaskHandler', + 'ClefLogger', +] diff --git a/packages/python/src/armonik/worker/taskhandler.py b/packages/python/src/armonik/worker/taskhandler.py index b321f651b..ae5d79c2c 100644 --- a/packages/python/src/armonik/worker/taskhandler.py +++ b/packages/python/src/armonik/worker/taskhandler.py @@ -1,5 +1,6 @@ from __future__ import annotations import os +from deprecation import deprecated from typing import Optional, Dict, List, Tuple, Union, cast from ..common import TaskOptions, TaskDefinition, Task, Result @@ -32,6 +33,7 @@ def __init__(self, request: ProcessRequest, agent_client: AgentStub): with open(os.path.join(self.data_folder, dd), "rb") as f: self.data_dependencies[dd] = f.read() + @deprecated(deprecated_in="3.15.0", details="Use submit_tasks instead and create the payload using create_results_metadata and send_results") def create_tasks(self, tasks: List[TaskDefinition], task_options: Optional[TaskOptions] = None) -> Tuple[List[Task], List[str]]: """Create new tasks for ArmoniK diff --git a/packages/python/tests/conftest.py b/packages/python/tests/conftest.py new file mode 100644 index 000000000..dc163d0ea --- /dev/null +++ b/packages/python/tests/conftest.py @@ -0,0 +1,143 @@ +import grpc +import pytest +import requests + +from armonik.client import ArmoniKResults, ArmoniKSubmitter, ArmoniKTasks, ArmoniKSessions, ArmoniKPartitions, ArmoniKVersions +from typing import Union + + +# Mock server endpoints used for the tests.
+grpc_endpoint = "localhost:5001" +calls_endpoint = "http://localhost:5000/calls.json" +reset_endpoint = "http://localhost:5000/reset" + + +@pytest.fixture(scope="session", autouse=True) +def clean_up(request): + """ + This fixture runs at the session scope and is automatically used before and after + running all the tests. It resets the mocking gRPC server counters to maintain a + clean testing environment. + + Yields: + None: This fixture is used as a context manager, and the test code runs between + the 'yield' statement and the cleanup code. + + Raises: + requests.exceptions.HTTPError: If an error occurs when attempting to reset + the mocking gRPC server counters. + """ + # Run all the tests + yield + + # Teardown code + try: + response = requests.post(reset_endpoint) + response.raise_for_status() + print("\nMock server reset.") + except requests.exceptions.HTTPError as e: + print("An error occurred when resetting the server: " + str(e)) + + +def get_client(client_name: str, endpoint: str = grpc_endpoint) -> [ArmoniKResults, ArmoniKSubmitter, ArmoniKTasks, ArmoniKSessions, ArmoniKPartitions, ArmoniKVersions]: + """ + Get the ArmoniK client instance based on the specified service name. + + Args: + client_name (str): The name of the ArmoniK client to retrieve. + endpoint (str, optional): The gRPC server endpoint. Defaults to grpc_endpoint. + + Returns: + Union[ArmoniKResults, ArmoniKSubmitter, ArmoniKTasks, ArmoniKSessions, ARmoniKPartitions]: + An instance of the specified ArmoniK client. + + Raises: + ValueError: If the specified service name is not recognized. + + Example: + >>> result_service = get_client("Results") + >>> submitter_service = get_client("Submitter", "custom_endpoint") + """ + channel = grpc.insecure_channel(endpoint).__enter__() + match client_name: + case "Results": + return ArmoniKResults(channel) + case "Submitter": + return ArmoniKSubmitter(channel) + case "Tasks": + return ArmoniKTasks(channel) + case "Sessions": + return ArmoniKSessions(channel) + case "Partitions": + return ARmoniKPartitions(channel) + case "Versions": + return ArmoniKVersions(channel) + case _: + raise ValueError("Unknown service name: " + str(client_name)) + + +def rpc_called(service_name: str, rpc_name: str, n_calls: int = 1, endpoint: str = calls_endpoint) -> bool: + """Check if a remote procedure call (RPC) has been made a specified number of times. + This function uses ArmoniK.Api.Mock. It just gets the '/calls.json' endpoint. + + Args: + service_name (str): The name of the service providing the RPC. + rpc_name (str): The name of the specific RPC to check for the number of calls. + n_calls (int, optional): The expected number of times the RPC should have been called. Default is 1. + endpoint (str, optional): The URL of the remote service providing RPC information. Defaults to + calls_endpoint. + + Returns: + bool: True if the specified RPC has been called the expected number of times, False otherwise. + + Raises: + requests.exceptions.RequestException: If an error occurs when requesting ArmoniK.Api.Mock.
+ + Example: + >>> rpc_called('Versions', 'ListVersions', 0) + True + """ + response = requests.get(endpoint) + response.raise_for_status() + data = response.json() + + # Check if the RPC has been called n_calls times + if data[service_name][rpc_name] == n_calls: + return True + return False + + +def all_rpc_called(service_name: str, endpoint: str = calls_endpoint) -> bool: + """ + Check if all remote procedure calls (RPCs) in a service have been made at least once. + This function uses ArmoniK.Api.Mock. It just gets the '/calls.json' endpoint. + + Args: + service_name (str): The name of the service containing the RPC information in the response. + endpoint (str, optional): The URL of the remote service providing RPC information. Default is + the value of calls_endpoint. + + Returns: + bool: True if all RPCs in the specified service have been called at least once, False otherwise. + + Raises: + requests.exceptions.RequestException: If an error occurs when requesting ArmoniK.Api.Mock. + + Example: + >>> all_rpc_called('Versions') + False + """ + response = requests.get(endpoint) + response.raise_for_status() + data = response.json() + + missing_rpcs = [] + + # Check if all RPCs in the service have been called at least once + for rpc_name, rpc_num_calls in data[service_name].items(): + if rpc_num_calls == 0: + missing_rpcs.append(rpc_name) + if missing_rpcs: + print(f"RPCs not implemented in {service_name} service: {missing_rpcs}.") + return False + return True diff --git a/packages/python/tests/test_sessions.py b/packages/python/tests/test_sessions.py new file mode 100644 index 000000000..c12e5d9e0 --- /dev/null +++ b/packages/python/tests/test_sessions.py @@ -0,0 +1,63 @@ +import datetime + +from .conftest import all_rpc_called, rpc_called, get_client +from armonik.client import ArmoniKSessions +from armonik.common import Session, SessionStatus, TaskOptions + + +class TestArmoniKSessions: + + def test_create_session(self): + sessions_client: ArmoniKSessions = get_client("Sessions") + default_task_options = TaskOptions( + max_duration=datetime.timedelta(seconds=1), + priority=1, + max_retries=1 + ) + session_id = sessions_client.create_session(default_task_options) + + assert rpc_called("Sessions", "CreateSession") + assert session_id == "session-id" + + def test_get_session(self): + sessions_client: ArmoniKSessions = get_client("Sessions") + session = sessions_client.get_session("session-id") + + assert rpc_called("Sessions", "GetSession") + assert isinstance(session, Session) + assert session.session_id == 'session-id' + assert session.status == SessionStatus.CANCELLED + assert session.partition_ids == [] + assert session.options == TaskOptions( + max_duration=datetime.timedelta(0), + priority=0, + max_retries=0, + partition_id='', + application_name='', + application_version='', + application_namespace='', + application_service='', + engine_type='', + options={} + ) + assert session.created_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert session.cancelled_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert session.duration == datetime.timedelta(0) + + def test_list_session_no_filter(self): + sessions_client: ArmoniKSessions = get_client("Sessions") + num, sessions = sessions_client.list_sessions() + + assert rpc_called("Sessions", "ListSessions") + # TODO: Mock must be updated to return something, which will change the following assertions + assert num
== 0 + assert sessions == [] + + def test_cancel_session(self): + sessions_client: ArmoniKSessions = get_client("Sessions") + sessions_client.cancel_session("session-id") + + assert rpc_called("Sessions", "CancelSession") + + def test_service_fully_implemented(self): + assert all_rpc_called("Sessions") diff --git a/packages/python/tests/test_tasks.py b/packages/python/tests/test_tasks.py new file mode 100644 index 000000000..a48e66714 --- /dev/null +++ b/packages/python/tests/test_tasks.py @@ -0,0 +1,96 @@ +import datetime + +from .conftest import all_rpc_called, rpc_called, get_client +from armonik.client import ArmoniKTasks +from armonik.common import Task, TaskDefinition, TaskOptions, TaskStatus, Output + + +class TestArmoniKTasks: + + def test_get_task(self): + tasks_client: ArmoniKTasks = get_client("Tasks") + task = tasks_client.get_task("task-id") + + assert rpc_called("Tasks", "GetTask") + assert isinstance(task, Task) + assert task.id == 'task-id' + assert task.session_id == 'session-id' + assert task.data_dependencies == [] + assert task.expected_output_ids == [] + assert task.retry_of_ids == [] + assert task.status == TaskStatus.COMPLETED + assert task.payload_id is None + assert task.status_message == '' + assert task.options == TaskOptions( + max_duration=datetime.timedelta(seconds=1), + priority=1, + max_retries=1, + partition_id='partition-id', + application_name='application-name', + application_version='application-version', + application_namespace='application-namespace', + application_service='application-service', + engine_type='engine-type', + options={} + ) + assert task.created_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert task.submitted_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert task.started_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert task.ended_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert task.pod_ttl == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert task.output == Output(error='') + assert task.pod_hostname == '' + assert task.received_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert task.acquired_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + + def test_list_tasks_detailed_no_filter(self): + tasks_client: ArmoniKTasks = get_client("Tasks") + num, tasks = tasks_client.list_tasks() + assert rpc_called("Tasks", "ListTasksDetailed") + # TODO: Mock must be updated to return something, which will change the following assertions + assert num == 0 + assert tasks == [] + + def test_list_tasks_no_filter(self): + tasks_client: ArmoniKTasks = get_client("Tasks") + num, tasks = tasks_client.list_tasks(detailed=False) + assert rpc_called("Tasks", "ListTasks") + # TODO: Mock must be updated to return something, which will change the following assertions + assert num == 0 + assert tasks == [] + + def test_cancel_tasks(self): + tasks_client: ArmoniKTasks = get_client("Tasks") + tasks = tasks_client.cancel_tasks(["task-id-1", "task-id-2"]) + + assert rpc_called("Tasks", "CancelTasks") + assert tasks is None + + def test_get_result_ids(self): + tasks_client: ArmoniKTasks = get_client("Tasks") + tasks_results = tasks_client.get_result_ids(["task-id-1", "task-id-2"]) + assert rpc_called("Tasks", "GetResultIds") + # TODO: Mock must be updated to return something, which will change the following assertions + assert tasks_results == {} + + def 
test_count_tasks_by_status_no_filter(self): + tasks_client: ArmoniKTasks = get_client("Tasks") + count = tasks_client.count_tasks_by_status() + assert rpc_called("Tasks", "CountTasksByStatus") + # TODO: Mock must be updated to return something, which will change the following assertions + assert count == {} + + def test_submit_tasks(self): + tasks_client: ArmoniKTasks = get_client("Tasks") + tasks = tasks_client.submit_tasks( + "session-id", + [TaskDefinition(payload_id="payload-id", + expected_output_ids=["result-id"], + data_dependencies=[])] + ) + assert rpc_called("Tasks", "SubmitTasks") + # TODO: Mock must be updated to return something, which will change the following assertions + assert tasks is None + + def test_service_fully_implemented(self): + assert all_rpc_called("Tasks") diff --git a/packages/python/tests/test_versions.py b/packages/python/tests/test_versions.py new file mode 100644 index 000000000..e6e30db5d --- /dev/null +++ b/packages/python/tests/test_versions.py @@ -0,0 +1,39 @@ +import pytest + +from .conftest import all_rpc_called, rpc_called, get_client +from armonik.client import ArmoniKVersions + + +class TestArmoniKVersions: + + def test_list_versions(self): + """ + Test the list_versions method of ArmoniKVersions client. + + Assertions: + Ensures that the RPC 'ListVersions' is called on the service 'Versions'. + Asserts that the 'core' version is returned with the correct value. + Asserts that the 'api' version is returned with the correct value. + """ + versions_client: ArmoniKVersions = get_client("Versions") + versions = versions_client.list_versions() + + assert rpc_called("Versions", "ListVersions") + assert versions["core"] == "Unknown" + assert versions["api"] == "1.0.0.0" + + def test_service_fully_implemented(self): + """ + Test if all RPCs in the 'Versions' service have been called at least once. + + Assertions: + Ensures that all RPCs in the 'Versions' service have been called at least once. + """ + assert all_rpc_called("Versions") From 91d2076c2053806746fe10bf445367f1e3134490 Mon Sep 17 00:00:00 2001 From: qdelamea Date: Fri, 22 Dec 2023 16:13:51 +0100 Subject: [PATCH 05/15] Add tests for Partitions service --- .../python/src/armonik/client/partitions.py | 5 +-- packages/python/tests/test_partitions.py | 31 +++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) create mode 100644 packages/python/tests/test_partitions.py diff --git a/packages/python/src/armonik/client/partitions.py b/packages/python/src/armonik/client/partitions.py index fb5e0421b..e0b0b0ada 100644 --- a/packages/python/src/armonik/client/partitions.py +++ b/packages/python/src/armonik/client/partitions.py @@ -29,7 +29,7 @@ def __init__(self, grpc_channel: Channel): """ self._client = PartitionsStub(grpc_channel) - def list_partitions(self, partition_filter: Filter, page: int = 0, page_size: int = 1000, sort_field: Filter = PartitionFieldFilter.PRIORITY, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Partition]]: + def list_partitions(self, partition_filter: Filter | None = None, page: int = 0, page_size: int = 1000, sort_field: Filter = PartitionFieldFilter.PRIORITY, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Partition]]: """List partitions based on a filter.
Args: @@ -47,9 +47,10 @@ def list_partitions(self, partition_filter: Filter, page: int = 0, page_size: in request = ListPartitionsRequest( page=page, page_size=page_size, - filters=cast(rawFilters, partition_filter.to_disjunction().to_message()), sort=ListPartitionsRequest.Sort(field=cast(PartitionField, sort_field.field), direction=sort_direction), ) + if partition_filter: + request.filters.CopyFrom(cast(rawFilters, partition_filter.to_disjunction().to_message())) response: ListPartitionsResponse = self._client.ListPartitions(request) return response.total, [Partition.from_message(p) for p in response.partitions] diff --git a/packages/python/tests/test_partitions.py b/packages/python/tests/test_partitions.py new file mode 100644 index 000000000..984bcd270 --- /dev/null +++ b/packages/python/tests/test_partitions.py @@ -0,0 +1,31 @@ +from .conftest import all_rpc_called, rpc_called, get_client +from armonik.client import ArmoniKPartitions +from armonik.common import Partition + + +class TestArmoniKPartitions: + + def test_get_partitions(self): + partitions_client: ArmoniKPartitions = get_client("Partitions") + partition = partitions_client.get_partition("partition-id") + + assert rpc_called("Partitions", "GetPartition") + assert isinstance(partition, Partition) + assert partition.id == 'partition-id' + assert partition.parent_partition_ids == [] + assert partition.pod_reserved == 1 + assert partition.pod_max == 1 + assert partition.pod_configuration == {} + assert partition.preemption_percentage == 0 + assert partition.priority == 1 + + def test_list_partitions_no_filter(self): + partitions_client: ArmoniKPartitions = get_client("Partitions") + num, partitions = partitions_client.list_partitions() + + assert rpc_called("Partitions", "ListPartitions") + assert num == 0 + assert partitions == [] + + def test_service_fully_implemented(self): + assert all_rpc_called("Partitions") From 43eba44af642f3c9730f25fedf9434dc0d845eb2 Mon Sep 17 00:00:00 2001 From: qdelamea Date: Tue, 26 Dec 2023 11:08:46 +0100 Subject: [PATCH 06/15] Add tests for Results service --- packages/python/src/armonik/client/results.py | 16 ++-- packages/python/tests/test_partitions.py | 1 + packages/python/tests/test_results.py | 94 +++++++++++++++++++ 3 files changed, 104 insertions(+), 7 deletions(-) create mode 100644 packages/python/tests/test_results.py diff --git a/packages/python/src/armonik/client/results.py b/packages/python/src/armonik/client/results.py index 74c5f4c6e..026b77095 100644 --- a/packages/python/src/armonik/client/results.py +++ b/packages/python/src/armonik/client/results.py @@ -1,5 +1,6 @@ from __future__ import annotations from grpc import Channel +from deprecation import deprecated from typing import List, Dict, cast, Tuple @@ -28,10 +29,11 @@ def __init__(self, grpc_channel: Channel): """ self._client = ResultsStub(grpc_channel) + @deprecated(deprecated_in="3.15.0", details="Use create_results_metadata or create_results instead.") def get_results_ids(self, session_id: str, names: List[str]) -> Dict[str, str]: return {r.name : r.result_id for r in cast(CreateResultsMetaDataResponse, self._client.CreateResultsMetaData(CreateResultsMetaDataRequest(results=[CreateResultsMetaDataRequest.ResultCreate(name = n) for n in names], session_id=session_id))).results} - def list_results(self, result_filter: Filter, page: int = 0, page_size: int = 1000, sort_field: Filter = ResultFieldFilter.STATUS,sort_direction: SortDirection = Direction.ASC ) -> Tuple[int, List[Result]]: + def list_results(self, result_filter: Filter | 
None = None, page: int = 0, page_size: int = 1000, sort_field: Filter = ResultFieldFilter.STATUS, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Result]]: """List results based on a filter. Args: @@ -48,9 +50,10 @@ def list_results(self, result_filter: Filter, page: int = 0, page_size: int = 10 request: ListResultsRequest = ListResultsRequest( page=page, page_size=page_size, - filters=cast(rawFilters, result_filter.to_disjunction().to_message()), sort=ListResultsRequest.Sort(field=cast(ResultField, sort_field.field), direction=sort_direction), ) + if result_filter: + request.filters.CopyFrom(cast(rawFilters, result_filter.to_disjunction().to_message())) list_response: ListResultsResponse = self._client.ListResults(request) return list_response.total, [Result.from_message(r) for r in list_response.results] @@ -86,7 +89,7 @@ def get_owner_task_id(self, result_ids: List[str], session_id: str, batch_size: results[result_task.result_id] = result_task.task_id return results - def create_result_metadata(self, result_names: List[str], session_id: str, batch_size: int = 100) -> Dict[str, Result]: + def create_results_metadata(self, result_names: List[str], session_id: str, batch_size: int = 100) -> Dict[str, Result]: """Create the metadata of multiple results at once. Data have to be uploaded separately. @@ -121,9 +124,9 @@ def create_results(self, results_data: Dict[str, bytes], session_id: str, batch_ A dictionary mapping each result name to its corresponding result summary. """ results = {} - for results_data_batch in batched(results_data, batch_size): + for results_names_batch in batched(results_data.keys(), batch_size): request = CreateResultsRequest( - results=[CreateResultsRequest.ResultCreate(name=name, data=data) for name, data in results_data_batch.items()], + results=[CreateResultsRequest.ResultCreate(name=name, data=results_data[name]) for name in results_names_batch], session_id=session_id ) response: CreateResultsResponse = self._client.CreateResults(request) @@ -191,8 +194,7 @@ def delete_result_data(self, result_ids: List[str], session_id: str, batch_size: result_id=result_ids_batch, session_id=session_id ) - response: DeleteResultsDataResponse = self._client.DeleteResultsData(request) - assert sorted(result_ids_batch) == sorted(response.result_id) + self._client.DeleteResultsData(request) def get_service_config(self) -> int: """Get the configuration of the service.
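Taken together, the Results changes above support the following flow (a minimal sketch, not part of the patch; it assumes a control plane at a hypothetical localhost:5001 endpoint, an already-existing session, and the name-to-Result mapping returned by create_results_metadata):

    import grpc
    from armonik.client import ArmoniKResults

    with grpc.insecure_channel("localhost:5001") as channel:  # hypothetical endpoint
        results_client = ArmoniKResults(channel)
        session_id = "session-id"  # assumed to exist already

        # Declare the result first; its payload is uploaded separately,
        # chunked according to the size reported by get_service_config().
        created = results_client.create_results_metadata(["my-output"], session_id)
        result_id = created["my-output"].result_id
        results_client.upload_result_data(result_id, b"some bytes", session_id)

        # Round-trip the data, then delete it server-side.
        assert results_client.download_result_data(result_id, session_id) == b"some bytes"
        results_client.delete_result_data([result_id], session_id, batch_size=100)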
diff --git a/packages/python/tests/test_partitions.py b/packages/python/tests/test_partitions.py index 984bcd270..b2ae6ecce 100644 --- a/packages/python/tests/test_partitions.py +++ b/packages/python/tests/test_partitions.py @@ -24,6 +24,7 @@ def test_list_partitions_no_filter(self): num, partitions = partitions_client.list_partitions() assert rpc_called("Partitions", "ListPartitions") + # TODO: Mock must be updated to return something, which will change the following assertions assert num == 0 assert partitions == [] diff --git a/packages/python/tests/test_results.py b/packages/python/tests/test_results.py new file mode 100644 index 000000000..05789e8a6 --- /dev/null +++ b/packages/python/tests/test_results.py @@ -0,0 +1,94 @@ +import datetime +import pytest + +from .conftest import all_rpc_called, rpc_called, get_client +from armonik.client import ArmoniKResults +from armonik.common import Result, ResultStatus + + +class TestArmoniKResults: + + def test_get_result(self): + results_client: ArmoniKResults = get_client("Results") + result = results_client.get_result("result-name") + + assert rpc_called("Results", "GetResult") + assert isinstance(result, Result) + assert result.session_id == 'session-id' + assert result.name == 'result-name' + assert result.owner_task_id == 'owner-task-id' + assert result.status == ResultStatus.COMPLETED + assert result.created_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert result.completed_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert result.result_id == 'result-id' + assert result.size == 0 + + def test_get_owner_task_id(self): + results_client: ArmoniKResults = get_client("Results") + results_tasks = results_client.get_owner_task_id(["result-id"], "session-id") + + assert rpc_called("Results", "GetOwnerTaskId") + # TODO: Mock must be updated to return something, which will change the following assertions + assert results_tasks == {} + + def test_list_results_no_filter(self): + results_client: ArmoniKResults = get_client("Results") + num, results = results_client.list_results() + + assert rpc_called("Results", "ListResults") + # TODO: Mock must be updated to return something, which will change the following assertions + assert num == 0 + assert results == [] + + def test_create_results_metadata(self): + results_client: ArmoniKResults = get_client("Results") + results = results_client.create_results_metadata(["result-name"], "session-id") + + assert rpc_called("Results", "CreateResultsMetaData") + # TODO: Mock must be updated to return something, which will change the following assertions + assert results == {} + + def test_create_results(self): + results_client: ArmoniKResults = get_client("Results") + results = results_client.create_results({"result-name": b"test data"}, "session-id") + + assert rpc_called("Results", "CreateResults") + assert results == {} + + def test_upload_result_data(self): + results_client: ArmoniKResults = get_client("Results") + result = results_client.upload_result_data("result-name", b"test data", "session-id") + + assert rpc_called("Results", "UploadResultData") + assert result is None + + def test_download_result_data(self): + results_client: ArmoniKResults = get_client("Results") + data = results_client.download_result_data("result-name", "session-id") + + assert rpc_called("Results", "DownloadResultData") + assert data == b"" + + def test_delete_result_data(self): + results_client: ArmoniKResults = get_client("Results") + result = 
results_client.delete_result_data(["result-name"], "session-id") + + assert rpc_called("Results", "DeleteResultsData") + assert result is None + + def test_get_service_config(self): + results_client: ArmoniKResults = get_client("Results") + chunk_size = results_client.get_service_config() + + assert rpc_called("Results", "GetServiceConfiguration", 2) + assert isinstance(chunk_size, int) + assert chunk_size == 81920 + + def test_watch_results(self): + results_client: ArmoniKResults = get_client("Results") + with pytest.raises(NotImplementedError, match=""): + results_client.watch_results() + assert rpc_called("Results", "WatchResults", 0) + + def test_service_fully_implemented(self): + assert all_rpc_called("Results", missings=["WatchResults"]) From 9cb5cf317df2da52874b74bff471f2801e201df4 Mon Sep 17 00:00:00 2001 From: qdelamea Date: Tue, 26 Dec 2023 13:53:53 +0100 Subject: [PATCH 07/15] Add tests for TaskHandler --- packages/python/src/armonik/common/objects.py | 3 +- .../python/src/armonik/worker/taskhandler.py | 32 +++--- packages/python/tests/conftest.py | 36 ++++-- packages/python/tests/test_results.py | 16 +-- packages/python/tests/test_taskhandler.py | 108 ++++++++++++++++++ 5 files changed, 163 insertions(+), 32 deletions(-) create mode 100644 packages/python/tests/test_taskhandler.py diff --git a/packages/python/src/armonik/common/objects.py b/packages/python/src/armonik/common/objects.py index 8e580dbdc..1b5801f7a 100644 --- a/packages/python/src/armonik/common/objects.py +++ b/packages/python/src/armonik/common/objects.py @@ -71,7 +71,8 @@ def to_message(self): @dataclass() class TaskDefinition: - payload_id: str + payload_id: str = field(default_factory=str) + payload: bytes = field(default_factory=bytes) expected_output_ids: List[str] = field(default_factory=list) data_dependencies: List[str] = field(default_factory=list) options: Optional[TaskOptions] = None diff --git a/packages/python/src/armonik/worker/taskhandler.py b/packages/python/src/armonik/worker/taskhandler.py index ae5d79c2c..7d18ef7db 100644 --- a/packages/python/src/armonik/worker/taskhandler.py +++ b/packages/python/src/armonik/worker/taskhandler.py @@ -79,44 +79,46 @@ def submit_tasks(self, tasks: List[TaskDefinition], default_task_options: Option batch_size: Batch size for submission """ for tasks_batch in batched(tasks, batch_size): - request = SubmitTasksRequest( - session_id=self.session_id, - task_options=default_task_options.to_message(), - communication_token=self.token - ) - task_creations = [] for t in tasks_batch: task_creation = SubmitTasksRequest.TaskCreation( expected_output_keys=t.expected_output_ids, payload_id=t.payload_id, - data_dependencies=t.data_dependencies if t.data_dependencies else None, - task_options=t.options.to_message() + data_dependencies=t.data_dependencies ) + if t.options: + task_creation.task_options.CopyFrom(t.options.to_message()) task_creations.append(task_creation) - request.task_creations = task_creations + request = SubmitTasksRequest( + session_id=self.session_id, + communication_token=self.token, + task_creations=task_creations + ) + + if default_task_options: + request.task_options.CopyFrom(default_task_options.to_message()) self._client.SubmitTasks(request) - def send_results(self, result_data: Dict[str, bytes | bytearray]) -> None: + def send_results(self, results_data: Dict[str, bytes | bytearray]) -> None: """Send results. Args: results_data: A dictionary mapping each result ID to its data.
""" - for result_id, result_data in result_data.items(): + for result_id, result_data in results_data.items(): with open(os.path.join(self.data_folder, result_id), "wb") as f: f.write(result_data) request = NotifyResultDataRequest( - ids=[NotifyResultDataRequest.ResultIdentifier(session_id=self.session_id, result_id=result_id) for result_id in result_data.keys()], + ids=[NotifyResultDataRequest.ResultIdentifier(session_id=self.session_id, result_id=result_id) for result_id in results_data.keys()], communication_token=self.token ) self._client.NotifyResultData(request) - def create_result_metadata(self, result_names: List[str], batch_size: int = 100) -> Dict[str, List[Result]]: + def create_results_metadata(self, result_names: List[str], batch_size: int = 100) -> Dict[str, List[Result]]: """ Create the metadata of multiple results at once. Data have to be uploaded separately. @@ -151,9 +153,9 @@ def create_results(self, results_data: Dict[str, bytes], batch_size: int = 1) -> A dictionnary mappin each result name to its corresponding result summary. """ results = {} - for results_data_batch in batched(results_data, batch_size): + for results_ids_batch in batched(results_data.keys(), batch_size): request = CreateResultsRequest( - results=[CreateResultsRequest.ResultCreate(name=name, data=data) for name, data in results_data_batch.items()], + results=[CreateResultsRequest.ResultCreate(name=name, data=results_data[name]) for name in results_ids_batch], session_id=self.session_id, communication_token=self.token ) diff --git a/packages/python/tests/conftest.py b/packages/python/tests/conftest.py index dc163d0ea..060c45ed6 100644 --- a/packages/python/tests/conftest.py +++ b/packages/python/tests/conftest.py @@ -1,23 +1,28 @@ import grpc +import os import pytest import requests from armonik.client import ArmoniKResults, ArmoniKSubmitter, ArmoniKTasks, ArmoniKSessions, ArmoniKPartitions, ArmoniKVersions -from typing import Union +from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub +from typing import List # Mock server endpoints used for the tests. grpc_endpoint = "localhost:5001" calls_endpoint = "http://localhost:5000/calls.json" reset_endpoint = "http://localhost:5000/reset" +data_folder = os.getcwd() @pytest.fixture(scope="session", autouse=True) def clean_up(request): """ This fixture runs at the session scope and is automatically used before and after - running all the tests. It resets the mocking gRPC server counters to maintain a - clean testing environment. + running all the tests. It set up and teardown the testing environments by: + - creating dummy files before testing begins; + - clear files after testing; + - resets the mocking gRPC server counters to maintain a clean testing environment. Yields: None: This fixture is used as a context manager, and the test code runs between @@ -27,10 +32,20 @@ def clean_up(request): requests.exceptions.HTTPError: If an error occurs when attempting to reset the mocking gRPC server counters. 
""" + # Write dumm payload and data dependency to files for testing purposes + with open(os.path.join(data_folder, "payload-id"), "wb") as f: + f.write("payload".encode()) + with open(os.path.join(data_folder, "dd-id"), "wb") as f: + f.write("dd".encode()) + # Run all the tests yield - # Teardown code + # Remove the temporary files created for testing + os.remove(os.path.join(data_folder, "payload-id")) + os.remove(os.path.join(data_folder, "dd-id")) + + # Reset the mock server counters try: response = requests.post(reset_endpoint) response.raise_for_status() @@ -39,7 +54,7 @@ def clean_up(request): print("An error occurred when resetting the server: " + str(e)) -def get_client(client_name: str, endpoint: str = grpc_endpoint) -> [ArmoniKResults, ArmoniKSubmitter, ArmoniKTasks, ArmoniKSessions, ArmoniKPartitions, ArmoniKVersions]: +def get_client(client_name: str, endpoint: str = grpc_endpoint) -> [ArmoniKResults, ArmoniKSubmitter, ArmoniKTasks, ArmoniKSessions, ArmoniKPartitions, ArmoniKVersions, AgentStub]: """ Get the ArmoniK client instance based on the specified service name. @@ -48,7 +63,7 @@ def get_client(client_name: str, endpoint: str = grpc_endpoint) -> [ArmoniKResul endpoint (str, optional): The gRPC server endpoint. Defaults to grpc_endpoint. Returns: - Union[ArmoniKResults, ArmoniKSubmitter, ArmoniKTasks, ArmoniKSessions, ARmoniKPartitions]: + Union[ArmoniKResults, ArmoniKSubmitter, ArmoniKTasks, ArmoniKSessions, ARmoniKPartitions, AgentStub]: An instance of the specified ArmoniK client. Raises: @@ -69,9 +84,11 @@ def get_client(client_name: str, endpoint: str = grpc_endpoint) -> [ArmoniKResul case "Sessions": return ArmoniKSessions(channel) case "Partitions": - return ARmoniKPartitions(channel) + return ArmoniKPartitions(channel) case "Versions": return ArmoniKVersions(channel) + case "Agent": + return AgentStub(channel) case _: raise ValueError("Unknown service name: " + str(service_name)) @@ -107,7 +124,7 @@ def rpc_called(service_name: str, rpc_name: str, n_calls: int = 1, endpoint: str return False -def all_rpc_called(service_name: str, endpoint: str = calls_endpoint) -> bool: +def all_rpc_called(service_name: str, missings: List[str] = [], endpoint: str = calls_endpoint) -> bool: """ Check if all remote procedure calls (RPCs) in a service have been made at least once. This function uses ArmoniK.Api.Mock. It just gets the '/calls.json' endpoint. @@ -116,6 +133,7 @@ def all_rpc_called(service_name: str, endpoint: str = calls_endpoint) -> bool: service_name (str): The name of the service containing the RPC information in the response. endpoint (str, optional): The URL of the remote service providing RPC information. Default is the value of calls_endpoint. + missings (List[str], optional): A list of RPCs known to be not implemented. Default is an empty list. Returns: bool: True if all RPCs in the specified service have been called at least once, False otherwise. 
@@ -138,6 +156,8 @@ def all_rpc_called(service_name: str, endpoint: str = calls_endpoint) -> bool: if rpc_num_calls == 0: missing_rpcs.append(rpc_name) if missing_rpcs: + if missings == missing_rpcs: + return True print(f"RPCs not implemented in {service_name} service: {missing_rpcs}.") return False return True diff --git a/packages/python/tests/test_results.py b/packages/python/tests/test_results.py index 05789e8a6..064cb7ba0 100644 --- a/packages/python/tests/test_results.py +++ b/packages/python/tests/test_results.py @@ -55,6 +55,14 @@ def test_create_results(self): assert rpc_called("Results", "CreateResults") assert results == {} + def test_get_service_config(self): + results_client: ArmoniKResults = get_client("Results") + chunk_size = results_client.get_service_config() + + assert rpc_called("Results", "GetServiceConfiguration") + assert isinstance(chunk_size, int) + assert chunk_size == 81920 + def test_upload_result_data(self): results_client: ArmoniKResults = get_client("Results") result = results_client.upload_result_data("result-name", b"test data", "session-id") @@ -76,14 +84,6 @@ def test_delete_result_data(self): assert rpc_called("Results", "DeleteResultsData") assert result is None - def test_get_service_config(self): - results_client: ArmoniKResults = get_client("Results") - chunk_size = results_client.get_service_config() - - assert rpc_called("Results", "GetServiceConfiguration", 2) - assert isinstance(chunk_size, int) - assert chunk_size == 81920 - def test_watch_results(self): results_client: ArmoniKResults = get_client("Results") with pytest.raises(NotImplementedError, match=""): diff --git a/packages/python/tests/test_taskhandler.py b/packages/python/tests/test_taskhandler.py new file mode 100644 index 000000000..f5e8634eb --- /dev/null +++ b/packages/python/tests/test_taskhandler.py @@ -0,0 +1,108 @@ +import datetime +import logging +import warnings + +from .conftest import all_rpc_called, rpc_called, get_client, data_folder +from armonik.common import TaskDefinition, TaskOptions +from armonik.worker import TaskHandler +from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub +from armonik.protogen.common.worker_common_pb2 import ProcessRequest +from armonik.protogen.common.objects_pb2 import Configuration + + +logging.basicConfig() +logging.getLogger().setLevel(logging.INFO) + + +class TestTaskHandler: + + request = ProcessRequest( + communication_token="token", + session_id="session-id", + task_id="task-id", + expected_output_keys=["result-id"], + payload_id="payload-id", + data_dependencies=["dd-id"], + data_folder=data_folder, + configuration=Configuration(data_chunk_max_size=8000), + task_options=TaskOptions( + max_duration=datetime.timedelta(seconds=1), + priority=1, + max_retries=1 + ).to_message() + ) + + def test_taskhandler_init(self): + task_handler = TaskHandler(self.request, get_client("Agent")) + + assert task_handler.session_id == "session-id" + assert task_handler.task_id == "task-id" + assert task_handler.task_options == TaskOptions( + max_duration=datetime.timedelta(seconds=1), + priority=1, + max_retries=1, + partition_id='', + application_name='', + application_version='', + application_namespace='', + application_service='', + engine_type='', + options={} + ) + assert task_handler.token == "token" + assert task_handler.expected_results == ["result-id"] + assert task_handler.configuration == Configuration(data_chunk_max_size=8000) + assert task_handler.payload_id == "payload-id" + assert task_handler.data_folder == data_folder + 
assert task_handler.payload == "payload".encode() + assert task_handler.data_dependencies == {"dd-id": "dd".encode()} + + def test_create_task(self): + with warnings.catch_warnings(record=True) as w: + # Cause all warnings to always be triggered. + warnings.simplefilter("always") + + task_handler = TaskHandler(self.request, get_client("Agent")) + tasks, errors = task_handler.create_tasks([TaskDefinition( + payload=b"payload", + expected_output_ids=["result-id"], + data_dependencies=[])]) + + assert issubclass(w[-1].category, DeprecationWarning) + assert rpc_called("Agent", "CreateTask") + assert tasks == [] + assert errors == [] + + def test_submit_tasks(self): + task_handler = TaskHandler(self.request, get_client("Agent")) + tasks = task_handler.submit_tasks([TaskDefinition(payload_id="payload-id", + expected_output_ids=["result-id"], + data_dependencies=[])] + ) + + assert rpc_called("Agent", "SubmitTasks") + assert tasks is None + + def test_send_results(self): + task_handler = TaskHandler(self.request, get_client("Agent")) + results = task_handler.send_results({"result-id": b"result data"}) + assert rpc_called("Agent", "NotifyResultData") + assert results is None + + def test_create_results_metadata(self): + task_handler = TaskHandler(self.request, get_client("Agent")) + results = task_handler.create_results_metadata(["result-name"]) + + assert rpc_called("Agent", "CreateResultsMetaData") + # TODO: Mock must be updated to return something, which will change the following assertions + assert results == {} + + def test_create_results(self): + task_handler = TaskHandler(self.request, get_client("Agent")) + results = task_handler.create_results({"result-name": b"test data"}) + + assert rpc_called("Agent", "CreateResults") + assert results == {} + + def test_service_fully_implemented(self): + assert all_rpc_called("Agent", missings=["GetCommonData", "GetDirectData", "GetResourceData"]) From e4ba3bc89d0ecc0b7f09f5f888c7f509003ca8c4 Mon Sep 17 00:00:00 2001 From: qdelamea Date: Tue, 26 Dec 2023 16:06:19 +0100 Subject: [PATCH 08/15] Add test for Events service --- .../python/src/armonik/client/__init__.py | 2 ++ packages/python/src/armonik/client/events.py | 11 ++++++---- packages/python/src/armonik/common/events.py | 8 +++++-- packages/python/tests/test_events.py | 22 +++++++++++++++++++ 4 files changed, 37 insertions(+), 6 deletions(-) create mode 100644 packages/python/tests/test_events.py diff --git a/packages/python/src/armonik/client/__init__.py b/packages/python/src/armonik/client/__init__.py index 8d1c923bc..32ad8b563 100644 --- a/packages/python/src/armonik/client/__init__.py +++ b/packages/python/src/armonik/client/__init__.py @@ -4,6 +4,7 @@ from .tasks import ArmoniKTasks from .results import ArmoniKResults from .versions import ArmoniKVersions +from .events import ArmoniKEvents __all__ = [ 'ArmoniKPartitions', @@ -12,4 +13,5 @@ 'ArmoniKTasks', 'ArmoniKResults', "ArmoniKVersions", + "ArmoniKEvents" ] diff --git a/packages/python/src/armonik/client/events.py b/packages/python/src/armonik/client/events.py index 440f6a382..88695c31f 100644 --- a/packages/python/src/armonik/client/events.py +++ b/packages/python/src/armonik/client/events.py @@ -3,7 +3,7 @@ from grpc import Channel from .results import ArmoniKResults -from ..common import EventTypes, Filter, NewTaskEvent, NewResultEvent, ResultOwnerUpdateEvent, ResultStatusUpdateEvent, TaskStatusUpdateEvent, ResultStatus +from ..common import EventTypes, Filter, NewTaskEvent, NewResultEvent, ResultOwnerUpdateEvent,
ResultStatusUpdateEvent, TaskStatusUpdateEvent, ResultStatus, Event from .results import ResultFieldFilter from ..protogen.client.events_service_pb2_grpc import EventsStub from ..protogen.common.events_common_pb2 import EventSubscriptionRequest, EventSubscriptionResponse @@ -29,7 +29,7 @@ def __init__(self, grpc_channel: Channel): self._client = EventsStub(grpc_channel) self._results_client = ArmoniKResults(grpc_channel) - def get_events(self, session_id: str, event_types: List[EventTypes], event_handlers: List[Callable[[str, Any], bool]], task_filter: Filter | None = None, result_filter: Filter | None = None) -> None: + def get_events(self, session_id: str, event_types: List[EventTypes], event_handlers: List[Callable[[str, EventTypes, Event], bool]], task_filter: Filter | None = None, result_filter: Filter | None = None) -> None: """Get events that represent updates of result and task data. Args: @@ -45,10 +45,13 @@ def get_events(self, session_id: str, event_types: List[EventTypes], event_handl """ request = EventSubscriptionRequest( session_id=session_id, -tasks_filters=cast(rawTaskFilters, task_filter.to_disjunction().to_message()), - results_filters=cast(rawResultFilters, result_filter.to_disjunction().to_message()), returned_events=event_types ) + if task_filter: + request.tasks_filters.CopyFrom(cast(rawTaskFilters, task_filter.to_disjunction().to_message())) + if result_filter: + request.results_filters.CopyFrom(cast(rawResultFilters, result_filter.to_disjunction().to_message())) + streaming_call = self._client.GetEvents(request) for message in streaming_call: event_type = message.WhichOneof("update") diff --git a/packages/python/src/armonik/common/events.py b/packages/python/src/armonik/common/events.py index 3fa9fe225..34acbd2c0 100644 --- a/packages/python/src/armonik/common/events.py +++ b/packages/python/src/armonik/common/events.py @@ -6,32 +6,35 @@ from .enumwrapper import TaskStatus, ResultStatus -@dataclass class Event(ABC): @classmethod def from_raw_event(cls, raw_event): values = {} - for raw_field in fields(cls): + for raw_field in cls.__annotations__.keys(): values[raw_field] = getattr(raw_event, raw_field) return cls(**values) +@dataclass class TaskStatusUpdateEvent(Event): task_id: str status: TaskStatus +@dataclass class ResultStatusUpdateEvent(Event): result_id: str status: ResultStatus +@dataclass class ResultOwnerUpdateEvent(Event): result_id: str previous_owner_id: str current_owner_id: str +@dataclass class NewTaskEvent(Event): task_id: str payload_id: str @@ -43,6 +46,7 @@ class NewTaskEvent(Event): parent_task_ids: List[str] +@dataclass class NewResultEvent(Event): result_id: str owner_id: str diff --git a/packages/python/tests/test_events.py b/packages/python/tests/test_events.py new file mode 100644 index 000000000..86ca9e42b --- /dev/null +++ b/packages/python/tests/test_events.py @@ -0,0 +1,22 @@ +from .conftest import all_rpc_called, rpc_called, get_client +from armonik.client import ArmoniKEvents +from armonik.common import EventTypes, NewResultEvent, ResultStatus + + +class TestArmoniKEvents: + def test_get_events_no_filter(self): + def test_handler(session_id, event_type, event): + assert session_id == "session-id" + assert event_type == EventTypes.NEW_RESULT + assert isinstance(event, NewResultEvent) + assert event.result_id == "result-id" + assert event.owner_id == "owner-id" + assert event.status == ResultStatus.CREATED + + events_client: ArmoniKEvents = get_client("Events") + events_client.get_events("session-id", [EventTypes.TASK_STATUS_UPDATE], [test_handler]) + 
assert rpc_called("Events", "GetEvents") + + def test_service_fully_implemented(self): + assert all_rpc_called("Events") From f6f1824e326a6527b496b23355265c1e008a6f4f Mon Sep 17 00:00:00 2001 From: qdelamea Date: Tue, 26 Dec 2023 16:08:00 +0100 Subject: [PATCH 09/15] Refactor tests --- packages/python/tests/common.py | 40 --- packages/python/tests/conftest.py | 6 +- packages/python/tests/taskhandler_test.py | 89 ------ packages/python/tests/tasks_test.py | 281 ------------------ .../{filters_test.py => test_filters.py} | 0 .../{helpers_test.py => test_helpers.py} | 15 +- packages/python/tests/test_submitter.py | 0 packages/python/tests/test_worker.py | 80 +++++ packages/python/tests/worker_test.py | 73 ----- 9 files changed, 98 insertions(+), 486 deletions(-) delete mode 100644 packages/python/tests/common.py delete mode 100644 packages/python/tests/taskhandler_test.py delete mode 100644 packages/python/tests/tasks_test.py rename packages/python/tests/{filters_test.py => test_filters.py} (100%) rename packages/python/tests/{helpers_test.py => test_helpers.py} (76%) create mode 100644 packages/python/tests/test_submitter.py create mode 100644 packages/python/tests/test_worker.py delete mode 100644 packages/python/tests/worker_test.py diff --git a/packages/python/tests/common.py b/packages/python/tests/common.py deleted file mode 100644 index db2feed4a..000000000 --- a/packages/python/tests/common.py +++ /dev/null @@ -1,40 +0,0 @@ -from grpc import Channel - - -class DummyChannel(Channel): - def __init__(self): - self.method_dict = {} - - def stream_unary(self, *args, **kwargs): - return self.get_method(args[0]) - - def unary_stream(self, *args, **kwargs): - return self.get_method(args[0]) - - def unary_unary(self, *args, **kwargs): - return self.get_method(args[0]) - - def stream_stream(self, *args, **kwargs): - return self.get_method(args[0]) - - def set_instance(self, instance): - self.method_dict = {func: getattr(instance, func) for func in dir(type(instance)) if callable(getattr(type(instance), func)) and not func.startswith("__")} - - def get_method(self, name: str): - return self.method_dict.get(name.split("/")[-1], None) - - def subscribe(self, callback, try_to_connect=False): - pass - - def unsubscribe(self, callback): - pass - - def close(self): - pass - - def __enter__(self): - pass - - def __exit__(self, exc_type, exc_val, exc_tb): - pass - diff --git a/packages/python/tests/conftest.py b/packages/python/tests/conftest.py index 060c45ed6..a5d47658c 100644 --- a/packages/python/tests/conftest.py +++ b/packages/python/tests/conftest.py @@ -3,7 +3,7 @@ import pytest import requests -from armonik.client import ArmoniKResults, ArmoniKSubmitter, ArmoniKTasks, ArmoniKSessions, ArmoniKPartitions, ArmoniKVersions +from armonik.client import ArmoniKResults, ArmoniKSubmitter, ArmoniKTasks, ArmoniKSessions, ArmoniKPartitions, ArmoniKVersions, ArmoniKEvents from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub from typing import List @@ -54,7 +54,7 @@ def clean_up(request): print("An error occurred when resetting the server: " + str(e)) -def get_client(client_name: str, endpoint: str = grpc_endpoint) -> [ArmoniKResults, ArmoniKSubmitter, ArmoniKTasks, ArmoniKSessions, ArmoniKPartitions, ArmoniKVersions, AgentStub]: +def get_client(client_name: str, endpoint: str = grpc_endpoint) -> [ArmoniKResults, ArmoniKSubmitter, ArmoniKTasks, ArmoniKSessions, ArmoniKPartitions, ArmoniKVersions, AgentStub, ArmoniKEvents]: """ Get the ArmoniK client instance based on the specified service 
name. @@ -89,6 +89,8 @@ def get_client(client_name: str, endpoint: str = grpc_endpoint) -> [ArmoniKResul return ArmoniKVersions(channel) case "Agent": return AgentStub(channel) + case "Events": + return ArmoniKEvents(channel) case _: raise ValueError("Unknown service name: " + str(service_name)) diff --git a/packages/python/tests/taskhandler_test.py b/packages/python/tests/taskhandler_test.py deleted file mode 100644 index e4f3c181c..000000000 --- a/packages/python/tests/taskhandler_test.py +++ /dev/null @@ -1,89 +0,0 @@ -#!/usr/bin/env python3 -import os - -import pytest -from typing import Iterator - -from armonik.common import TaskDefinition - -from .common import DummyChannel -from armonik.worker import TaskHandler -from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub -from armonik.protogen.common.agent_common_pb2 import CreateTaskRequest, CreateTaskReply, NotifyResultDataRequest, NotifyResultDataResponse -from armonik.protogen.common.worker_common_pb2 import ProcessRequest -from armonik.protogen.common.objects_pb2 import Configuration -import logging - -logging.basicConfig() -logging.getLogger().setLevel(logging.INFO) - -data_folder = os.getcwd() - - -@pytest.fixture(autouse=True, scope="session") -def setup_teardown(): - with open(os.path.join(data_folder, "payloadid"), "wb") as f: - f.write("payload".encode()) - with open(os.path.join(data_folder, "ddid"), "wb") as f: - f.write("dd".encode()) - yield - os.remove(os.path.join(data_folder, "payloadid")) - os.remove(os.path.join(data_folder, "ddid")) - - -class DummyAgent(AgentStub): - - def __init__(self, channel: DummyChannel) -> None: - channel.set_instance(self) - super(DummyAgent, self).__init__(channel) - self.create_task_messages = [] - self.send_result_task_message = [] - - def CreateTask(self, request_iterator: Iterator[CreateTaskRequest]) -> CreateTaskReply: - self.create_task_messages = [r for r in request_iterator] - return CreateTaskReply(creation_status_list=CreateTaskReply.CreationStatusList(creation_statuses=[ - CreateTaskReply.CreationStatus( - task_info=CreateTaskReply.TaskInfo(task_id="TaskId", expected_output_keys=["EOK"], - data_dependencies=["DD"]))])) - - def NotifyResultData(self, request: NotifyResultDataRequest) -> NotifyResultDataResponse: - self.send_result_task_message.append(request) - return NotifyResultDataResponse(result_ids=[i.result_id for i in request.ids]) - - -should_succeed_case = ProcessRequest(communication_token="token", session_id="sessionid", task_id="taskid", expected_output_keys=["resultid"], payload_id="payloadid", data_dependencies=["ddid"], data_folder=data_folder, configuration=Configuration(data_chunk_max_size=8000)) - - -@pytest.mark.parametrize("requests", [should_succeed_case]) -def test_taskhandler_create_should_succeed(requests: ProcessRequest): - agent = DummyAgent(DummyChannel()) - task_handler = TaskHandler(requests, agent) - assert task_handler.token is not None and len(task_handler.token) > 0 - assert len(task_handler.payload) > 0 - assert task_handler.session_id is not None and len(task_handler.session_id) > 0 - assert task_handler.task_id is not None and len(task_handler.task_id) > 0 - - -def test_taskhandler_data_are_correct(): - agent = DummyAgent(DummyChannel()) - task_handler = TaskHandler(should_succeed_case, agent) - assert len(task_handler.payload) > 0 - - task_handler.create_tasks([TaskDefinition("Payload".encode("utf-8"), ["EOK"], ["DD"])]) - - tasks = agent.create_task_messages - assert len(tasks) == 5 - assert tasks[0].WhichOneof("type") == 
"init_request" - assert tasks[1].WhichOneof("type") == "init_task" - assert len(tasks[1].init_task.header.data_dependencies) == 1 \ - and tasks[1].init_task.header.data_dependencies[0] == "DD" - assert len(tasks[1].init_task.header.expected_output_keys) == 1 \ - and tasks[1].init_task.header.expected_output_keys[0] == "EOK" - assert tasks[2].WhichOneof("type") == "task_payload" - assert tasks[2].task_payload.data == "Payload".encode("utf-8") - assert tasks[3].WhichOneof("type") == "task_payload" - assert tasks[3].task_payload.data_complete - assert tasks[4].WhichOneof("type") == "init_task" - assert tasks[4].init_task.last_task - - diff --git a/packages/python/tests/tasks_test.py b/packages/python/tests/tasks_test.py deleted file mode 100644 index 752b2ac63..000000000 --- a/packages/python/tests/tasks_test.py +++ /dev/null @@ -1,281 +0,0 @@ -#!/usr/bin/env python3 -import dataclasses -from typing import Optional, List, Any, Union, Dict, Collection -from google.protobuf.timestamp_pb2 import Timestamp - -from datetime import datetime - -import pytest - -from .common import DummyChannel -from armonik.client import ArmoniKTasks -from armonik.client.tasks import TaskFieldFilter -from armonik.common import TaskStatus, datetime_to_timestamp, Task -from armonik.common.filter import StringFilter, Filter -from armonik.protogen.client.tasks_service_pb2_grpc import TasksStub -from armonik.protogen.common.tasks_common_pb2 import GetTaskRequest, GetTaskResponse, TaskDetailed -from armonik.protogen.common.tasks_filters_pb2 import Filters, FilterField -from armonik.protogen.common.filters_common_pb2 import * -from armonik.protogen.common.tasks_fields_pb2 import * -from .submitter_test import default_task_option - - -class DummyTasksService(TasksStub): - def __init__(self, channel: DummyChannel): - channel.set_instance(self) - super().__init__(channel) - self.task_request: Optional[GetTaskRequest] = None - - def GetTask(self, request: GetTaskRequest) -> GetTaskResponse: - self.task_request = request - raw = TaskDetailed(id="TaskId", session_id="SessionId", owner_pod_id="PodId", parent_task_ids=["ParentTaskId"], - data_dependencies=["DD"], expected_output_ids=["EOK"], retry_of_ids=["RetryId"], - status=TaskStatus.COMPLETED, status_message="Message", - options=default_task_option.to_message(), - created_at=datetime_to_timestamp(datetime.now()), - started_at=datetime_to_timestamp(datetime.now()), - submitted_at=datetime_to_timestamp(datetime.now()), - ended_at=datetime_to_timestamp(datetime.now()), pod_ttl=datetime_to_timestamp(datetime.now()), - output=TaskDetailed.Output(success=True), pod_hostname="Hostname", received_at=datetime_to_timestamp(datetime.now()), - acquired_at=datetime_to_timestamp(datetime.now()) - ) - return GetTaskResponse(task=raw) - - -def test_tasks_get_task_should_succeed(): - channel = DummyChannel() - inner = DummyTasksService(channel) - tasks = ArmoniKTasks(channel) - task = tasks.get_task("TaskId") - assert task is not None - assert inner.task_request is not None - assert inner.task_request.task_id == "TaskId" - assert task.id == "TaskId" - assert task.session_id == "SessionId" - assert task.parent_task_ids == ["ParentTaskId"] - assert task.output - assert task.output.success - - -def test_task_refresh(): - channel = DummyChannel() - inner = DummyTasksService(channel) - tasks = ArmoniKTasks(channel) - current = Task(id="TaskId") - current.refresh(tasks) - assert current is not None - assert inner.task_request is not None - assert inner.task_request.task_id == "TaskId" - assert 
current.id == "TaskId" - assert current.session_id == "SessionId" - assert current.parent_task_ids == ["ParentTaskId"] - assert current.output - assert current.output.success - - -def test_task_filters(): - filt: StringFilter = TaskFieldFilter.TASK_ID == "TaskId" - message = filt.to_message() - assert isinstance(message, FilterField) - assert message.field.WhichOneof("field") == "task_summary_field" - assert message.field.task_summary_field.field == TASK_SUMMARY_ENUM_FIELD_TASK_ID - assert message.filter_string.value == "TaskId" - assert message.filter_string.operator == FILTER_STRING_OPERATOR_EQUAL - - filt: StringFilter = TaskFieldFilter.TASK_ID != "TaskId" - message = filt.to_message() - assert isinstance(message, FilterField) - assert message.field.WhichOneof("field") == "task_summary_field" - assert message.field.task_summary_field.field == TASK_SUMMARY_ENUM_FIELD_TASK_ID - assert message.filter_string.value == "TaskId" - assert message.filter_string.operator == FILTER_STRING_OPERATOR_NOT_EQUAL - - -@dataclasses.dataclass -class SimpleFieldFilter: - field: Any - value: Any - operator: Any - - -@pytest.mark.parametrize("filt,n_or,n_and,filters", [ - ( - (TaskFieldFilter.INITIAL_TASK_ID == "TestId"), - 1, [1], - [ - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_INITIAL_TASK_ID, "TestId", FILTER_STRING_OPERATOR_EQUAL) - ] - ), - ( - (TaskFieldFilter.APPLICATION_NAME.contains("TestName") & (TaskFieldFilter.CREATED_AT > Timestamp(seconds=1000, nanos=500))), - 1, [2], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_NAME, "TestName", FILTER_STRING_OPERATOR_CONTAINS), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_CREATED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_AFTER) - ] - ), - ( - (((TaskFieldFilter.MAX_RETRIES <= 3) & ~(TaskFieldFilter.SESSION_ID == "SessionId")) | (TaskFieldFilter.task_options_key("MyKey").startswith("Start"))), - 2, [1, 2], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_MAX_RETRIES, 3, FILTER_NUMBER_OPERATOR_LESS_THAN_OR_EQUAL), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_SESSION_ID, "SessionId", FILTER_STRING_OPERATOR_NOT_EQUAL), - SimpleFieldFilter("MyKey", "Start", FILTER_STRING_OPERATOR_STARTS_WITH) - ] - ), - ( - (((TaskFieldFilter.PRIORITY > 3) & ~(TaskFieldFilter.STATUS == TaskStatus.COMPLETED) & TaskFieldFilter.APPLICATION_VERSION.contains("1.0")) | (TaskFieldFilter.ENGINE_TYPE.endswith("Test") & (TaskFieldFilter.ENDED_AT <= Timestamp(seconds=1000, nanos=500)))), - 2, [2, 3], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_PRIORITY, 3, FILTER_NUMBER_OPERATOR_GREATER_THAN), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_STATUS, TaskStatus.COMPLETED, FILTER_STATUS_OPERATOR_NOT_EQUAL), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_VERSION, "1.0", FILTER_STRING_OPERATOR_CONTAINS), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_ENGINE_TYPE, "Test", FILTER_STRING_OPERATOR_ENDS_WITH), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_ENDED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_BEFORE_OR_EQUAL), - ] - ), - ( - (((TaskFieldFilter.PRIORITY >= 3) * -(TaskFieldFilter.STATUS != TaskStatus.COMPLETED) * -TaskFieldFilter.APPLICATION_VERSION.contains("1.0")) + (TaskFieldFilter.ENGINE_TYPE.endswith("Test") * (TaskFieldFilter.ENDED_AT <= Timestamp(seconds=1000, nanos=500)))), - 2, [2, 3], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_PRIORITY, 3, FILTER_NUMBER_OPERATOR_GREATER_THAN_OR_EQUAL), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_STATUS, TaskStatus.COMPLETED, FILTER_STATUS_OPERATOR_EQUAL), - 
SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_VERSION, "1.0", FILTER_STRING_OPERATOR_NOT_CONTAINS), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_ENGINE_TYPE, "Test", FILTER_STRING_OPERATOR_ENDS_WITH), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_ENDED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_BEFORE_OR_EQUAL), - ] - ) -]) -def test_filter_combination(filt: Filter, n_or: int, n_and: List[int], filters: List[SimpleFieldFilter]): - filt = filt.to_disjunction() - assert len(filt._filters) == n_or - sorted_n_and = sorted(n_and) - sorted_actual = sorted([len(f) for f in filt._filters]) - assert len(sorted_n_and) == len(sorted_actual) - assert all((sorted_n_and[i] == sorted_actual[i] for i in range(len(sorted_actual)))) - for f in filt._filters: - for ff in f: - field_value = getattr(ff.field, ff.field.WhichOneof("field")).field - for i, expected in enumerate(filters): - if expected.field == field_value and expected.value == ff.value and expected.operator == ff.operator: - filters.pop(i) - break - else: - print(f"Could not find {str(ff)}") - assert False - assert len(filters) == 0 - - -def test_name_from_value(): - assert TaskStatus.name_from_value(TaskStatus.COMPLETED) == "TASK_STATUS_COMPLETED" - - -class BasicFilterAnd: - - def __setattr__(self, key, value): - self.__dict__[key] = value - - def __getattr__(self, item): - return self.__dict__[item] - - -@pytest.mark.parametrize("filt,n_or,n_and,filters,expected_type", [ - ( - (TaskFieldFilter.INITIAL_TASK_ID == "TestId"), - 1, [1], - [ - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_INITIAL_TASK_ID, "TestId", FILTER_STRING_OPERATOR_EQUAL) - ], - 0 - ), - ( - (TaskFieldFilter.APPLICATION_NAME.contains("TestName") & (TaskFieldFilter.CREATED_AT > Timestamp(seconds=1000, nanos=500))), - 1, [2], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_NAME, "TestName", FILTER_STRING_OPERATOR_CONTAINS), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_CREATED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_AFTER) - ], - 1 - ), - ( - (((TaskFieldFilter.MAX_RETRIES <= 3) & ~(TaskFieldFilter.SESSION_ID == "SessionId")) | (TaskFieldFilter.task_options_key("MyKey").startswith("Start"))), - 2, [1, 2], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_MAX_RETRIES, 3, FILTER_NUMBER_OPERATOR_LESS_THAN_OR_EQUAL), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_SESSION_ID, "SessionId", FILTER_STRING_OPERATOR_NOT_EQUAL), - SimpleFieldFilter("MyKey", "Start", FILTER_STRING_OPERATOR_STARTS_WITH) - ], - 2 - ), - ( - (((TaskFieldFilter.PRIORITY > 3) & ~(TaskFieldFilter.STATUS == TaskStatus.COMPLETED) & TaskFieldFilter.APPLICATION_VERSION.contains("1.0")) | (TaskFieldFilter.ENGINE_TYPE.endswith("Test") & (TaskFieldFilter.ENDED_AT <= Timestamp(seconds=1000, nanos=500)))), - 2, [2, 3], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_PRIORITY, 3, FILTER_NUMBER_OPERATOR_GREATER_THAN), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_STATUS, TaskStatus.COMPLETED, FILTER_STATUS_OPERATOR_NOT_EQUAL), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_VERSION, "1.0", FILTER_STRING_OPERATOR_CONTAINS), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_ENGINE_TYPE, "Test", FILTER_STRING_OPERATOR_ENDS_WITH), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_ENDED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_BEFORE_OR_EQUAL), - ], - 2 - ), - ( - (((TaskFieldFilter.PRIORITY >= 3) * -(TaskFieldFilter.STATUS != TaskStatus.COMPLETED) * -TaskFieldFilter.APPLICATION_VERSION.contains("1.0")) + (TaskFieldFilter.ENGINE_TYPE.endswith("Test") * 
(TaskFieldFilter.ENDED_AT <= Timestamp(seconds=1000, nanos=500)))), - 2, [2, 3], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_PRIORITY, 3, FILTER_NUMBER_OPERATOR_GREATER_THAN_OR_EQUAL), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_STATUS, TaskStatus.COMPLETED, FILTER_STATUS_OPERATOR_EQUAL), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_VERSION, "1.0", FILTER_STRING_OPERATOR_NOT_CONTAINS), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_ENGINE_TYPE, "Test", FILTER_STRING_OPERATOR_ENDS_WITH), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_ENDED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_BEFORE_OR_EQUAL), - ], - 2 - ), - ( - (((TaskFieldFilter.PRIORITY >= 3) * -(TaskFieldFilter.STATUS != TaskStatus.COMPLETED) * -TaskFieldFilter.APPLICATION_VERSION.contains("1.0")) + (TaskFieldFilter.ENGINE_TYPE.endswith("Test") * (TaskFieldFilter.ENDED_AT <= Timestamp(seconds=1000, nanos=500)))) + (((TaskFieldFilter.MAX_RETRIES <= 3) & ~(TaskFieldFilter.SESSION_ID == "SessionId")) | (TaskFieldFilter.task_options_key("MyKey").startswith("Start"))), - 4, [2, 3, 2, 1], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_PRIORITY, 3, FILTER_NUMBER_OPERATOR_GREATER_THAN_OR_EQUAL), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_STATUS, TaskStatus.COMPLETED, FILTER_STATUS_OPERATOR_EQUAL), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_VERSION, "1.0", FILTER_STRING_OPERATOR_NOT_CONTAINS), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_ENGINE_TYPE, "Test", FILTER_STRING_OPERATOR_ENDS_WITH), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_ENDED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_BEFORE_OR_EQUAL), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_MAX_RETRIES, 3, FILTER_NUMBER_OPERATOR_LESS_THAN_OR_EQUAL), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_SESSION_ID, "SessionId", FILTER_STRING_OPERATOR_NOT_EQUAL), - SimpleFieldFilter("MyKey", "Start", FILTER_STRING_OPERATOR_STARTS_WITH) - ], - 2 - ) -]) -def test_taskfilter_to_message(filt: Filter, n_or: int, n_and: List[int], filters: List[SimpleFieldFilter], expected_type: int): - print(filt) - message = filt.to_message() - conjs: Collection = [] - if expected_type == 2: # Disjunction - conjs: Collection = getattr(message, "or") - assert len(conjs) == n_or - sorted_n_and = sorted(n_and) - sorted_actual = sorted([len(getattr(f, "and")) for f in conjs]) - assert len(sorted_n_and) == len(sorted_actual) - assert all((sorted_n_and[i] == sorted_actual[i] for i in range(len(sorted_actual)))) - - if expected_type == 1: # Conjunction - conjs: Collection = [message] - - if expected_type == 0: # Simple filter - m = BasicFilterAnd() - setattr(m, "and", [message]) - conjs: Collection = [m] - - for conj in conjs: - basics = getattr(conj, "and") - for f in basics: - field_value = getattr(f.field, f.field.WhichOneof("field")).field - for i, expected in enumerate(filters): - if expected.field == field_value and expected.value == getattr(f, f.WhichOneof("value_condition")).value and expected.operator == getattr(f, f.WhichOneof("value_condition")).operator: - filters.pop(i) - break - else: - print(f"Could not find {str(f)}") - assert False - assert len(filters) == 0 diff --git a/packages/python/tests/filters_test.py b/packages/python/tests/test_filters.py similarity index 100% rename from packages/python/tests/filters_test.py rename to packages/python/tests/test_filters.py diff --git a/packages/python/tests/helpers_test.py b/packages/python/tests/test_helpers.py similarity index 76% rename from packages/python/tests/helpers_test.py rename to 
packages/python/tests/test_helpers.py index 20e07d922..1c0cf517b 100644 --- a/packages/python/tests/helpers_test.py +++ b/packages/python/tests/test_helpers.py @@ -5,7 +5,9 @@ from google.protobuf.duration_pb2 import Duration from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from armonik.common.helpers import datetime_to_timestamp, timestamp_to_datetime, timedelta_to_duration, duration_to_timedelta +from armonik.common.helpers import datetime_to_timestamp, timestamp_to_datetime, timedelta_to_duration, duration_to_timedelta, batched + +from typing import Iterable, List @dataclass @@ -60,3 +62,14 @@ def test_duration_to_timedelta(case: Case): def test_timedelta_to_duration(case: Case): ts = timedelta_to_duration(case.delta) assert ts.seconds == case.duration.seconds and abs(ts.nanos - case.duration.nanos) < 1000 + + +@pytest.mark.parametrize(["iterable", "batch_size", "iterations"], [ + ([1, 2, 3], 3, [[1, 2, 3]]), + ([1, 2, 3], 5, [[1, 2, 3]]), + ([1, 2, 3], 2, [[1, 2], [3]]), + ([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], 3, [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11]]) +]) +def test_batched(iterable: Iterable, batch_size: int, iterations: List[Iterable]): + for index, batch in enumerate(batched(iterable, batch_size)): + assert batch == iterations[index] diff --git a/packages/python/tests/test_submitter.py b/packages/python/tests/test_submitter.py new file mode 100644 index 000000000..e69de29bb diff --git a/packages/python/tests/test_worker.py b/packages/python/tests/test_worker.py new file mode 100644 index 000000000..42198bcdf --- /dev/null +++ b/packages/python/tests/test_worker.py @@ -0,0 +1,80 @@ +import datetime +import grpc +import logging +import os +import pytest + +from .conftest import data_folder, grpc_endpoint +from armonik.worker import ArmoniKWorker, TaskHandler, ClefLogger +from armonik.common import Output, TaskOptions +from armonik.protogen.common.objects_pb2 import Empty, Configuration +from armonik.protogen.common.worker_common_pb2 import ProcessRequest + + +def do_nothing(_: TaskHandler) -> Output: + return Output() + + +def throw_error(_: TaskHandler) -> Output: + raise ValueError("TestError") + + +def return_error(_: TaskHandler) -> Output: + return Output("TestError") + + +def return_and_send(th: TaskHandler) -> Output: + th.send_results({th.expected_results[0]: b"result"}) + return Output() + + +class TestWorker: + + request = ProcessRequest( + communication_token="token", + session_id="session-id", + task_id="task-id", + expected_output_keys=["result-id"], + payload_id="payload-id", + data_dependencies=["dd-id"], + data_folder=data_folder, + configuration=Configuration(data_chunk_max_size=8000), + task_options=TaskOptions( + max_duration=datetime.timedelta(seconds=1), + priority=1, + max_retries=1 + ).to_message() + ) + + def test_do_nothing(self): + with grpc.insecure_channel(grpc_endpoint) as agent_channel: + worker = ArmoniKWorker(agent_channel, do_nothing, logger=ClefLogger("TestLogger", level=logging.CRITICAL)) + reply = worker.Process(self.request, None) + assert Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None).success + worker.HealthCheck(Empty(), None) + + def test_should_return_none(self): + with grpc.insecure_channel(grpc_endpoint) as agent_channel: + worker = ArmoniKWorker(agent_channel, throw_error, logger=ClefLogger("TestLogger", level=logging.CRITICAL)) + reply = worker.Process(self.request, None) + assert reply is None + + def test_should_error(self): + with 
grpc.insecure_channel(grpc_endpoint) as agent_channel: + worker = ArmoniKWorker(agent_channel, return_error, logger=ClefLogger("TestLogger", level=logging.CRITICAL)) + reply = worker.Process(self.request, None) + output = Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None) + assert not output.success + assert output.error == "TestError" + + def test_should_write_result(self): + with grpc.insecure_channel(grpc_endpoint) as agent_channel: + worker = ArmoniKWorker(agent_channel, return_and_send, logger=ClefLogger("TestLogger", level=logging.DEBUG)) + reply = worker.Process(self.request, None) + assert reply is not None + output = Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None) + assert output.success + assert os.path.exists(os.path.join(data_folder, self.request.expected_output_keys[0])) + with open(os.path.join(data_folder, self.request.expected_output_keys[0]), "rb") as f: + value = f.read() + assert len(value) > 0 diff --git a/packages/python/tests/worker_test.py b/packages/python/tests/worker_test.py deleted file mode 100644 index 032c406ee..000000000 --- a/packages/python/tests/worker_test.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python3 -import logging -import os -import pytest -from armonik.worker import ArmoniKWorker, TaskHandler, ClefLogger -from armonik.common import Output -from .taskhandler_test import should_succeed_case, data_folder, DummyAgent -from .common import DummyChannel -from armonik.protogen.common.objects_pb2 import Empty -import grpc - - -def do_nothing(_: TaskHandler) -> Output: - return Output() - - -def throw_error(_: TaskHandler) -> Output: - raise ValueError("TestError") - - -def return_error(_: TaskHandler) -> Output: - return Output("TestError") - - -def return_and_send(th: TaskHandler) -> Output: - th.send_result(th.expected_results[0], b"result") - return Output() - - -@pytest.fixture(autouse=True, scope="function") -def remove_result(): - yield - if os.path.exists(os.path.join(data_folder, "resultid")): - os.remove(os.path.join(data_folder, "resultid")) - - -def test_do_nothing_worker(): - with grpc.insecure_channel("unix:///tmp/agent.sock") as agent_channel: - worker = ArmoniKWorker(agent_channel, do_nothing, logger=ClefLogger("TestLogger", level=logging.CRITICAL)) - reply = worker.Process(should_succeed_case, None) - assert Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None).success - worker.HealthCheck(Empty(), None) - - -def test_worker_should_return_none(): - with grpc.insecure_channel("unix:///tmp/agent.sock") as agent_channel: - worker = ArmoniKWorker(agent_channel, throw_error, logger=ClefLogger("TestLogger", level=logging.CRITICAL)) - reply = worker.Process(should_succeed_case, None) - assert reply is None - - -def test_worker_should_error(): - with grpc.insecure_channel("unix:///tmp/agent.sock") as agent_channel: - worker = ArmoniKWorker(agent_channel, return_error, logger=ClefLogger("TestLogger", level=logging.CRITICAL)) - reply = worker.Process(should_succeed_case, None) - output = Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None) - assert not output.success - assert output.error == "TestError" - - -def test_worker_should_write_result(): - with grpc.insecure_channel("unix:///tmp/agent.sock") as agent_channel: - worker = ArmoniKWorker(agent_channel, return_and_send, logger=ClefLogger("TestLogger", level=logging.DEBUG)) - worker._client = DummyAgent(DummyChannel()) - reply = 
worker.Process(should_succeed_case, None)
-    assert reply is not None
-    output = Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None)
-    assert output.success
-    assert os.path.exists(os.path.join(data_folder, should_succeed_case.expected_output_keys[0]))
-    with open(os.path.join(data_folder, should_succeed_case.expected_output_keys[0]), "rb") as f:
-        value = f.read()
-        assert len(value) > 0
-

From 26d0b9d4cdff7e518e3bd0220f99c22d542704c7 Mon Sep 17 00:00:00 2001
From: qdelamea
Date: Wed, 27 Dec 2023 11:36:30 +0100
Subject: [PATCH 10/15] Add HealthChecks service to Python API with tests

---
 .../python/src/armonik/client/__init__.py     |  4 +++-
 .../python/src/armonik/client/health_check.py | 21 +++++++++++++++++++
 .../python/src/armonik/common/__init__.py     |  3 ++-
 .../python/src/armonik/common/enumwrapper.py  |  8 +++++++
 packages/python/tests/conftest.py             |  7 +++++--
 packages/python/tests/test_healthcheck.py     | 18 ++++++++++++++++
 6 files changed, 57 insertions(+), 4 deletions(-)
 create mode 100644 packages/python/src/armonik/client/health_check.py
 create mode 100644 packages/python/tests/test_healthcheck.py

diff --git a/packages/python/src/armonik/client/__init__.py b/packages/python/src/armonik/client/__init__.py
index 32ad8b563..6510b4663 100644
--- a/packages/python/src/armonik/client/__init__.py
+++ b/packages/python/src/armonik/client/__init__.py
@@ -5,6 +5,7 @@
 from .results import ArmoniKResults
 from .versions import ArmoniKVersions
 from .events import ArmoniKEvents
+from .health_check import ArmoniKHealthChecks
 
 __all__ = [
     'ArmoniKPartitions',
@@ -13,5 +14,6 @@
     'ArmoniKTasks',
     'ArmoniKResults',
     "ArmoniKVersions",
-    "ArmoniKEvents"
+    "ArmoniKEvents",
+    "ArmoniKHealthChecks"
 ]
diff --git a/packages/python/src/armonik/client/health_check.py b/packages/python/src/armonik/client/health_check.py
new file mode 100644
index 000000000..76c04e251
--- /dev/null
+++ b/packages/python/src/armonik/client/health_check.py
@@ -0,0 +1,21 @@
+from typing import cast, List, Tuple
+
+from grpc import Channel
+
+from ..common import HealthCheckStatus
+from ..protogen.client.health_checks_service_pb2_grpc import HealthChecksServiceStub
+from ..protogen.common.health_checks_common_pb2 import CheckHealthRequest, CheckHealthResponse
+
+
+class ArmoniKHealthChecks:
+    def __init__(self, grpc_channel: Channel):
+        """ Health checks service client
+
+        Args:
+            grpc_channel: gRPC channel to use
+        """
+        self._client = HealthChecksServiceStub(grpc_channel)
+
+    def check_health(self):
+        response: CheckHealthResponse = self._client.CheckHealth(CheckHealthRequest())
+        return {service.name: {"message": service.message, "status": service.healthy} for service in response.services}
diff --git a/packages/python/src/armonik/common/__init__.py b/packages/python/src/armonik/common/__init__.py
index 836872785..04105d3d4 100644
--- a/packages/python/src/armonik/common/__init__.py
+++ b/packages/python/src/armonik/common/__init__.py
@@ -7,7 +7,7 @@
     batched
 )
 from .objects import Task, TaskDefinition, TaskOptions, Output, ResultAvailability, Session, Result, Partition
-from .enumwrapper import HealthCheckStatus, TaskStatus, Direction, SessionStatus, ResultStatus, EventTypes
+from .enumwrapper import HealthCheckStatus, TaskStatus, Direction, SessionStatus, ResultStatus, EventTypes, ServiceHealthCheckStatus
 from .events import *
 from .filter import Filter, StringFilter, StatusFilter
 
@@ -37,4 +37,5 @@
     'Filter',
     'StringFilter',
     'StatusFilter',
+    'ServiceHealthCheckStatus'
 ]
diff --git 
a/packages/python/src/armonik/common/enumwrapper.py b/packages/python/src/armonik/common/enumwrapper.py index 9aba6d2cd..d6fe134cb 100644 --- a/packages/python/src/armonik/common/enumwrapper.py +++ b/packages/python/src/armonik/common/enumwrapper.py @@ -4,6 +4,7 @@ from ..protogen.common.events_common_pb2 import EventsEnum as rawEventsEnum, EVENTS_ENUM_UNSPECIFIED, EVENTS_ENUM_NEW_TASK, EVENTS_ENUM_TASK_STATUS_UPDATE, EVENTS_ENUM_NEW_RESULT, EVENTS_ENUM_RESULT_STATUS_UPDATE, EVENTS_ENUM_RESULT_OWNER_UPDATE from ..protogen.common.session_status_pb2 import SessionStatus as RawSessionStatus, _SESSIONSTATUS, SESSION_STATUS_UNSPECIFIED, SESSION_STATUS_CANCELLED, SESSION_STATUS_RUNNING from ..protogen.common.result_status_pb2 import ResultStatus as RawResultStatus, _RESULTSTATUS, RESULT_STATUS_UNSPECIFIED, RESULT_STATUS_CREATED, RESULT_STATUS_COMPLETED, RESULT_STATUS_ABORTED, RESULT_STATUS_NOTFOUND +from ..protogen.common.health_checks_common_pb2 import HEALTH_STATUS_ENUM_UNSPECIFIED, HEALTH_STATUS_ENUM_HEALTHY, HEALTH_STATUS_ENUM_DEGRADED, HEALTH_STATUS_ENUM_UNHEALTHY from ..protogen.common.worker_common_pb2 import HealthCheckReply from ..protogen.common.sort_direction_pb2 import SORT_DIRECTION_ASC, SORT_DIRECTION_DESC @@ -72,3 +73,10 @@ class EventTypes: @classmethod def from_string(cls, name: str): return getattr(cls, name.upper()) + + +class ServiceHealthCheckStatus: + UNSPECIFIED = HEALTH_STATUS_ENUM_UNSPECIFIED + HEALTHY = HEALTH_STATUS_ENUM_HEALTHY + DEGRADED = HEALTH_STATUS_ENUM_DEGRADED + UNHEALTHY = HEALTH_STATUS_ENUM_UNHEALTHY diff --git a/packages/python/tests/conftest.py b/packages/python/tests/conftest.py index a5d47658c..1b5949a7e 100644 --- a/packages/python/tests/conftest.py +++ b/packages/python/tests/conftest.py @@ -3,7 +3,7 @@ import pytest import requests -from armonik.client import ArmoniKResults, ArmoniKSubmitter, ArmoniKTasks, ArmoniKSessions, ArmoniKPartitions, ArmoniKVersions, ArmoniKEvents +from armonik.client import ArmoniKResults, ArmoniKSubmitter, ArmoniKTasks, ArmoniKSessions, ArmoniKPartitions, ArmoniKVersions, ArmoniKEvents, ArmoniKHealthChecks from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub from typing import List @@ -44,6 +44,7 @@ def clean_up(request): # Remove the temporary files created for testing os.remove(os.path.join(data_folder, "payload-id")) os.remove(os.path.join(data_folder, "dd-id")) + os.remove(os.path.join(data_folder, "result-id")) # Reset the mock server counters try: @@ -54,7 +55,7 @@ def clean_up(request): print("An error occurred when resetting the server: " + str(e)) -def get_client(client_name: str, endpoint: str = grpc_endpoint) -> [ArmoniKResults, ArmoniKSubmitter, ArmoniKTasks, ArmoniKSessions, ArmoniKPartitions, ArmoniKVersions, AgentStub, ArmoniKEvents]: +def get_client(client_name: str, endpoint: str = grpc_endpoint) -> [ArmoniKResults, ArmoniKSubmitter, ArmoniKTasks, ArmoniKSessions, ArmoniKPartitions, ArmoniKVersions, AgentStub, ArmoniKEvents, ArmoniKHealthChecks]: """ Get the ArmoniK client instance based on the specified service name. 
@@ -91,6 +92,8 @@ def get_client(client_name: str, endpoint: str = grpc_endpoint) -> [ArmoniKResul return AgentStub(channel) case "Events": return ArmoniKEvents(channel) + case "HealthChecks": + return ArmoniKHealthChecks(channel) case _: raise ValueError("Unknown service name: " + str(service_name)) diff --git a/packages/python/tests/test_healthcheck.py b/packages/python/tests/test_healthcheck.py new file mode 100644 index 000000000..8062aaab6 --- /dev/null +++ b/packages/python/tests/test_healthcheck.py @@ -0,0 +1,18 @@ +import datetime + +from .conftest import all_rpc_called, rpc_called, get_client +from armonik.client import ArmoniKHealthChecks +from armonik.common import ServiceHealthCheckStatus + + +class TestArmoniKHealthChecks: + + def test_check_health(self): + health_checks_client: ArmoniKHealthChecks = get_client("HealthChecks") + services_health = health_checks_client.check_health() + + assert rpc_called("HealthChecks", "CheckHealth") + assert services_health == {'mock': {'message': 'Mock is healthy', 'status': ServiceHealthCheckStatus.HEALTHY}} + + def test_service_fully_implemented(self): + assert all_rpc_called("HealthChecks") From b900412f913f74a553a3e61143a8accd02b25e25 Mon Sep 17 00:00:00 2001 From: qdelamea Date: Wed, 27 Dec 2023 11:47:18 +0100 Subject: [PATCH 11/15] Correct CI for Python API testing --- .github/workflows/ci.yml | 10 ++++++++++ packages/python/pyproject.toml | 3 ++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b91f05761..350688f02 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -231,6 +231,16 @@ jobs: - name: Install dependencies run: pip install "$(echo pkg/armonik*.whl)[tests]" + - name: Install .NET Core + uses: actions/setup-dotnet@3447fd6a9f9e57506b15f895c5b76d3b197dc7c2 # v3 + with: + dotnet-version: 6.x + + - name: Start Mock server + run: | + cd ../csharp/ArmoniK.Api.Mock + dotnet run + - name: Run tests run: python -m pytest tests --cov=armonik --cov-config=.coveragerc --cov-report=term-missing --cov-report xml:coverage.xml --cov-report html:coverage_report diff --git a/packages/python/pyproject.toml b/packages/python/pyproject.toml index 42a8a65d8..e646ecf39 100644 --- a/packages/python/pyproject.toml +++ b/packages/python/pyproject.toml @@ -41,9 +41,10 @@ tests = [ 'pytest', 'pytest-cov', 'pytest-benchmark[histogram]', + 'requests' ] [tool.pytest.ini_options] addopts = [ "--import-mode=importlib", -] \ No newline at end of file +] From 67db9cda771b7edfbdb7341807dcec29bd76f6b3 Mon Sep 17 00:00:00 2001 From: qdelamea Date: Wed, 27 Dec 2023 11:58:10 +0100 Subject: [PATCH 12/15] Fix ci Python test server start --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 350688f02..fe5174b1d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -239,7 +239,7 @@ jobs: - name: Start Mock server run: | cd ../csharp/ArmoniK.Api.Mock - dotnet run + nohup dotnet run > /dev/null 2>&1 & - name: Run tests run: python -m pytest tests --cov=armonik --cov-config=.coveragerc --cov-report=term-missing --cov-report xml:coverage.xml --cov-report html:coverage_report From 5eee0b278f47ebf34997a462c3065c1799feae9f Mon Sep 17 00:00:00 2001 From: qdelamea Date: Wed, 27 Dec 2023 11:59:08 +0100 Subject: [PATCH 13/15] Python API delete old test --- packages/python/tests/submitter_test.py | 313 ------------------------ 1 file changed, 313 deletions(-) delete mode 
100644 packages/python/tests/submitter_test.py diff --git a/packages/python/tests/submitter_test.py b/packages/python/tests/submitter_test.py deleted file mode 100644 index c849efdc0..000000000 --- a/packages/python/tests/submitter_test.py +++ /dev/null @@ -1,313 +0,0 @@ -#!/usr/bin/env python3 -import datetime -import logging -import pytest -from armonik.client import ArmoniKSubmitter -from typing import Iterator, Optional, List -from .common import DummyChannel -from armonik.common import TaskOptions, TaskDefinition, TaskStatus, timedelta_to_duration -from armonik.protogen.client.submitter_service_pb2_grpc import SubmitterStub -from armonik.protogen.common.objects_pb2 import Empty, Configuration, Session, TaskIdList, ResultRequest, TaskError, Error, \ - Count, StatusCount, DataChunk -from armonik.protogen.common.submitter_common_pb2 import CreateSessionRequest, CreateSessionReply, CreateLargeTaskRequest, \ - CreateTaskReply, TaskFilter, ResultReply, AvailabilityReply, WaitRequest, GetTaskStatusRequest, GetTaskStatusReply - -logging.basicConfig() -logging.getLogger().setLevel(logging.INFO) - - -class DummySubmitter(SubmitterStub): - def __init__(self, channel: DummyChannel, max_chunk_size=300): - channel.set_instance(self) - super().__init__(channel) - self.max_chunk_size = max_chunk_size - self.large_tasks_requests: List[CreateLargeTaskRequest] = [] - self.task_filter: Optional[TaskFilter] = None - self.create_session: Optional[CreateSessionRequest] = None - self.session: Optional[Session] = None - self.result_stream: List[ResultReply] = [] - self.result_request: Optional[ResultRequest] = None - self.is_available = True - self.wait_request: Optional[WaitRequest] = None - self.get_status_request: Optional[GetTaskStatusRequest] = None - - def GetServiceConfiguration(self, _: Empty) -> Configuration: - return Configuration(data_chunk_max_size=self.max_chunk_size) - - def CreateSession(self, request: CreateSessionRequest) -> CreateSessionReply: - self.create_session = request - return CreateSessionReply(session_id="SessionId") - - def CancelSession(self, request: Session) -> Empty: - self.session = request - return Empty() - - def CreateLargeTasks(self, request: Iterator[CreateLargeTaskRequest]) -> CreateTaskReply: - self.large_tasks_requests = [r for r in request] - return CreateTaskReply(creation_status_list=CreateTaskReply.CreationStatusList(creation_statuses=[ - CreateTaskReply.CreationStatus( - task_info=CreateTaskReply.TaskInfo(task_id="TaskId", expected_output_keys=["EOK"], - data_dependencies=["DD"])), - CreateTaskReply.CreationStatus(error="TestError")])) - - def ListTasks(self, request: TaskFilter) -> TaskIdList: - self.task_filter = request - return TaskIdList(task_ids=["TaskId"]) - - def TryGetResultStream(self, request: ResultRequest) -> Iterator[ResultReply]: - self.result_request = request - for r in self.result_stream: - yield r - - def WaitForAvailability(self, request: ResultRequest) -> AvailabilityReply: - from armonik.protogen.common.task_status_pb2 import TASK_STATUS_ERROR - self.result_request = request - return AvailabilityReply(ok=Empty()) if self.is_available else AvailabilityReply( - error=TaskError(task_id="TaskId", errors=[Error(task_status=TASK_STATUS_ERROR, detail="TestError")])) - - def WaitForCompletion(self, request: WaitRequest) -> Count: - from armonik.protogen.common.task_status_pb2 import TASK_STATUS_COMPLETED - self.wait_request = request - return Count(values=[StatusCount(status=TASK_STATUS_COMPLETED, count=1)]) - - def GetTaskStatus(self, request: 
GetTaskStatusRequest) -> GetTaskStatusReply: - from armonik.protogen.common.task_status_pb2 import TASK_STATUS_COMPLETED - self.get_status_request = request - return GetTaskStatusReply( - id_statuses=[GetTaskStatusReply.IdStatus(task_id="TaskId", status=TASK_STATUS_COMPLETED)]) - - -default_task_option = TaskOptions(datetime.timedelta(seconds=300), priority=1, max_retries=5) - - -@pytest.mark.parametrize("task_options,partitions", [(default_task_option, None), (default_task_option, ["default"])]) -def test_armonik_submitter_should_create_session(task_options, partitions): - channel = DummyChannel() - inner = DummySubmitter(channel) - submitter = ArmoniKSubmitter(channel) - session_id = submitter.create_session(default_task_options=task_options, partition_ids=partitions) - assert session_id == "SessionId" - assert inner.create_session - assert inner.create_session.default_task_option.priority == task_options.priority - assert len(inner.create_session.partition_ids) == 0 if partitions is None else list(inner.create_session.partition_ids) == partitions - assert len(inner.create_session.default_task_option.options) == len(task_options.options) - assert inner.create_session.default_task_option.max_duration == timedelta_to_duration(task_options.max_duration) - assert inner.create_session.default_task_option.max_retries == task_options.max_retries - - -def test_armonik_submitter_should_cancel_session(): - channel = DummyChannel() - inner = DummySubmitter(channel) - submitter = ArmoniKSubmitter(channel) - submitter.cancel_session("SessionId") - assert inner.session is not None - assert inner.session.id == "SessionId" - - -def test_armonik_submitter_should_get_config(): - channel = DummyChannel() - inner = DummySubmitter(channel) - submitter = ArmoniKSubmitter(channel) - config = submitter.get_service_configuration() - assert config is not None - assert config.data_chunk_max_size == 300 - - -should_submit = [ - [TaskDefinition("Payload1".encode('utf-8'), expected_output_ids=["EOK"], data_dependencies=["DD"]), - TaskDefinition("Payload2".encode('utf-8'), expected_output_ids=["EOK"], data_dependencies=["DD"])], - [TaskDefinition("Payload1".encode('utf-8'), expected_output_ids=["EOK"]), - TaskDefinition("Payload2".encode('utf-8'), expected_output_ids=["EOK"])], - [TaskDefinition("".encode('utf-8'), expected_output_ids=["EOK"]), - TaskDefinition("".encode('utf-8'), expected_output_ids=["EOK"])] -] - - -@pytest.mark.parametrize("task_list,task_options", - [(t, default_task_option if i else None) for t in should_submit for i in [True, False]]) -def test_armonik_submitter_should_submit(task_list, task_options): - channel = DummyChannel() - inner = DummySubmitter(channel, max_chunk_size=5) - submitter = ArmoniKSubmitter(channel) - successes, errors = submitter.submit("SessionId", tasks=task_list, task_options=task_options) - # The dummy submitter has been set to submit one successful task and one submission error - assert len(successes) == 1 - assert len(errors) == 1 - assert successes[0].id == "TaskId" - assert successes[0].session_id == "SessionId" - assert errors[0] == "TestError" - - reqs = inner.large_tasks_requests - assert len(reqs) > 0 - offset = 0 - assert reqs[0 + offset].WhichOneof("type") == "init_request" - assert reqs[0 + offset].init_request.session_id == "SessionId" - assert reqs[1 + offset].WhichOneof("type") == "init_task" - assert reqs[1 + offset].init_task.header.expected_output_keys[0] == "EOK" - assert reqs[1 + offset].init_task.header.data_dependencies[0] == "DD" if len( - 
task_list[0].data_dependencies) > 0 else len(reqs[1 + offset].init_task.header.data_dependencies) == 0 - assert reqs[2 + offset].WhichOneof("type") == "task_payload" - assert reqs[2 + offset].task_payload.data == "".encode("utf-8") if len(task_list[0].payload) == 0 \ - else reqs[2 + offset].task_payload.data == task_list[0].payload[:5] - if len(task_list[0].payload) > 0: - offset += 1 - assert reqs[2 + offset].WhichOneof("type") == "task_payload" - assert reqs[2 + offset].task_payload.data == task_list[0].payload[5:] - assert reqs[3 + offset].WhichOneof("type") == "task_payload" - assert reqs[3 + offset].task_payload.data_complete - assert reqs[4 + offset].WhichOneof("type") == "init_task" - assert reqs[4 + offset].init_task.header.expected_output_keys[0] == "EOK" - assert reqs[4 + offset].init_task.header.data_dependencies[0] == "DD" if len( - task_list[1].data_dependencies) > 0 else len(reqs[4 + offset].init_task.header.data_dependencies) == 0 - assert reqs[5 + offset].WhichOneof("type") == "task_payload" - assert reqs[5 + offset].task_payload.data == "".encode("utf-8") if len(task_list[1].payload) == 0 \ - else reqs[5 + offset].task_payload.data == task_list[1].payload[:5] - if len(task_list[1].payload) > 0: - offset += 1 - assert reqs[5 + offset].WhichOneof("type") == "task_payload" - assert reqs[5 + offset].task_payload.data == task_list[1].payload[5:] - assert reqs[6 + offset].WhichOneof("type") == "task_payload" - assert reqs[6 + offset].task_payload.data_complete - assert reqs[7 + offset].WhichOneof("type") == "init_task" - assert reqs[7 + offset].init_task.last_task - - -filters_params = [(session_ids, task_ids, included_statuses, excluded_statuses, - (session_ids is None or task_ids is None) and ( - included_statuses is None or excluded_statuses is None)) - for session_ids in [["SessionId"], None] - for task_ids in [["TaskId"], None] - for included_statuses in [[TaskStatus.COMPLETED], None] - for excluded_statuses in [[TaskStatus.COMPLETED], None]] - - -@pytest.mark.parametrize("session_ids,task_ids,included_statuses,excluded_statuses,should_succeed", filters_params) -def test_armonik_submitter_should_list_tasks(session_ids, task_ids, included_statuses, excluded_statuses, - should_succeed): - channel = DummyChannel() - inner = DummySubmitter(channel) - submitter = ArmoniKSubmitter(channel) - if should_succeed: - tasks = submitter.list_tasks(session_ids=session_ids, task_ids=task_ids, included_statuses=included_statuses, - excluded_statuses=excluded_statuses) - assert len(tasks) > 0 - assert tasks[0] == "TaskId" - assert inner.task_filter is not None - assert all(map(lambda x: x[1] == session_ids[x[0]], enumerate(inner.task_filter.session.ids))) - assert all(map(lambda x: x[1] == task_ids[x[0]], enumerate(inner.task_filter.task.ids))) - assert all(map(lambda x: x[1] == included_statuses[x[0]], enumerate(inner.task_filter.included.statuses))) - assert all(map(lambda x: x[1] == excluded_statuses[x[0]], enumerate(inner.task_filter.excluded.statuses))) - else: - with pytest.raises(ValueError): - _ = submitter.list_tasks(session_ids=session_ids, task_ids=task_ids, included_statuses=included_statuses, - excluded_statuses=excluded_statuses) - - -def test_armonik_submitter_should_get_status(): - channel = DummyChannel() - inner = DummySubmitter(channel) - submitter = ArmoniKSubmitter(channel) - - statuses = submitter.get_task_status(["TaskId"]) - assert len(statuses) > 0 - assert "TaskId" in statuses - assert statuses["TaskId"] == TaskStatus.COMPLETED - assert inner.get_status_request is 
not None - assert len(inner.get_status_request.task_ids) == 1 - assert inner.get_status_request.task_ids[0] == "TaskId" - - -get_result_should_throw = [ - [], - [ResultReply(result=DataChunk(data="payload".encode("utf-8")))], - [ResultReply(result=DataChunk(data="payload".encode("utf-8"))), ResultReply(result=DataChunk(data_complete=True)), - ResultReply(result=DataChunk(data="payload".encode("utf-8")))], - [ResultReply( - error=TaskError(task_id="TaskId", errors=[Error(task_status=TaskStatus.ERROR, detail="TestError")]))], -] - -get_result_should_succeed = [ - [ResultReply(result=DataChunk(data="payload".encode("utf-8"))), ResultReply(result=DataChunk(data_complete=True))] -] - -get_result_should_none = [ - [ResultReply(not_completed_task="NotCompleted")] -] - - -@pytest.mark.parametrize("stream", [iter(x) for x in get_result_should_succeed]) -def test_armonik_submitter_should_get_result(stream): - channel = DummyChannel() - inner = DummySubmitter(channel) - inner.result_stream = stream - submitter = ArmoniKSubmitter(channel) - result = submitter.get_result("SessionId", "ResultId") - assert result is not None - assert len(result) > 0 - assert inner.result_request - assert inner.result_request.result_id == "ResultId" - assert inner.result_request.session == "SessionId" - - -@pytest.mark.parametrize("stream", [iter(x) for x in get_result_should_throw]) -def test_armonik_submitter_get_result_should_throw(stream): - channel = DummyChannel() - inner = DummySubmitter(channel) - inner.result_stream = stream - submitter = ArmoniKSubmitter(channel) - with pytest.raises(Exception): - _ = submitter.get_result("SessionId", "ResultId") - - -@pytest.mark.parametrize("stream", [iter(x) for x in get_result_should_none]) -def test_armonik_submitter_get_result_should_none(stream): - channel = DummyChannel() - inner = DummySubmitter(channel) - inner.result_stream = stream - submitter = ArmoniKSubmitter(channel) - result = submitter.get_result("SessionId", "ResultId") - assert result is None - assert inner.result_request - assert inner.result_request.result_id == "ResultId" - assert inner.result_request.session == "SessionId" - - -@pytest.mark.parametrize("available", [True, False]) -def test_armonik_submitter_wait_availability(available): - channel = DummyChannel() - inner = DummySubmitter(channel) - inner.is_available = available - submitter = ArmoniKSubmitter(channel) - reply = submitter.wait_for_availability("SessionId", "ResultId") - assert reply is not None - assert reply.is_available() == available - assert len(reply.errors) == 0 if available else reply.errors[0] == "TestError" - - -@pytest.mark.parametrize("session_ids,task_ids,included_statuses,excluded_statuses,should_succeed", filters_params) -def test_armonik_submitter_wait_completion(session_ids, task_ids, included_statuses, excluded_statuses, should_succeed): - channel = DummyChannel() - inner = DummySubmitter(channel) - submitter = ArmoniKSubmitter(channel) - - if should_succeed: - counts = submitter.wait_for_completion(session_ids=session_ids, task_ids=task_ids, - included_statuses=included_statuses, - excluded_statuses=excluded_statuses) - assert len(counts) > 0 - assert TaskStatus.COMPLETED in counts - assert counts[TaskStatus.COMPLETED] == 1 - assert inner.wait_request is not None - assert all(map(lambda x: x[1] == session_ids[x[0]], enumerate(inner.wait_request.filter.session.ids))) - assert all(map(lambda x: x[1] == task_ids[x[0]], enumerate(inner.wait_request.filter.task.ids))) - assert all(map(lambda x: x[1] == included_statuses[x[0]], 
- enumerate(inner.wait_request.filter.included.statuses))) - assert all(map(lambda x: x[1] == excluded_statuses[x[0]], - enumerate(inner.wait_request.filter.excluded.statuses))) - assert not inner.wait_request.stop_on_first_task_error - assert not inner.wait_request.stop_on_first_task_cancellation - else: - with pytest.raises(ValueError): - _ = submitter.wait_for_completion(session_ids=session_ids, task_ids=task_ids, - included_statuses=included_statuses, - excluded_statuses=excluded_statuses) From cbea27219cdb00ed75aef384a6e8deee87887f32 Mon Sep 17 00:00:00 2001 From: qdelamea Date: Wed, 27 Dec 2023 12:27:22 +0100 Subject: [PATCH 14/15] Fix Python test CI to wait for server start up --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fe5174b1d..8a8889353 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -240,6 +240,7 @@ jobs: run: | cd ../csharp/ArmoniK.Api.Mock nohup dotnet run > /dev/null 2>&1 & + sleep 10 - name: Run tests run: python -m pytest tests --cov=armonik --cov-config=.coveragerc --cov-report=term-missing --cov-report xml:coverage.xml --cov-report html:coverage_report From b0f427096bc233648a8bc53fb1f94e2920d1f30c Mon Sep 17 00:00:00 2001 From: qdelamea Date: Wed, 3 Jan 2024 09:59:11 +0100 Subject: [PATCH 15/15] Extend waiting time for mock server to start --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8a8889353..19579998f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -240,7 +240,7 @@ jobs: run: | cd ../csharp/ArmoniK.Api.Mock nohup dotnet run > /dev/null 2>&1 & - sleep 10 + sleep 60 - name: Run tests run: python -m pytest tests --cov=armonik --cov-config=.coveragerc --cov-report=term-missing --cov-report xml:coverage.xml --cov-report html:coverage_report
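
The events client introduced earlier in this series and the health-checks client from patch 10 can be exercised together. The sketch below is illustrative only: the endpoint, session ID and handler body are assumptions, not values taken from these patches.

    import grpc

    from armonik.client import ArmoniKEvents, ArmoniKHealthChecks
    from armonik.common import EventTypes

    # Assumed control-plane address; substitute your deployment's endpoint.
    with grpc.insecure_channel("localhost:5001") as channel:
        # check_health() returns a dict mapping each service name to its
        # message and status, as pinned down by test_healthcheck.py.
        print(ArmoniKHealthChecks(channel).check_health())

        # Handlers now receive (session_id, event_type, event), matching the
        # signature change exercised in test_events.py.
        def on_event(session_id, event_type, event) -> bool:
            print(session_id, event_type, event)
            return True

        # get_events consumes a server stream, so this call blocks while the
        # subscription is open.
        ArmoniKEvents(channel).get_events(
            "my-session-id",  # assumed session
            [EventTypes.NEW_RESULT, EventTypes.RESULT_STATUS_UPDATE],
            [on_event],
        )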
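
The renamed helper tests in patch 09 (test_helpers.py) pin down the behavior of the batched utility that the clients rely on for chunked requests, imported there from armonik.common.helpers. Its implementation is not part of these patches, so the version below is only a behavior-equivalent sketch consistent with those test cases.

    from itertools import islice
    from typing import Iterable, Iterator, List, TypeVar

    T = TypeVar("T")

    def batched(iterable: Iterable[T], batch_size: int) -> Iterator[List[T]]:
        # Hypothetical re-implementation: yield successive lists of at most
        # batch_size items, with a final short batch for any remainder.
        it = iter(iterable)
        while True:
            batch = list(islice(it, batch_size))
            if not batch:
                return
            yield batch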
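
Patches 11 through 15 converge on starting the C# mock server in the background and waiting a fixed 60 seconds before running the tests. When startup time varies, a readiness probe is a sturdier alternative; here is a sketch in Python, where the probed URL is an assumption about the mock server's HTTP endpoint rather than something specified in these patches.

    import time

    import requests

    def wait_for_mock(url: str = "http://localhost:5000/", timeout: float = 60.0) -> None:
        # Poll the assumed endpoint until it answers or the deadline passes.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            try:
                requests.get(url, timeout=1)
                return
            except requests.ConnectionError:
                time.sleep(1)
        raise TimeoutError(f"Mock server did not come up within {timeout} seconds")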