Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

更新jupyter task 参数 #58

Merged
Merged
6 changes: 4 additions & 2 deletions python/wedpr_ml_toolkit/job_exceuter/pws_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import random
import time
import requests
Expand Down Expand Up @@ -40,8 +41,9 @@ def run(self, params):
response = requests.request("POST", self.pws_url, json=payload, headers=headers)
if response.status_code != 200:
raise Exception(f"创建任务失败: {response.json()}")
return
# return self._poll_task_status(response.data, self.token)
print(response.text)
# self._poll_task_status(response.data, self.token)
return json.loads(response.text)

def _poll_task_status(self, job_id, token):
while True:
Expand Down
9 changes: 5 additions & 4 deletions python/wedpr_ml_toolkit/test/test_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@
# 从jupyter环境中获取project_id等信息
# create workspace
# 相同项目/刷新专家模式project_id固定
project_id = 'p-123'
project_id = '测试-xinyi'
user = 'flyhuang1'
my_agency='sgd'
my_agency='SGD'
pws_endpoint = 'http://139.159.202.235:8005' # http
hdfs_endpoint = 'http://192.168.0.18:50070' # client
token = 'abc...'


# 自定义合作方机构
partner_agency1='webank'
partner_agency1='WeBank'
partner_agency2='TX'

# 初始化project ctx 信息
Expand All @@ -45,7 +45,8 @@
dataset1.save_values(path='d-101') # './milestone2\\sgd\\flyhuang1\\share\\d-101'

# hdfs_path
dataset2 = WedprData(ctx, dataset_path='/user/ppc/milestone2/webank/flyhuang/d-9606695119693829', agency=agency2)
ctx2 = BaseContext(project_id, 'flyhuang', pws_endpoint, hdfs_endpoint, token)
dataset2 = WedprData(ctx2, dataset_path='/user/ppc/milestone2/webank/flyhuang/d-9606695119693829', agency=agency2)
dataset2.storage_client = None
# dataset2.load_values()
if dataset2.storage_client is None:
Expand Down
8 changes: 4 additions & 4 deletions python/wedpr_ml_toolkit/wedpr_session/wedpr_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def task(self, params: dict = {}):
self.check_agencies()
job_response = self.excute.run(params)

return job_response.job_id
return job_response['data']

def psi(self, dataset: DataContext = None, merge_filed: str = 'id'):

Expand All @@ -44,16 +44,16 @@ def psi(self, dataset: DataContext = None, merge_filed: str = 'id'):
# 构造参数
# params = {merge_filed: merge_filed}
params = {'jobType': 'PSI',
'projectName': 'jupyter',
'projectName': self.dataset.ctx.project_id,
'param': json.dumps({'dataSetList': self.dataset_list}).replace('"', '\\"'),
'taskParties': self.task_parties,
'datasetList': [None, None]}
'datasetList': self.dataset_id_list}

# 执行任务
job_id = self.task(params)

# 结果处理
psi_result = PSIResult(dataset, job_id)
psi_result = PSIResult(dataset, 'psi-' + job_id)

return psi_result

Expand Down
Loading