From 55f7765ab7d3fccb214647589901a1e0c26e116d Mon Sep 17 00:00:00 2001 From: cyjseagull Date: Thu, 17 Oct 2024 15:22:51 +0800 Subject: [PATCH] refactor wedpr_ml_toolkit --- .../wedpr_ml_toolkit/common/base_context.py | 13 -- python/wedpr_ml_toolkit/common/base_result.py | 8 - .../utils}/__init__.py | 0 .../wedpr_ml_toolkit/common/utils/constant.py | 14 ++ python/wedpr_ml_toolkit/common/utils/utils.py | 26 +++ .../{result => context}/__init__.py | 0 .../{wedpr_data => context}/data_context.py | 11 +- .../wedpr_ml_toolkit/context/job_context.py | 139 +++++++++++++++ .../{utils => context/result}/__init__.py | 0 .../context/result/fe_result_context.py | 22 +++ .../context/result/model_result_context.py | 51 ++++++ .../context/result/psi_result_context.py | 24 +++ .../context/result/result_context.py | 14 ++ .../job_exceuter/pws_client.py | 79 --------- python/wedpr_ml_toolkit/result/fe_result.py | 27 --- .../wedpr_ml_toolkit/result/model_result.py | 58 ------- python/wedpr_ml_toolkit/result/psi_result.py | 27 --- python/wedpr_ml_toolkit/test/test_dev.py | 28 ++-- .../{wedpr_data => transport}/__init__.py | 0 .../transport/credential_generator.py | 55 ++++++ .../storage_entrypoint.py} | 13 +- .../transport/wedpr_entrypoint.py | 41 +++++ .../transport/wedpr_remote_job_client.py | 91 ++++++++++ python/wedpr_ml_toolkit/utils/agency.py | 5 - python/wedpr_ml_toolkit/utils/utils.py | 12 -- .../wedpr_ml_toolkit/wedpr_data/wedpr_data.py | 68 -------- python/wedpr_ml_toolkit/wedpr_ml_toolkit.py | 26 +++ .../wedpr_session/__init__.py | 0 .../wedpr_session/wedpr_session.py | 158 ------------------ 29 files changed, 532 insertions(+), 478 deletions(-) delete mode 100644 python/wedpr_ml_toolkit/common/base_context.py delete mode 100644 python/wedpr_ml_toolkit/common/base_result.py rename python/wedpr_ml_toolkit/{job_exceuter => common/utils}/__init__.py (100%) create mode 100644 python/wedpr_ml_toolkit/common/utils/constant.py create mode 100644 
python/wedpr_ml_toolkit/common/utils/utils.py rename python/wedpr_ml_toolkit/{result => context}/__init__.py (100%) rename python/wedpr_ml_toolkit/{wedpr_data => context}/data_context.py (82%) create mode 100644 python/wedpr_ml_toolkit/context/job_context.py rename python/wedpr_ml_toolkit/{utils => context/result}/__init__.py (100%) create mode 100644 python/wedpr_ml_toolkit/context/result/fe_result_context.py create mode 100644 python/wedpr_ml_toolkit/context/result/model_result_context.py create mode 100644 python/wedpr_ml_toolkit/context/result/psi_result_context.py create mode 100644 python/wedpr_ml_toolkit/context/result/result_context.py delete mode 100644 python/wedpr_ml_toolkit/job_exceuter/pws_client.py delete mode 100644 python/wedpr_ml_toolkit/result/fe_result.py delete mode 100644 python/wedpr_ml_toolkit/result/model_result.py delete mode 100644 python/wedpr_ml_toolkit/result/psi_result.py rename python/wedpr_ml_toolkit/{wedpr_data => transport}/__init__.py (100%) create mode 100644 python/wedpr_ml_toolkit/transport/credential_generator.py rename python/wedpr_ml_toolkit/{job_exceuter/hdfs_client.py => transport/storage_entrypoint.py} (76%) create mode 100644 python/wedpr_ml_toolkit/transport/wedpr_entrypoint.py create mode 100644 python/wedpr_ml_toolkit/transport/wedpr_remote_job_client.py delete mode 100644 python/wedpr_ml_toolkit/utils/agency.py delete mode 100644 python/wedpr_ml_toolkit/utils/utils.py delete mode 100644 python/wedpr_ml_toolkit/wedpr_data/wedpr_data.py create mode 100644 python/wedpr_ml_toolkit/wedpr_ml_toolkit.py delete mode 100644 python/wedpr_ml_toolkit/wedpr_session/__init__.py delete mode 100644 python/wedpr_ml_toolkit/wedpr_session/wedpr_session.py diff --git a/python/wedpr_ml_toolkit/common/base_context.py b/python/wedpr_ml_toolkit/common/base_context.py deleted file mode 100644 index cd2cdb66..00000000 --- a/python/wedpr_ml_toolkit/common/base_context.py +++ /dev/null @@ -1,13 +0,0 @@ -import os - - -class BaseContext: - - def 
__init__(self, project_id, user_name, pws_endpoint=None, hdfs_endpoint=None, token=None): - - self.project_id = project_id - self.user_name = user_name - self.pws_endpoint = pws_endpoint - self.hdfs_endpoint = hdfs_endpoint - self.token = token - self.workspace = './milestone2' diff --git a/python/wedpr_ml_toolkit/common/base_result.py b/python/wedpr_ml_toolkit/common/base_result.py deleted file mode 100644 index 88bb3f8f..00000000 --- a/python/wedpr_ml_toolkit/common/base_result.py +++ /dev/null @@ -1,8 +0,0 @@ -from wedpr_ml_toolkit.common.base_context import BaseContext - - -class BaseResult: - - def __init__(self, ctx: BaseContext): - - self.ctx = ctx diff --git a/python/wedpr_ml_toolkit/job_exceuter/__init__.py b/python/wedpr_ml_toolkit/common/utils/__init__.py similarity index 100% rename from python/wedpr_ml_toolkit/job_exceuter/__init__.py rename to python/wedpr_ml_toolkit/common/utils/__init__.py diff --git a/python/wedpr_ml_toolkit/common/utils/constant.py b/python/wedpr_ml_toolkit/common/utils/constant.py new file mode 100644 index 00000000..460f36d8 --- /dev/null +++ b/python/wedpr_ml_toolkit/common/utils/constant.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- + +class Constant: + NUMERIC_ARRAY = [i for i in range(10)] + HTTP_STATUS_OK = 200 + DEFAULT_SUBMIT_JOB_URI = '/api/wedpr/v3/project/submitJob' + DEFAULT_QUERY_JOB_STATUS_URL = '/api/wedpr/v3/project/queryJobByCondition' + PSI_RESULT_FILE = "psi_result.csv" + + FEATURE_BIN_FILE = "feature_bin.json" + TEST_MODEL_OUTPUT_FILE = "test_output.csv" + TRAIN_MODEL_OUTPUT_FILE = "train_output.csv" + + FE_RESULT_FILE = "fe_result.csv" \ No newline at end of file diff --git a/python/wedpr_ml_toolkit/common/utils/utils.py b/python/wedpr_ml_toolkit/common/utils/utils.py new file mode 100644 index 00000000..f517eaba --- /dev/null +++ b/python/wedpr_ml_toolkit/common/utils/utils.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- +import uuid +from enum import Enum +import random +from common.utils.constant import 
Constant +from urllib.parse import urlencode, urlparse, parse_qs, quote + +class IdPrefixEnum(Enum): + DATASET = "d-" + ALGORITHM = "a-" + JOB = "j-" + + +def make_id(prefix): + return prefix + str(uuid.uuid4()).replace("-", "") + +def generate_nonce(nonce_len): + return ''.join(random.choice(Constant.NUMERIC_ARRAY) for _ in range(nonce_len)) + +def add_params_to_url(url, params): + parsed_url = urlparse(url) + query_params = parse_qs(parsed_url.query) + for key, value in params.items(): + query_params[key] = value + new_query = urlencode(query_params, doseq=True) + return parsed_url._replace(query=new_query).geturl() diff --git a/python/wedpr_ml_toolkit/result/__init__.py b/python/wedpr_ml_toolkit/context/__init__.py similarity index 100% rename from python/wedpr_ml_toolkit/result/__init__.py rename to python/wedpr_ml_toolkit/context/__init__.py diff --git a/python/wedpr_ml_toolkit/wedpr_data/data_context.py b/python/wedpr_ml_toolkit/context/data_context.py similarity index 82% rename from python/wedpr_ml_toolkit/wedpr_data/data_context.py rename to python/wedpr_ml_toolkit/context/data_context.py index cf9e7645..2dc611f8 100644 --- a/python/wedpr_ml_toolkit/wedpr_data/data_context.py +++ b/python/wedpr_ml_toolkit/context/data_context.py @@ -8,15 +8,18 @@ class DataContext: def __init__(self, *datasets): self.datasets = list(datasets) self.ctx = self.datasets[0].ctx - + self._check_datasets() def _save_dataset(self, dataset): if dataset.dataset_path is None: - dataset.dataset_id = utils.make_id(utils.IdPrefixEnum.DATASET.value) - dataset.dataset_path = os.path.join(dataset.storage_workspace, dataset.dataset_id) + dataset.dataset_id = utils.make_id( + utils.IdPrefixEnum.DATASET.value) + dataset.dataset_path = os.path.join( + dataset.storage_workspace, dataset.dataset_id) if dataset.storage_client is not None: - dataset.storage_client.upload(dataset.values, dataset.dataset_path) + dataset.storage_client.upload( + dataset.values, dataset.dataset_path) def 
_check_datasets(self): for dataset in self.datasets: diff --git a/python/wedpr_ml_toolkit/context/job_context.py b/python/wedpr_ml_toolkit/context/job_context.py new file mode 100644 index 00000000..620bc1fb --- /dev/null +++ b/python/wedpr_ml_toolkit/context/job_context.py @@ -0,0 +1,139 @@ +# -*- coding: utf-8 -*- +import json + +from wedpr_ml_toolkit.context.data_context import DataContext +from wedpr_ml_toolkit.transport.wedpr_remote_job_client import JobParam +from wedpr_ml_toolkit.transport.wedpr_remote_job_client import JobInfo +from abc import abstractmethod +from wedpr_ml_toolkit.transport.wedpr_remote_job_client import WeDPRRemoteJobClient + + +class JobContext: + + def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, dataset: DataContext = None, my_agency=None): + if dataset is None: + raise Exception("Must define the job related datasets!") + self.remote_job_client = remote_job_client + self.project_name = project_name + self.dataset = dataset + self.create_agency = my_agency + self.participant_id_list = [] + self.task_parties = [] + self.dataset_id_list = [] + self.dataset_list = [] + self.label_holder_agency = None + self.label_columns = None + self.__init_participant__() + self.__init_label_information__() + self.result_receiver_id_list = [my_agency] # 仅限jupyter所在机构 + self.__check__() + + def __check__(self): + """ + 校验机构数和任务是否匹配 + """ + if len(self.participant_id_list) < 2: + raise Exception("至少需要传入两个机构") + if not self.label_holder_agency or self.label_holder_agency not in self.participant_id_list: + raise Exception("数据集中标签提供方配置错误") + + def __init_participant__(self): + participant_id_list = [] + dataset_id_list = [] + for dataset in self.dataset.datasets: + participant_id_list.append(dataset.agency.agency_id) + dataset_id_list.append(dataset.dataset_id) + self.task_parties.append({'userName': dataset.ctx.user_name, + 'agency': dataset.agency.agency_id}) + self.participant_id_list = participant_id_list + 
self.dataset_id_list = dataset_id_list + + def __init_label_information__(self): + label_holder_agency = None + label_columns = None + for dataset in self.dataset.datasets: + if dataset.is_label_holder: + label_holder_agency = dataset.agency.agency_id + label_columns = 'y' + self.label_holder_agency = label_holder_agency + self.label_columns = label_columns + + @abstractmethod + def build(self) -> JobParam: + pass + + @abstractmethod + def get_job_type(self) -> str: + pass + + def submit(self, project_name): + return self.submit(self.build(project_name)) + + +class PSIJobContext(JobContext): + def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, dataset: DataContext = None, my_agency=None, merge_field: str = 'id'): + super().__init__(remote_job_client, project_name, dataset, my_agency) + self.merge_field = merge_field + + def get_job_type(self) -> str: + return "PSI" + + def build(self) -> JobParam: + self.dataset_list = self.dataset.to_psi_format( + self.merge_field, self.result_receiver_id_list) + job_info = JobInfo(self.get_job_type(), self.project_name, json.dumps( + {'dataSetList': self.dataset_list}).replace('"', '\\"')) + job_param = JobParam(job_info, self.task_parties, self.dataset_id_list) + return job_param + + +class PreprocessingJobContext(JobContext): + def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, model_setting, dataset: DataContext = None, my_agency=None): + super().__init__(remote_job_client, project_name, dataset, my_agency) + self.model_setting = model_setting + + def get_job_type(self) -> str: + return "PREPROCESSING" + + # TODO: build the request + def build(self) -> JobParam: + return None + + +class FeatureEngineeringJobContext(JobContext): + def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, model_setting, dataset: DataContext = None, my_agency=None): + super().__init__(remote_job_client, project_name, dataset, my_agency) + self.model_setting = 
model_setting + + def get_job_type(self) -> str: + return "FEATURE_ENGINEERING" + + # TODO: build the jobParam + def build(self) -> JobParam: + return None + + +class SecureLGBMTrainingJobContext(JobContext): + def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, model_setting, dataset: DataContext = None, my_agency=None): + super().__init__(remote_job_client, project_name, dataset, my_agency) + self.model_setting = model_setting + + def get_job_type(self) -> str: + return "XGB_TRAINING" + + # TODO: build the jobParam + def build(self) -> JobParam: + return None + + +class SecureLGBMPredictJobContext(JobContext): + def __init__(self, remote_job_client: WeDPRRemoteJobClient, project_name: str, model_setting, dataset: DataContext = None, my_agency=None): + super().__init__(remote_job_client, project_name, dataset, my_agency) + self.model_setting = model_setting + + def get_job_type(self) -> str: + return "XGB_PREDICTING" + + # TODO: build the jobParam + def build(self) -> JobParam: + return None diff --git a/python/wedpr_ml_toolkit/utils/__init__.py b/python/wedpr_ml_toolkit/context/result/__init__.py similarity index 100% rename from python/wedpr_ml_toolkit/utils/__init__.py rename to python/wedpr_ml_toolkit/context/result/__init__.py diff --git a/python/wedpr_ml_toolkit/context/result/fe_result_context.py b/python/wedpr_ml_toolkit/context/result/fe_result_context.py new file mode 100644 index 00000000..81512241 --- /dev/null +++ b/python/wedpr_ml_toolkit/context/result/fe_result_context.py @@ -0,0 +1,22 @@ +import os + +from wedpr_ml_toolkit.context.data_context import DataContext +from wedpr_ml_toolkit.common.utils.constant import Constant +from wedpr_ml_toolkit.context.result.result_context import ResultContext +from wedpr_ml_toolkit.context.job_context import JobContext + + +class FeResultContext(ResultContext): + + def __init__(self, job_context: JobContext, job_id: str): + super().__init__(job_context, job_id) + + def 
parse_result(self): + result_list = [] + for dataset in self.job_context.dataset.datasets: + dataset.update_path(os.path.join( + self.job_id, Constant.FE_RESULT_FILE)) + result_list.append(dataset) + + fe_result = DataContext(*result_list) + return fe_result diff --git a/python/wedpr_ml_toolkit/context/result/model_result_context.py b/python/wedpr_ml_toolkit/context/result/model_result_context.py new file mode 100644 index 00000000..c7b3bd48 --- /dev/null +++ b/python/wedpr_ml_toolkit/context/result/model_result_context.py @@ -0,0 +1,51 @@ +import os +import numpy as np + +from ppc_common.ppc_utils import utils +from wedpr_ml_toolkit.context.result.result_context import ResultContext +from wedpr_ml_toolkit.transport.storage_entrypoint import StorageEntryPoint +from wedpr_ml_toolkit.common.utils.constant import Constant +from wedpr_ml_toolkit.context.job_context import JobContext + + +class ModelResultContext(ResultContext): + def __init__(self, job_context: JobContext, job_id: str, storage_entrypoint: StorageEntryPoint): + super().__init__(job_context, job_id) + self.storage_entrypoint = storage_entrypoint + + +class SecureLGBMResultContext(ModelResultContext): + MODEL_DATA_FILE = utils.XGB_TREE_PERFIX + '.json' + + def __init__(self, job_context: JobContext, job_id: str, storage_entrypoint: StorageEntryPoint): + super().__init__(job_context, job_id, storage_entrypoint) + + def parse_result(self): + + # train_praba, test_praba, train_y, test_y, feature_importance, split_xbin, trees, params + # 从hdfs读取结果文件信息,构造为属性 + train_praba_path = os.path.join( + self.job_id, Constant.TRAIN_MODEL_OUTPUT_FILE) + test_praba_path = os.path.join( + self.job_id, Constant.TEST_MODEL_OUTPUT_FILE) + train_output = self.storage_entrypoint.download(train_praba_path) + test_output = self.storage_entrypoint.download(test_praba_path) + self.train_praba = train_output['class_pred'].values + self.test_praba = test_output['class_pred'].values + if 'class_label' in train_output.columns: + 
self.train_y = train_output['class_label'].values + self.test_y = test_output['class_label'].values + else: + self.train_y = None + self.test_y = None + + feature_bin_path = os.path.join(self.job_id, Constant.FEATURE_BIN_FILE) + model_path = os.path.join(self.job_id, self.MODEL_DATA_FILE) + feature_bin_data = self.storage_entrypoint.download_data( + feature_bin_path) + model_data = self.storage_entrypoint.download_data(model_path) + + self.feature_importance = ... + self.split_xbin = feature_bin_data + self.trees = model_data + self.params = ... diff --git a/python/wedpr_ml_toolkit/context/result/psi_result_context.py b/python/wedpr_ml_toolkit/context/result/psi_result_context.py new file mode 100644 index 00000000..f5a94113 --- /dev/null +++ b/python/wedpr_ml_toolkit/context/result/psi_result_context.py @@ -0,0 +1,24 @@ +# -*- coding: utf-8 -*- +import os + +from wedpr_ml_toolkit.context.job_context import JobContext +from wedpr_ml_toolkit.context.data_context import DataContext +from wedpr_ml_toolkit.common.utils.constant import Constant +from wedpr_ml_toolkit.context.result.result_context import ResultContext + + +class PSIResultContext(ResultContext): + + PSI_RESULT_FILE = "psi_result.csv" + + def __init__(self, job_context: JobContext, job_id: str): + super().__init__(job_context, job_id) + + def parse_result(self): + result_list = [] + for dataset in self.job_context.dataset.datasets: + dataset.update_path(os.path.join( + self.job_id, Constant.PSI_RESULT_FILE)) + result_list.append(dataset) + + self.psi_result = DataContext(*result_list) diff --git a/python/wedpr_ml_toolkit/context/result/result_context.py b/python/wedpr_ml_toolkit/context/result/result_context.py new file mode 100644 index 00000000..dcb0670a --- /dev/null +++ b/python/wedpr_ml_toolkit/context/result/result_context.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +from wedpr_ml_toolkit.context.job_context import JobContext +from abc import abstractmethod + + +class ResultContext: + def 
__init__(self, job_context: JobContext, job_id: str): + self.job_id = job_id + self.job_context = job_context + self.parse_result() + + @abstractmethod + def parse_result(self): + pass diff --git a/python/wedpr_ml_toolkit/job_exceuter/pws_client.py b/python/wedpr_ml_toolkit/job_exceuter/pws_client.py deleted file mode 100644 index 08cb2e5d..00000000 --- a/python/wedpr_ml_toolkit/job_exceuter/pws_client.py +++ /dev/null @@ -1,79 +0,0 @@ -import json -import random -import time -import requests - -from ppc_common.ppc_utils import http_utils -from ppc_common.ppc_utils.exception import PpcException, PpcErrorCode - - -PWS_URL = '/api/wedpr/v3/project/submitJob' - - -class PWSApi: - def __init__(self, endpoint, token, - polling_interval_s: int = 5, max_retries: int = 5, retry_delay_s: int = 5): - self.pws_url = endpoint + PWS_URL - self.token = token - self.polling_interval_s = polling_interval_s - self.max_retries = max_retries - self.retry_delay_s = retry_delay_s - self._completed_status = 'COMPLETED' - self._failed_status = 'FAILED' - - def run(self, params): - - headers = { - "Authorization": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3MzEzMTUwMTksInVzZXIiOiJ7XCJ1c2VybmFtZVwiOlwiZmx5aHVhbmcxXCIsXCJncm91cEluZm9zXCI6W3tcImdyb3VwSWRcIjpcIjEwMDAwMDAwMDAwMDAwMDBcIixcImdyb3VwTmFtZVwiOlwi5Yid5aeL55So5oi357uEXCIsXCJncm91cEFkbWluTmFtZVwiOlwiYWRtaW5cIn1dLFwicm9sZU5hbWVcIjpcIm9yaWdpbmFsX3VzZXJcIixcInBlcm1pc3Npb25zXCI6bnVsbCxcImFjY2Vzc0tleUlEXCI6bnVsbCxcImFkbWluXCI6ZmFsc2V9In0.1jZFOVbiISzCvvE9SOsTx0IWb0-OQc3o3rJgCu9GM9A", - "content-type": "application/json" - } - - payload = { - "job": { - "jobType": params['jobType'], - "projectName": params['projectName'], - "param": params['param'] - }, - "taskParties": params['taskParties'], - "datasetList": params['datasetList'] - } - - response = requests.request("POST", self.pws_url, json=payload, headers=headers) - if response.status_code != 200: - raise Exception(f"创建任务失败: {response.json()}") - print(response.text) - # 
self._poll_task_status(response.data, self.token) - return json.loads(response.text) - - def _poll_task_status(self, job_id, token): - while True: - params = { - 'jsonrpc': '1', - 'method': self._get_task_status_method, - 'token': token, - 'id': random.randint(1, 65535), - 'params': { - 'taskID': job_id, - } - } - response = self._send_request_with_retry(http_utils.send_post_request, self.endpoint, None, params) - if response.status_code != 200: - raise Exception(f"轮询任务失败: {response.json()}") - if response['result']['status'] == self._completed_status: - return response['result'] - elif response['result']['status'] == self._failed_status: - raise PpcException(PpcErrorCode.CALL_SCS_ERROR.get_code(), response['data']) - time.sleep(self.polling_interval_s) - - def _send_request_with_retry(self, request_func, *args, **kwargs): - attempt = 0 - while attempt < self.max_retries: - try: - response = request_func(*args, **kwargs) - return response - except Exception as e: - attempt += 1 - if attempt < self.max_retries: - time.sleep(self.retry_delay_s) - else: - raise e diff --git a/python/wedpr_ml_toolkit/result/fe_result.py b/python/wedpr_ml_toolkit/result/fe_result.py deleted file mode 100644 index 3caa0fea..00000000 --- a/python/wedpr_ml_toolkit/result/fe_result.py +++ /dev/null @@ -1,27 +0,0 @@ -import os - -from wedpr_ml_toolkit.wedpr_data.data_context import DataContext -from wedpr_ml_toolkit.common.base_result import BaseResult - - -class FeResult(BaseResult): - - FE_RESULT_FILE = "fe_result.csv" - - def __init__(self, dataset: DataContext, job_id: str): - - super().__init__(dataset.ctx) - self.job_id = job_id - - participant_id_list = [] - for dataset in self.dataset.datasets: - participant_id_list.append(dataset.agency.agency_id) - self.participant_id_list = participant_id_list - - result_list = [] - for dataset in self.dataset.datasets: - dataset.update_path(os.path.join(self.job_id, self.FE_RESULT_FILE)) - result_list.append(dataset) - - fe_result = 
DataContext(*result_list) - return fe_result diff --git a/python/wedpr_ml_toolkit/result/model_result.py b/python/wedpr_ml_toolkit/result/model_result.py deleted file mode 100644 index 7609003f..00000000 --- a/python/wedpr_ml_toolkit/result/model_result.py +++ /dev/null @@ -1,58 +0,0 @@ -import os -import numpy as np - -from ppc_common.ppc_utils import utils - -from wedpr_ml_toolkit.wedpr_data.data_context import DataContext -from wedpr_ml_toolkit.common.base_result import BaseResult -from wedpr_ml_toolkit.job_exceuter.hdfs_client import HDFSApi - - -class ModelResult(BaseResult): - - FEATURE_BIN_FILE = "feature_bin.json" - MODEL_DATA_FILE = utils.XGB_TREE_PERFIX + '.json' - TEST_MODEL_OUTPUT_FILE = "test_output.csv" - TRAIN_MODEL_OUTPUT_FILE = "train_output.csv" - - def __init__(self, dataset: DataContext, job_id: str, job_type: str): - - super().__init__(dataset.ctx) - self.job_id = job_id - - participant_id_list = [] - for dataset in self.dataset.datasets: - participant_id_list.append(dataset.agency.agency_id) - self.participant_id_list = participant_id_list - - if job_type == 'xgb_training': - self._xgb_train_result() - - def _xgb_train_result(self): - - # train_praba, test_praba, train_y, test_y, feature_importance, split_xbin, trees, params - # 从hdfs读取结果文件信息,构造为属性 - train_praba_path = os.path.join( - self.job_id, self.TRAIN_MODEL_OUTPUT_FILE) - test_praba_path = os.path.join( - self.job_id, self.TEST_MODEL_OUTPUT_FILE) - train_output = HDFSApi.download(train_praba_path) - test_output = HDFSApi.download(test_praba_path) - self.train_praba = train_output['class_pred'].values - self.test_praba = test_output['class_pred'].values - if 'class_label' in train_output.columns: - self.train_y = train_output['class_label'].values - self.test_y = test_output['class_label'].values - else: - self.train_y = None - self.test_y = None - - feature_bin_path = os.path.join(self.job_id, self.FEATURE_BIN_FILE) - model_path = os.path.join(self.job_id, self.MODEL_DATA_FILE) - 
feature_bin_data = HDFSApi.download_data(feature_bin_path) - model_data = HDFSApi.download_data(model_path) - - self.feature_importance = ... - self.split_xbin = feature_bin_data - self.trees = model_data - self.params = ... diff --git a/python/wedpr_ml_toolkit/result/psi_result.py b/python/wedpr_ml_toolkit/result/psi_result.py deleted file mode 100644 index 3b7da74c..00000000 --- a/python/wedpr_ml_toolkit/result/psi_result.py +++ /dev/null @@ -1,27 +0,0 @@ -import os - -from wedpr_ml_toolkit.wedpr_data.data_context import DataContext -from wedpr_ml_toolkit.common.base_result import BaseResult - - -class PSIResult(BaseResult): - - PSI_RESULT_FILE = "psi_result.csv" - - def __init__(self, dataset: DataContext, job_id: str): - - super().__init__(dataset.ctx) - self.job_id = job_id - - participant_id_list = [] - for dataset in self.dataset.datasets: - participant_id_list.append(dataset.agency.agency_id) - self.participant_id_list = participant_id_list - - result_list = [] - for dataset in self.dataset.datasets: - dataset.update_path(os.path.join(self.job_id, self.PSI_RESULT_FILE)) - result_list.append(dataset) - - psi_result = DataContext(*result_list) - return psi_result diff --git a/python/wedpr_ml_toolkit/test/test_dev.py b/python/wedpr_ml_toolkit/test/test_dev.py index bfaf7a8d..db5efc9a 100644 --- a/python/wedpr_ml_toolkit/test/test_dev.py +++ b/python/wedpr_ml_toolkit/test/test_dev.py @@ -5,9 +5,9 @@ from wedpr_ml_toolkit.common.base_context import BaseContext from wedpr_ml_toolkit.utils.agency import Agency -from wedpr_ml_toolkit.wedpr_data.wedpr_data import WedprData -from wedpr_ml_toolkit.wedpr_data.data_context import DataContext -from wedpr_ml_toolkit.wedpr_session.wedpr_session import WedprSession +from wedpr_ml_toolkit.toolkit import DatasetToolkit +from wedpr_ml_toolkit.context.data_context import DataContext +from wedpr_ml_toolkit.context.job_context import JobContext # 从jupyter环境中获取project_id等信息 @@ -15,15 +15,15 @@ # 相同项目/刷新专家模式project_id固定 project_id 
= '测试-xinyi' user = 'flyhuang1' -my_agency='SGD' +my_agency = 'sgd' pws_endpoint = 'http://139.159.202.235:8005' # http hdfs_endpoint = 'http://192.168.0.18:50070' # client token = 'abc...' # 自定义合作方机构 -partner_agency1='WeBank' -partner_agency2='TX' +partner_agency1 = 'webank' +partner_agency2 = 'TX' # 初始化project ctx 信息 ctx = BaseContext(project_id, user, pws_endpoint, hdfs_endpoint, token) @@ -40,13 +40,14 @@ **{f'x{i}': np.random.rand(100) for i in range(1, 11)} # x1到x10列,随机数 }) -dataset1 = WedprData(ctx, values=df, agency=agency1, is_label_holder=True) +dataset1 = DatasetToolkit(ctx, values=df, agency=agency1, is_label_holder=True) dataset1.storage_client = None -dataset1.save_values(path='d-101') # './milestone2\\sgd\\flyhuang1\\share\\d-101' +# './milestone2\\sgd\\flyhuang1\\share\\d-101' +dataset1.save_values(path='d-101') # hdfs_path -ctx2 = BaseContext(project_id, 'flyhuang', pws_endpoint, hdfs_endpoint, token) -dataset2 = WedprData(ctx2, dataset_path='/user/ppc/milestone2/webank/flyhuang/d-9606695119693829', agency=agency2) +dataset2 = DatasetToolkit( + ctx, dataset_path='/user/ppc/milestone2/webank/flyhuang/d-9606695119693829', agency=agency2) dataset2.storage_client = None # dataset2.load_values() if dataset2.storage_client is None: @@ -57,20 +58,21 @@ }) dataset2.update_values(values=df2) if dataset1.storage_client is not None: - dataset1.update_values(path='/user/ppc/milestone2/sgd/flyhuang1/d-9606704699156485') + dataset1.update_values( + path='/user/ppc/milestone2/sgd/flyhuang1/d-9606704699156485') dataset1.load_values() # 构建 dataset context dataset = DataContext(dataset1, dataset2) # 初始化 wedpr task session(含数据) -task = WedprSession(dataset, my_agency=my_agency) +task = JobContext(dataset, my_agency=my_agency) print(task.participant_id_list, task.result_receiver_id_list) # 执行psi任务 psi_result = task.psi() # 初始化 wedpr task session(不含数据) (推荐:使用更灵活) -task = WedprSession(my_agency=my_agency) +task = JobContext(my_agency=my_agency) # 执行psi任务 fe_result = 
task.proprecessing(dataset) print(task.participant_id_list, task.result_receiver_id_list) diff --git a/python/wedpr_ml_toolkit/wedpr_data/__init__.py b/python/wedpr_ml_toolkit/transport/__init__.py similarity index 100% rename from python/wedpr_ml_toolkit/wedpr_data/__init__.py rename to python/wedpr_ml_toolkit/transport/__init__.py diff --git a/python/wedpr_ml_toolkit/transport/credential_generator.py b/python/wedpr_ml_toolkit/transport/credential_generator.py new file mode 100644 index 00000000..c8a23a7d --- /dev/null +++ b/python/wedpr_ml_toolkit/transport/credential_generator.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +import hashlib +from common import utils +import time + + +class CredentialInfo: + ACCESS_ID_KEY = "accessKeyID" + NONCE_KEY = "nonce" + TIMESTAMP_KEY = "timestamp" + SIGNATURE_KEY = "signature" + + def __init__(self, access_key_id: str, nonce: str, timestamp: str, signature: str): + self.access_key_id = access_key_id + self.nonce = nonce + self.timestamp = timestamp + self.signature = signature + + def to_dict(self): + result = {} + result.update(CredentialInfo.ACCESS_ID_KEY, self.access_key_id) + result.update(CredentialInfo.NONCE_KEY, self.nonce) + result.update(CredentialInfo.TIMESTAMP_KEY, self.timestamp) + result.update(CredentialInfo.SIGNATURE_KEY, self.signature) + + def update_url_with_auth_info(self, url): + auth_params = self.to_dict() + return utils.add_params_to_url(auth_params) + + +class CredentialGenerator: + def __init__(self, access_key_id: str, access_key_secret: str, nonce_len=5): + self.access_key_id = access_key_id + self.access_key_secret = access_key_secret + self.nonce_len = nonce_len + + def generate_credential(self) -> CredentialInfo: + nonce = utils.generate_nonce(self.nonce_len) + timestamp = int(time.time()) + # generate the signature + signature = CredentialGenerator.generate_signature( + self.access_key_id, self.access_key_secret, nonce, timestamp) + return CredentialInfo(self.access_key_id, nonce, timestamp, 
signature) + + @staticmethod + def generate_signature(access_key_id, access_key_secret, nonce, timestamp) -> str: + anti_replay_info_hash = hashlib.sha3_256() + # hash(access_key_id + nonce + timestamp) + anti_replay_info = f"{access_key_id}{nonce}{timestamp}" + anti_replay_info_hash.update(anti_replay_info) + # hash(anti_replay_info + access_key_secret) + signature_hash = hashlib.sha3_256() + signature_hash.update(anti_replay_info_hash.hexdigest()) + signature_hash.update(access_key_secret) + return signature_hash.hexdigest() diff --git a/python/wedpr_ml_toolkit/job_exceuter/hdfs_client.py b/python/wedpr_ml_toolkit/transport/storage_entrypoint.py similarity index 76% rename from python/wedpr_ml_toolkit/job_exceuter/hdfs_client.py rename to python/wedpr_ml_toolkit/transport/storage_entrypoint.py index d8c7be68..06770d05 100644 --- a/python/wedpr_ml_toolkit/job_exceuter/hdfs_client.py +++ b/python/wedpr_ml_toolkit/transport/storage_entrypoint.py @@ -2,16 +2,17 @@ import io from ppc_common.deps_services import storage_loader +from wedpr_ml_toolkit.config.wedpr_ml_config import StorageConfig -class HDFSApi: - def __init__(self, hdfs_endpoint): - self.hdfs_endpoint = hdfs_endpoint +class StorageEntryPoint: + def __init__(self, storage_config: StorageConfig): + self.storage_config = storage_config config_data = {} config_data['STORAGE_TYPE'] = 'HDFS' - config_data['HDFS_URL'] = self.hdfs_endpoint - config_data['HDFS_ENDPOINT'] = self.hdfs_endpoint + config_data['HDFS_URL'] = self.storage_config.endpoint + config_data['HDFS_ENDPOINT'] = self.storage_config.endpoint self.storage_client = storage_loader.load(config_data, logger=None) def upload(self, dataframe, hdfs_path): @@ -34,7 +35,7 @@ def download(self, hdfs_path): :return: Pandas DataFrame """ content = self.storage_client.get_data(hdfs_path) - dataframe = pd.read_csv(io.BytesIO(content)) + dataframe = pd.read_csv(io.BytesIO(content)) return dataframe def download_byte(self, hdfs_path): diff --git 
a/python/wedpr_ml_toolkit/transport/wedpr_entrypoint.py b/python/wedpr_ml_toolkit/transport/wedpr_entrypoint.py new file mode 100644 index 00000000..0803bc9a --- /dev/null +++ b/python/wedpr_ml_toolkit/transport/wedpr_entrypoint.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +import requests +from transport.credential_generator import CredentialGenerator +from common.utils.constant import Constant +import json + + +class LoadBanlancer: + def __init__(self, remote_entrypoints): + if remote_entrypoints == None or len(remote_entrypoints) == 0: + raise Exception(f"Must define the wedpr entrypoints") + self.remote_entrypoints = remote_entrypoints + self.round_robin_idx = 0 + + # choose with round-robin policy + def select(self, uri_path: str): + selected_idx = self.round_robin_idx + self.round_robin_idx += 1 + selected_entrypoint = self.remote_entrypoints[selected_idx % len( + self.remote_entrypoints)] + return f"{selected_entrypoint}/${uri_path}" + + +class WeDPREntryPoint: + def __init__(self, access_key_id: str, access_key_secret: str, remote_entrypoints, nonce_len: int = 5): + self.credential_generator = CredentialGenerator( + access_key_id, access_key_secret, nonce_len) + self.loadbalancer = LoadBanlancer(remote_entrypoints) + + def send_post_request(self, uri, params, headers, data): + credential_info = self.credential_generator.generate_credential() + url = credential_info.update_url_with_auth_info( + self.loadbalancer.select(uri)) + if not headers: + headers = {'content-type': 'application/json'} + response = requests.post(url, data=data, json=params, headers=headers) + if response.status_code != Constant.HTTP_STATUS_OK: + raise Exception( + f"send post request to {url} failed, response: {response.text}") + # parse the result + return json.loads(response.text) diff --git a/python/wedpr_ml_toolkit/transport/wedpr_remote_job_client.py b/python/wedpr_ml_toolkit/transport/wedpr_remote_job_client.py new file mode 100644 index 00000000..45246c63 --- /dev/null +++ 
# transport/wedpr_remote_job_client.py -- reconstructed from a mangled patch chunk
from transport.wedpr_entrypoint import WeDPREntryPoint
from common.utils.constant import Constant
from config.wedpr_ml_config import AuthConfig
from config.wedpr_ml_config import JobConfig
import random
import json
import time


