From 25e376dc3915701975131010bc5917f9886252fb Mon Sep 17 00:00:00 2001 From: Xinyi YAN <41045439+yanxinyi620@users.noreply.github.com> Date: Thu, 17 Oct 2024 10:43:16 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0jupyter=20task=20=E5=8F=82?= =?UTF-8?q?=E6=95=B0=20(#58)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update jupyter task --- python/wedpr_ml_toolkit/job_exceuter/pws_client.py | 6 ++++-- python/wedpr_ml_toolkit/test/test_dev.py | 9 +++++---- python/wedpr_ml_toolkit/wedpr_session/wedpr_session.py | 8 ++++---- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/python/wedpr_ml_toolkit/job_exceuter/pws_client.py b/python/wedpr_ml_toolkit/job_exceuter/pws_client.py index 40df0238..08cb2e5d 100644 --- a/python/wedpr_ml_toolkit/job_exceuter/pws_client.py +++ b/python/wedpr_ml_toolkit/job_exceuter/pws_client.py @@ -1,3 +1,4 @@ +import json import random import time import requests @@ -40,8 +41,9 @@ def run(self, params): response = requests.request("POST", self.pws_url, json=payload, headers=headers) if response.status_code != 200: raise Exception(f"创建任务失败: {response.json()}") - return - # return self._poll_task_status(response.data, self.token) + print(response.text) + # self._poll_task_status(response.data, self.token) + return json.loads(response.text) def _poll_task_status(self, job_id, token): while True: diff --git a/python/wedpr_ml_toolkit/test/test_dev.py b/python/wedpr_ml_toolkit/test/test_dev.py index 9ccc864f..bfaf7a8d 100644 --- a/python/wedpr_ml_toolkit/test/test_dev.py +++ b/python/wedpr_ml_toolkit/test/test_dev.py @@ -13,16 +13,16 @@ # 从jupyter环境中获取project_id等信息 # create workspace # 相同项目/刷新专家模式project_id固定 -project_id = 'p-123' +project_id = '测试-xinyi' user = 'flyhuang1' -my_agency='sgd' +my_agency='SGD' pws_endpoint = 'http://139.159.202.235:8005' # http hdfs_endpoint = 'http://192.168.0.18:50070' # client token = 'abc...' # 自定义合作方机构 -partner_agency1='webank' +partner_agency1='WeBank' partner_agency2='TX' # 初始化project ctx 信息 @@ -45,7 +45,8 @@ dataset1.save_values(path='d-101') # './milestone2\\sgd\\flyhuang1\\share\\d-101' # hdfs_path -dataset2 = WedprData(ctx, dataset_path='/user/ppc/milestone2/webank/flyhuang/d-9606695119693829', agency=agency2) +ctx2 = BaseContext(project_id, 'flyhuang', pws_endpoint, hdfs_endpoint, token) +dataset2 = WedprData(ctx2, dataset_path='/user/ppc/milestone2/webank/flyhuang/d-9606695119693829', agency=agency2) dataset2.storage_client = None # dataset2.load_values() if dataset2.storage_client is None: diff --git a/python/wedpr_ml_toolkit/wedpr_session/wedpr_session.py b/python/wedpr_ml_toolkit/wedpr_session/wedpr_session.py index 85b81d52..2cdae065 100644 --- a/python/wedpr_ml_toolkit/wedpr_session/wedpr_session.py +++ b/python/wedpr_ml_toolkit/wedpr_session/wedpr_session.py @@ -32,7 +32,7 @@ def task(self, params: dict = {}): self.check_agencies() job_response = self.excute.run(params) - return job_response.job_id + return job_response['data'] def psi(self, dataset: DataContext = None, merge_filed: str = 'id'): @@ -44,16 +44,16 @@ def psi(self, dataset: DataContext = None, merge_filed: str = 'id'): # 构造参数 # params = {merge_filed: merge_filed} params = {'jobType': 'PSI', - 'projectName': 'jupyter', + 'projectName': self.dataset.ctx.project_id, 'param': json.dumps({'dataSetList': self.dataset_list}).replace('"', '\\"'), 'taskParties': self.task_parties, - 'datasetList': [None, None]} + 'datasetList': self.dataset_id_list} # 执行任务 job_id = self.task(params) # 结果处理 - psi_result = PSIResult(dataset, job_id) + psi_result = PSIResult(dataset, 'psi-' + job_id) return psi_result