diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b91f05761..19579998f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -231,6 +231,17 @@ jobs: - name: Install dependencies run: pip install "$(echo pkg/armonik*.whl)[tests]" + - name: Install .NET Core + uses: actions/setup-dotnet@3447fd6a9f9e57506b15f895c5b76d3b197dc7c2 # v3 + with: + dotnet-version: 6.x + + - name: Start Mock server + run: | + cd ../csharp/ArmoniK.Api.Mock + nohup dotnet run > /dev/null 2>&1 & + sleep 60 + - name: Run tests run: python -m pytest tests --cov=armonik --cov-config=.coveragerc --cov-report=term-missing --cov-report xml:coverage.xml --cov-report html:coverage_report diff --git a/packages/python/.gitignore b/packages/python/.gitignore index 53df0df6c..6fe9f8e9f 100644 --- a/packages/python/.gitignore +++ b/packages/python/.gitignore @@ -4,3 +4,4 @@ build/ *.egg-info **/_version.py **/.pytest_cache +.ruff_cache diff --git a/packages/python/proto2python.sh b/packages/python/proto2python.sh index c2fe45244..6fe79c0e7 100755 --- a/packages/python/proto2python.sh +++ b/packages/python/proto2python.sh @@ -33,7 +33,7 @@ python -m pip install --upgrade pip python -m venv $PYTHON_VENV source $PYTHON_VENV/bin/activate # We need to fix grpc to 1.56 until this bug is solved : https://github.com/grpc/grpc/issues/34305 -python -m pip install build grpcio==1.56.2 grpcio-tools==1.56.2 click pytest setuptools_scm[toml] +python -m pip install build grpcio==1.56.2 grpcio-tools==1.56.2 click pytest setuptools_scm[toml] ruff requests unset proto_files for proto in ${armonik_worker_files[@]}; do diff --git a/packages/python/pyproject.toml b/packages/python/pyproject.toml index 42a8a65d8..e646ecf39 100644 --- a/packages/python/pyproject.toml +++ b/packages/python/pyproject.toml @@ -41,9 +41,10 @@ tests = [ 'pytest', 'pytest-cov', 'pytest-benchmark[histogram]', + 'requests' ] [tool.pytest.ini_options] addopts = [ "--import-mode=importlib", -] \ No newline at end of file +] diff --git a/packages/python/src/armonik/client/__init__.py b/packages/python/src/armonik/client/__init__.py index e94d7dde9..6510b4663 100644 --- a/packages/python/src/armonik/client/__init__.py +++ b/packages/python/src/armonik/client/__init__.py @@ -1,3 +1,19 @@ +from .partitions import ArmoniKPartitions +from .sessions import ArmoniKSessions from .submitter import ArmoniKSubmitter from .tasks import ArmoniKTasks -from .results import ArmoniKResult +from .results import ArmoniKResults +from .versions import ArmoniKVersions +from .events import ArmoniKEvents +from .health_check import ArmoniKHealthChecks + +__all__ = [ + 'ArmoniKPartitions', + 'ArmoniKSessions', + 'ArmoniKSubmitter', + 'ArmoniKTasks', + 'ArmoniKResults', + "ArmoniKVersions", + "ArmoniKEvents", + "ArmoniKHealthChecks" +] diff --git a/packages/python/src/armonik/client/events.py b/packages/python/src/armonik/client/events.py new file mode 100644 index 000000000..88695c31f --- /dev/null +++ b/packages/python/src/armonik/client/events.py @@ -0,0 +1,86 @@ +from typing import Any, Callable, cast, List + +from grpc import Channel + +from .results import ArmoniKResults +from ..common import EventTypes, Filter, NewTaskEvent, NewResultEvent, ResultOwnerUpdateEvent, ResultStatusUpdateEvent, TaskStatusUpdateEvent, ResultStatus, Event +from .results import ResultFieldFilter +from ..protogen.client.events_service_pb2_grpc import EventsStub +from ..protogen.common.events_common_pb2 import EventSubscriptionRequest, EventSubscriptionResponse +from 
..protogen.common.results_filters_pb2 import Filters as rawResultFilters
+from ..protogen.common.tasks_filters_pb2 import Filters as rawTaskFilters
+
+
+class ArmoniKEvents:
+
+    _events_obj_mapping = {
+        "new_result": NewResultEvent,
+        "new_task": NewTaskEvent,
+        "result_owner_update": ResultOwnerUpdateEvent,
+        "result_status_update": ResultStatusUpdateEvent,
+        "task_status_update": TaskStatusUpdateEvent
+    }
+
+    def __init__(self, grpc_channel: Channel):
+        """Events service client
+
+        Args:
+            grpc_channel: gRPC channel to use
+        """
+        self._client = EventsStub(grpc_channel)
+        self._results_client = ArmoniKResults(grpc_channel)
+
+    def get_events(self, session_id: str, event_types: List[EventTypes], event_handlers: List[Callable[[str, EventTypes, Event], bool]], task_filter: Filter | None = None, result_filter: Filter | None = None) -> None:
+        """Get events that represent updates of task and result data.
+
+        Args:
+            session_id: The ID of the session.
+            event_types: The list of the types of event to catch.
+            event_handlers: The list of handlers that process the events. Handlers are evaluated in the order they are provided.
+                A handler takes three positional arguments: the ID of the session, the type of event and the event as an object.
+                A handler returns a boolean: if True, processing continues; otherwise the stream is closed and the service stops
+                listening for new events.
+            task_filter: A filter on tasks.
+            result_filter: A filter on results.
+
+        """
+        request = EventSubscriptionRequest(
+            session_id=session_id,
+            returned_events=event_types
+        )
+        if task_filter:
+            request.tasks_filters.CopyFrom(cast(rawTaskFilters, task_filter.to_disjunction().to_message()))
+        if result_filter:
+            request.results_filters.CopyFrom(cast(rawResultFilters, result_filter.to_disjunction().to_message()))
+
+        streaming_call = self._client.GetEvents(request)
+        for message in streaming_call:
+            event_type = message.WhichOneof("update")
+            if not all([event_handler(session_id, EventTypes.from_string(event_type), self._events_obj_mapping[event_type].from_raw_event(getattr(message, event_type))) for event_handler in event_handlers]):
+                break
+
+    def wait_for_result_availability(self, result_id: str, session_id: str) -> None:
+        """Wait until a result is ready, i.e. its status updates to COMPLETED.
+
+        Args:
+            result_id: The ID of the result.
+            session_id: The ID of the session.
+
+        Raises:
+            RuntimeError: If the result status is ABORTED.
+        """
+        def handler(session_id, event_type, event):
+            if not isinstance(event, ResultStatusUpdateEvent):
+                raise ValueError("Handler should receive event of type 'ResultStatusUpdateEvent'.")
+            if event.status == ResultStatus.COMPLETED:
+                return False
+            elif event.status == ResultStatus.ABORTED:
+                raise RuntimeError(f"Result {result.name} with ID {result_id} is aborted.")
+            return True
+
+        result = self._results_client.get_result(result_id)
+        if result.status == ResultStatus.COMPLETED:
+            return
+        elif result.status == ResultStatus.ABORTED:
+            raise RuntimeError(f"Result {result.name} with ID {result_id} is aborted.")
+
+        self.get_events(session_id, [EventTypes.RESULT_STATUS_UPDATE], [handler], result_filter=(ResultFieldFilter.RESULT_ID == result_id))
diff --git a/packages/python/src/armonik/client/health_check.py b/packages/python/src/armonik/client/health_check.py
new file mode 100644
index 000000000..76c04e251
--- /dev/null
+++ b/packages/python/src/armonik/client/health_check.py
@@ -0,0 +1,21 @@
+from typing import Any, Dict
+
+from grpc import Channel
+
+from ..protogen.client.health_checks_service_pb2_grpc import HealthChecksServiceStub
+from ..protogen.common.health_checks_common_pb2 import CheckHealthRequest, CheckHealthResponse
+
+
+class ArmoniKHealthChecks:
+    def __init__(self, grpc_channel: Channel):
+        """Health checks service client
+
+        Args:
+            grpc_channel: gRPC channel to use
+        """
+        self._client = HealthChecksServiceStub(grpc_channel)
+
+    def check_health(self) -> Dict[str, Dict[str, Any]]:
+        response: CheckHealthResponse = self._client.CheckHealth(CheckHealthRequest())
+        return {service.name: {"message": service.message, "status": service.healthy} for service in response.services}
diff --git a/packages/python/src/armonik/client/partitions.py b/packages/python/src/armonik/client/partitions.py
new file mode 100644
index 000000000..e0b0b0ada
--- /dev/null
+++ b/packages/python/src/armonik/client/partitions.py
@@ -0,0 +1,66 @@
+from typing import cast, List, Tuple
+
+from grpc import Channel
+
+from ..common import Direction, Partition
+from ..common.filter import Filter, NumberFilter
+from ..protogen.client.partitions_service_pb2_grpc import PartitionsStub
+from ..protogen.common.partitions_common_pb2 import ListPartitionsRequest, ListPartitionsResponse, GetPartitionRequest, GetPartitionResponse
+from ..protogen.common.partitions_fields_pb2 import PartitionField, PartitionRawField, PARTITION_RAW_ENUM_FIELD_PRIORITY
+from ..protogen.common.partitions_filters_pb2 import Filters as rawFilters, FiltersAnd as rawFiltersAnd, FilterField as rawFilterField
+from ..protogen.common.sort_direction_pb2 import SortDirection
+
+
+class PartitionFieldFilter:
+    PRIORITY = NumberFilter(
+        PartitionField(partition_raw_field=PartitionRawField(field=PARTITION_RAW_ENUM_FIELD_PRIORITY)),
+        rawFilters,
+        rawFiltersAnd,
+        rawFilterField
+    )
+
+
+class ArmoniKPartitions:
+    def __init__(self, grpc_channel: Channel):
+        """Partitions service client
+
+        Args:
+            grpc_channel: gRPC channel to use
+        """
+        self._client = PartitionsStub(grpc_channel)
+
+    def list_partitions(self, partition_filter: Filter | None = None, page: int = 0, page_size: int = 1000, sort_field: Filter = PartitionFieldFilter.PRIORITY, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Partition]]:
+        """List partitions based on a filter.
+
+        Args:
+            partition_filter: Filter to apply when listing partitions
+            page: page number to request, useful for pagination, defaults to 0
+            page_size: size of a page, defaults to 1000
+            sort_field: field to sort the resulting list by, defaults to the priority
+            sort_direction: direction of the sort, defaults to ascending
+
+        Returns:
+            A tuple containing:
+            - The total number of partitions for the given filter
+            - The obtained list of partitions
+        """
+        request = ListPartitionsRequest(
+            page=page,
+            page_size=page_size,
+            sort=ListPartitionsRequest.Sort(field=cast(PartitionField, sort_field.field), direction=sort_direction),
+        )
+        if partition_filter:
+            request.filters.CopyFrom(cast(rawFilters, partition_filter.to_disjunction().to_message()))
+        response: ListPartitionsResponse = self._client.ListPartitions(request)
+        return response.total, [Partition.from_message(p) for p in response.partitions]
+
+    def get_partition(self, partition_id: str) -> Partition:
+        """Get a partition by its ID.
+
+        Args:
+            partition_id: The partition ID.
+
+        Return:
+            The partition summary.
+        """
+        return Partition.from_message(self._client.GetPartition(GetPartitionRequest(id=partition_id)).partition)
diff --git a/packages/python/src/armonik/client/results.py b/packages/python/src/armonik/client/results.py
index 942add9c1..026b77095 100644
--- a/packages/python/src/armonik/client/results.py
+++ b/packages/python/src/armonik/client/results.py
@@ -1,21 +1,26 @@
 from __future__ import annotations
 from grpc import Channel
+from deprecation import deprecated
 from typing import List, Dict, cast, Tuple
 
 from ..protogen.client.results_service_pb2_grpc import ResultsStub
-from ..protogen.common.results_common_pb2 import CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, ListResultsRequest, ListResultsResponse
+from ..protogen.common.results_common_pb2 import CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, ListResultsRequest, ListResultsResponse, GetOwnerTaskIdRequest, GetOwnerTaskIdResponse, CreateResultsRequest, CreateResultsResponse, ResultsServiceConfigurationResponse, DeleteResultsDataRequest, DeleteResultsDataResponse, UploadResultDataRequest, UploadResultDataResponse, DownloadResultDataRequest, DownloadResultDataResponse, GetResultRequest, GetResultResponse
 from ..protogen.common.results_filters_pb2 import Filters as rawFilters, FiltersAnd as rawFilterAnd, FilterField as rawFilterField, FilterStatus as rawFilterStatus
 from ..protogen.common.results_fields_pb2 import ResultField
+from ..protogen.common.objects_pb2 import Empty
 from ..common.filter import StringFilter, StatusFilter, DateFilter, NumberFilter, Filter
 from ..protogen.common.sort_direction_pb2 import SortDirection
 from ..common import Direction , Result
-from ..protogen.common.results_fields_pb2 import ResultField, ResultRawField, ResultRawEnumField, RESULT_RAW_ENUM_FIELD_STATUS
+from ..protogen.common.results_fields_pb2 import ResultField, ResultRawField, ResultRawEnumField, RESULT_RAW_ENUM_FIELD_STATUS, RESULT_RAW_ENUM_FIELD_RESULT_ID
+from ..common.helpers import batched
+
 
 class ResultFieldFilter:
     STATUS = StatusFilter(ResultField(result_raw_field=ResultRawField(field=RESULT_RAW_ENUM_FIELD_STATUS)), rawFilters, rawFilterAnd, rawFilterField, rawFilterStatus)
+    RESULT_ID = StringFilter(ResultField(result_raw_field=ResultRawField(field=RESULT_RAW_ENUM_FIELD_RESULT_ID)), rawFilters, rawFilterAnd, rawFilterField)
 
-class ArmoniKResult:
+
+class ArmoniKResults:
     def __init__(self, grpc_channel: Channel):
         """ Result service client
 
@@ -24,10 +29,11 @@ def __init__(self, grpc_channel: Channel):
         """
         self._client = ResultsStub(grpc_channel)
 
+    @deprecated(deprecated_in="3.15.0", details="Use create_results_metadata or create_results instead.")
     def get_results_ids(self, session_id: str, names: List[str]) -> Dict[str, str]:
         return {r.name : r.result_id for r in cast(CreateResultsMetaDataResponse, self._client.CreateResultsMetaData(CreateResultsMetaDataRequest(results=[CreateResultsMetaDataRequest.ResultCreate(name = n) for n in names], session_id=session_id))).results}
 
-    def list_results(self, result_filter: Filter, page: int = 0, page_size: int = 1000, sort_field: Filter = ResultFieldFilter.STATUS,sort_direction: SortDirection = Direction.ASC ) -> Tuple[int, List[Result]]:
+    def list_results(self, result_filter: Filter | None = None, page: int = 0, page_size: int = 1000, sort_field: Filter = ResultFieldFilter.STATUS, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Result]]:
         """List results based on a filter.
 
         Args:
@@ -44,8 +50,160 @@ def list_results(self, result_filter: Filter, page: int = 0, page_size: int = 10
         request: ListResultsRequest = ListResultsRequest(
             page=page,
             page_size=page_size,
-            filters=cast(rawFilters, result_filter.to_disjunction().to_message()),
             sort=ListResultsRequest.Sort(field=cast(ResultField, sort_field.field), direction=sort_direction),
         )
+        if result_filter:
+            request.filters.CopyFrom(cast(rawFilters, result_filter.to_disjunction().to_message()))
         list_response: ListResultsResponse = self._client.ListResults(request)
         return list_response.total, [Result.from_message(r) for r in list_response.results]
+
+    def get_result(self, result_id: str) -> Result:
+        """Get a result by ID.
+
+        Args:
+            result_id: The ID of the result.
+
+        Return:
+            The result summary.
+        """
+        request = GetResultRequest(result_id=result_id)
+        response: GetResultResponse = self._client.GetResult(request)
+        return Result.from_message(response.result)
+
+    def get_owner_task_id(self, result_ids: List[str], session_id: str, batch_size: int = 500) -> Dict[str, str]:
+        """Get the IDs of the tasks that should produce the results.
+
+        Args:
+            result_ids: A list of result IDs.
+            session_id: The ID of the session to which the results belong.
+            batch_size: Batch size for querying.
+
+        Return:
+            A dictionary mapping each result ID to the ID of its owner task.
+        """
+        results = {}
+        for result_ids_batch in batched(result_ids, batch_size):
+            request = GetOwnerTaskIdRequest(session_id=session_id, result_id=result_ids_batch)
+            response: GetOwnerTaskIdResponse = self._client.GetOwnerTaskId(request)
+            for result_task in response.result_task:
+                results[result_task.result_id] = result_task.task_id
+        return results
+
+    def create_results_metadata(self, result_names: List[str], session_id: str, batch_size: int = 100) -> Dict[str, Result]:
+        """Create the metadata of multiple results at once.
+        Data has to be uploaded separately.
+
+        Args:
+            result_names: The list of the names of the results to create.
+            session_id: The ID of the session to which the results belong.
+            batch_size: Batch size for querying.
+
+        Return:
+            A dictionary mapping each result name to its corresponding result summary.
+        """
+        results = {}
+        for result_names_batch in batched(result_names, batch_size):
+            request = CreateResultsMetaDataRequest(
+                results=[CreateResultsMetaDataRequest.ResultCreate(name=result_name) for result_name in result_names_batch],
+                session_id=session_id
+            )
+            response: CreateResultsMetaDataResponse = self._client.CreateResultsMetaData(request)
+            for result_message in response.results:
+                results[result_message.name] = Result.from_message(result_message)
+        return results
+
+    def create_results(self, results_data: Dict[str, bytes], session_id: str, batch_size: int = 1) -> Dict[str, Result]:
+        """Create results with their data included in the request.
+
+        Args:
+            results_data: A dictionary mapping the result names to their actual data.
+            session_id: The ID of the session to which the results belong.
+            batch_size: Batch size for querying.
+
+        Return:
+            A dictionary mapping each result name to its corresponding result summary.
+        """
+        results = {}
+        for results_names_batch in batched(results_data.keys(), batch_size):
+            request = CreateResultsRequest(
+                results=[CreateResultsRequest.ResultCreate(name=name, data=results_data[name]) for name in results_names_batch],
+                session_id=session_id
+            )
+            response: CreateResultsResponse = self._client.CreateResults(request)
+            for message in response.results:
+                results[message.name] = Result.from_message(message)
+        return results
+
+    def upload_result_data(self, result_id: str, session_id: str, result_data: bytes | bytearray) -> None:
+        """Upload data for an empty result already created.
+
+        Args:
+            result_id: The ID of the result.
+            session_id: The ID of the session.
+            result_data: The result data.
+        """
+        data_chunk_max_size = self.get_service_config()
+
+        def upload_result_stream():
+            request = UploadResultDataRequest(
+                id=UploadResultDataRequest.ResultIdentifier(
+                    session_id=session_id, result_id=result_id
+                )
+            )
+            yield request
+
+            start = 0
+            data_len = len(result_data)
+            while start < data_len:
+                chunk_size = min(data_chunk_max_size, data_len - start)
+                request = UploadResultDataRequest(
+                    data_chunk=result_data[start : start + chunk_size]
+                )
+                yield request
+                start += chunk_size
+
+        self._client.UploadResultData(upload_result_stream())
+
+    def download_result_data(self, result_id: str, session_id: str) -> bytes:
+        """Retrieve data of a result.
+
+        Args:
+            result_id: The ID of the result.
+            session_id: The ID of the session.
+
+        Return:
+            Result data.
+        """
+        request = DownloadResultDataRequest(
+            result_id=result_id,
+            session_id=session_id
+        )
+        streaming_call = self._client.DownloadResultData(request)
+        return b''.join([message.data_chunk for message in streaming_call])
+
+    def delete_result_data(self, result_ids: List[str], session_id: str, batch_size: int = 100) -> None:
+        """Delete data from multiple results.
+
+        Args:
+            result_ids: The IDs of the results whose data must be deleted.
+            session_id: The ID of the session to which the results belong.
+            batch_size: Batch size for querying.
+        """
+        for result_ids_batch in batched(result_ids, batch_size):
+            request = DeleteResultsDataRequest(
+                result_id=result_ids_batch,
+                session_id=session_id
+            )
+            self._client.DeleteResultsData(request)
+
+    def get_service_config(self) -> int:
+        """Get the configuration of the service.
+
+        Return:
+            Maximum size supported by a data chunk for the result service.
+        """
+        response: ResultsServiceConfigurationResponse = self._client.GetServiceConfiguration(Empty())
+        return response.data_chunk_max_size
+
+    def watch_results(self):
+        raise NotImplementedError()
+ """ + response: ResultsServiceConfigurationResponse = self._client.GetServiceConfiguration(Empty()) + return response.data_chunk_max_size + + def watch_results(self): + raise NotImplementedError() diff --git a/packages/python/src/armonik/client/sessions.py b/packages/python/src/armonik/client/sessions.py index 8f676144d..a15fddb01 100644 --- a/packages/python/src/armonik/client/sessions.py +++ b/packages/python/src/armonik/client/sessions.py @@ -60,7 +60,20 @@ def create_session(self, default_task_options: TaskOptions, partition_ids: Optio request.partition_ids.append(partition) return self._client.CreateSession(request).session_id - def list_sessions(self, task_filter: Filter, page: int = 0, page_size: int = 1000, sort_field: Filter = SessionFieldFilter.STATUS, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Session]]: + def get_session(self, session_id: str): + """Get a session by its ID. + + Args: + session_id: The ID of the session. + + Return: + The session summary. + """ + request = GetSessionRequest(session_id=session_id) + response: GetSessionResponse = self._client.GetSession(request) + return Session.from_message(response.session) + + def list_sessions(self, session_filter: Filter | None = None, page: int = 0, page_size: int = 1000, sort_field: Filter = SessionFieldFilter.STATUS, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Session]]: """ List sessions @@ -76,14 +89,15 @@ def list_sessions(self, task_filter: Filter, page: int = 0, page_size: int = 100 - The total number of sessions for the given filter - The obtained list of sessions """ - request : ListSessionsRequest = ListSessionsRequest( + request = ListSessionsRequest( page=page, page_size=page_size, - filters=cast(rawFilters, task_filter.to_disjunction().to_message()), sort=ListSessionsRequest.Sort(field=cast(SessionField, sort_field.field), direction=sort_direction), ) - list_response : ListSessionsResponse = self._client.ListSessions(request) - return list_response.total, [Session.from_message(t) for t in list_response.sessions] + if session_filter: + request.filters = cast(rawFilters, session_filter.to_disjunction().to_message()), + response : ListSessionsResponse = self._client.ListSessions(request) + return response.total, [Session.from_message(s) for s in response.sessions] def cancel_session(self, session_id: str) -> None: """Cancel a session @@ -92,4 +106,3 @@ def cancel_session(self, session_id: str) -> None: session_id: Id of the session to b cancelled """ self._client.CancelSession(CancelSessionRequest(session_id=session_id)) - \ No newline at end of file diff --git a/packages/python/src/armonik/client/tasks.py b/packages/python/src/armonik/client/tasks.py index 18f7c3478..26879991e 100644 --- a/packages/python/src/armonik/client/tasks.py +++ b/packages/python/src/armonik/client/tasks.py @@ -1,15 +1,15 @@ from __future__ import annotations from grpc import Channel -from typing import cast, Tuple, List +from typing import cast, Dict, Optional, Tuple, List -from ..common import Task, Direction +from ..common import Task, Direction, TaskDefinition, TaskOptions, TaskStatus from ..common.filter import StringFilter, StatusFilter, DateFilter, NumberFilter, Filter, DurationFilter from ..protogen.client.tasks_service_pb2_grpc import TasksStub -from ..protogen.common.tasks_common_pb2 import GetTaskRequest, GetTaskResponse, ListTasksRequest, ListTasksDetailedResponse +from ..protogen.common.tasks_common_pb2 import GetTaskRequest, GetTaskResponse, ListTasksRequest, 
ListTasksDetailedResponse, CancelTasksRequest, CancelTasksResponse, GetResultIdsRequest, GetResultIdsResponse, SubmitTasksRequest, SubmitTasksResponse, CountTasksByStatusRequest, CountTasksByStatusResponse, ListTasksResponse from ..protogen.common.tasks_filters_pb2 import Filters as rawFilters, FiltersAnd as rawFilterAnd, FilterField as rawFilterField, FilterStatus as rawFilterStatus from ..protogen.common.sort_direction_pb2 import SortDirection - from ..protogen.common.tasks_fields_pb2 import * +from ..common.helpers import batched class TaskFieldFilter: @@ -77,7 +77,7 @@ def get_task(self, task_id: str) -> Task: task_response: GetTaskResponse = self._client.GetTask(GetTaskRequest(task_id=task_id)) return Task.from_message(task_response.task) - def list_tasks(self, task_filter: Filter, with_errors: bool = False, page: int = 0, page_size: int = 1000, sort_field: Filter = TaskFieldFilter.TASK_ID, sort_direction: SortDirection = Direction.ASC) -> Tuple[int, List[Task]]: + def list_tasks(self, task_filter: Filter | None = None, with_errors: bool = False, page: int = 0, page_size: int = 1000, sort_field: Filter = TaskFieldFilter.TASK_ID, sort_direction: SortDirection = Direction.ASC, detailed: bool = True) -> Tuple[int, List[Task]]: """List tasks If the total returned exceeds the requested page size, you may want to use this function again and ask for subsequent pages. @@ -89,6 +89,7 @@ def list_tasks(self, task_filter: Filter, with_errors: bool = False, page: int = page_size: size of a page, defaults to 1000 sort_field: field on which to sort the resulting list, defaults to the task_id sort_direction: direction of the sort, defaults to ascending + detailed: Wether to retrieve the detailed description of the task. Returns: A tuple containing : @@ -96,9 +97,98 @@ def list_tasks(self, task_filter: Filter, with_errors: bool = False, page: int = - The obtained list of tasks """ request = ListTasksRequest(page=page, - page_size=page_size, - filters=cast(rawFilters, task_filter.to_disjunction().to_message()), - sort=ListTasksRequest.Sort(field=cast(TaskField, sort_field.field), direction=sort_direction), - with_errors=with_errors) - list_response: ListTasksDetailedResponse = self._client.ListTasksDetailed(request) - return list_response.total, [Task.from_message(t) for t in list_response.tasks] + page_size=page_size, + sort=ListTasksRequest.Sort(field=cast(TaskField, sort_field.field), direction=sort_direction), + with_errors=with_errors + ) + if task_filter: + request.filters = cast(rawFilters, task_filter.to_disjunction().to_message()) + if detailed: + response: ListTasksDetailedResponse = self._client.ListTasksDetailed(request) + return response.total, [Task.from_message(t) for t in response.tasks] + response: ListTasksResponse = self._client.ListTasks(request) + return response.total, [Task.from_message(t) for t in response.tasks] + + def cancel_tasks(self, task_ids: List[str], chunk_size: Optional[int] = 500): + """Cancel tasks. + + Args: + task_ids: IDs of the tasks. + chunk_size: Batch size for cancelling. + + Return: + The list of cancelled tasks. + """ + for task_id_batch in batched(task_ids, chunk_size): + request = CancelTasksRequest(task_ids=task_id_batch) + self._client.CancelTasks(request) + + def get_result_ids(self, task_ids: List[str], chunk_size: Optional[int] = 500) -> Dict[str, List[str]]: + """Get result IDs of a list of tasks. + + Args: + task_ids: The IDs of the tasks. + chunk_size: Batch size for retrieval. 
+ + Return: + A dictionary mapping the ID of a task to the IDs of its results.. + """ + tasks_result_ids = {} + + for task_ids_batch in batched(task_ids, chunk_size): + request = GetResultIdsRequest(task_id=task_ids_batch) + result_ids_response: GetResultIdsResponse = self._client.GetResultIds(request) + for t in result_ids_response.task_results: + tasks_result_ids[t.task_id] = list(t.result_ids) + return tasks_result_ids + + def count_tasks_by_status(self, task_filter: Filter | None = None) -> Dict[TaskStatus, int]: + """Get number of tasks by status. + + Args: + task_filter: Filter for the tasks to be listed + + Return: + A dictionnary mapping each status to the number of filtered tasks. + """ + if task_filter: + request = CountTasksByStatusRequest(filters=cast(rawFilters, task_filter.to_disjunction().to_message())) + else: + request = CountTasksByStatusRequest() + count_tasks_by_status_response: CountTasksByStatusResponse = self._client.CountTasksByStatus(request) + return {TaskStatus(status_count.status): status_count.count for status_count in count_tasks_by_status_response.status} + + def submit_tasks(self, session_id: str, tasks: List[TaskDefinition], default_task_options: Optional[TaskOptions | None] = None, chunk_size: Optional[int] = 100) -> List[Task]: + """Submit tasks to ArmoniK. + + Args: + session_id: Session Id + tasks: List of task definitions + default_task_options: Default Task Options used if a task has its options not set + chunk_size: Batch size for submission + + Returns: + Tuple containing the list of successfully sent tasks, and + the list of submission errors if any + """ + for tasks_batch in batched(tasks, chunk_size): + task_creations = [] + + for t in tasks_batch: + task_creation = SubmitTasksRequest.TaskCreation( + expected_output_keys=t.expected_output_ids, + payload_id=t.payload_id, + data_dependencies=t.data_dependencies, + ) + if t.options: + task_creation.task_options = t.options.to_message() + task_creations.append(task_creation) + + request = SubmitTasksRequest( + session_id=session_id, + task_creations=task_creations + ) + if default_task_options: + request.task_options = default_task_options.to_message() + + self._client.SubmitTasks(request) diff --git a/packages/python/src/armonik/client/versions.py b/packages/python/src/armonik/client/versions.py new file mode 100644 index 000000000..db6f23a69 --- /dev/null +++ b/packages/python/src/armonik/client/versions.py @@ -0,0 +1,26 @@ +from typing import Dict + +from grpc import Channel + +from ..protogen.client.versions_service_pb2_grpc import VersionsStub +from ..protogen.common.versions_common_pb2 import ListVersionsRequest, ListVersionsResponse + + +class ArmoniKVersions: + def __init__(self, grpc_channel: Channel): + """ Result service client + + Args: + grpc_channel: gRPC channel to use + """ + self._client = VersionsStub(grpc_channel) + + def list_versions(self) -> Dict[str, str]: + """Get versions of ArmoniK components. + + Return: + A dictionnary mapping each component to its version. 
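Since submit_tasks now expects the payload to pre-exist as a result (referenced through payload_id), a sketch of a minimal submission flow; the endpoint, names and session ID are placeholders.

```python
import grpc

from armonik.client import ArmoniKResults, ArmoniKTasks
from armonik.common import TaskDefinition

with grpc.insecure_channel("localhost:5001") as channel:
    results = ArmoniKResults(channel)
    tasks = ArmoniKTasks(channel)

    # The payload is itself a result; the task references it by ID.
    payload = results.create_results({"payload": b"42"}, "session-id")["payload"]
    output = results.create_results_metadata(["output"], "session-id")["output"]

    tasks.submit_tasks("session-id", [
        TaskDefinition(payload_id=payload.result_id, expected_output_ids=[output.result_id])
    ])
```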
+ """ + request = ListVersionsRequest() + response: ListVersionsResponse = self._client.ListVersions(request) + return {"core": response.core, "api": response.api} diff --git a/packages/python/src/armonik/common/__init__.py b/packages/python/src/armonik/common/__init__.py index 001721868..04105d3d4 100644 --- a/packages/python/src/armonik/common/__init__.py +++ b/packages/python/src/armonik/common/__init__.py @@ -1,4 +1,41 @@ -from .helpers import datetime_to_timestamp, timestamp_to_datetime, duration_to_timedelta, timedelta_to_duration, get_task_filter -from .objects import Task, TaskDefinition, TaskOptions, Output, ResultAvailability, Session, Result -from .enumwrapper import HealthCheckStatus, TaskStatus, Direction -from .filter import StringFilter, StatusFilter +from .helpers import ( + datetime_to_timestamp, + timestamp_to_datetime, + duration_to_timedelta, + timedelta_to_duration, + get_task_filter, + batched +) +from .objects import Task, TaskDefinition, TaskOptions, Output, ResultAvailability, Session, Result, Partition +from .enumwrapper import HealthCheckStatus, TaskStatus, Direction, SessionStatus, ResultStatus, EventTypes, ServiceHealthCheckStatus +from .events import * +from .filter import Filter, StringFilter, StatusFilter + +__all__ = [ + 'datetime_to_timestamp', + 'timestamp_to_datetime', + 'duration_to_timedelta', + 'timedelta_to_duration', + 'get_task_filter', + 'batched', + 'Task', + 'TaskDefinition', + 'TaskOptions', + 'Output', + 'ResultAvailability', + 'Session', + 'Result', + 'Partition', + 'HealthCheckStatus', + 'TaskStatus', + 'Direction', + 'SessionStatus', + 'ResultStatus', + 'EventTypes', + # Include all names from events module + # Add names from filter module + 'Filter', + 'StringFilter', + 'StatusFilter', + 'ServiceHealthCheckStatus' +] diff --git a/packages/python/src/armonik/common/enumwrapper.py b/packages/python/src/armonik/common/enumwrapper.py index 9c19a9a82..d6fe134cb 100644 --- a/packages/python/src/armonik/common/enumwrapper.py +++ b/packages/python/src/armonik/common/enumwrapper.py @@ -1,8 +1,10 @@ from __future__ import annotations from ..protogen.common.task_status_pb2 import TaskStatus as RawStatus, _TASKSTATUS, TASK_STATUS_CANCELLED, TASK_STATUS_CANCELLING, TASK_STATUS_COMPLETED, TASK_STATUS_CREATING, TASK_STATUS_DISPATCHED, TASK_STATUS_ERROR, TASK_STATUS_PROCESSED, TASK_STATUS_PROCESSING, TASK_STATUS_SUBMITTED, TASK_STATUS_TIMEOUT, TASK_STATUS_UNSPECIFIED, TASK_STATUS_RETRIED +from ..protogen.common.events_common_pb2 import EventsEnum as rawEventsEnum, EVENTS_ENUM_UNSPECIFIED, EVENTS_ENUM_NEW_TASK, EVENTS_ENUM_TASK_STATUS_UPDATE, EVENTS_ENUM_NEW_RESULT, EVENTS_ENUM_RESULT_STATUS_UPDATE, EVENTS_ENUM_RESULT_OWNER_UPDATE from ..protogen.common.session_status_pb2 import SessionStatus as RawSessionStatus, _SESSIONSTATUS, SESSION_STATUS_UNSPECIFIED, SESSION_STATUS_CANCELLED, SESSION_STATUS_RUNNING from ..protogen.common.result_status_pb2 import ResultStatus as RawResultStatus, _RESULTSTATUS, RESULT_STATUS_UNSPECIFIED, RESULT_STATUS_CREATED, RESULT_STATUS_COMPLETED, RESULT_STATUS_ABORTED, RESULT_STATUS_NOTFOUND +from ..protogen.common.health_checks_common_pb2 import HEALTH_STATUS_ENUM_UNSPECIFIED, HEALTH_STATUS_ENUM_HEALTHY, HEALTH_STATUS_ENUM_DEGRADED, HEALTH_STATUS_ENUM_UNHEALTHY from ..protogen.common.worker_common_pb2 import HealthCheckReply from ..protogen.common.sort_direction_pb2 import SORT_DIRECTION_ASC, SORT_DIRECTION_DESC @@ -58,3 +60,23 @@ def name_from_value(status: RawResultStatus) -> str: COMPLETED = RESULT_STATUS_COMPLETED ABORTED = 
RESULT_STATUS_ABORTED NOTFOUND = RESULT_STATUS_NOTFOUND + + +class EventTypes: + UNSPECIFIED = EVENTS_ENUM_UNSPECIFIED + NEW_TASK = EVENTS_ENUM_NEW_TASK + TASK_STATUS_UPDATE = EVENTS_ENUM_TASK_STATUS_UPDATE + NEW_RESULT = EVENTS_ENUM_NEW_RESULT + RESULT_STATUS_UPDATE = EVENTS_ENUM_RESULT_STATUS_UPDATE + RESULT_OWNER_UPDATE = EVENTS_ENUM_RESULT_OWNER_UPDATE + + @classmethod + def from_string(cls, name: str): + return getattr(cls, name.upper()) + + +class ServiceHealthCheckStatus: + UNSPECIFIED = HEALTH_STATUS_ENUM_UNSPECIFIED + HEALTHY = HEALTH_STATUS_ENUM_HEALTHY + DEGRADED = HEALTH_STATUS_ENUM_DEGRADED + UNHEALTHY = HEALTH_STATUS_ENUM_UNHEALTHY diff --git a/packages/python/src/armonik/common/events.py b/packages/python/src/armonik/common/events.py new file mode 100644 index 000000000..34acbd2c0 --- /dev/null +++ b/packages/python/src/armonik/common/events.py @@ -0,0 +1,53 @@ +from abc import ABC +from typing import List + +from dataclasses import dataclass, fields + +from .enumwrapper import TaskStatus, ResultStatus + + +class Event(ABC): + @classmethod + def from_raw_event(cls, raw_event): + values = {} + for raw_field in cls.__annotations__.keys(): + values[raw_field] = getattr(raw_event, raw_field) + return cls(**values) + + +@dataclass +class TaskStatusUpdateEvent(Event): + task_id: str + status: TaskStatus + + +@dataclass +class ResultStatusUpdateEvent(Event): + result_id: str + status: ResultStatus + + +@dataclass +class ResultOwnerUpdateEvent(Event): + result_id: str + previous_owner_id: str + current_owner_id: str + + +@dataclass +class NewTaskEvent(Event): + task_id: str + payload_id: str + origin_task_id: str + status: TaskStatus + expected_output_keys: List[str] + data_dependencies: List[str] + retry_of_ids: List[str] + parent_task_ids: List[str] + + +@dataclass +class NewResultEvent(Event): + result_id: str + owner_id: str + status: ResultStatus diff --git a/packages/python/src/armonik/common/helpers.py b/packages/python/src/armonik/common/helpers.py index e174e2f42..3a9cb8324 100644 --- a/packages/python/src/armonik/common/helpers.py +++ b/packages/python/src/armonik/common/helpers.py @@ -1,6 +1,6 @@ from __future__ import annotations from datetime import timedelta, datetime, timezone -from typing import List, Optional +from typing import List, Optional, Iterable, TypeVar import google.protobuf.duration_pb2 as duration import google.protobuf.timestamp_pb2 as timestamp @@ -9,6 +9,9 @@ from .enumwrapper import TaskStatus +T = TypeVar('T') + + def get_task_filter(session_ids: Optional[List[str]] = None, task_ids: Optional[List[str]] = None, included_statuses: Optional[List[TaskStatus]] = None, excluded_statuses: Optional[List[TaskStatus]] = None) -> TaskFilter: @@ -96,3 +99,29 @@ def timedelta_to_duration(delta: timedelta) -> duration.Duration: d = duration.Duration() d.FromTimedelta(delta) return d + + +def batched(iterable: Iterable[T], n: int) -> Iterable[List[T]]: + """ + Batches elements from an iterable into lists of size at most 'n'. + + Args: + iterable : The input iterable. + n : The batch size. + + Yields: + A generator yielding batches of elements from the input iterable. 
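A one-line sanity check of the helper above (each yielded batch is a fresh list, so collecting them is safe):

```python
from armonik.common.helpers import batched

# Order is preserved and the final batch may be shorter than n.
assert list(batched(range(5), 2)) == [[0, 1], [2, 3], [4]]
```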
+ """ + it = iter(iterable) + + sentinel = object() + batch = [] + c = next(it, sentinel) + while c is not sentinel: + batch.append(c) + if len(batch) == n: + yield batch + batch.clear() + c = next(it, sentinel) + if len(batch) > 0: + yield batch diff --git a/packages/python/src/armonik/common/objects.py b/packages/python/src/armonik/common/objects.py index 4f9527377..1b5801f7a 100644 --- a/packages/python/src/armonik/common/objects.py +++ b/packages/python/src/armonik/common/objects.py @@ -8,6 +8,7 @@ from ..protogen.common.objects_pb2 import Empty, Output as WorkerOutput, TaskOptions as RawTaskOptions from ..protogen.common.task_status_pb2 import TaskStatus as RawTaskStatus from .enumwrapper import TaskStatus, SessionStatus, ResultStatus +from ..protogen.common.partitions_common_pb2 import PartitionRaw from ..protogen.common.session_status_pb2 import SessionStatus as RawSessionStatus from ..protogen.common.sessions_common_pb2 import SessionRaw from ..protogen.common.result_status_pb2 import ResultStatus as RawResultStatus @@ -70,9 +71,11 @@ def to_message(self): @dataclass() class TaskDefinition: - payload: bytes + payload_id: str = field(default_factory=str) + payload: bytes = field(default_factory=bytes) expected_output_ids: List[str] = field(default_factory=list) data_dependencies: List[str] = field(default_factory=list) + options: Optional[TaskOptions] = None def __post_init__(self): if len(self.expected_output_ids) <= 0: @@ -89,6 +92,7 @@ class Task: expected_output_ids: List[str] = field(default_factory=list) retry_of_ids: List[str] = field(default_factory=list) status: RawTaskStatus = TaskStatus.UNSPECIFIED + payload_id: Optional[str] = None status_message: Optional[str] = None options: Optional[TaskOptions] = None created_at: Optional[datetime] = None @@ -211,3 +215,25 @@ def from_message(cls, result_raw: ResultRaw) -> "Result": result_id=result_raw.result_id, size=result_raw.size ) + +@dataclass +class Partition: + id: str + parent_partition_ids: List[str] + pod_reserved: int + pod_max: int + pod_configuration: Dict[str, str] + preemption_percentage: int + priority: int + + @classmethod + def from_message(cls, partition_raw: PartitionRaw) -> "Partition": + return cls( + id=partition_raw.id, + parent_partition_ids=partition_raw.parent_partition_ids, + pod_reserved=partition_raw.pod_reserved, + pod_max=partition_raw.pod_max, + pod_configuration=partition_raw.pod_configuration, + preemption_percentage=partition_raw.preemption_percentage, + priority=partition_raw.priority + ) diff --git a/packages/python/src/armonik/worker/__init__.py b/packages/python/src/armonik/worker/__init__.py index 508d49ae5..78a61174c 100644 --- a/packages/python/src/armonik/worker/__init__.py +++ b/packages/python/src/armonik/worker/__init__.py @@ -1,3 +1,9 @@ from .worker import ArmoniKWorker from .taskhandler import TaskHandler from .seqlogger import ClefLogger + +__all__ = [ + 'ArmoniKWorker', + 'TaskHandler', + 'ClefLogger', +] diff --git a/packages/python/src/armonik/worker/taskhandler.py b/packages/python/src/armonik/worker/taskhandler.py index 49eeb8ff3..7d18ef7db 100644 --- a/packages/python/src/armonik/worker/taskhandler.py +++ b/packages/python/src/armonik/worker/taskhandler.py @@ -1,12 +1,14 @@ from __future__ import annotations import os +from deprecation import deprecated from typing import Optional, Dict, List, Tuple, Union, cast -from ..common import TaskOptions, TaskDefinition, Task -from ..protogen.common.agent_common_pb2 import CreateTaskRequest, CreateResultsMetaDataRequest, 
diff --git a/packages/python/src/armonik/worker/taskhandler.py b/packages/python/src/armonik/worker/taskhandler.py
index 49eeb8ff3..7d18ef7db 100644
--- a/packages/python/src/armonik/worker/taskhandler.py
+++ b/packages/python/src/armonik/worker/taskhandler.py
@@ -1,12 +1,14 @@
 from __future__ import annotations
 import os
+from deprecation import deprecated
 from typing import Optional, Dict, List, Tuple, Union, cast
 
-from ..common import TaskOptions, TaskDefinition, Task
-from ..protogen.common.agent_common_pb2 import CreateTaskRequest, CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, NotifyResultDataRequest
+from ..common import TaskOptions, TaskDefinition, Task, Result
+from ..protogen.common.agent_common_pb2 import CreateTaskRequest, CreateResultsMetaDataRequest, CreateResultsMetaDataResponse, NotifyResultDataRequest, CreateResultsRequest, CreateResultsResponse, SubmitTasksRequest, SubmitTasksResponse
 from ..protogen.common.objects_pb2 import TaskRequest, DataChunk, InitTaskRequest, TaskRequestHeader, Configuration
 from ..protogen.worker.agent_service_pb2_grpc import AgentStub
 from ..protogen.common.worker_common_pb2 import ProcessRequest
+from ..common.helpers import batched
 
 
 class TaskHandler:
@@ -31,6 +33,7 @@ def __init__(self, request: ProcessRequest, agent_client: AgentStub):
             with open(os.path.join(self.data_folder, dd), "rb") as f:
                 self.data_dependencies[dd] = f.read()
 
+    @deprecated(deprecated_in="3.15.0", details="Use submit_tasks instead and create the payload using create_results_metadata and send_results.")
     def create_tasks(self, tasks: List[TaskDefinition], task_options: Optional[TaskOptions] = None) -> Tuple[List[Task], List[str]]:
         """Create new tasks for ArmoniK
 
@@ -67,21 +70,99 @@ def create_tasks(self, tasks: List[TaskDefinition], task_options: Optional[TaskO
                 raise Exception("Unknown value")
         return tasks_created, tasks_creation_failed
 
-    def send_result(self, key: str, data: Union[bytes, bytearray]) -> None:
-        """ Send task result
+    def submit_tasks(self, tasks: List[TaskDefinition], default_task_options: Optional[TaskOptions] = None, batch_size: Optional[int] = 100) -> None:
+        """Submit tasks to the agent.
 
         Args:
-            key: Result key
-            data: Result data
+            tasks: List of task definitions
+            default_task_options: Default task options used if a task has its options not set
+            batch_size: Batch size for submission
         """
-        with open(os.path.join(self.data_folder, key), "wb") as f:
-            f.write(data)
+        for tasks_batch in batched(tasks, batch_size):
+            task_creations = []
+
+            for t in tasks_batch:
+                task_creation = SubmitTasksRequest.TaskCreation(
+                    expected_output_keys=t.expected_output_ids,
+                    payload_id=t.payload_id,
+                    data_dependencies=t.data_dependencies
+                )
+                if t.options:
+                    task_creation.task_options.CopyFrom(t.options.to_message())
+                task_creations.append(task_creation)
+
+            request = SubmitTasksRequest(
+                session_id=self.session_id,
+                communication_token=self.token,
+                task_creations=task_creations
+            )
+
+            if default_task_options:
+                request.task_options.CopyFrom(default_task_options.to_message())
 
-        self._client.NotifyResultData(NotifyResultDataRequest(ids=[NotifyResultDataRequest.ResultIdentifier(session_id=self.session_id, result_id=key)], communication_token=self.token))
+            self._client.SubmitTasks(request)
 
-    def get_results_ids(self, names: List[str]) -> Dict[str, str]:
-        return {r.name: r.result_id for r in cast(CreateResultsMetaDataResponse, self._client.CreateResultsMetaData(CreateResultsMetaDataRequest(results=[CreateResultsMetaDataRequest.ResultCreate(name=n) for n in names], session_id=self.session_id, communication_token=self.token))).results}
+    def send_results(self, results_data: Dict[str, bytes | bytearray]) -> None:
+        """Send results.
+
+        Args:
+            results_data: A dictionary mapping each result ID to its data.
+        """
+        for result_id, data in results_data.items():
+            with open(os.path.join(self.data_folder, result_id), "wb") as f:
+                f.write(data)
+
+        request = NotifyResultDataRequest(
+            ids=[NotifyResultDataRequest.ResultIdentifier(session_id=self.session_id, result_id=result_id) for result_id in results_data.keys()],
+            communication_token=self.token
+        )
+        self._client.NotifyResultData(request)
+
+    def create_results_metadata(self, result_names: List[str], batch_size: int = 100) -> Dict[str, Result]:
+        """
+        Create the metadata of multiple results at once.
+        Data has to be uploaded separately.
+
+        Args:
+            result_names: The names of the results to create.
+            batch_size: Batch size for querying.
+
+        Return:
+            A dictionary mapping each result name to its result summary.
+        """
+        results = {}
+        for result_names_batch in batched(result_names, batch_size):
+            request = CreateResultsMetaDataRequest(
+                results=[CreateResultsMetaDataRequest.ResultCreate(name=result_name) for result_name in result_names_batch],
+                session_id=self.session_id,
+                communication_token=self.token
+            )
+            response: CreateResultsMetaDataResponse = self._client.CreateResultsMetaData(request)
+            for result_message in response.results:
+                results[result_message.name] = Result.from_message(result_message)
+        return results
+
+    def create_results(self, results_data: Dict[str, bytes], batch_size: int = 1) -> Dict[str, Result]:
+        """Create results with their data included in the request.
+
+        Args:
+            results_data: A dictionary mapping the result names to their actual data.
+            batch_size: Batch size for querying.
+
+        Return:
+            A dictionary mapping each result name to its corresponding result summary.
+        """
+        results = {}
+        for results_names_batch in batched(results_data.keys(), batch_size):
+            request = CreateResultsRequest(
+                results=[CreateResultsRequest.ResultCreate(name=name, data=results_data[name]) for name in results_names_batch],
+                session_id=self.session_id,
+                communication_token=self.token
+            )
+            response: CreateResultsResponse = self._client.CreateResults(request)
+            for message in response.results:
+                results[message.name] = Result.from_message(message)
+        return results
 
 def _to_request_stream_internal(request, communication_token, is_last, chunk_max_size):
     req = CreateTaskRequest(
diff --git a/packages/python/tests/common.py b/packages/python/tests/common.py
deleted file mode 100644
index db2feed4a..000000000
--- a/packages/python/tests/common.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from grpc import Channel
-
-
-class DummyChannel(Channel):
-    def __init__(self):
-        self.method_dict = {}
-
-    def stream_unary(self, *args, **kwargs):
-        return self.get_method(args[0])
-
-    def unary_stream(self, *args, **kwargs):
-        return self.get_method(args[0])
-
-    def unary_unary(self, *args, **kwargs):
-        return self.get_method(args[0])
-
-    def stream_stream(self, *args, **kwargs):
-        return self.get_method(args[0])
-
-    def set_instance(self, instance):
-        self.method_dict = {func: getattr(instance, func) for func in dir(type(instance)) if callable(getattr(type(instance), func)) and not func.startswith("__")}
-
-    def get_method(self, name: str):
-        return self.method_dict.get(name.split("/")[-1], None)
-
-    def subscribe(self, callback, try_to_connect=False):
-        pass
-
-    def unsubscribe(self, callback):
-        pass
-
-    def close(self):
-        pass
-
-    def __enter__(self):
-        pass
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        pass
-
diff --git a/packages/python/tests/conftest.py b/packages/python/tests/conftest.py
new file mode 100644
index 000000000..1b5949a7e
--- /dev/null
+++ b/packages/python/tests/conftest.py
@@ -0,0 +1,168 @@
+import grpc
+import os
+import pytest
+import requests
+
+from armonik.client import ArmoniKResults, ArmoniKSubmitter, ArmoniKTasks, ArmoniKSessions, ArmoniKPartitions, ArmoniKVersions, ArmoniKEvents, ArmoniKHealthChecks
+from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub
+from typing import List, Union
+
+
+# Mock server endpoints used for the tests.
+grpc_endpoint = "localhost:5001"
+calls_endpoint = "http://localhost:5000/calls.json"
+reset_endpoint = "http://localhost:5000/reset"
+data_folder = os.getcwd()
+
+
+@pytest.fixture(scope="session", autouse=True)
+def clean_up(request):
+    """
+    This fixture runs at the session scope and is automatically used before and after
+    running all the tests. It sets up and tears down the testing environment by:
+        - creating dummy files before testing begins;
+        - clearing the files after testing;
+        - resetting the mock gRPC server counters to maintain a clean testing environment.
+
+    Yields:
+        None: This fixture is used as a context manager, and the test code runs between
+        the 'yield' statement and the cleanup code.
+
+    Raises:
+        requests.exceptions.HTTPError: If an error occurs when attempting to reset
+        the mock gRPC server counters.
+    """
+    # Write a dummy payload and data dependency to files for testing purposes
+    with open(os.path.join(data_folder, "payload-id"), "wb") as f:
+        f.write("payload".encode())
+    with open(os.path.join(data_folder, "dd-id"), "wb") as f:
+        f.write("dd".encode())
+
+    # Run all the tests
+    yield
+
+    # Remove the temporary files created for testing
+    os.remove(os.path.join(data_folder, "payload-id"))
+    os.remove(os.path.join(data_folder, "dd-id"))
+    os.remove(os.path.join(data_folder, "result-id"))
+
+    # Reset the mock server counters
+    try:
+        response = requests.post(reset_endpoint)
+        response.raise_for_status()
+        print("\nMock server reset.")
+    except requests.exceptions.HTTPError as e:
+        print("An error occurred when resetting the server: " + str(e))
+
+
+def get_client(client_name: str, endpoint: str = grpc_endpoint) -> Union[ArmoniKResults, ArmoniKSubmitter, ArmoniKTasks, ArmoniKSessions, ArmoniKPartitions, ArmoniKVersions, AgentStub, ArmoniKEvents, ArmoniKHealthChecks]:
+    """
+    Get the ArmoniK client instance based on the specified client name.
+
+    Args:
+        client_name (str): The name of the ArmoniK client to retrieve.
+        endpoint (str, optional): The gRPC server endpoint. Defaults to grpc_endpoint.
+
+    Returns:
+        Union[ArmoniKResults, ArmoniKSubmitter, ArmoniKTasks, ArmoniKSessions, ArmoniKPartitions, ArmoniKVersions, AgentStub, ArmoniKEvents, ArmoniKHealthChecks]:
+            An instance of the specified ArmoniK client.
+
+    Raises:
+        ValueError: If the specified client name is not recognized.
+
+    Example:
+        >>> result_service = get_client("Results")
+        >>> submitter_service = get_client("Submitter", "custom_endpoint")
+    """
+    channel = grpc.insecure_channel(endpoint).__enter__()
+    match client_name:
+        case "Results":
+            return ArmoniKResults(channel)
+        case "Submitter":
+            return ArmoniKSubmitter(channel)
+        case "Tasks":
+            return ArmoniKTasks(channel)
+        case "Sessions":
+            return ArmoniKSessions(channel)
+        case "Partitions":
+            return ArmoniKPartitions(channel)
+        case "Versions":
+            return ArmoniKVersions(channel)
+        case "Agent":
+            return AgentStub(channel)
+        case "Events":
+            return ArmoniKEvents(channel)
+        case "HealthChecks":
+            return ArmoniKHealthChecks(channel)
+        case _:
+            raise ValueError("Unknown client name: " + str(client_name))
+
+
+def rpc_called(service_name: str, rpc_name: str, n_calls: int = 1, endpoint: str = calls_endpoint) -> bool:
+    """Check if a remote procedure call (RPC) has been made a specified number of times.
+    This function uses ArmoniK.Api.Mock. It just gets the '/calls.json' endpoint.
+
+    Args:
+        service_name (str): The name of the service providing the RPC.
+        rpc_name (str): The name of the specific RPC to check for the number of calls.
+        n_calls (int, optional): The expected number of times the RPC should have been called. Default is 1.
+        endpoint (str, optional): The URL of the remote service providing RPC information. Defaults to
+            calls_endpoint.
+
+    Returns:
+        bool: True if the specified RPC has been called the expected number of times, False otherwise.
+
+    Raises:
+        requests.exceptions.RequestException: If an error occurs when requesting ArmoniK.Api.Mock.
+
+    Example:
+        >>> rpc_called('Versions', 'ListVersions', 0)
+        True
+    """
+    response = requests.get(endpoint)
+    response.raise_for_status()
+    data = response.json()
+
+    # Check if the RPC has been called n_calls times
+    if data[service_name][rpc_name] == n_calls:
+        return True
+    return False
+
+
+def all_rpc_called(service_name: str, missings: List[str] = [], endpoint: str = calls_endpoint) -> bool:
+    """
+    Check if all remote procedure calls (RPCs) in a service have been made at least once.
+    This function uses ArmoniK.Api.Mock. It just gets the '/calls.json' endpoint.
+
+    Args:
+        service_name (str): The name of the service containing the RPC information in the response.
+        missings (List[str], optional): A list of RPCs known to be not implemented. Default is an empty list.
+        endpoint (str, optional): The URL of the remote service providing RPC information. Defaults to
+            calls_endpoint.
+
+    Returns:
+        bool: True if all RPCs in the specified service have been called at least once, False otherwise.
+
+    Raises:
+        requests.exceptions.RequestException: If an error occurs when requesting ArmoniK.Api.Mock.
+
+    Example:
+        >>> all_rpc_called('Versions')
+        False
+    """
+    response = requests.get(endpoint)
+    response.raise_for_status()
+    data = response.json()
+
+    missing_rpcs = []
+
+    # Check if all RPCs in the service have been called at least once
+    for rpc_name, rpc_num_calls in data[service_name].items():
+        if rpc_num_calls == 0:
+            missing_rpcs.append(rpc_name)
+    if missing_rpcs:
+        if missings == missing_rpcs:
+            return True
+        print(f"RPCs not implemented in {service_name} service: {missing_rpcs}.")
+        return False
+    return True
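A sketch of the kind of test these fixtures enable against the mock server (counter values accumulate over a test session, so exact counts depend on test ordering):

```python
from .conftest import get_client, rpc_called


def test_list_versions_is_counted():
    versions = get_client("Versions")
    versions.list_versions()
    # The mock server increments its per-RPC counter on each call.
    assert rpc_called("Versions", "ListVersions")
```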
diff --git a/packages/python/tests/submitter_test.py b/packages/python/tests/submitter_test.py
deleted file mode 100644
index c849efdc0..000000000
--- a/packages/python/tests/submitter_test.py
+++ /dev/null
@@ -1,313 +0,0 @@
-#!/usr/bin/env python3
-import datetime
-import logging
-import pytest
-from armonik.client import ArmoniKSubmitter
-from typing import Iterator, Optional, List
-from .common import DummyChannel
-from armonik.common import TaskOptions, TaskDefinition, TaskStatus, timedelta_to_duration
-from armonik.protogen.client.submitter_service_pb2_grpc import SubmitterStub
-from armonik.protogen.common.objects_pb2 import Empty, Configuration, Session, TaskIdList, ResultRequest, TaskError, Error, \
-    Count, StatusCount, DataChunk
-from armonik.protogen.common.submitter_common_pb2 import CreateSessionRequest, CreateSessionReply, CreateLargeTaskRequest, \
-    CreateTaskReply, TaskFilter, ResultReply, AvailabilityReply, WaitRequest, GetTaskStatusRequest, GetTaskStatusReply
-
-logging.basicConfig()
-logging.getLogger().setLevel(logging.INFO)
-
-
-class DummySubmitter(SubmitterStub):
-    def __init__(self, channel: DummyChannel, max_chunk_size=300):
-        channel.set_instance(self)
-        super().__init__(channel)
-        self.max_chunk_size = max_chunk_size
-        self.large_tasks_requests: List[CreateLargeTaskRequest] = []
-        self.task_filter: Optional[TaskFilter] = None
-        self.create_session: Optional[CreateSessionRequest] = None
-        self.session: Optional[Session] = None
-        self.result_stream: List[ResultReply] = []
-        self.result_request: Optional[ResultRequest] = None
-        self.is_available = True
-        self.wait_request: Optional[WaitRequest] = None
-        self.get_status_request: Optional[GetTaskStatusRequest] = None
-
-    def GetServiceConfiguration(self, _: Empty) -> Configuration:
-        return Configuration(data_chunk_max_size=self.max_chunk_size)
-
-    def CreateSession(self, request: CreateSessionRequest) -> CreateSessionReply:
-        self.create_session = request
-        return CreateSessionReply(session_id="SessionId")
-
-    def CancelSession(self, request: Session) -> Empty:
-        self.session = request
-        return Empty()
-
-    def CreateLargeTasks(self, request: Iterator[CreateLargeTaskRequest]) -> CreateTaskReply:
-        self.large_tasks_requests = [r for r in request]
-        return CreateTaskReply(creation_status_list=CreateTaskReply.CreationStatusList(creation_statuses=[
-            CreateTaskReply.CreationStatus(
-                task_info=CreateTaskReply.TaskInfo(task_id="TaskId", expected_output_keys=["EOK"],
-                                                   data_dependencies=["DD"])),
-            CreateTaskReply.CreationStatus(error="TestError")]))
-
-    def ListTasks(self, request: TaskFilter) -> TaskIdList:
-        self.task_filter = request
-        return TaskIdList(task_ids=["TaskId"])
-
-    def TryGetResultStream(self, request: ResultRequest) -> Iterator[ResultReply]:
-        self.result_request = request
-        for r in self.result_stream:
-            yield r
-
-    def WaitForAvailability(self, request: ResultRequest) -> AvailabilityReply:
-        from armonik.protogen.common.task_status_pb2 import
TASK_STATUS_ERROR - self.result_request = request - return AvailabilityReply(ok=Empty()) if self.is_available else AvailabilityReply( - error=TaskError(task_id="TaskId", errors=[Error(task_status=TASK_STATUS_ERROR, detail="TestError")])) - - def WaitForCompletion(self, request: WaitRequest) -> Count: - from armonik.protogen.common.task_status_pb2 import TASK_STATUS_COMPLETED - self.wait_request = request - return Count(values=[StatusCount(status=TASK_STATUS_COMPLETED, count=1)]) - - def GetTaskStatus(self, request: GetTaskStatusRequest) -> GetTaskStatusReply: - from armonik.protogen.common.task_status_pb2 import TASK_STATUS_COMPLETED - self.get_status_request = request - return GetTaskStatusReply( - id_statuses=[GetTaskStatusReply.IdStatus(task_id="TaskId", status=TASK_STATUS_COMPLETED)]) - - -default_task_option = TaskOptions(datetime.timedelta(seconds=300), priority=1, max_retries=5) - - -@pytest.mark.parametrize("task_options,partitions", [(default_task_option, None), (default_task_option, ["default"])]) -def test_armonik_submitter_should_create_session(task_options, partitions): - channel = DummyChannel() - inner = DummySubmitter(channel) - submitter = ArmoniKSubmitter(channel) - session_id = submitter.create_session(default_task_options=task_options, partition_ids=partitions) - assert session_id == "SessionId" - assert inner.create_session - assert inner.create_session.default_task_option.priority == task_options.priority - assert len(inner.create_session.partition_ids) == 0 if partitions is None else list(inner.create_session.partition_ids) == partitions - assert len(inner.create_session.default_task_option.options) == len(task_options.options) - assert inner.create_session.default_task_option.max_duration == timedelta_to_duration(task_options.max_duration) - assert inner.create_session.default_task_option.max_retries == task_options.max_retries - - -def test_armonik_submitter_should_cancel_session(): - channel = DummyChannel() - inner = DummySubmitter(channel) - submitter = ArmoniKSubmitter(channel) - submitter.cancel_session("SessionId") - assert inner.session is not None - assert inner.session.id == "SessionId" - - -def test_armonik_submitter_should_get_config(): - channel = DummyChannel() - inner = DummySubmitter(channel) - submitter = ArmoniKSubmitter(channel) - config = submitter.get_service_configuration() - assert config is not None - assert config.data_chunk_max_size == 300 - - -should_submit = [ - [TaskDefinition("Payload1".encode('utf-8'), expected_output_ids=["EOK"], data_dependencies=["DD"]), - TaskDefinition("Payload2".encode('utf-8'), expected_output_ids=["EOK"], data_dependencies=["DD"])], - [TaskDefinition("Payload1".encode('utf-8'), expected_output_ids=["EOK"]), - TaskDefinition("Payload2".encode('utf-8'), expected_output_ids=["EOK"])], - [TaskDefinition("".encode('utf-8'), expected_output_ids=["EOK"]), - TaskDefinition("".encode('utf-8'), expected_output_ids=["EOK"])] -] - - -@pytest.mark.parametrize("task_list,task_options", - [(t, default_task_option if i else None) for t in should_submit for i in [True, False]]) -def test_armonik_submitter_should_submit(task_list, task_options): - channel = DummyChannel() - inner = DummySubmitter(channel, max_chunk_size=5) - submitter = ArmoniKSubmitter(channel) - successes, errors = submitter.submit("SessionId", tasks=task_list, task_options=task_options) - # The dummy submitter has been set to submit one successful task and one submission error - assert len(successes) == 1 - assert len(errors) == 1 - assert successes[0].id == 
"TaskId" - assert successes[0].session_id == "SessionId" - assert errors[0] == "TestError" - - reqs = inner.large_tasks_requests - assert len(reqs) > 0 - offset = 0 - assert reqs[0 + offset].WhichOneof("type") == "init_request" - assert reqs[0 + offset].init_request.session_id == "SessionId" - assert reqs[1 + offset].WhichOneof("type") == "init_task" - assert reqs[1 + offset].init_task.header.expected_output_keys[0] == "EOK" - assert reqs[1 + offset].init_task.header.data_dependencies[0] == "DD" if len( - task_list[0].data_dependencies) > 0 else len(reqs[1 + offset].init_task.header.data_dependencies) == 0 - assert reqs[2 + offset].WhichOneof("type") == "task_payload" - assert reqs[2 + offset].task_payload.data == "".encode("utf-8") if len(task_list[0].payload) == 0 \ - else reqs[2 + offset].task_payload.data == task_list[0].payload[:5] - if len(task_list[0].payload) > 0: - offset += 1 - assert reqs[2 + offset].WhichOneof("type") == "task_payload" - assert reqs[2 + offset].task_payload.data == task_list[0].payload[5:] - assert reqs[3 + offset].WhichOneof("type") == "task_payload" - assert reqs[3 + offset].task_payload.data_complete - assert reqs[4 + offset].WhichOneof("type") == "init_task" - assert reqs[4 + offset].init_task.header.expected_output_keys[0] == "EOK" - assert reqs[4 + offset].init_task.header.data_dependencies[0] == "DD" if len( - task_list[1].data_dependencies) > 0 else len(reqs[4 + offset].init_task.header.data_dependencies) == 0 - assert reqs[5 + offset].WhichOneof("type") == "task_payload" - assert reqs[5 + offset].task_payload.data == "".encode("utf-8") if len(task_list[1].payload) == 0 \ - else reqs[5 + offset].task_payload.data == task_list[1].payload[:5] - if len(task_list[1].payload) > 0: - offset += 1 - assert reqs[5 + offset].WhichOneof("type") == "task_payload" - assert reqs[5 + offset].task_payload.data == task_list[1].payload[5:] - assert reqs[6 + offset].WhichOneof("type") == "task_payload" - assert reqs[6 + offset].task_payload.data_complete - assert reqs[7 + offset].WhichOneof("type") == "init_task" - assert reqs[7 + offset].init_task.last_task - - -filters_params = [(session_ids, task_ids, included_statuses, excluded_statuses, - (session_ids is None or task_ids is None) and ( - included_statuses is None or excluded_statuses is None)) - for session_ids in [["SessionId"], None] - for task_ids in [["TaskId"], None] - for included_statuses in [[TaskStatus.COMPLETED], None] - for excluded_statuses in [[TaskStatus.COMPLETED], None]] - - -@pytest.mark.parametrize("session_ids,task_ids,included_statuses,excluded_statuses,should_succeed", filters_params) -def test_armonik_submitter_should_list_tasks(session_ids, task_ids, included_statuses, excluded_statuses, - should_succeed): - channel = DummyChannel() - inner = DummySubmitter(channel) - submitter = ArmoniKSubmitter(channel) - if should_succeed: - tasks = submitter.list_tasks(session_ids=session_ids, task_ids=task_ids, included_statuses=included_statuses, - excluded_statuses=excluded_statuses) - assert len(tasks) > 0 - assert tasks[0] == "TaskId" - assert inner.task_filter is not None - assert all(map(lambda x: x[1] == session_ids[x[0]], enumerate(inner.task_filter.session.ids))) - assert all(map(lambda x: x[1] == task_ids[x[0]], enumerate(inner.task_filter.task.ids))) - assert all(map(lambda x: x[1] == included_statuses[x[0]], enumerate(inner.task_filter.included.statuses))) - assert all(map(lambda x: x[1] == excluded_statuses[x[0]], enumerate(inner.task_filter.excluded.statuses))) - else: - with 
pytest.raises(ValueError): - _ = submitter.list_tasks(session_ids=session_ids, task_ids=task_ids, included_statuses=included_statuses, - excluded_statuses=excluded_statuses) - - -def test_armonik_submitter_should_get_status(): - channel = DummyChannel() - inner = DummySubmitter(channel) - submitter = ArmoniKSubmitter(channel) - - statuses = submitter.get_task_status(["TaskId"]) - assert len(statuses) > 0 - assert "TaskId" in statuses - assert statuses["TaskId"] == TaskStatus.COMPLETED - assert inner.get_status_request is not None - assert len(inner.get_status_request.task_ids) == 1 - assert inner.get_status_request.task_ids[0] == "TaskId" - - -get_result_should_throw = [ - [], - [ResultReply(result=DataChunk(data="payload".encode("utf-8")))], - [ResultReply(result=DataChunk(data="payload".encode("utf-8"))), ResultReply(result=DataChunk(data_complete=True)), - ResultReply(result=DataChunk(data="payload".encode("utf-8")))], - [ResultReply( - error=TaskError(task_id="TaskId", errors=[Error(task_status=TaskStatus.ERROR, detail="TestError")]))], -] - -get_result_should_succeed = [ - [ResultReply(result=DataChunk(data="payload".encode("utf-8"))), ResultReply(result=DataChunk(data_complete=True))] -] - -get_result_should_none = [ - [ResultReply(not_completed_task="NotCompleted")] -] - - -@pytest.mark.parametrize("stream", [iter(x) for x in get_result_should_succeed]) -def test_armonik_submitter_should_get_result(stream): - channel = DummyChannel() - inner = DummySubmitter(channel) - inner.result_stream = stream - submitter = ArmoniKSubmitter(channel) - result = submitter.get_result("SessionId", "ResultId") - assert result is not None - assert len(result) > 0 - assert inner.result_request - assert inner.result_request.result_id == "ResultId" - assert inner.result_request.session == "SessionId" - - -@pytest.mark.parametrize("stream", [iter(x) for x in get_result_should_throw]) -def test_armonik_submitter_get_result_should_throw(stream): - channel = DummyChannel() - inner = DummySubmitter(channel) - inner.result_stream = stream - submitter = ArmoniKSubmitter(channel) - with pytest.raises(Exception): - _ = submitter.get_result("SessionId", "ResultId") - - -@pytest.mark.parametrize("stream", [iter(x) for x in get_result_should_none]) -def test_armonik_submitter_get_result_should_none(stream): - channel = DummyChannel() - inner = DummySubmitter(channel) - inner.result_stream = stream - submitter = ArmoniKSubmitter(channel) - result = submitter.get_result("SessionId", "ResultId") - assert result is None - assert inner.result_request - assert inner.result_request.result_id == "ResultId" - assert inner.result_request.session == "SessionId" - - -@pytest.mark.parametrize("available", [True, False]) -def test_armonik_submitter_wait_availability(available): - channel = DummyChannel() - inner = DummySubmitter(channel) - inner.is_available = available - submitter = ArmoniKSubmitter(channel) - reply = submitter.wait_for_availability("SessionId", "ResultId") - assert reply is not None - assert reply.is_available() == available - assert len(reply.errors) == 0 if available else reply.errors[0] == "TestError" - - -@pytest.mark.parametrize("session_ids,task_ids,included_statuses,excluded_statuses,should_succeed", filters_params) -def test_armonik_submitter_wait_completion(session_ids, task_ids, included_statuses, excluded_statuses, should_succeed): - channel = DummyChannel() - inner = DummySubmitter(channel) - submitter = ArmoniKSubmitter(channel) - - if should_succeed: - counts = 
submitter.wait_for_completion(session_ids=session_ids, task_ids=task_ids, - included_statuses=included_statuses, - excluded_statuses=excluded_statuses) - assert len(counts) > 0 - assert TaskStatus.COMPLETED in counts - assert counts[TaskStatus.COMPLETED] == 1 - assert inner.wait_request is not None - assert all(map(lambda x: x[1] == session_ids[x[0]], enumerate(inner.wait_request.filter.session.ids))) - assert all(map(lambda x: x[1] == task_ids[x[0]], enumerate(inner.wait_request.filter.task.ids))) - assert all(map(lambda x: x[1] == included_statuses[x[0]], - enumerate(inner.wait_request.filter.included.statuses))) - assert all(map(lambda x: x[1] == excluded_statuses[x[0]], - enumerate(inner.wait_request.filter.excluded.statuses))) - assert not inner.wait_request.stop_on_first_task_error - assert not inner.wait_request.stop_on_first_task_cancellation - else: - with pytest.raises(ValueError): - _ = submitter.wait_for_completion(session_ids=session_ids, task_ids=task_ids, - included_statuses=included_statuses, - excluded_statuses=excluded_statuses) diff --git a/packages/python/tests/taskhandler_test.py b/packages/python/tests/taskhandler_test.py deleted file mode 100644 index e4f3c181c..000000000 --- a/packages/python/tests/taskhandler_test.py +++ /dev/null @@ -1,89 +0,0 @@ -#!/usr/bin/env python3 -import os - -import pytest -from typing import Iterator - -from armonik.common import TaskDefinition - -from .common import DummyChannel -from armonik.worker import TaskHandler -from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub -from armonik.protogen.common.agent_common_pb2 import CreateTaskRequest, CreateTaskReply, NotifyResultDataRequest, NotifyResultDataResponse -from armonik.protogen.common.worker_common_pb2 import ProcessRequest -from armonik.protogen.common.objects_pb2 import Configuration -import logging - -logging.basicConfig() -logging.getLogger().setLevel(logging.INFO) - -data_folder = os.getcwd() - - -@pytest.fixture(autouse=True, scope="session") -def setup_teardown(): - with open(os.path.join(data_folder, "payloadid"), "wb") as f: - f.write("payload".encode()) - with open(os.path.join(data_folder, "ddid"), "wb") as f: - f.write("dd".encode()) - yield - os.remove(os.path.join(data_folder, "payloadid")) - os.remove(os.path.join(data_folder, "ddid")) - - -class DummyAgent(AgentStub): - - def __init__(self, channel: DummyChannel) -> None: - channel.set_instance(self) - super(DummyAgent, self).__init__(channel) - self.create_task_messages = [] - self.send_result_task_message = [] - - def CreateTask(self, request_iterator: Iterator[CreateTaskRequest]) -> CreateTaskReply: - self.create_task_messages = [r for r in request_iterator] - return CreateTaskReply(creation_status_list=CreateTaskReply.CreationStatusList(creation_statuses=[ - CreateTaskReply.CreationStatus( - task_info=CreateTaskReply.TaskInfo(task_id="TaskId", expected_output_keys=["EOK"], - data_dependencies=["DD"]))])) - - def NotifyResultData(self, request: NotifyResultDataRequest) -> NotifyResultDataResponse: - self.send_result_task_message.append(request) - return NotifyResultDataResponse(result_ids=[i.result_id for i in request.ids]) - - -should_succeed_case = ProcessRequest(communication_token="token", session_id="sessionid", task_id="taskid", expected_output_keys=["resultid"], payload_id="payloadid", data_dependencies=["ddid"], data_folder=data_folder, configuration=Configuration(data_chunk_max_size=8000)) - - -@pytest.mark.parametrize("requests", [should_succeed_case]) -def 
test_taskhandler_create_should_succeed(requests: ProcessRequest): - agent = DummyAgent(DummyChannel()) - task_handler = TaskHandler(requests, agent) - assert task_handler.token is not None and len(task_handler.token) > 0 - assert len(task_handler.payload) > 0 - assert task_handler.session_id is not None and len(task_handler.session_id) > 0 - assert task_handler.task_id is not None and len(task_handler.task_id) > 0 - - -def test_taskhandler_data_are_correct(): - agent = DummyAgent(DummyChannel()) - task_handler = TaskHandler(should_succeed_case, agent) - assert len(task_handler.payload) > 0 - - task_handler.create_tasks([TaskDefinition("Payload".encode("utf-8"), ["EOK"], ["DD"])]) - - tasks = agent.create_task_messages - assert len(tasks) == 5 - assert tasks[0].WhichOneof("type") == "init_request" - assert tasks[1].WhichOneof("type") == "init_task" - assert len(tasks[1].init_task.header.data_dependencies) == 1 \ - and tasks[1].init_task.header.data_dependencies[0] == "DD" - assert len(tasks[1].init_task.header.expected_output_keys) == 1 \ - and tasks[1].init_task.header.expected_output_keys[0] == "EOK" - assert tasks[2].WhichOneof("type") == "task_payload" - assert tasks[2].task_payload.data == "Payload".encode("utf-8") - assert tasks[3].WhichOneof("type") == "task_payload" - assert tasks[3].task_payload.data_complete - assert tasks[4].WhichOneof("type") == "init_task" - assert tasks[4].init_task.last_task - - diff --git a/packages/python/tests/tasks_test.py b/packages/python/tests/tasks_test.py deleted file mode 100644 index 752b2ac63..000000000 --- a/packages/python/tests/tasks_test.py +++ /dev/null @@ -1,281 +0,0 @@ -#!/usr/bin/env python3 -import dataclasses -from typing import Optional, List, Any, Union, Dict, Collection -from google.protobuf.timestamp_pb2 import Timestamp - -from datetime import datetime - -import pytest - -from .common import DummyChannel -from armonik.client import ArmoniKTasks -from armonik.client.tasks import TaskFieldFilter -from armonik.common import TaskStatus, datetime_to_timestamp, Task -from armonik.common.filter import StringFilter, Filter -from armonik.protogen.client.tasks_service_pb2_grpc import TasksStub -from armonik.protogen.common.tasks_common_pb2 import GetTaskRequest, GetTaskResponse, TaskDetailed -from armonik.protogen.common.tasks_filters_pb2 import Filters, FilterField -from armonik.protogen.common.filters_common_pb2 import * -from armonik.protogen.common.tasks_fields_pb2 import * -from .submitter_test import default_task_option - - -class DummyTasksService(TasksStub): - def __init__(self, channel: DummyChannel): - channel.set_instance(self) - super().__init__(channel) - self.task_request: Optional[GetTaskRequest] = None - - def GetTask(self, request: GetTaskRequest) -> GetTaskResponse: - self.task_request = request - raw = TaskDetailed(id="TaskId", session_id="SessionId", owner_pod_id="PodId", parent_task_ids=["ParentTaskId"], - data_dependencies=["DD"], expected_output_ids=["EOK"], retry_of_ids=["RetryId"], - status=TaskStatus.COMPLETED, status_message="Message", - options=default_task_option.to_message(), - created_at=datetime_to_timestamp(datetime.now()), - started_at=datetime_to_timestamp(datetime.now()), - submitted_at=datetime_to_timestamp(datetime.now()), - ended_at=datetime_to_timestamp(datetime.now()), pod_ttl=datetime_to_timestamp(datetime.now()), - output=TaskDetailed.Output(success=True), pod_hostname="Hostname", received_at=datetime_to_timestamp(datetime.now()), - acquired_at=datetime_to_timestamp(datetime.now()) - ) - return 
GetTaskResponse(task=raw) - - -def test_tasks_get_task_should_succeed(): - channel = DummyChannel() - inner = DummyTasksService(channel) - tasks = ArmoniKTasks(channel) - task = tasks.get_task("TaskId") - assert task is not None - assert inner.task_request is not None - assert inner.task_request.task_id == "TaskId" - assert task.id == "TaskId" - assert task.session_id == "SessionId" - assert task.parent_task_ids == ["ParentTaskId"] - assert task.output - assert task.output.success - - -def test_task_refresh(): - channel = DummyChannel() - inner = DummyTasksService(channel) - tasks = ArmoniKTasks(channel) - current = Task(id="TaskId") - current.refresh(tasks) - assert current is not None - assert inner.task_request is not None - assert inner.task_request.task_id == "TaskId" - assert current.id == "TaskId" - assert current.session_id == "SessionId" - assert current.parent_task_ids == ["ParentTaskId"] - assert current.output - assert current.output.success - - -def test_task_filters(): - filt: StringFilter = TaskFieldFilter.TASK_ID == "TaskId" - message = filt.to_message() - assert isinstance(message, FilterField) - assert message.field.WhichOneof("field") == "task_summary_field" - assert message.field.task_summary_field.field == TASK_SUMMARY_ENUM_FIELD_TASK_ID - assert message.filter_string.value == "TaskId" - assert message.filter_string.operator == FILTER_STRING_OPERATOR_EQUAL - - filt: StringFilter = TaskFieldFilter.TASK_ID != "TaskId" - message = filt.to_message() - assert isinstance(message, FilterField) - assert message.field.WhichOneof("field") == "task_summary_field" - assert message.field.task_summary_field.field == TASK_SUMMARY_ENUM_FIELD_TASK_ID - assert message.filter_string.value == "TaskId" - assert message.filter_string.operator == FILTER_STRING_OPERATOR_NOT_EQUAL - - -@dataclasses.dataclass -class SimpleFieldFilter: - field: Any - value: Any - operator: Any - - -@pytest.mark.parametrize("filt,n_or,n_and,filters", [ - ( - (TaskFieldFilter.INITIAL_TASK_ID == "TestId"), - 1, [1], - [ - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_INITIAL_TASK_ID, "TestId", FILTER_STRING_OPERATOR_EQUAL) - ] - ), - ( - (TaskFieldFilter.APPLICATION_NAME.contains("TestName") & (TaskFieldFilter.CREATED_AT > Timestamp(seconds=1000, nanos=500))), - 1, [2], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_NAME, "TestName", FILTER_STRING_OPERATOR_CONTAINS), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_CREATED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_AFTER) - ] - ), - ( - (((TaskFieldFilter.MAX_RETRIES <= 3) & ~(TaskFieldFilter.SESSION_ID == "SessionId")) | (TaskFieldFilter.task_options_key("MyKey").startswith("Start"))), - 2, [1, 2], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_MAX_RETRIES, 3, FILTER_NUMBER_OPERATOR_LESS_THAN_OR_EQUAL), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_SESSION_ID, "SessionId", FILTER_STRING_OPERATOR_NOT_EQUAL), - SimpleFieldFilter("MyKey", "Start", FILTER_STRING_OPERATOR_STARTS_WITH) - ] - ), - ( - (((TaskFieldFilter.PRIORITY > 3) & ~(TaskFieldFilter.STATUS == TaskStatus.COMPLETED) & TaskFieldFilter.APPLICATION_VERSION.contains("1.0")) | (TaskFieldFilter.ENGINE_TYPE.endswith("Test") & (TaskFieldFilter.ENDED_AT <= Timestamp(seconds=1000, nanos=500)))), - 2, [2, 3], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_PRIORITY, 3, FILTER_NUMBER_OPERATOR_GREATER_THAN), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_STATUS, TaskStatus.COMPLETED, FILTER_STATUS_OPERATOR_NOT_EQUAL), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_VERSION, "1.0", 
FILTER_STRING_OPERATOR_CONTAINS), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_ENGINE_TYPE, "Test", FILTER_STRING_OPERATOR_ENDS_WITH), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_ENDED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_BEFORE_OR_EQUAL), - ] - ), - ( - (((TaskFieldFilter.PRIORITY >= 3) * -(TaskFieldFilter.STATUS != TaskStatus.COMPLETED) * -TaskFieldFilter.APPLICATION_VERSION.contains("1.0")) + (TaskFieldFilter.ENGINE_TYPE.endswith("Test") * (TaskFieldFilter.ENDED_AT <= Timestamp(seconds=1000, nanos=500)))), - 2, [2, 3], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_PRIORITY, 3, FILTER_NUMBER_OPERATOR_GREATER_THAN_OR_EQUAL), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_STATUS, TaskStatus.COMPLETED, FILTER_STATUS_OPERATOR_EQUAL), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_VERSION, "1.0", FILTER_STRING_OPERATOR_NOT_CONTAINS), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_ENGINE_TYPE, "Test", FILTER_STRING_OPERATOR_ENDS_WITH), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_ENDED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_BEFORE_OR_EQUAL), - ] - ) -]) -def test_filter_combination(filt: Filter, n_or: int, n_and: List[int], filters: List[SimpleFieldFilter]): - filt = filt.to_disjunction() - assert len(filt._filters) == n_or - sorted_n_and = sorted(n_and) - sorted_actual = sorted([len(f) for f in filt._filters]) - assert len(sorted_n_and) == len(sorted_actual) - assert all((sorted_n_and[i] == sorted_actual[i] for i in range(len(sorted_actual)))) - for f in filt._filters: - for ff in f: - field_value = getattr(ff.field, ff.field.WhichOneof("field")).field - for i, expected in enumerate(filters): - if expected.field == field_value and expected.value == ff.value and expected.operator == ff.operator: - filters.pop(i) - break - else: - print(f"Could not find {str(ff)}") - assert False - assert len(filters) == 0 - - -def test_name_from_value(): - assert TaskStatus.name_from_value(TaskStatus.COMPLETED) == "TASK_STATUS_COMPLETED" - - -class BasicFilterAnd: - - def __setattr__(self, key, value): - self.__dict__[key] = value - - def __getattr__(self, item): - return self.__dict__[item] - - -@pytest.mark.parametrize("filt,n_or,n_and,filters,expected_type", [ - ( - (TaskFieldFilter.INITIAL_TASK_ID == "TestId"), - 1, [1], - [ - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_INITIAL_TASK_ID, "TestId", FILTER_STRING_OPERATOR_EQUAL) - ], - 0 - ), - ( - (TaskFieldFilter.APPLICATION_NAME.contains("TestName") & (TaskFieldFilter.CREATED_AT > Timestamp(seconds=1000, nanos=500))), - 1, [2], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_NAME, "TestName", FILTER_STRING_OPERATOR_CONTAINS), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_CREATED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_AFTER) - ], - 1 - ), - ( - (((TaskFieldFilter.MAX_RETRIES <= 3) & ~(TaskFieldFilter.SESSION_ID == "SessionId")) | (TaskFieldFilter.task_options_key("MyKey").startswith("Start"))), - 2, [1, 2], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_MAX_RETRIES, 3, FILTER_NUMBER_OPERATOR_LESS_THAN_OR_EQUAL), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_SESSION_ID, "SessionId", FILTER_STRING_OPERATOR_NOT_EQUAL), - SimpleFieldFilter("MyKey", "Start", FILTER_STRING_OPERATOR_STARTS_WITH) - ], - 2 - ), - ( - (((TaskFieldFilter.PRIORITY > 3) & ~(TaskFieldFilter.STATUS == TaskStatus.COMPLETED) & TaskFieldFilter.APPLICATION_VERSION.contains("1.0")) | (TaskFieldFilter.ENGINE_TYPE.endswith("Test") & (TaskFieldFilter.ENDED_AT <= Timestamp(seconds=1000, nanos=500)))), - 2, [2, 3], - [ - 
SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_PRIORITY, 3, FILTER_NUMBER_OPERATOR_GREATER_THAN), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_STATUS, TaskStatus.COMPLETED, FILTER_STATUS_OPERATOR_NOT_EQUAL), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_VERSION, "1.0", FILTER_STRING_OPERATOR_CONTAINS), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_ENGINE_TYPE, "Test", FILTER_STRING_OPERATOR_ENDS_WITH), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_ENDED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_BEFORE_OR_EQUAL), - ], - 2 - ), - ( - (((TaskFieldFilter.PRIORITY >= 3) * -(TaskFieldFilter.STATUS != TaskStatus.COMPLETED) * -TaskFieldFilter.APPLICATION_VERSION.contains("1.0")) + (TaskFieldFilter.ENGINE_TYPE.endswith("Test") * (TaskFieldFilter.ENDED_AT <= Timestamp(seconds=1000, nanos=500)))), - 2, [2, 3], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_PRIORITY, 3, FILTER_NUMBER_OPERATOR_GREATER_THAN_OR_EQUAL), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_STATUS, TaskStatus.COMPLETED, FILTER_STATUS_OPERATOR_EQUAL), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_VERSION, "1.0", FILTER_STRING_OPERATOR_NOT_CONTAINS), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_ENGINE_TYPE, "Test", FILTER_STRING_OPERATOR_ENDS_WITH), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_ENDED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_BEFORE_OR_EQUAL), - ], - 2 - ), - ( - (((TaskFieldFilter.PRIORITY >= 3) * -(TaskFieldFilter.STATUS != TaskStatus.COMPLETED) * -TaskFieldFilter.APPLICATION_VERSION.contains("1.0")) + (TaskFieldFilter.ENGINE_TYPE.endswith("Test") * (TaskFieldFilter.ENDED_AT <= Timestamp(seconds=1000, nanos=500)))) + (((TaskFieldFilter.MAX_RETRIES <= 3) & ~(TaskFieldFilter.SESSION_ID == "SessionId")) | (TaskFieldFilter.task_options_key("MyKey").startswith("Start"))), - 4, [2, 3, 2, 1], - [ - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_PRIORITY, 3, FILTER_NUMBER_OPERATOR_GREATER_THAN_OR_EQUAL), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_STATUS, TaskStatus.COMPLETED, FILTER_STATUS_OPERATOR_EQUAL), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_APPLICATION_VERSION, "1.0", FILTER_STRING_OPERATOR_NOT_CONTAINS), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_ENGINE_TYPE, "Test", FILTER_STRING_OPERATOR_ENDS_WITH), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_ENDED_AT, Timestamp(seconds=1000, nanos=500), FILTER_DATE_OPERATOR_BEFORE_OR_EQUAL), - SimpleFieldFilter(TASK_OPTION_ENUM_FIELD_MAX_RETRIES, 3, FILTER_NUMBER_OPERATOR_LESS_THAN_OR_EQUAL), - SimpleFieldFilter(TASK_SUMMARY_ENUM_FIELD_SESSION_ID, "SessionId", FILTER_STRING_OPERATOR_NOT_EQUAL), - SimpleFieldFilter("MyKey", "Start", FILTER_STRING_OPERATOR_STARTS_WITH) - ], - 2 - ) -]) -def test_taskfilter_to_message(filt: Filter, n_or: int, n_and: List[int], filters: List[SimpleFieldFilter], expected_type: int): - print(filt) - message = filt.to_message() - conjs: Collection = [] - if expected_type == 2: # Disjunction - conjs: Collection = getattr(message, "or") - assert len(conjs) == n_or - sorted_n_and = sorted(n_and) - sorted_actual = sorted([len(getattr(f, "and")) for f in conjs]) - assert len(sorted_n_and) == len(sorted_actual) - assert all((sorted_n_and[i] == sorted_actual[i] for i in range(len(sorted_actual)))) - - if expected_type == 1: # Conjunction - conjs: Collection = [message] - - if expected_type == 0: # Simple filter - m = BasicFilterAnd() - setattr(m, "and", [message]) - conjs: Collection = [m] - - for conj in conjs: - basics = getattr(conj, "and") - for f in basics: - field_value = getattr(f.field, 
f.field.WhichOneof("field")).field - for i, expected in enumerate(filters): - if expected.field == field_value and expected.value == getattr(f, f.WhichOneof("value_condition")).value and expected.operator == getattr(f, f.WhichOneof("value_condition")).operator: - filters.pop(i) - break - else: - print(f"Could not find {str(f)}") - assert False - assert len(filters) == 0 diff --git a/packages/python/tests/test_events.py b/packages/python/tests/test_events.py new file mode 100644 index 000000000..86ca9e42b --- /dev/null +++ b/packages/python/tests/test_events.py @@ -0,0 +1,22 @@ +from .conftest import all_rpc_called, rpc_called, get_client +from armonik.client import ArmoniKEvents +from armonik.common import EventTypes, NewResultEvent, ResultStatus + + +class TestArmoniKEvents: + def test_get_events_no_filter(self): + def test_handler(session_id, event_type, event): + assert session_id == "session-id" + assert event_type == EventTypes.NEW_RESULT + assert isinstance(event, NewResultEvent) + assert event.result_id == "result-id" + assert event.owner_id == "owner-id" + assert event.status == ResultStatus.CREATED + + tasks_client: ArmoniKEvents = get_client("Events") + tasks_client.get_events("session-id", [EventTypes.TASK_STATUS_UPDATE], [test_handler]) + + assert rpc_called("Events", "GetEvents") + + def test_service_fully_implemented(self): + assert all_rpc_called("Events") diff --git a/packages/python/tests/filters_test.py b/packages/python/tests/test_filters.py similarity index 100% rename from packages/python/tests/filters_test.py rename to packages/python/tests/test_filters.py diff --git a/packages/python/tests/test_healthcheck.py b/packages/python/tests/test_healthcheck.py new file mode 100644 index 000000000..8062aaab6 --- /dev/null +++ b/packages/python/tests/test_healthcheck.py @@ -0,0 +1,18 @@ +import datetime + +from .conftest import all_rpc_called, rpc_called, get_client +from armonik.client import ArmoniKHealthChecks +from armonik.common import ServiceHealthCheckStatus + + +class TestArmoniKHealthChecks: + + def test_check_health(self): + health_checks_client: ArmoniKHealthChecks = get_client("HealthChecks") + services_health = health_checks_client.check_health() + + assert rpc_called("HealthChecks", "CheckHealth") + assert services_health == {'mock': {'message': 'Mock is healthy', 'status': ServiceHealthCheckStatus.HEALTHY}} + + def test_service_fully_implemented(self): + assert all_rpc_called("HealthChecks") diff --git a/packages/python/tests/helpers_test.py b/packages/python/tests/test_helpers.py similarity index 76% rename from packages/python/tests/helpers_test.py rename to packages/python/tests/test_helpers.py index 20e07d922..1c0cf517b 100644 --- a/packages/python/tests/helpers_test.py +++ b/packages/python/tests/test_helpers.py @@ -5,7 +5,9 @@ from google.protobuf.duration_pb2 import Duration from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from armonik.common.helpers import datetime_to_timestamp, timestamp_to_datetime, timedelta_to_duration, duration_to_timedelta +from armonik.common.helpers import datetime_to_timestamp, timestamp_to_datetime, timedelta_to_duration, duration_to_timedelta, batched + +from typing import Iterable, List @dataclass @@ -60,3 +62,14 @@ def test_duration_to_timedelta(case: Case): def test_timedelta_to_duration(case: Case): ts = timedelta_to_duration(case.delta) assert ts.seconds == case.duration.seconds and abs(ts.nanos - case.duration.nanos) < 1000 + + +@pytest.mark.parametrize(["iterable", "batch_size", 
"iterations"], [ + ([1, 2, 3], 3, [[1, 2, 3]]), + ([1, 2, 3], 5, [[1, 2, 3]]), + ([1, 2, 3], 2, [[1, 2], [3]]), + ([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], 3, [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11]]) +]) +def test_batched(iterable: Iterable, batch_size: int, iterations: List[Iterable]): + for index, batch in enumerate(batched(iterable, batch_size)): + assert batch == iterations[index] diff --git a/packages/python/tests/test_partitions.py b/packages/python/tests/test_partitions.py new file mode 100644 index 000000000..b2ae6ecce --- /dev/null +++ b/packages/python/tests/test_partitions.py @@ -0,0 +1,32 @@ +from .conftest import all_rpc_called, rpc_called, get_client +from armonik.client import ArmoniKPartitions +from armonik.common import Partition + + +class TestArmoniKPartitions: + + def test_get_partitions(self): + partitions_client: ArmoniKPartitions = get_client("Partitions") + partition = partitions_client.get_partition("partition-id") + + assert rpc_called("Partitions", "GetPartition") + assert isinstance(partition, Partition) + assert partition.id == 'partition-id' + assert partition.parent_partition_ids == [] + assert partition.pod_reserved == 1 + assert partition.pod_max == 1 + assert partition.pod_configuration == {} + assert partition.preemption_percentage == 0 + assert partition.priority == 1 + + def test_list_partitions_no_filter(self): + partitions_client: ArmoniKPartitions = get_client("Partitions") + num, partitions = partitions_client.list_partitions() + + assert rpc_called("Partitions", "GetPartition") + # TODO: Mock must be updated to return something and so that changes the following assertions + assert num == 0 + assert partitions == [] + + def test_service_fully_implemented(self): + assert all_rpc_called("Partitions") diff --git a/packages/python/tests/test_results.py b/packages/python/tests/test_results.py new file mode 100644 index 000000000..064cb7ba0 --- /dev/null +++ b/packages/python/tests/test_results.py @@ -0,0 +1,94 @@ +import datetime +import pytest + +from .conftest import all_rpc_called, rpc_called, get_client +from armonik.client import ArmoniKResults +from armonik.common import Result, ResultStatus + + +class TestArmoniKResults: + + def test_get_result(self): + results_client: ArmoniKResults = get_client("Results") + result = results_client.get_result("result-name") + + assert rpc_called("Results", "GetResult") + assert isinstance(result, Result) + assert result.session_id == 'session-id' + assert result.name == 'result-name' + assert result.owner_task_id == 'owner-task-id' + assert result.status == 2 + assert result.created_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert result.completed_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert result.result_id == 'result-id' + assert result.size == 0 + + def test_get_owner_task_id(self): + results_client: ArmoniKResults = get_client("Results") + results_tasks = results_client.get_owner_task_id(["result-id"], "session-id") + + assert rpc_called("Results", "GetOwnerTaskId") + # TODO: Mock must be updated to return something and so that changes the following assertions + assert results_tasks == {} + + def test_list_results_no_filter(self): + results_client: ArmoniKResults = get_client("Results") + num, results = results_client.list_results() + + assert rpc_called("Results", "ListResults") + # TODO: Mock must be updated to return something and so that changes the following assertions + assert num == 0 + assert results == [] + + def 
test_create_results_metadata(self): + results_client: ArmoniKResults = get_client("Results") + results = results_client.create_results_metadata(["result-name"], "session-id") + + assert rpc_called("Results", "CreateResultsMetaData") + # TODO: Mock must be updated to return something, which will change the following assertions + assert results == {} + + def test_create_results(self): + results_client: ArmoniKResults = get_client("Results") + results = results_client.create_results({"result-name": b"test data"}, "session-id") + + assert rpc_called("Results", "CreateResults") + assert results == {} + + def test_get_service_config(self): + results_client: ArmoniKResults = get_client("Results") + chunk_size = results_client.get_service_config() + + assert rpc_called("Results", "GetServiceConfiguration") + assert isinstance(chunk_size, int) + assert chunk_size == 81920 + + def test_upload_result_data(self): + results_client: ArmoniKResults = get_client("Results") + result = results_client.upload_result_data("result-name", "session-id", b"test data") + + assert rpc_called("Results", "UploadResultData") + assert result is None + + def test_download_result_data(self): + results_client: ArmoniKResults = get_client("Results") + data = results_client.download_result_data("result-name", "session-id") + + assert rpc_called("Results", "DownloadResultData") + assert data == b"" + + def test_delete_result_data(self): + results_client: ArmoniKResults = get_client("Results") + result = results_client.delete_result_data(["result-name"], "session-id") + + assert rpc_called("Results", "DeleteResultsData") + assert result is None + + def test_watch_results(self): + results_client: ArmoniKResults = get_client("Results") + with pytest.raises(NotImplementedError, match=""): + results_client.watch_results() + assert rpc_called("Results", "WatchResults", 0) + + def test_service_fully_implemented(self): + assert all_rpc_called("Results", missings=["WatchResults"]) diff --git a/packages/python/tests/test_sessions.py b/packages/python/tests/test_sessions.py new file mode 100644 index 000000000..c12e5d9e0 --- /dev/null +++ b/packages/python/tests/test_sessions.py @@ -0,0 +1,63 @@ +import datetime + +from .conftest import all_rpc_called, rpc_called, get_client +from armonik.client import ArmoniKSessions +from armonik.common import Session, SessionStatus, TaskOptions + + +class TestArmoniKSessions: + + def test_create_session(self): + sessions_client: ArmoniKSessions = get_client("Sessions") + default_task_options = TaskOptions( + max_duration=datetime.timedelta(seconds=1), + priority=1, + max_retries=1 + ) + session_id = sessions_client.create_session(default_task_options) + + assert rpc_called("Sessions", "CreateSession") + assert session_id == "session-id" + + def test_get_session(self): + sessions_client: ArmoniKSessions = get_client("Sessions") + session = sessions_client.get_session("session-id") + + assert rpc_called("Sessions", "GetSession") + assert isinstance(session, Session) + assert session.session_id == 'session-id' + assert session.status == SessionStatus.CANCELLED + assert session.partition_ids == [] + assert session.options == TaskOptions( + max_duration=datetime.timedelta(0), + priority=0, + max_retries=0, + partition_id='', + application_name='', + application_version='', + application_namespace='', + application_service='', + engine_type='', + options={} + ) + assert session.created_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert session.cancelled_at == 
datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert session.duration == datetime.timedelta(0) + + def test_list_session_no_filter(self): + sessions_client: ArmoniKSessions = get_client("Sessions") + num, sessions = sessions_client.list_sessions() + + assert rpc_called("Sessions", "ListSessions") + # TODO: Mock must be updated to return something, which will change the following assertions + assert num == 0 + assert sessions == [] + + def test_cancel_session(self): + sessions_client: ArmoniKSessions = get_client("Sessions") + sessions_client.cancel_session("session-id") + + assert rpc_called("Sessions", "CancelSession") + + def test_service_fully_implemented(self): + assert all_rpc_called("Sessions") diff --git a/packages/python/tests/test_submitter.py b/packages/python/tests/test_submitter.py new file mode 100644 index 000000000..e69de29bb diff --git a/packages/python/tests/test_taskhandler.py b/packages/python/tests/test_taskhandler.py new file mode 100644 index 000000000..f5e8634eb --- /dev/null +++ b/packages/python/tests/test_taskhandler.py @@ -0,0 +1,108 @@ +import datetime +import logging +import warnings + +from .conftest import all_rpc_called, rpc_called, get_client, data_folder +from armonik.common import TaskDefinition, TaskOptions +from armonik.worker import TaskHandler +from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub +from armonik.protogen.common.worker_common_pb2 import ProcessRequest +from armonik.protogen.common.objects_pb2 import Configuration + + +logging.basicConfig() +logging.getLogger().setLevel(logging.INFO) + + +class TestTaskHandler: + + request = ProcessRequest( + communication_token="token", + session_id="session-id", + task_id="task-id", + expected_output_keys=["result-id"], + payload_id="payload-id", + data_dependencies=["dd-id"], + data_folder=data_folder, + configuration=Configuration(data_chunk_max_size=8000), + task_options=TaskOptions( + max_duration=datetime.timedelta(seconds=1), + priority=1, + max_retries=1 + ).to_message() + ) + + def test_taskhandler_init(self): + task_handler = TaskHandler(self.request, get_client("Agent")) + + assert task_handler.session_id == "session-id" + assert task_handler.task_id == "task-id" + assert task_handler.task_options == TaskOptions( + max_duration=datetime.timedelta(seconds=1), + priority=1, + max_retries=1, + partition_id='', + application_name='', + application_version='', + application_namespace='', + application_service='', + engine_type='', + options={} + ) + assert task_handler.token == "token" + assert task_handler.expected_results == ["result-id"] + assert task_handler.configuration == Configuration(data_chunk_max_size=8000) + assert task_handler.payload_id == "payload-id" + assert task_handler.data_folder == data_folder + assert task_handler.payload == "payload".encode() + assert task_handler.data_dependencies == {"dd-id": "dd".encode()} + + def test_create_task(self): + with warnings.catch_warnings(record=True) as w: + # Cause all warnings to always be triggered. 
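+            # create_tasks is expected to emit a DeprecationWarning, which the assertion on w[-1] below checks.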
+            warnings.simplefilter("always") + + task_handler = TaskHandler(self.request, get_client("Agent")) + tasks, errors = task_handler.create_tasks([TaskDefinition( + payload=b"payload", + expected_output_ids=["result-id"], + data_dependencies=[])]) + + assert issubclass(w[-1].category, DeprecationWarning) + assert rpc_called("Agent", "CreateTask") + assert tasks == [] + assert errors == [] + + def test_submit_tasks(self): + task_handler = TaskHandler(self.request, get_client("Agent")) + tasks = task_handler.submit_tasks([TaskDefinition(payload_id="payload-id", + expected_output_ids=["result-id"], + data_dependencies=[])] + ) + + assert rpc_called("Agent", "SubmitTasks") + assert tasks is None + + def test_send_results(self): + task_handler = TaskHandler(self.request, get_client("Agent")) + results = task_handler.send_results({"result-id": b"result data"}) + assert rpc_called("Agent", "NotifyResultData") + assert results is None + + def test_create_results_metadata(self): + task_handler = TaskHandler(self.request, get_client("Agent")) + results = task_handler.create_results_metadata(["result-name"]) + + assert rpc_called("Agent", "CreateResultsMetaData") + # TODO: Mock must be updated to return something, which will change the following assertions + assert results == {} + + def test_create_results(self): + task_handler = TaskHandler(self.request, get_client("Agent")) + results = task_handler.create_results({"result-name": b"test data"}) + + assert rpc_called("Agent", "CreateResults") + assert results == {} + + def test_service_fully_implemented(self): + assert all_rpc_called("Agent", missings=["GetCommonData", "GetDirectData", "GetResourceData"]) diff --git a/packages/python/tests/test_tasks.py b/packages/python/tests/test_tasks.py new file mode 100644 index 000000000..a48e66714 --- /dev/null +++ b/packages/python/tests/test_tasks.py @@ -0,0 +1,96 @@ +import datetime + +from .conftest import all_rpc_called, rpc_called, get_client +from armonik.client import ArmoniKTasks +from armonik.common import Task, TaskDefinition, TaskOptions, TaskStatus, Output + + +class TestArmoniKTasks: + + def test_get_task(self): + tasks_client: ArmoniKTasks = get_client("Tasks") + task = tasks_client.get_task("task-id") + + assert rpc_called("Tasks", "GetTask") + assert isinstance(task, Task) + assert task.id == 'task-id' + assert task.session_id == 'session-id' + assert task.data_dependencies == [] + assert task.expected_output_ids == [] + assert task.retry_of_ids == [] + assert task.status == TaskStatus.COMPLETED + assert task.payload_id is None + assert task.status_message == '' + assert task.options == TaskOptions( + max_duration=datetime.timedelta(seconds=1), + priority=1, + max_retries=1, + partition_id='partition-id', + application_name='application-name', + application_version='application-version', + application_namespace='application-namespace', + application_service='application-service', + engine_type='engine-type', + options={} + ) + assert task.created_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert task.submitted_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert task.started_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert task.ended_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert task.pod_ttl == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert task.output == Output(error='') + assert task.pod_hostname == '' + assert task.received_at == 
datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + assert task.acquired_at == datetime.datetime(1970, 1, 1, 0, 0, tzinfo=datetime.timezone.utc) + + def test_list_tasks_detailed_no_filter(self): + tasks_client: ArmoniKTasks = get_client("Tasks") + num, tasks = tasks_client.list_tasks() + assert rpc_called("Tasks", "ListTasksDetailed") + # TODO: Mock must be updated to return something, which will change the following assertions + assert num == 0 + assert tasks == [] + + def test_list_tasks_no_filter(self): + tasks_client: ArmoniKTasks = get_client("Tasks") + num, tasks = tasks_client.list_tasks(detailed=False) + assert rpc_called("Tasks", "ListTasks") + # TODO: Mock must be updated to return something, which will change the following assertions + assert num == 0 + assert tasks == [] + + def test_cancel_tasks(self): + tasks_client: ArmoniKTasks = get_client("Tasks") + tasks = tasks_client.cancel_tasks(["task-id-1", "task-id-2"]) + + assert rpc_called("Tasks", "CancelTasks") + assert tasks is None + + def test_get_result_ids(self): + tasks_client: ArmoniKTasks = get_client("Tasks") + tasks_results = tasks_client.get_result_ids(["task-id-1", "task-id-2"]) + assert rpc_called("Tasks", "GetResultIds") + # TODO: Mock must be updated to return something, which will change the following assertions + assert tasks_results == {} + + def test_count_tasks_by_status_no_filter(self): + tasks_client: ArmoniKTasks = get_client("Tasks") + count = tasks_client.count_tasks_by_status() + assert rpc_called("Tasks", "CountTasksByStatus") + # TODO: Mock must be updated to return something, which will change the following assertions + assert count == {} + + def test_submit_tasks(self): + tasks_client: ArmoniKTasks = get_client("Tasks") + tasks = tasks_client.submit_tasks( + "session-id", + [TaskDefinition(payload_id="payload-id", + expected_output_ids=["result-id"], + data_dependencies=[])] + ) + assert rpc_called("Tasks", "SubmitTasks") + # TODO: Mock must be updated to return something, which will change the following assertions + assert tasks is None + + def test_service_fully_implemented(self): + assert all_rpc_called("Tasks") diff --git a/packages/python/tests/test_versions.py b/packages/python/tests/test_versions.py new file mode 100644 index 000000000..e6e30db5d --- /dev/null +++ b/packages/python/tests/test_versions.py @@ -0,0 +1,39 @@ +import pytest + +from .conftest import all_rpc_called, rpc_called, get_client +from armonik.client import ArmoniKVersions + + +class TestArmoniKVersions: + + def test_list_versions(self): + """ + Test the list_versions method of the ArmoniKVersions client. + + Assertions: + Ensures that the RPC 'ListVersions' is called on the service 'Versions'. + Asserts that the 'core' version is returned with the correct value. + Asserts that the 'api' version is returned with the correct value. + """ + versions_client: ArmoniKVersions = get_client("Versions") + versions = versions_client.list_versions() + + assert rpc_called("Versions", "ListVersions") + assert versions["core"] == "Unknown" + assert versions["api"] == "1.0.0.0" + + def test_service_fully_implemented(self): + """ + Test if all RPCs in the 'Versions' service have been called at least once. + + Assertions: + Ensures that all RPCs in the 'Versions' service have been called at least once. 
+ """ + assert all_rpc_called("Versions") diff --git a/packages/python/tests/test_worker.py b/packages/python/tests/test_worker.py new file mode 100644 index 000000000..42198bcdf --- /dev/null +++ b/packages/python/tests/test_worker.py @@ -0,0 +1,80 @@ +import datetime +import grpc +import logging +import os +import pytest + +from .conftest import data_folder, grpc_endpoint +from armonik.worker import ArmoniKWorker, TaskHandler, ClefLogger +from armonik.common import Output, TaskOptions +from armonik.protogen.common.objects_pb2 import Empty, Configuration +from armonik.protogen.common.worker_common_pb2 import ProcessRequest + + +def do_nothing(_: TaskHandler) -> Output: + return Output() + + +def throw_error(_: TaskHandler) -> Output: + raise ValueError("TestError") + + +def return_error(_: TaskHandler) -> Output: + return Output("TestError") + + +def return_and_send(th: TaskHandler) -> Output: + th.send_results({th.expected_results[0]: b"result"}) + return Output() + + +class TestWorker: + + request = ProcessRequest( + communication_token="token", + session_id="session-id", + task_id="task-id", + expected_output_keys=["result-id"], + payload_id="payload-id", + data_dependencies=["dd-id"], + data_folder=data_folder, + configuration=Configuration(data_chunk_max_size=8000), + task_options=TaskOptions( + max_duration=datetime.timedelta(seconds=1), + priority=1, + max_retries=1 + ).to_message() + ) + + def test_do_nothing(self): + with grpc.insecure_channel(grpc_endpoint) as agent_channel: + worker = ArmoniKWorker(agent_channel, do_nothing, logger=ClefLogger("TestLogger", level=logging.CRITICAL)) + reply = worker.Process(self.request, None) + assert Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None).success + worker.HealthCheck(Empty(), None) + + def test_should_return_none(self): + with grpc.insecure_channel(grpc_endpoint) as agent_channel: + worker = ArmoniKWorker(agent_channel, throw_error, logger=ClefLogger("TestLogger", level=logging.CRITICAL)) + reply = worker.Process(self.request, None) + assert reply is None + + def test_should_error(self): + with grpc.insecure_channel(grpc_endpoint) as agent_channel: + worker = ArmoniKWorker(agent_channel, return_error, logger=ClefLogger("TestLogger", level=logging.CRITICAL)) + reply = worker.Process(self.request, None) + output = Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None) + assert not output.success + assert output.error == "TestError" + + def test_should_write_result(self): + with grpc.insecure_channel(grpc_endpoint) as agent_channel: + worker = ArmoniKWorker(agent_channel, return_and_send, logger=ClefLogger("TestLogger", level=logging.DEBUG)) + reply = worker.Process(self.request, None) + assert reply is not None + output = Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None) + assert output.success + assert os.path.exists(os.path.join(data_folder, self.request.expected_output_keys[0])) + with open(os.path.join(data_folder, self.request.expected_output_keys[0]), "rb") as f: + value = f.read() + assert len(value) > 0 diff --git a/packages/python/tests/worker_test.py b/packages/python/tests/worker_test.py deleted file mode 100644 index 032c406ee..000000000 --- a/packages/python/tests/worker_test.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python3 -import logging -import os -import pytest -from armonik.worker import ArmoniKWorker, TaskHandler, ClefLogger -from armonik.common import Output -from .taskhandler_test import 
should_succeed_case, data_folder, DummyAgent -from .common import DummyChannel -from armonik.protogen.common.objects_pb2 import Empty -import grpc - - -def do_nothing(_: TaskHandler) -> Output: - return Output() - - -def throw_error(_: TaskHandler) -> Output: - raise ValueError("TestError") - - -def return_error(_: TaskHandler) -> Output: - return Output("TestError") - - -def return_and_send(th: TaskHandler) -> Output: - th.send_result(th.expected_results[0], b"result") - return Output() - - -@pytest.fixture(autouse=True, scope="function") -def remove_result(): - yield - if os.path.exists(os.path.join(data_folder, "resultid")): - os.remove(os.path.join(data_folder, "resultid")) - - -def test_do_nothing_worker(): - with grpc.insecure_channel("unix:///tmp/agent.sock") as agent_channel: - worker = ArmoniKWorker(agent_channel, do_nothing, logger=ClefLogger("TestLogger", level=logging.CRITICAL)) - reply = worker.Process(should_succeed_case, None) - assert Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None).success - worker.HealthCheck(Empty(), None) - - -def test_worker_should_return_none(): - with grpc.insecure_channel("unix:///tmp/agent.sock") as agent_channel: - worker = ArmoniKWorker(agent_channel, throw_error, logger=ClefLogger("TestLogger", level=logging.CRITICAL)) - reply = worker.Process(should_succeed_case, None) - assert reply is None - - -def test_worker_should_error(): - with grpc.insecure_channel("unix:///tmp/agent.sock") as agent_channel: - worker = ArmoniKWorker(agent_channel, return_error, logger=ClefLogger("TestLogger", level=logging.CRITICAL)) - reply = worker.Process(should_succeed_case, None) - output = Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None) - assert not output.success - assert output.error == "TestError" - - -def test_worker_should_write_result(): - with grpc.insecure_channel("unix:///tmp/agent.sock") as agent_channel: - worker = ArmoniKWorker(agent_channel, return_and_send, logger=ClefLogger("TestLogger", level=logging.DEBUG)) - worker._client = DummyAgent(DummyChannel()) - reply = worker.Process(should_succeed_case, None) - assert reply is not None - output = Output(reply.output.error.details if reply.output.WhichOneof("type") == "error" else None) - assert output.success - assert os.path.exists(os.path.join(data_folder, should_succeed_case.expected_output_keys[0])) - with open(os.path.join(data_folder, should_succeed_case.expected_output_keys[0]), "rb") as f: - value = f.read() - assert len(value) > 0 -
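Reviewer note: every new test module in this patch imports get_client, rpc_called, all_rpc_called, data_folder and grpc_endpoint from packages/python/tests/conftest.py, which does not appear in this diff; the CI job shown earlier starts the ArmoniK.Api.Mock server those helpers talk to. For orientation only, here is a minimal sketch of what such a conftest could look like. The endpoint addresses, the calls-recap route, and the client registry are illustrative assumptions, not the implementation actually shipped with this patch.

# conftest.py -- hypothetical sketch, not part of this patch.
import os

import grpc
import requests

from armonik.client import (
    ArmoniKEvents, ArmoniKHealthChecks, ArmoniKPartitions,
    ArmoniKResults, ArmoniKSessions, ArmoniKTasks, ArmoniKVersions,
)
from armonik.protogen.worker.agent_service_pb2_grpc import AgentStub

# Assumed addresses of the ArmoniK.Api.Mock server started in CI.
grpc_endpoint = "localhost:5001"                      # assumption
calls_endpoint = "http://localhost:5000/calls.json"   # assumption

# The real conftest presumably also seeds data_folder with the payload and
# data-dependency files that the task handler and worker tests read back.
data_folder = os.getcwd()

_clients = {
    "Agent": AgentStub,
    "Events": ArmoniKEvents,
    "HealthChecks": ArmoniKHealthChecks,
    "Partitions": ArmoniKPartitions,
    "Results": ArmoniKResults,
    "Sessions": ArmoniKSessions,
    "Tasks": ArmoniKTasks,
    "Versions": ArmoniKVersions,
}

_channel = grpc.insecure_channel(grpc_endpoint)


def get_client(name: str):
    """Build a client for the named service on a channel to the mock server."""
    return _clients[name](_channel)


def rpc_called(service: str, rpc: str, n_calls: int = 1) -> bool:
    """True if the mock recorded exactly n_calls invocations of service.rpc."""
    calls = requests.get(calls_endpoint).json()
    return calls[service][rpc] == n_calls


def all_rpc_called(service: str, missings=None) -> bool:
    """True if every RPC of the service was called at least once, ignoring missings."""
    missings = missings or []
    calls = requests.get(calls_endpoint).json()
    return all(n > 0 for rpc, n in calls[service].items() if rpc not in missings)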