Skip to content

Commit

Permalink
fix queryJobDetail
Browse files Browse the repository at this point in the history
  • Loading branch information
cyjseagull committed Dec 10, 2024
1 parent 0686c02 commit 4b3b031
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 43 deletions.
74 changes: 37 additions & 37 deletions python/wedpr_ml_toolkit/jupyter-demo/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# In[1]:


import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, accuracy_score, f1_score, precision_score, recall_score
import sys
import numpy as np
import pandas as pd
Expand All @@ -22,7 +24,8 @@


# 读取配置文件
wedpr_config = WeDPRMlConfigBuilder.build_from_properties_file('config.properties')
wedpr_config = WeDPRMlConfigBuilder.build_from_properties_file(
'config.properties')
wedpr_ml_toolkit = WeDPRMlToolkit(wedpr_config)


Expand All @@ -32,7 +35,7 @@
# dataset1
dataset1 = DatasetContext(storage_entrypoint=wedpr_ml_toolkit.get_storage_entry_point(),
dataset_client=wedpr_ml_toolkit.dataset_client,
dataset_id = 'd-9743660607744005',
dataset_id='d-9743660607744005',
is_label_holder=True)
print(f"* load dataset1: {dataset1}")
(values, cols, shapes) = dataset1.load_values()
Expand All @@ -45,8 +48,8 @@

# dataset2
dataset2 = DatasetContext(storage_entrypoint=wedpr_ml_toolkit.get_storage_entry_point(),
dataset_client = wedpr_ml_toolkit.dataset_client,
dataset_id = "d-9743674298214405")
dataset_client=wedpr_ml_toolkit.dataset_client,
dataset_id="d-9743674298214405")
print(f"* dataset2: {dataset2}")

# 构建 dataset context
Expand All @@ -60,11 +63,10 @@
model_setting = ModelSetting()
model_setting.use_psi = True
xgb_job_context = wedpr_ml_toolkit.build_job_context(
job_type = JobType.XGB_TRAINING,
project_id = project_id,
dataset = dataset,
model_setting = model_setting,
id_fields = "id")
job_type=JobType.XGB_TRAINING,
project_id=project_id,
dataset=dataset,
model_setting=model_setting)
print(f"* build xgb job context: {xgb_job_context}")


Expand All @@ -81,18 +83,20 @@

# 获取xgb任务结果
print(xgb_job_id)
#xgb_job_id = "9868279583877126"
# xgb_job_id = "9868279583877126"
xgb_result_detail = xgb_job_context.fetch_job_result(xgb_job_id, True)
# load the result context
xgb_result_context = wedpr_ml_toolkit.build_result_context(job_context=xgb_job_context,
job_result_detail=xgb_result_detail)
print(f"* xgb job result ctx: {xgb_result_context}")

xgb_test_dataset = xgb_result_context.test_result_dataset
print(f"* xgb_test_dataset: {xgb_test_dataset}, file_path: {xgb_test_dataset.dataset_meta.file_path}")
print(
f"* xgb_test_dataset: {xgb_test_dataset}, file_path: {xgb_test_dataset.dataset_meta.file_path}")

(data, cols, shapes) = xgb_test_dataset.load_values()
print(f"* test dataset detail, columns: {cols}, shape: {shapes}, value: {data}")
print(
f"* test dataset detail, columns: {cols}, shape: {shapes}, value: {data}")


# In[8]:
Expand All @@ -102,17 +106,19 @@
result_context: TrainResultContext = xgb_result_context
evaluation_result_dataset = result_context.evaluation_dataset
(eval_data, cols, shape) = evaluation_result_dataset.load_values(header=0)
print(f"* evaluation detail, col: {cols}, shape: {shape}, eval_data: {eval_data}")
print(
f"* evaluation detail, col: {cols}, shape: {shape}, eval_data: {eval_data}")


# In[9]:


# feature importance
# feature importance
feature_importance_dataset = result_context.feature_importance_dataset
(feature_importance_data, cols, shape) = feature_importance_dataset.load_values()

print(f"* feature_importance detail, col: {cols}, shape: {shape}, feature_importance_data: {feature_importance_data}")
print(
f"* feature_importance detail, col: {cols}, shape: {shape}, feature_importance_data: {feature_importance_data}")


# In[10]:
Expand All @@ -122,7 +128,8 @@
preprocessing_dataset = result_context.preprocessing_dataset
(preprocessing_data, cols, shape) = preprocessing_dataset.load_values()

print(f"* preprocessing detail, col: {cols}, shape: {shape}, preprocessing_data: {preprocessing_data}")
print(
f"* preprocessing detail, col: {cols}, shape: {shape}, preprocessing_data: {preprocessing_data}")


# In[11]:
Expand All @@ -132,15 +139,14 @@
model_result_dataset = result_context.model_result_dataset
(model_result, cols, shape) = model_result_dataset.load_values()

print(f"* model_result detail, col: {cols}, shape: {shape}, model_result: {model_result}")
print(
f"* model_result detail, col: {cols}, shape: {shape}, model_result: {model_result}")