def _to_dict(obj):
    """json.dumps `default` hook: serialize request objects via __dict__."""
    return obj.__dict__


class JobInfo:
    # request object mirroring the server-side Job definition (camelCase fields)
    def __init__(self, job_type: str, project_name: str, param: str):
        self.jobType = job_type
        self.projectName = project_name
        self.param = param


class JobParam:
    # top-level job-submission payload
    def __init__(self, job_info: JobInfo, task_parities, dataset_list):
        self.job = job_info
        self.taskParties = task_parities
        self.datasetList = dataset_list


class WeDPRResponse:
    # generic response envelope: {code, msg, data}; code 0 marks success
    def __init__(self, code, msg, data):
        self.code = code
        self.msg = msg
        self.data = data

    def success(self):
        return self.code == 0


class QueryJobRequest:
    def __init__(self, job_id):
        # the job query condition
        # BUGFIX: the original called dict.update("id", job_id), which raises
        # TypeError (update accepts at most one positional mapping/iterable)
        self.job = {"id": job_id}


class WeDPRRemoteJobClient(WeDPREntryPoint):
    """Remote client for submitting WeDPR jobs and polling their status."""

    # retry defaults, used when job_config does not define them
    DEFAULT_MAX_RETRIES = 3
    DEFAULT_RETRY_DELAY_S = 1

    def __init__(self, auth_config: AuthConfig, job_config: JobConfig):
        """:raises Exception: when auth_config or job_config is missing."""
        if auth_config is None:
            raise Exception("Must define the auth config!")
        if job_config is None:
            raise Exception("Must define the job config")
        super().__init__(auth_config.access_key_id, auth_config.access_key_secret,
                         auth_config.remote_entrypoints, auth_config.nonce_len)
        self.auth_config = auth_config
        self.job_config = job_config
        # BUGFIX: _send_request_with_retry read self.max_retries and
        # self.retry_delay_s, which were never assigned -> AttributeError on
        # the first retry; take them from job_config when present
        self.max_retries = getattr(
            job_config, "max_retries", self.DEFAULT_MAX_RETRIES)
        self.retry_delay_s = getattr(
            job_config, "retry_delay_s", self.DEFAULT_RETRY_DELAY_S)

    def get_auth_config(self):
        return self.auth_config

    def get_job_config(self):
        return self.job_config

    def submit_job(self, job_params: JobParam) -> WeDPRResponse:
        """Submit the job; return the job id on success, raise on failure."""
        # BUGFIX: JobParam is not JSON-serializable by default; the original
        # json.dumps(job_params) raised TypeError
        wedpr_response = self.send_post_request(
            self.job_config._submit_job_uri, None, None,
            json.dumps(job_params, default=_to_dict))
        submit_result = WeDPRResponse(**wedpr_response)
        # return the job_id
        if submit_result.success():
            return submit_result.data
        raise Exception(
            f"submit_job failed, code: {submit_result.code}, msg: {submit_result.msg}")

    def _poll_task_status(self, job_id, token):
        """Query the job status until the query succeeds.

        NOTE(review): the terminal-status check is still a TODO -- a
        successful response returns immediately, so the loop currently runs
        at most one iteration; the polling sleep belongs after a
        still-running status check once it is implemented.
        """
        while True:
            raw_response = self._send_request_with_retry(
                self.send_post_request,
                self.job_config.query_job_status_uri,
                None, None,
                json.dumps(QueryJobRequest(job_id), default=_to_dict))
            # BUGFIX: the original passed the raw response dict as the single
            # positional argument of WeDPRResponse(code, msg, data); unpack it
            wedpr_response = WeDPRResponse(**raw_response)
            # TODO: check with status and sleep(job_config.polling_interval_s)
            # between polls while the job is still running
            if wedpr_response.success():
                return wedpr_response
            raise Exception(
                f"_poll_task_status for job {job_id} failed, code: {wedpr_response.code}, msg: {wedpr_response.msg}")

    def _send_request_with_retry(self, request_func, *args, **kwargs):
        """Invoke request_func, retrying up to self.max_retries times."""
        attempt = 0
        while attempt < self.max_retries:
            try:
                return request_func(*args, **kwargs)
            except Exception:
                attempt += 1
                if attempt >= self.max_retries:
                    raise
                time.sleep(self.retry_delay_s)
