diff --git a/packages/python/src/armonik/client/sessions.py b/packages/python/src/armonik/client/sessions.py index 02b5dcf7e..bfff9c66f 100644 --- a/packages/python/src/armonik/client/sessions.py +++ b/packages/python/src/armonik/client/sessions.py @@ -9,9 +9,20 @@ from ..protogen.client.sessions_service_pb2_grpc import SessionsStub from ..protogen.common.sessions_common_pb2 import ( CancelSessionRequest, + CancelSessionResponse, CreateSessionRequest, + DeleteSessionRequest, + DeleteSessionResponse, GetSessionRequest, GetSessionResponse, + PauseSessionRequest, + PauseSessionResponse, + PurgeSessionRequest, + PurgeSessionResponse, + ResumeSessionRequest, + ResumeSessionResponse, + StopSubmissionRequest, + StopSubmissionResponse, ListSessionsRequest, ListSessionsResponse, ) @@ -148,10 +159,79 @@ def list_sessions( response: ListSessionsResponse = self._client.ListSessions(request) return response.total, [Session.from_message(s) for s in response.sessions] - def cancel_session(self, session_id: str) -> None: + def cancel_session(self, session_id: str) -> Session: """Cancel a session Args: - session_id: Id of the session to b cancelled + session_id: Id of the session to be cancelled """ - self._client.CancelSession(CancelSessionRequest(session_id=session_id)) + request = CancelSessionRequest(session_id=session_id) + response: CancelSessionResponse = self._client.CancelSession(request) + return Session.from_message(response.session) + + def pause_session(self, session_id: str) -> Session: + """Pause a session by its id. + + Args: + session_id: Id of the session to be paused. + + Returns: + session metadata + """ + request = PauseSessionRequest(session_id=session_id) + response: PauseSessionResponse = self._client.PauseSession(request) + return Session.from_message(response.session) + + def resume_session(self, session_id: str) -> Session: + """Resume a session by its id. + + Args: + session_id: Id of the session to be resumed. + + Returns: + session metadata + """ + request = ResumeSessionRequest(session_id=session_id) + response: ResumeSessionResponse = self._client.ResumeSession(request) + return Session.from_message(response.session) + + def purge_session(self, session_id: str) -> Session: + """Purge a session by its id. + + Args: + session_id: Id of the session to be purged. + + Returns: + session metadata + """ + request = PurgeSessionRequest(session_id=session_id) + response: PurgeSessionResponse = self._client.PurgeSession(request) + return Session.from_message(response.session) + + def delete_session(self, session_id: str) -> Session: + """Delete a session by its id. + + Args: + session_id: Id of the session to be deleted. + + Returns: + session metadata + """ + request = DeleteSessionRequest(session_id=session_id) + response: DeleteSessionResponse = self._client.DeleteSession(request) + return Session.from_message(response.session) + + def stop_submission_session(self, session_id: str, client: bool, worker:bool) -> Session: + """Stops clients and/or workers from submitting new tasks in the given session. + + Args: + session_id: Id of the session. + client: Stops clients from submitting new tasks in the given session. + worker: Stops workers from submitting new tasks in the given session. + + Returns: + session metadata + """ + request = StopSubmissionRequest(session_id=session_id, client=client, worker=worker) + response: StopSubmissionResponse = self._client.StopSubmission(request) + return Session.from_message(response.session) \ No newline at end of file diff --git a/packages/python/src/armonik/common/objects.py b/packages/python/src/armonik/common/objects.py index e34596c82..2e4452503 100644 --- a/packages/python/src/armonik/common/objects.py +++ b/packages/python/src/armonik/common/objects.py @@ -181,10 +181,14 @@ def is_available(self) -> bool: class Session: session_id: Optional[str] = None status: RawSessionStatus = SessionStatus.UNSPECIFIED + client_submission: Optional[bool] = None + worker_submission: Optional[bool] = None partition_ids: List[str] = field(default_factory=list) options: Optional[TaskOptions] = None created_at: Optional[datetime] = None cancelled_at: Optional[datetime] = None + purged_at: Optional[datetime] = None + deleted_at: Optional[datetime] = None duration: Optional[timedelta] = None @classmethod @@ -192,10 +196,14 @@ def from_message(cls, session_raw: SessionRaw) -> "Session": return cls( session_id=session_raw.session_id, status=session_raw.status, + client_submission=session_raw.client_submission, + worker_submission=session_raw.worker_submission, partition_ids=list(session_raw.partition_ids), options=TaskOptions.from_message(session_raw.options), created_at=timestamp_to_datetime(session_raw.created_at), cancelled_at=timestamp_to_datetime(session_raw.cancelled_at), + purged_at=timestamp_to_datetime(session_raw.purged_at), + deleted_at=timestamp_to_datetime(session_raw.deleted_at), duration=duration_to_timedelta(session_raw.duration), ) diff --git a/packages/python/tests/test_sessions.py b/packages/python/tests/test_sessions.py index 108971b0c..fde5240e3 100644 --- a/packages/python/tests/test_sessions.py +++ b/packages/python/tests/test_sessions.py @@ -71,5 +71,35 @@ def test_cancel_session(self): assert rpc_called("Sessions", "CancelSession") + def test_pause_session(self): + session_client: ArmoniKSessions = get_client("Sessions") + session_client.pause_session("session-id") + + assert rpc_called("Sessions", "PauseSession") + + def test_resume_session(self): + session_client: ArmoniKSessions = get_client("Sessions") + session_client.resume_session("session-id") + + assert rpc_called("Sessions", "ResumeSession") + + def test_purge_session(self): + session_client: ArmoniKSessions = get_client("Sessions") + session_client.purge_session("session-id") + + assert rpc_called("Sessions", "PurgeSession") + + def test_delete_session(self): + session_client: ArmoniKSessions = get_client("Sessions") + session_client.delete_session("session-id") + + assert rpc_called("Sessions", "DeleteSession") + + def test_stop_submission_session(self): + session_client: ArmoniKSessions = get_client("Sessions") + session_client.stop_submission_session("session-id", True, True) + + assert rpc_called("Sessions", "StopSubmission") + def test_service_fully_implemented(self): assert all_rpc_called("Sessions")