# In[12]:


# 明文处理预测结果
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, accuracy_score, f1_score, precision_score, recall_score
import matplotlib.pyplot as plt

# 提取真实标签和预测概率
y_true = data['class_label']
Expand Down Expand Up @@ -191,15 +197,14 @@
# 构造xgb预测任务配置
predict_setting = ModelSetting()
predict_setting.use_psi = True
#model_predict_algorithm = {}
#model_predict_algorithm.update({"setting": xgb_result_context.job_result_detail.model})
# model_predict_algorithm = {}
# model_predict_algorithm.update({"setting": xgb_result_context.job_result_detail.model})
predict_xgb_job_context = wedpr_ml_toolkit.build_job_context(
job_type=JobType.XGB_PREDICTING,
project_id = project_id,
dataset= dataset,
model_setting= predict_setting,
id_fields = "id",
predict_algorithm = xgb_result_context.job_result_detail.model_predict_algorithm)
job_type=JobType.XGB_PREDICTING,
project_id=project_id,
dataset=dataset,
model_setting=predict_setting,
predict_algorithm=xgb_result_context.job_result_detail.model_predict_algorithm)
print(f"* predict_xgb_job_context: {predict_xgb_job_context}")


Expand All @@ -218,25 +223,24 @@
# query the job detail
print(f"* xgb_predict_job_id: {xgb_predict_job_id}")

predict_xgb_job_result = predict_xgb_job_context.fetch_job_result(xgb_predict_job_id, True)
predict_xgb_job_result = predict_xgb_job_context.fetch_job_result(
xgb_predict_job_id, True)

# generate the result context
result_context = wedpr_ml_toolkit.build_result_context(job_context=predict_xgb_job_context,
result_context = wedpr_ml_toolkit.build_result_context(job_context=predict_xgb_job_context,
job_result_detail=predict_xgb_job_result)

xgb_predict_result_context : PredictResultContext = result_context
xgb_predict_result_context: PredictResultContext = result_context
print(f"* result_context is {xgb_predict_result_context}")


# In[16]:


# 明文处理预测结果
from sklearn.metrics import roc_curve, roc_auc_score, precision_recall_curve, accuracy_score, f1_score, precision_score, recall_score
import matplotlib.pyplot as plt


(data, cols, shapes) = xgb_predict_result_context.model_result_dataset.load_values(header = 0)
(data, cols, shapes) = xgb_predict_result_context.model_result_dataset.load_values(header=0)

# 提取真实标签和预测概率
y_true = data['class_label']
Expand Down Expand Up @@ -282,7 +286,3 @@


# In[ ]:




Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,17 @@ def __repr__(self):
f"resultFileInfo: {self.resultFileInfo}, model: {self.model}"


class JobDetailRequest(BaseObject):
    """Payload object for the query-job-detail API.

    Field names are deliberately camelCase: they mirror the JSON schema the
    remote WeDPR service expects (the object is serialized via ``as_dict()``
    before being sent).
    """

    def __init__(self, job_id=None,
                 fetch_job_detail=True,
                 fetch_job_result=True,
                 fetch_log=False):
        # Identifier of the job to query; None is allowed by the signature,
        # though callers are expected to supply a concrete id.
        self.jobID = job_id
        # Toggles controlling which sections the server includes in its reply.
        self.fetchJobDetail = fetch_job_detail
        self.fetchJobResult = fetch_job_result
        # Log retrieval is opt-in; it is the only flag defaulting to False.
        self.fetchLog = fetch_log


class WeDPRRemoteJobClient(WeDPREntryPoint, BaseObject):
def __init__(self, http_config: HttpConfig, auth_config: AuthConfig, job_config: JobConfig):
if auth_config is None:
Expand Down Expand Up @@ -189,19 +200,19 @@ def query_job_detail(self, job_id, block_until_finish) -> JobDetailResponse:
or (not job_result.job_status.run_success()):
return JobDetailResponse(job_info=job_result, params=None)
# success case, query the job detail
params = {}
params["jobID"] = job_id
job_detail_requests = JobDetailRequest(job_id)
response_dict = self.execute_with_retry(self.send_request,
self.job_config.max_retries,
self.job_config.retry_delay_s,
False,
True,
self.job_config.query_job_detail_uri,
params,
None, None)
None,
None, json.dumps(job_detail_requests.as_dict()))
wedpr_response = WeDPRResponse(**response_dict)
if not wedpr_response.success():
raise Exception(
f"query_job_detail exception, job: {job_id}, code: {wedpr_response.code}, msg: {wedpr_response.msg}")
f"query_job_detail exception, job: {job_id}, "
f"code: {wedpr_response.code}, msg: {wedpr_response.msg}")
return JobDetailResponse(**(wedpr_response.data))

def poll_job_result(self, job_id, block_until_finish) -> JobInfo:
Expand Down

0 comments on commit 4b3b031

Please sign in to comment.