agency=None, - values=None, - is_label_holder=False): - - # super().__init__(project_id) - self.ctx = ctx - - self.dataset_id = dataset_id - self.dataset_path = dataset_path - self.agency = agency - self.values = values - self.is_label_holder = is_label_holder - self.columns = None - self.shape = None - - self.storage_client = HDFSApi(self.ctx.hdfs_endpoint) - self.storage_workspace = os.path.join(self.ctx.workspace, self.agency.agency_id, self.ctx.user_name, 'share') - - if self.values is not None: - self.columns = self.values.columns - self.shape = self.values.shape - - def load_values(self): - # 加载hdfs的数据集 - if self.storage_client is not None: - self.values = self.storage_client.download(self.dataset_path) - self.columns = self.values.columns - self.shape = self.values.shape - - def save_values(self, path=None): - # 保存数据到hdfs目录 - if path is not None: - self.dataset_path = path - if not self.dataset_path.startswith(self.ctx.workspace): - self.dataset_path = os.path.join(self.storage_workspace, self.dataset_path) - if self.storage_client is not None: - self.storage_client.upload(self.values, self.dataset_path) - - def update_values(self, values: pd.DataFrame = None, path: str = None): - # 将数据集存入hdfs相同路径,替换旧数据集 - if values is not None: - self.values = values - self.columns = self.values.columns - self.shape = self.values.shape - if path is not None: - self.dataset_path = path - if values is not None and self.storage_client is not None: - self.storage_client.upload(self.values, self.dataset_path) - - def update_path(self, path: str = None): - # 将数据集存入hdfs相同路径,替换旧数据集 - if path is not None: - self.dataset_path = path - if self.values is not None: - self.values = None diff --git a/python/wedpr_ml_toolkit/wedpr_ml_toolkit.py b/python/wedpr_ml_toolkit/wedpr_ml_toolkit.py new file mode 100644 index 00000000..02fa8d47 --- /dev/null +++ b/python/wedpr_ml_toolkit/wedpr_ml_toolkit.py @@ -0,0 +1,26 @@ +# -*- coding: utf-8 -*- + +from config.wedpr_ml_config import WeDPRMlConfig 
from wedpr_ml_toolkit.transport.wedpr_remote_job_client import WeDPRRemoteJobClient
from wedpr_ml_toolkit.transport.storage_entrypoint import StorageEntryPoint


