From 83d2fe884326d16e91814a3f92f6a50e6a65128b Mon Sep 17 00:00:00 2001 From: octopus <912554887@qq.com> Date: Fri, 27 Sep 2024 09:46:39 +0800 Subject: [PATCH] fix model client bug --- .../async_thread_executor.py | 2 +- .../model_node_client.py | 38 +++++++++++-------- .../node/computing_node_client/psi_client.py | 6 ++- python/ppc_scheduler/node/node_manager.py | 1 + .../workflow/common/worker_type.py | 1 + .../workflow/worker/engine/model_engine.py | 3 +- .../workflow/worker/model_worker.py | 6 +-- .../workflow/worker/worker_factory.py | 3 +- 8 files changed, 36 insertions(+), 24 deletions(-) diff --git a/python/ppc_common/ppc_async_executor/async_thread_executor.py b/python/ppc_common/ppc_async_executor/async_thread_executor.py index fe012878..15b1f5ca 100644 --- a/python/ppc_common/ppc_async_executor/async_thread_executor.py +++ b/python/ppc_common/ppc_async_executor/async_thread_executor.py @@ -69,7 +69,7 @@ def _cleanup_finished_threads(self): for target_id in finished_threads: with self.lock: del self.threads[target_id] - self.logger.info(f"cleanup finished thread {target_id}") + self.logger.info(f"Cleanup finished thread {target_id}") def __del__(self): self.kill_all() diff --git a/python/ppc_scheduler/node/computing_node_client/model_node_client.py b/python/ppc_scheduler/node/computing_node_client/model_node_client.py index 64a14120..0bb9d9fe 100644 --- a/python/ppc_scheduler/node/computing_node_client/model_node_client.py +++ b/python/ppc_scheduler/node/computing_node_client/model_node_client.py @@ -1,3 +1,4 @@ +import json import time from ppc_common.ppc_utils import http_utils @@ -8,40 +9,47 @@ class ModelClient: - def __init__(self, logger, endpoint, polling_interval_s: int = 5, max_retries: int = 5, retry_delay_s: int = 5): + def __init__(self, logger, endpoint, token, polling_interval_s: int = 5, max_retries: int = 5, retry_delay_s: int = 5): self.logger = logger self.endpoint = endpoint + self.token = token self.polling_interval_s = polling_interval_s self.max_retries = max_retries self.retry_delay_s = retry_delay_s self._completed_status = 'COMPLETED' self._failed_status = 'FAILED' - def run(self, args): - task_id = args['task_id'] + def run(self, *args): + + params = args[0] + if type(params) == str: + params = json.loads(params) + + task_id = params['task_id'] + try: - self.logger.info(f"ModelApi: begin to run model task {task_id}") + self.logger.info(f"model client begin to run model task {task_id}") response = self._send_request_with_retry(http_utils.send_post_request, endpoint=self.endpoint, uri=RUN_MODEL_API_PREFIX + task_id, - params=args) + params=params) check_response(response) return self._poll_task_status(task_id) except Exception as e: - self.logger.error(f"ModelApi: run model task error, task: {task_id}, error: {e}") + self.logger.error(f"model client run model task exception, task: {task_id}, e: {e}") raise e - def kill(self, job_id): + def kill(self, task_id): try: - self.logger.info(f"ModelApi: begin to kill model task {job_id}") + self.logger.info(f"model client begin to kill model task {task_id}") response = self._send_request_with_retry(http_utils.send_delete_request, endpoint=self.endpoint, - uri=RUN_MODEL_API_PREFIX + job_id) + uri=RUN_MODEL_API_PREFIX + task_id) check_response(response) - self.logger.info(f"ModelApi: model task {job_id} was killed") + self.logger.info(f"model client task {task_id} was killed") return response except Exception as e: - self.logger.warn(f"ModelApi: kill model task {job_id} failed, error: {e}") + self.logger.warn(f"model client kill task {task_id} exception, e: {e}") raise e def _poll_task_status(self, task_id): @@ -51,18 +59,18 @@ def _poll_task_status(self, task_id): uri=RUN_MODEL_API_PREFIX + task_id) check_response(response) if response['data']['status'] == self._completed_status: - self.logger.info(f"task {task_id} completed, response: {response['data']}") + self.logger.info(f"model client task {task_id} completed, response: {response['data']}") return response elif response['data']['status'] == self._failed_status: - self.logger.warn(f"task {task_id} failed, response: {response['data']}") + self.logger.warn(f"model client task {task_id} failed, response: {response['data']}") raise PpcException(PpcErrorCode.CALL_SCS_ERROR.get_code(), response['data']) else: time.sleep(self.polling_interval_s) - def get_remote_log(self, job_id): + def get_remote_log(self, remote_id): response = self._send_request_with_retry(http_utils.send_get_request, endpoint=self.endpoint, - uri=GET_MODEL_LOG_API_PREFIX + job_id) + uri=GET_MODEL_LOG_API_PREFIX + remote_id) check_response(response) return response['data'] diff --git a/python/ppc_scheduler/node/computing_node_client/psi_client.py b/python/ppc_scheduler/node/computing_node_client/psi_client.py index 43fbd5f1..7dcd5ba3 100644 --- a/python/ppc_scheduler/node/computing_node_client/psi_client.py +++ b/python/ppc_scheduler/node/computing_node_client/psi_client.py @@ -25,6 +25,8 @@ def run(self, *args): if type(params) == str: params = json.loads(params) + task_id = params['taskID'] + json_rpc_request = { 'jsonrpc': '1', 'method': 'asyncRunTask', @@ -34,9 +36,9 @@ def run(self, *args): } response = self._send_request_with_retry(http_utils.send_post_request, self.endpoint, None, json_rpc_request) check_privacy_service_response(response) - return self._poll_task_status(params['taskID']) + return self._poll_task_status(task_id) - def _poll_task_status(self, task_id): + def _poll_task_status(self, task_id: str): while True: params = { 'jsonrpc': '1', diff --git a/python/ppc_scheduler/node/node_manager.py b/python/ppc_scheduler/node/node_manager.py index fa2b2247..cadf3e29 100644 --- a/python/ppc_scheduler/node/node_manager.py +++ b/python/ppc_scheduler/node/node_manager.py @@ -5,6 +5,7 @@ class ComputingNodeManager: type_map = { WorkerType.T_PSI: 'PSI', + WorkerType.T_ML_PSI: 'PSI', WorkerType.T_MPC: 'MPC', WorkerType.T_PREPROCESSING: 'MODEL', WorkerType.T_FEATURE_ENGINEERING: 'MODEL', diff --git a/python/ppc_scheduler/workflow/common/worker_type.py b/python/ppc_scheduler/workflow/common/worker_type.py index 54ccd245..6f430978 100644 --- a/python/ppc_scheduler/workflow/common/worker_type.py +++ b/python/ppc_scheduler/workflow/common/worker_type.py @@ -6,6 +6,7 @@ class WorkerType: # specific job worker T_PSI = 'PSI' + T_ML_PSI = 'ML_PSI' T_MPC = 'MPC' T_PREPROCESSING = 'PREPROCESSING' T_FEATURE_ENGINEERING = 'FEATURE_ENGINEERING' diff --git a/python/ppc_scheduler/workflow/worker/engine/model_engine.py b/python/ppc_scheduler/workflow/worker/engine/model_engine.py index 68084783..8f2b24ba 100644 --- a/python/ppc_scheduler/workflow/worker/engine/model_engine.py +++ b/python/ppc_scheduler/workflow/worker/engine/model_engine.py @@ -1,4 +1,3 @@ -import os import time from ppc_scheduler.workflow.common.job_context import JobContext @@ -29,7 +28,7 @@ def run(self, *args) -> list: self.logger.info(f"## model engine run begin, job_id={job_id}, worker_id={self.worker_id}, args: {args}") # send job request to model node and wait for the job to finish - # self.psi_client.run(*args) + self.model_client.run(*args) time_costs = time.time() - start_time self.logger.info(f"## model engine run finished, job_id={job_id}, timecost: {time_costs}s") diff --git a/python/ppc_scheduler/workflow/worker/model_worker.py b/python/ppc_scheduler/workflow/worker/model_worker.py index 6be85fe4..0268e705 100644 --- a/python/ppc_scheduler/workflow/worker/model_worker.py +++ b/python/ppc_scheduler/workflow/worker/model_worker.py @@ -9,11 +9,11 @@ def __init__(self, components, job_context, worker_id, worker_type, worker_args, super().__init__(components, job_context, worker_id, worker_type, worker_args, *args, **kwargs) def engine_run(self, worker_inputs): - node_endpoint = self.node_manager.get_node(self.worker_type) - model_client = ModelClient(self.components.logger(), node_endpoint) + model_client_node = self.node_manager.get_node(self.worker_type) + model_client = ModelClient(self.components.logger(), model_client_node[0], model_client_node[1]) model_engine = ModelWorkerEngine(model_client, self.worker_type, self.worker_id, self.components, self.job_context) try: outputs = model_engine.run(*self.worker_args) return outputs finally: - self.node_manager.release_node(node_endpoint, self.worker_type) + self.node_manager.release_node(model_client_node, self.worker_type) diff --git a/python/ppc_scheduler/workflow/worker/worker_factory.py b/python/ppc_scheduler/workflow/worker/worker_factory.py index 2518a89b..9b15db5d 100644 --- a/python/ppc_scheduler/workflow/worker/worker_factory.py +++ b/python/ppc_scheduler/workflow/worker/worker_factory.py @@ -20,7 +20,8 @@ def build_worker(job_context, worker_id, worker_type, worker_args, *args, **kwar return PythonWorker(components, job_context, worker_id, worker_type, worker_args, *args, *kwargs) elif worker_type == WorkerType.T_SHELL: return ShellWorker(components, job_context, worker_id, worker_type, worker_args, *args, **kwargs) - elif worker_type == WorkerType.T_PSI: + elif worker_type == WorkerType.T_PSI or \ + worker_type == WorkerType.T_ML_PSI: return PsiWorker(components, job_context, worker_id, worker_type, worker_args, *args, **kwargs) elif worker_type == WorkerType.T_MPC: return MpcWorker(components, job_context, worker_id, worker_type, worker_args, *args, **kwargs)