Skip to content

Commit

Permalink
fix model client bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ywy2090 committed Sep 27, 2024
1 parent cd7f122 commit 83d2fe8
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _cleanup_finished_threads(self):
for target_id in finished_threads:
with self.lock:
del self.threads[target_id]
self.logger.info(f"cleanup finished thread {target_id}")
self.logger.info(f"Cleanup finished thread {target_id}")

def __del__(self):
self.kill_all()
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import time

from ppc_common.ppc_utils import http_utils
Expand All @@ -8,40 +9,47 @@


class ModelClient:
def __init__(self, logger, endpoint, polling_interval_s: int = 5, max_retries: int = 5, retry_delay_s: int = 5):
def __init__(self, logger, endpoint, token, polling_interval_s: int = 5, max_retries: int = 5, retry_delay_s: int = 5):
self.logger = logger
self.endpoint = endpoint
self.token = token
self.polling_interval_s = polling_interval_s
self.max_retries = max_retries
self.retry_delay_s = retry_delay_s
self._completed_status = 'COMPLETED'
self._failed_status = 'FAILED'

def run(self, *args):
    """Submit a model task to the model node and wait for it to finish.

    Args:
        *args: the first positional argument carries the task parameters,
            either as a dict or as a JSON string that decodes to one.
            The parameters must contain the key ``'task_id'``.

    Returns:
        Whatever ``_poll_task_status`` returns for the submitted task.

    Raises:
        IndexError: if no positional argument is given.
        KeyError: if the parameters have no ``'task_id'`` entry.
        Exception: any error raised while submitting or polling is logged
            and then re-raised unchanged.
    """
    params = args[0]
    # isinstance instead of an exact type comparison, so str subclasses
    # are decoded as well.
    if isinstance(params, str):
        params = json.loads(params)

    task_id = params['task_id']

    try:
        self.logger.info(f"model client begin to run model task {task_id}")
        response = self._send_request_with_retry(
            http_utils.send_post_request,
            endpoint=self.endpoint,
            uri=RUN_MODEL_API_PREFIX + task_id,
            params=params)
        check_response(response)
        return self._poll_task_status(task_id)
    except Exception as e:
        self.logger.error(f"model client run model task exception, task: {task_id}, e: {e}")
        # A bare raise re-raises the active exception with its traceback intact.
        raise

def kill(self, task_id):
    """Ask the model node to stop a task.

    Args:
        task_id: identifier appended to ``RUN_MODEL_API_PREFIX`` to build
            the URI of the delete request.

    Returns:
        The response of the delete request, after ``check_response`` has
        accepted it.

    Raises:
        Exception: any error raised while sending or checking the request
            is logged as a warning and then re-raised unchanged.
    """
    try:
        self.logger.info(f"model client begin to kill model task {task_id}")
        response = self._send_request_with_retry(
            http_utils.send_delete_request,
            endpoint=self.endpoint,
            uri=RUN_MODEL_API_PREFIX + task_id)
        check_response(response)
        self.logger.info(f"model client task {task_id} was killed")
        return response
    except Exception as e:
        # Logger.warn is a deprecated alias; warning() is the supported name.
        self.logger.warning(f"model client kill task {task_id} exception, e: {e}")
        # A bare raise re-raises the active exception with its traceback intact.
        raise

def _poll_task_status(self, task_id):
Expand All @@ -51,18 +59,18 @@ def _poll_task_status(self, task_id):
uri=RUN_MODEL_API_PREFIX + task_id)
check_response(response)
if response['data']['status'] == self._completed_status:
self.logger.info(f"task {task_id} completed, response: {response['data']}")
self.logger.info(f"model client task {task_id} completed, response: {response['data']}")
return response
elif response['data']['status'] == self._failed_status:
self.logger.warn(f"task {task_id} failed, response: {response['data']}")
self.logger.warn(f"model client task {task_id} failed, response: {response['data']}")
raise PpcException(PpcErrorCode.CALL_SCS_ERROR.get_code(), response['data'])
else:
time.sleep(self.polling_interval_s)

def get_remote_log(self, remote_id):
    """Fetch a log from the model node and return the response's ``'data'`` field.

    Args:
        remote_id: identifier appended to ``GET_MODEL_LOG_API_PREFIX`` to
            build the URI of the get request.

    Returns:
        ``response['data']`` of the request, once ``check_response`` has
        accepted the response.
    """
    log_uri = GET_MODEL_LOG_API_PREFIX + remote_id
    reply = self._send_request_with_retry(
        http_utils.send_get_request,
        endpoint=self.endpoint,
        uri=log_uri,
    )
    check_response(reply)
    return reply['data']