class WeDPRMlToolkit:
    """Facade bundling the remote job client and the storage entrypoint."""

    def __init__(self, config: WeDPRMlConfig):
        # keep the full configuration around for later inspection
        self.config = config
        # HTTP client used to submit and query WeDPR jobs
        self.remote_job_client = WeDPRRemoteJobClient(
            config.auth_config, config.job_config)
        # HDFS-backed storage accessor
        self.storage_entry_point = StorageEntryPoint(config.storage_config)

    def get_config(self) -> WeDPRMlConfig:
        """Return the toolkit configuration."""
        return self.config

    def get_remote_job_client(self) -> WeDPRRemoteJobClient:
        """Return the remote job client."""
        return self.remote_job_client

    def get_storage_entry_point(self) -> StorageEntryPoint:
        """Return the storage entrypoint."""
        return self.storage_entry_point

    def submit(self, job_param):
        """Submit job_param through the remote job client."""
        return self.remote_job_client.submit_job(job_param)
仅限jupyter所在机构 - - self.excute = PWSApi(self.dataset.ctx.pws_endpoint, self.dataset.ctx.token) - - def task(self, params: dict = {}): - - self.check_agencies() - job_response = self.excute.run(params) - - return job_response['data'] - - def psi(self, dataset: DataContext = None, merge_filed: str = 'id'): - - if dataset is not None: - self.update_dataset(dataset) - - self.dataset_list = self.dataset.to_psi_format(merge_filed, self.result_receiver_id_list) - - # 构造参数 - # params = {merge_filed: merge_filed} - params = {'jobType': 'PSI', - 'projectName': self.dataset.ctx.project_id, - 'param': json.dumps({'dataSetList': self.dataset_list}).replace('"', '\\"'), - 'taskParties': self.task_parties, - 'datasetList': self.dataset_id_list} - - # 执行任务 - job_id = self.task(params) - - # 结果处理 - psi_result = PSIResult(dataset, 'psi-' + job_id) - - return psi_result - - def proprecessing(self, dataset: DataContext = None, psi_result = None, params: dict = None): - - if dataset is not None: - self.update_dataset(dataset) - - job_id = self.task(self.dataset.to_model_formort()) - - # 结果处理 - datasets_pre = FeResult(dataset, job_id) - - return datasets_pre - - def feature_engineering(self, dataset: DataContext = None, psi_result = None, params: dict = None): - - if dataset is not None: - self.update_dataset(dataset) - - job_id = self.task(self.dataset.to_model_formort()) - - # 结果处理 - datasets_fe = FeResult(dataset, job_id) - - return datasets_fe - - def xgb_training(self, dataset: DataContext = None, psi_result = None, params: dict = None): - - if dataset is not None: - self.update_dataset(dataset) - self.check_datasets() - - job_id = self.task(self.dataset.to_model_formort()) - - # 结果处理 - model_result = ModelResult(dataset, job_id, job_type='xgb_training') - - return model_result - - def xgb_predict(self, dataset: DataContext = None, psi_result = None, model = None): - - if dataset is not None: - self.update_dataset(dataset) - self.check_datasets() - - # 上传模型到hdfs - job_id = 
self.task(self.dataset.to_model_formort()) - - # 结果处理 - model_result = ModelResult(dataset, job_id, job_type='xgb_predict') - - # 结果处理 - return model_result - - def update_dataset(self, dataset: DataContext): - self.dataset = dataset - self.participant_id_list = self.get_agencies() - self.label_holder_agency = self.get_label_holder_agency() - - def get_agencies(self): - participant_id_list = [] - dataset_id_list = [] - for dataset in self.dataset.datasets: - participant_id_list.append(dataset.agency.agency_id) - dataset_id_list.append(dataset.dataset_id) - self.task_parties.append({'userName': dataset.ctx.user_name, - 'agency': dataset.agency.agency_id}) - self.participant_id_list = participant_id_list - self.dataset_id_list = dataset_id_list - - def get_label_holder_agency(self): - label_holder_agency = None - label_columns = None - for dataset in self.dataset.datasets: - if dataset.is_label_holder: - label_holder_agency = dataset.agency.agency_id - label_columns = 'y' - self.label_holder_agency = label_holder_agency - self.label_columns = label_columns - - def check_agencies(self): - """ - 校验机构数和任务是否匹配 - """ - if len(self.participant_id_list) < 2: - raise ValueError("至少需要传入两个机构") - - def check_datasets(self): - """ - 校验是否包含标签提供方 - """ - if not self.label_holder_agency or self.label_holder_agency not in self.participant_id_list: - raise ValueError("数据集中标签提供方配置错误") - - # def get_agencies(self): - # """ - # 返回所有机构名称的列表。 - - # :return: 机构名称的列表 - # """ - # return self.agencies