diff --git a/python/wedpr_ml_toolkit/job_exceuter/pws_client.py b/python/wedpr_ml_toolkit/job_exceuter/pws_client.py index 40df0238..08cb2e5d 100644 --- a/python/wedpr_ml_toolkit/job_exceuter/pws_client.py +++ b/python/wedpr_ml_toolkit/job_exceuter/pws_client.py @@ -1,3 +1,4 @@ +import json import random import time import requests @@ -40,8 +41,9 @@ def run(self, params): response = requests.request("POST", self.pws_url, json=payload, headers=headers) if response.status_code != 200: raise Exception(f"创建任务失败: {response.json()}") - return - # return self._poll_task_status(response.data, self.token) + print(response.text) + # self._poll_task_status(response.data, self.token) + return json.loads(response.text) def _poll_task_status(self, job_id, token): while True: diff --git a/python/wedpr_ml_toolkit/test/test_dev.py b/python/wedpr_ml_toolkit/test/test_dev.py index 9ccc864f..bfaf7a8d 100644 --- a/python/wedpr_ml_toolkit/test/test_dev.py +++ b/python/wedpr_ml_toolkit/test/test_dev.py @@ -13,16 +13,16 @@ # 从jupyter环境中获取project_id等信息 # create workspace # 相同项目/刷新专家模式project_id固定 -project_id = 'p-123' +project_id = '测试-xinyi' user = 'flyhuang1' -my_agency='sgd' +my_agency='SGD' pws_endpoint = 'http://139.159.202.235:8005' # http hdfs_endpoint = 'http://192.168.0.18:50070' # client token = 'abc...' # 自定义合作方机构 -partner_agency1='webank' +partner_agency1='WeBank' partner_agency2='TX' # 初始化project ctx 信息 @@ -45,7 +45,8 @@ dataset1.save_values(path='d-101') # './milestone2\\sgd\\flyhuang1\\share\\d-101' # hdfs_path -dataset2 = WedprData(ctx, dataset_path='/user/ppc/milestone2/webank/flyhuang/d-9606695119693829', agency=agency2) +ctx2 = BaseContext(project_id, 'flyhuang', pws_endpoint, hdfs_endpoint, token) +dataset2 = WedprData(ctx2, dataset_path='/user/ppc/milestone2/webank/flyhuang/d-9606695119693829', agency=agency2) dataset2.storage_client = None # dataset2.load_values() if dataset2.storage_client is None: diff --git a/python/wedpr_ml_toolkit/wedpr_session/wedpr_session.py b/python/wedpr_ml_toolkit/wedpr_session/wedpr_session.py index 85b81d52..2cdae065 100644 --- a/python/wedpr_ml_toolkit/wedpr_session/wedpr_session.py +++ b/python/wedpr_ml_toolkit/wedpr_session/wedpr_session.py @@ -32,7 +32,7 @@ def task(self, params: dict = {}): self.check_agencies() job_response = self.excute.run(params) - return job_response.job_id + return job_response['data'] def psi(self, dataset: DataContext = None, merge_filed: str = 'id'): @@ -44,16 +44,16 @@ def psi(self, dataset: DataContext = None, merge_filed: str = 'id'): # 构造参数 # params = {merge_filed: merge_filed} params = {'jobType': 'PSI', - 'projectName': 'jupyter', + 'projectName': self.dataset.ctx.project_id, 'param': json.dumps({'dataSetList': self.dataset_list}).replace('"', '\\"'), 'taskParties': self.task_parties, - 'datasetList': [None, None]} + 'datasetList': self.dataset_id_list} # 执行任务 job_id = self.task(params) # 结果处理 - psi_result = PSIResult(dataset, job_id) + psi_result = PSIResult(dataset, 'psi-' + job_id) return psi_result