Expand Down
6 changes: 4 additions & 2 deletions python/ppc_scheduler/node/computing_node_client/psi_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ def run(self, *args):
if type(params) == str:
params = json.loads(params)

task_id = params['taskID']

json_rpc_request = {
'jsonrpc': '1',
'method': 'asyncRunTask',
Expand All @@ -34,9 +36,9 @@ def run(self, *args):
}
response = self._send_request_with_retry(http_utils.send_post_request, self.endpoint, None, json_rpc_request)
check_privacy_service_response(response)
return self._poll_task_status(params['taskID'])
return self._poll_task_status(task_id)

def _poll_task_status(self, task_id):
def _poll_task_status(self, task_id: str):
while True:
params = {
'jsonrpc': '1',
Expand Down
1 change: 1 addition & 0 deletions python/ppc_scheduler/node/node_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
class ComputingNodeManager:
type_map = {
WorkerType.T_PSI: 'PSI',
WorkerType.T_ML_PSI: 'PSI',
WorkerType.T_MPC: 'MPC',
WorkerType.T_PREPROCESSING: 'MODEL',
WorkerType.T_FEATURE_ENGINEERING: 'MODEL',
Expand Down
1 change: 1 addition & 0 deletions python/ppc_scheduler/workflow/common/worker_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ class WorkerType:

# specific job worker
T_PSI = 'PSI'
T_ML_PSI = 'ML_PSI'
T_MPC = 'MPC'
T_PREPROCESSING = 'PREPROCESSING'
T_FEATURE_ENGINEERING = 'FEATURE_ENGINEERING'
Expand Down
3 changes: 1 addition & 2 deletions python/ppc_scheduler/workflow/worker/engine/model_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import time

from ppc_scheduler.workflow.common.job_context import JobContext
Expand Down Expand Up @@ -29,7 +28,7 @@ def run(self, *args) -> list:
self.logger.info(f"## model engine run begin, job_id={job_id}, worker_id={self.worker_id}, args: {args}")

# send job request to model node and wait for the job to finish
# self.psi_client.run(*args)
self.model_client.run(*args)

time_costs = time.time() - start_time
self.logger.info(f"## model engine run finished, job_id={job_id}, timecost: {time_costs}s")
Expand Down
6 changes: 3 additions & 3 deletions python/ppc_scheduler/workflow/worker/model_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ def __init__(self, components, job_context, worker_id, worker_type, worker_args,
super().__init__(components, job_context, worker_id, worker_type, worker_args, *args, **kwargs)

def engine_run(self, worker_inputs):
    """Run the model engine on a node obtained from the node manager.

    Args:
        worker_inputs: not read by this method.

    Returns:
        The value returned by ``ModelWorkerEngine.run`` for ``self.worker_args``.
    """
    model_client_node = self.node_manager.get_node(self.worker_type)
    try:
        # Build the client and engine inside the try block so the node is
        # released even when construction fails (for example when the node
        # entry has fewer than two elements).
        # Element 0 is passed as the client's endpoint, element 1 as its token.
        model_client = ModelClient(self.components.logger(), model_client_node[0], model_client_node[1])
        model_engine = ModelWorkerEngine(model_client, self.worker_type, self.worker_id, self.components, self.job_context)
        return model_engine.run(*self.worker_args)
    finally:
        self.node_manager.release_node(model_client_node, self.worker_type)
3 changes: 2 additions & 1 deletion python/ppc_scheduler/workflow/worker/worker_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def build_worker(job_context, worker_id, worker_type, worker_args, *args, **kwar
return PythonWorker(components, job_context, worker_id, worker_type, worker_args, *args, *kwargs)
elif worker_type == WorkerType.T_SHELL:
return ShellWorker(components, job_context, worker_id, worker_type, worker_args, *args, **kwargs)
elif worker_type == WorkerType.T_PSI:
elif worker_type == WorkerType.T_PSI or \
worker_type == WorkerType.T_ML_PSI:
return PsiWorker(components, job_context, worker_id, worker_type, worker_args, *args, **kwargs)
elif worker_type == WorkerType.T_MPC:
return MpcWorker(components, job_context, worker_id, worker_type, worker_args, *args, **kwargs)
Expand Down

0 comments on commit 83d2fe8

Please sign in to comment.