From 3fa242598ada2d602bd41a010bae5bb454034d7b Mon Sep 17 00:00:00 2001 From: cyjseagull Date: Mon, 21 Oct 2024 19:04:31 +0800 Subject: [PATCH] update script --- cpp/CMakeLists.txt | 4 +- cpp/ppc-framework/protocol/INodeInfo.h | 3 +- .../grpc/client/RemoteFrontBuilder.cpp | 2 +- .../protobuf/src/NodeInfoImpl.h | 5 ++ .../ppc-gateway/gateway/GatewayImpl.cpp | 1 + .../gateway/router/GatewayNodeInfo.h | 5 ++ .../gateway/router/GatewayNodeInfoImpl.cpp | 2 +- .../ppc-gateway/gateway/router/LocalRouter.h | 1 - .../gateway/router/PeerRouterTable.cpp | 18 +++- .../transport/impl/transport_loader.py | 2 +- python/ppc_model/common/initializer.py | 84 +++++++++---------- .../common/mock/mock_model_transport.py | 4 +- python/ppc_model/common/model_setting.py | 12 +-- python/ppc_model/conf/application-sample.yml | 2 +- python/ppc_model/conf/logging.conf | 2 +- python/ppc_model/datasets/dataset.py | 8 +- .../vertical/passive_party.py | 2 +- .../network/wedpr_model_transport.py | 53 +++++++----- .../network/wedpr_model_transport_api.py | 2 +- python/ppc_model/ppc_model_app.py | 5 +- .../ppc_model/secure_lgbm/vertical/booster.py | 2 +- .../ppc_model/secure_lr/vertical/booster.py | 3 +- python/ppc_model/task/task_manager.py | 2 +- python/ppc_model/tools/start.sh | 63 +++++++------- python/ppc_model/tools/stop.sh | 53 +++++++++--- 25 files changed, 201 insertions(+), 139 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index a92bbebc..74642c67 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -76,8 +76,7 @@ endif() set(TRANSPORT_SDK_SOURCE_LIST wedpr-protocol wedpr-transport/ppc-front - wedpr-transport/sdk - wedpr-transport/sdk-wrapper) + wedpr-transport/sdk) set(TRANSPORT_SDK_TOOLKIT_SOURCE_LIST ${TRANSPORT_SDK_SOURCE_LIST} @@ -90,6 +89,7 @@ set(ALL_SOURCE_LIST wedpr-helper/libhelper wedpr-helper/ppc-tools wedpr-storage/ppc-io wedpr-storage/ppc-storage wedpr-transport/ppc-gateway + wedpr-transport/ppc-rpc wedpr-transport/ppc-http wedpr-computing/ppc-psi wedpr-computing/ppc-mpc wedpr-computing/ppc-pir ${CEM_SOURCE} wedpr-initializer wedpr-main) diff --git a/cpp/ppc-framework/protocol/INodeInfo.h b/cpp/ppc-framework/protocol/INodeInfo.h index 37f8b772..6dbf5281 100644 --- a/cpp/ppc-framework/protocol/INodeInfo.h +++ b/cpp/ppc-framework/protocol/INodeInfo.h @@ -49,6 +49,7 @@ class INodeInfo virtual void setComponents(std::set const& components) = 0; virtual bool addComponent(std::string const& component) = 0; virtual bool eraseComponent(std::string const& component) = 0; + virtual bool componentExist(std::string const& component) const = 0; virtual std::set const& components() const = 0; virtual std::vector copiedComponents() const = 0; @@ -88,7 +89,7 @@ inline std::string printNodeInfo(INodeInfo::Ptr const& nodeInfo) stringstream << LOG_KV("endPoint", nodeInfo->endPoint()) << LOG_KV("nodeID", printNodeID(nodeInfo->nodeID())); std::string components = ""; - for (auto const& it : nodeInfo->components()) + for (auto const& it : nodeInfo->copiedComponents()) { components = components + it + ","; } diff --git a/cpp/wedpr-protocol/grpc/client/RemoteFrontBuilder.cpp b/cpp/wedpr-protocol/grpc/client/RemoteFrontBuilder.cpp index a2d06fa9..05129869 100644 --- a/cpp/wedpr-protocol/grpc/client/RemoteFrontBuilder.cpp +++ b/cpp/wedpr-protocol/grpc/client/RemoteFrontBuilder.cpp @@ -32,7 +32,7 @@ IFrontClient::Ptr RemoteFrontBuilder::buildClient(std::string endPoint, auto frontClient = std::make_shared(m_grpcConfig, endPoint); if (m_healthChecker) { - auto healthCheckHandler = std::make_shared("front" + endPoint); + auto healthCheckHandler = std::make_shared("front_" + endPoint); healthCheckHandler->checkHealthHandler = [frontClient]() { return frontClient->checkHealth(); }; diff --git a/cpp/wedpr-protocol/protobuf/src/NodeInfoImpl.h b/cpp/wedpr-protocol/protobuf/src/NodeInfoImpl.h index 2871074a..04f0496c 100644 --- a/cpp/wedpr-protocol/protobuf/src/NodeInfoImpl.h +++ b/cpp/wedpr-protocol/protobuf/src/NodeInfoImpl.h @@ -65,6 +65,11 @@ class NodeInfoImpl : public INodeInfo bcos::ReadGuard l(x_components); return m_components; } + bool componentExist(std::string const& component) const override + { + bcos::ReadGuard l(x_components); + return m_components.count(component); + } std::vector copiedComponents() const override { diff --git a/cpp/wedpr-transport/ppc-gateway/ppc-gateway/gateway/GatewayImpl.cpp b/cpp/wedpr-transport/ppc-gateway/ppc-gateway/gateway/GatewayImpl.cpp index 06dbab01..273bb7fe 100644 --- a/cpp/wedpr-transport/ppc-gateway/ppc-gateway/gateway/GatewayImpl.cpp +++ b/cpp/wedpr-transport/ppc-gateway/ppc-gateway/gateway/GatewayImpl.cpp @@ -253,6 +253,7 @@ bcos::Error::Ptr GatewayImpl::registerNodeInfo(ppc::protocol::INodeInfo::Ptr con return; } gateway->m_localRouter->unRegisterNode(nodeInfo->nodeID().toBytes()); + gateway->m_localRouter->increaseSeq(); }, true); return nullptr; diff --git a/cpp/wedpr-transport/ppc-gateway/ppc-gateway/gateway/router/GatewayNodeInfo.h b/cpp/wedpr-transport/ppc-gateway/ppc-gateway/gateway/router/GatewayNodeInfo.h index 6014a6e2..e12fe0d6 100644 --- a/cpp/wedpr-transport/ppc-gateway/ppc-gateway/gateway/router/GatewayNodeInfo.h +++ b/cpp/wedpr-transport/ppc-gateway/ppc-gateway/gateway/router/GatewayNodeInfo.h @@ -90,6 +90,11 @@ inline std::string printNodeStatus(GatewayNodeInfo::Ptr const& status) stringstream << LOG_KV("p2pNodeID", printP2PIDElegantly(status->p2pNodeID())) << LOG_KV("agency", status->agency()) << LOG_KV("statusSeq", status->statusSeq()) << LOG_KV("nodeSize", status->nodeSize()); + auto nodeInfoList = status->nodeList(); + for (auto const& it : nodeInfoList) + { + stringstream << printNodeInfo(it.second); + } return stringstream.str(); } } // namespace ppc::gateway \ No newline at end of file diff --git a/cpp/wedpr-transport/ppc-gateway/ppc-gateway/gateway/router/GatewayNodeInfoImpl.cpp b/cpp/wedpr-transport/ppc-gateway/ppc-gateway/gateway/router/GatewayNodeInfoImpl.cpp index 1c52fa8a..00a1e503 100644 --- a/cpp/wedpr-transport/ppc-gateway/ppc-gateway/gateway/router/GatewayNodeInfoImpl.cpp +++ b/cpp/wedpr-transport/ppc-gateway/ppc-gateway/gateway/router/GatewayNodeInfoImpl.cpp @@ -64,7 +64,7 @@ bool GatewayNodeInfoImpl::existComponent(std::string const& component) const bcos::ReadGuard l(x_nodeList); for (auto const& it : m_nodeList) { - if (it.second->components().count(component)) + if (it.second->componentExist(component)) { return true; } diff --git a/cpp/wedpr-transport/ppc-gateway/ppc-gateway/gateway/router/LocalRouter.h b/cpp/wedpr-transport/ppc-gateway/ppc-gateway/gateway/router/LocalRouter.h index d1fc9040..7c2e015e 100644 --- a/cpp/wedpr-transport/ppc-gateway/ppc-gateway/gateway/router/LocalRouter.h +++ b/cpp/wedpr-transport/ppc-gateway/ppc-gateway/gateway/router/LocalRouter.h @@ -67,7 +67,6 @@ class LocalRouter GatewayNodeInfo::Ptr const& routerInfo() const { return m_routerInfo; } -private: uint32_t increaseSeq() { uint32_t statusSeq = ++m_statusSeq; diff --git a/cpp/wedpr-transport/ppc-gateway/ppc-gateway/gateway/router/PeerRouterTable.cpp b/cpp/wedpr-transport/ppc-gateway/ppc-gateway/gateway/router/PeerRouterTable.cpp index b7988588..6c7c1f4a 100644 --- a/cpp/wedpr-transport/ppc-gateway/ppc-gateway/gateway/router/PeerRouterTable.cpp +++ b/cpp/wedpr-transport/ppc-gateway/ppc-gateway/gateway/router/PeerRouterTable.cpp @@ -34,6 +34,7 @@ void PeerRouterTable::updateGatewayInfo(GatewayNodeInfo::Ptr const& gatewayInfo) auto nodeList = gatewayInfo->nodeList(); removeP2PNodeIDFromNodeIDInfos(gatewayInfo); + removeP2PNodeIDFromAgencyInfos(gatewayInfo->p2pNodeID()); insertGatewayInfo(gatewayInfo); } @@ -170,6 +171,9 @@ std::vector PeerRouterTable::selectTargetNodes( auto selectedP2PNodes = selectRouter(routeType, routeInfo); if (selectedP2PNodes.empty()) { + PEER_ROUTER_LOG(INFO) << LOG_DESC("selectTargetNodes with empty result") + << LOG_KV("routeType", routeType) + << LOG_KV("routeInfo", printOptionalField(routeInfo)); return std::vector(); } for (auto const& it : selectedP2PNodes) @@ -177,11 +181,20 @@ std::vector PeerRouterTable::selectTargetNodes( auto nodeList = it->nodeList(); for (auto const& it : nodeList) { + if (routeType == RouteType::ROUTE_THROUGH_COMPONENT) + { + if (it.second->componentExist(routeInfo->componentType())) + { + targetNodeList.insert(std::string(it.first.begin(), it.first.end())); + } + continue; + } targetNodeList.insert(std::string(it.first.begin(), it.first.end())); } } PEER_ROUTER_LOG(INFO) << LOG_DESC("selectTargetNodes, result: ") - << printCollection(targetNodeList); + << printCollection(targetNodeList) << LOG_KV("routeType", routeType) + << LOG_KV("routeInfo", printOptionalField(routeInfo)); return std::vector(targetNodeList.begin(), targetNodeList.end()); } @@ -261,7 +274,7 @@ void PeerRouterTable::selectRouterByComponent(GatewayNodeInfos& choosedGateway, auto const& nodeListInfo = it->nodeList(); for (auto const& nodeInfo : nodeListInfo) { - if (nodeInfo.second->components().count(routeInfo->componentType())) + if (nodeInfo.second->componentExist(routeInfo->componentType())) { choosedGateway.insert(it); break; @@ -270,6 +283,7 @@ void PeerRouterTable::selectRouterByComponent(GatewayNodeInfos& choosedGateway, } } + void PeerRouterTable::asyncBroadcastMessage(ppc::protocol::Message::Ptr const& msg) const { bcos::ReadGuard l(x_mutex); diff --git a/cpp/wedpr-transport/sdk-wrapper/python/bindings/wedpr_python_gateway_sdk/transport/impl/transport_loader.py b/cpp/wedpr-transport/sdk-wrapper/python/bindings/wedpr_python_gateway_sdk/transport/impl/transport_loader.py index f5e18fd4..7ee50de4 100644 --- a/cpp/wedpr-transport/sdk-wrapper/python/bindings/wedpr_python_gateway_sdk/transport/impl/transport_loader.py +++ b/cpp/wedpr-transport/sdk-wrapper/python/bindings/wedpr_python_gateway_sdk/transport/impl/transport_loader.py @@ -55,7 +55,7 @@ def load(transport_config: TransportConfig) -> Transport: return Transport(transport, transport_config) @staticmethod - def build(self, transport_threadpool_size: int = 4, + def build(transport_threadpool_size: int = 4, transport_node_id: str = None, transport_gateway_targets: str = None, transport_host_ip: str = None, diff --git a/python/ppc_model/common/initializer.py b/python/ppc_model/common/initializer.py index 5ebcf5b9..8a89404e 100644 --- a/python/ppc_model/common/initializer.py +++ b/python/ppc_model/common/initializer.py @@ -7,6 +7,7 @@ from ppc_common.deps_services import storage_loader from ppc_common.ppc_utils import common_func +from ppc_common.ppc_async_executor.thread_event_manager import ThreadEventManager from wedpr_python_gateway_sdk.transport.impl.transport_loader import TransportLoader from ppc_model.network.wedpr_model_transport import ModelTransport from ppc_model.task.task_manager import TaskManager @@ -15,41 +16,35 @@ class Initializer: def __init__(self, log_config_path, config_path, plot_lock=None): self.log_config_path = log_config_path + logging.config.fileConfig(self.log_config_path) self.config_path = config_path self.config_data = None - self.grpc_options = None - self.task_manager = None - self.thread_event_manager = None - self.storage_client = None - # default send msg timeout - self.transport = None - self.send_msg_timeout_ms = 5000 - self.pop_msg_timeout_ms = 60000 - self.MODEL_COMPONENT = "WEDPR_MODEL" # 只用于测试 self.mock_logger = None self.public_key_length = 2048 self.homo_algorithm = 0 + self.init_config() + self.job_cache_dir = common_func.get_config_value( + "JOB_TEMP_DIR", "/tmp", self.config_data, False) + self.thread_event_manager = ThreadEventManager() + self.task_manager = TaskManager( + logger=self.logger(), + thread_event_manager=self.thread_event_manager, + task_timeout_h=self.config_data['TASK_TIMEOUT_H'] + ) + self.storage_client = storage_loader.load( + self.config_data, self.logger()) + # default send msg timeout + self.MODEL_COMPONENT = "WEDPR_MODEL" + self.send_msg_timeout_ms = 5000 + self.pop_msg_timeout_ms = 60000 + # for UT + self.transport = None # matplotlib 线程不安全,并行任务绘图增加全局锁 self.plot_lock = plot_lock if plot_lock is None: self.plot_lock = threading.Lock() - def init_all(self): - self.init_log() - self.init_config() - self.init_task_manager() - self.init_transport() - self.init_storage_client() - self.init_cache() - - def init_log(self): - logging.config.fileConfig(self.log_config_path) - - def init_cache(self): - self.job_cache_dir = common_func.get_config_value( - "JOB_TEMP_DIR", "/tmp", self.config_data, False) - def init_config(self): with open(self.config_path, 'rb') as f: self.config_data = yaml.safe_load(f.read()) @@ -59,34 +54,37 @@ def init_config(self): if 'HOMO_ALGORITHM' in self.config_data: self.homo_algorithm = self.config_data['HOMO_ALGORITHM'] - def init_transport(self): + def init_all(self): + agency_id = common_func.get_config_value( + "AGENCY_ID", None, self.config_data, True) + self.init_transport(task_manager=self.task_manager, + self_agency_id=agency_id, + component_type=self.MODEL_COMPONENT, + send_msg_timeout_ms=self.send_msg_timeout_ms, + pop_msg_timeout_ms=self.pop_msg_timeout_ms) + + def init_transport(self, task_manager: TaskManager, + self_agency_id: str, + component_type: str, + send_msg_timeout_ms: int, + pop_msg_timeout_ms: int): # create the transport transport = TransportLoader.build(**self.config_data) self.logger( - f"Create transport success, config: {self.get_config().desc()}") + f"Create transport success, config: {transport.get_config().desc()}") # start the transport transport.start() self.logger().info( f"Start transport success, config: {transport.get_config().desc()}") - transport.register_component(self.MODEL_COMPONENT) + transport.register_component(component_type) self.logger().info( - f"Register the component {self.MODEL_COMPONENT} success") + f"Register the component {component_type} success") self.transport = ModelTransport(transport=transport, - task_manager=self.task_manager, - component_type=self.MODEL_COMPONENT, - send_msg_timeout_ms=self.send_msg_timeout_ms, - pop_msg_timeout_ms=self.pop_msg_timeout_ms) - - def init_task_manager(self): - self.task_manager = TaskManager( - logger=self.logger(), - thread_event_manager=self.thread_event_manager, - task_timeout_h=self.config_data['TASK_TIMEOUT_H'] - ) - - def init_storage_client(self): - self.storage_client = storage_loader.load( - self.config_data, self.logger()) + self_agency_id=self_agency_id, + task_manager=task_manager, + component_type=component_type, + send_msg_timeout_ms=send_msg_timeout_ms, + pop_msg_timeout_ms=pop_msg_timeout_ms) def logger(self, name=None): if self.mock_logger is None: diff --git a/python/ppc_model/common/mock/mock_model_transport.py b/python/ppc_model/common/mock/mock_model_transport.py index e2048a56..e4f97521 100644 --- a/python/ppc_model/common/mock/mock_model_transport.py +++ b/python/ppc_model/common/mock/mock_model_transport.py @@ -12,9 +12,9 @@ def __init__(self, agency_name): def get_topic(task_id: str, task_type: str, dst_agency: str): return f"{dst_agency}_{task_id}{task_type}" - def push_by_nodeid(self, task_id: str, task_type: str, dst_node: str, dst_inst: str, payload: bytes, seq: int = 0): + def push_by_nodeid(self, task_id: str, task_type: str, dst_node: str, payload: bytes, seq: int = 0): self.msg_queue.update({MockModelTransportApi.get_topic( - task_id, task_type, dst_inst): payload}) + task_id, task_type, self.agency_name): payload}) def pop(self, task_id: str, task_type: str, dst_inst: str): topic = MockModelTransportApi.get_topic(task_id, task_type, dst_inst) diff --git a/python/ppc_model/common/model_setting.py b/python/ppc_model/common/model_setting.py index d2340878..d325b73a 100644 --- a/python/ppc_model/common/model_setting.py +++ b/python/ppc_model/common/model_setting.py @@ -42,7 +42,7 @@ def __init__(self, model_dict): "iv_thresh", 0.1, model_dict, False)) -class CommmonModelSetting: +class CommonModelSetting: def __init__(self, model_dict): self.learning_rate = float(common_func.get_config_value( "learning_rate", 0.1, model_dict, False)) @@ -67,7 +67,7 @@ def __init__(self, model_dict): "n_jobs", 0, model_dict, False)) -class SecureLGBMSetting(CommmonModelSetting): +class SecureLGBMSetting(CommonModelSetting): def __init__(self, model_dict): super().__init__(model_dict) self.test_size = float(common_func.get_config_value( @@ -107,7 +107,7 @@ def __init__(self, model_dict): "one_hot", 0, model_dict, False) -class SecureLRSetting(CommmonModelSetting): +class SecureLRSetting(CommonModelSetting): def __init__(self, model_dict): super().__init__(model_dict) self.feature_rate = float(common_func.get_config_value( @@ -123,8 +123,8 @@ def __init__(self, model_dict): # init PreprocessingSetting super().__init__(model_dict) # init FeatureEngineeringEngineSetting - super(FeatureEngineeringEngineSetting, self).__init__(model_dict) + FeatureEngineeringEngineSetting.__init__(self, model_dict) # init SecureLGBMSetting - super(SecureLGBMSetting, self).__init__(model_dict) + SecureLGBMSetting.__init__(self, model_dict) # init SecureLRSetting - super(SecureLRSetting, self).__init__(model_dict) + SecureLRSetting.__init__(self, model_dict) diff --git a/python/ppc_model/conf/application-sample.yml b/python/ppc_model/conf/application-sample.yml index 01b30aaf..cf3cbd11 100644 --- a/python/ppc_model/conf/application-sample.yml +++ b/python/ppc_model/conf/application-sample.yml @@ -40,4 +40,4 @@ transport_threadpool_size: 4 transport_node_id: "MODEL_WeBank_NODE" transport_gateway_targets: "ipv4:127.0.0.1:40600,127.0.0.1:40601" transport_host_ip: "127.0.0.1" -transport_listen_port: 6200 \ No newline at end of file +transport_listen_port: 6500 \ No newline at end of file diff --git a/python/ppc_model/conf/logging.conf b/python/ppc_model/conf/logging.conf index b9b3bdb8..f78ab644 100644 --- a/python/ppc_model/conf/logging.conf +++ b/python/ppc_model/conf/logging.conf @@ -29,7 +29,7 @@ formatter=simpleFormatter [handler_consoleHandler] class=StreamHandler args=(sys.stdout,) -level=ERROR +level=INFO formatter=simpleFormatter [formatters] diff --git a/python/ppc_model/datasets/dataset.py b/python/ppc_model/datasets/dataset.py index 451ee44a..994cef06 100644 --- a/python/ppc_model/datasets/dataset.py +++ b/python/ppc_model/datasets/dataset.py @@ -112,7 +112,7 @@ def _random_split_dataset(self): def _customized_split_dataset(self): if self.ctx.role == TaskRole.ACTIVE_PARTY: for partner_index in range(1, len(self.ctx.participant_id_list)): - byte_data = SendMessage._receive_byte_data(self.ctx.components.stub, self.ctx, + byte_data = SendMessage._receive_byte_data(self.ctx.model_router, self.ctx, f'{CommonMessage.EVAL_SET_FILE.value}', partner_index) if not os.path.exists(self.eval_column_file) and byte_data != bytes(): with open(self.eval_column_file, 'wb') as f: @@ -120,7 +120,7 @@ def _customized_split_dataset(self): with open(self.eval_column_file, 'rb') as f: byte_data = f.read() for partner_index in range(1, len(self.ctx.participant_id_list)): - SendMessage._send_byte_data(self.ctx.components.stub, self.ctx, f'{CommonMessage.EVAL_SET_FILE.value}', + SendMessage._send_byte_data(self.ctx.model_router, self.ctx, f'{CommonMessage.EVAL_SET_FILE.value}', byte_data, partner_index) else: if not os.path.exists(self.eval_column_file): @@ -128,9 +128,9 @@ def _customized_split_dataset(self): else: with open(self.eval_column_file, 'rb') as f: byte_data = f.read() - SendMessage._send_byte_data(self.ctx.components.stub, self.ctx, f'{CommonMessage.EVAL_SET_FILE.value}', + SendMessage._send_byte_data(self.ctx.model_router, self.ctx, f'{CommonMessage.EVAL_SET_FILE.value}', byte_data, 0) - byte_data = SendMessage._receive_byte_data(self.ctx.components.stub, self.ctx, + byte_data = SendMessage._receive_byte_data(self.ctx.model_router, self.ctx, f'{CommonMessage.EVAL_SET_FILE.value}', 0) if not os.path.exists(self.eval_column_file): with open(self.eval_column_file, 'wb') as f: diff --git a/python/ppc_model/feature_engineering/vertical/passive_party.py b/python/ppc_model/feature_engineering/vertical/passive_party.py index 9f6b36a4..f3d1aaaa 100644 --- a/python/ppc_model/feature_engineering/vertical/passive_party.py +++ b/python/ppc_model/feature_engineering/vertical/passive_party.py @@ -151,7 +151,7 @@ def _send_all_enc_aggr_labels(self, public_key, aggr_labels_bytes_list): def _get_and_save_result(self): active_party = self.ctx.participant_id_list[0] - if self.ctx.components.stub.agency_id in self.ctx.result_receiver_id_list: + if self.ctx.components.transport.self_agency_id in self.ctx.result_receiver_id_list: # 保存来自标签方的woe/iv结果 data = self.ctx.model_router.pop( task_id=self.ctx.task_id, task_type=FeMessage.WOE_FILE.value, from_inst=active_party) diff --git a/python/ppc_model/network/wedpr_model_transport.py b/python/ppc_model/network/wedpr_model_transport.py index b1cec812..b260d9bc 100644 --- a/python/ppc_model/network/wedpr_model_transport.py +++ b/python/ppc_model/network/wedpr_model_transport.py @@ -9,9 +9,12 @@ class ModelTransport(ModelTransportApi): - def __init__(self, transport: Transport, task_manager: TaskManager, - component_type, send_msg_timeout_ms: int = 5000, pop_msg_timeout_ms: int = 60000): + def __init__(self, transport: Transport, self_agency_id: str, task_manager: TaskManager, + component_type, + send_msg_timeout_ms: int = 5000, + pop_msg_timeout_ms: int = 60000): self.transport = transport + self.self_agency_id = self_agency_id # default send msg timeout self.send_msg_timeout = send_msg_timeout_ms self.pop_msg_timeout = pop_msg_timeout_ms @@ -19,8 +22,8 @@ def __init__(self, transport: Transport, task_manager: TaskManager, self.component_type = component_type @staticmethod - def get_topic(task_id: str, task_type: str, dst_agency: str): - return f"{dst_agency}_{task_id}{task_type}" + def get_topic(task_id: str, task_type: str, agency: str): + return f"{agency}_{task_id}_{task_type}" def push_by_component(self, task_id: str, task_type: str, dst_inst: str, data): self.transport.push_by_component(topic=self.get_topic(task_id, task_type, dst_inst), @@ -28,15 +31,15 @@ def push_by_component(self, task_id: str, task_type: str, dst_inst: str, data): component=self.component_type, payload=data, timeout=self.send_msg_timeout) - def push_by_nodeid(self, task_id: str, task_type: str, dst_node: str, dst_inst: str, payload: bytes, seq: int = 0): - self.transport.push_by_nodeid(topic=self.get_topic(task_id, task_type, dst_inst), + def push_by_nodeid(self, task_id: str, task_type: str, dst_node: str, payload: bytes, seq: int = 0): + self.transport.push_by_nodeid(topic=self.get_topic(task_id, task_type, self.self_agency_id), dstNode=bytes( dst_node, encoding="utf-8"), seq=seq, payload=payload, timeout=self.send_msg_timeout) def pop(self, task_id: str, task_type: str, dst_inst: str) -> MessageAPI: - while self.task_manager.task_finished(task_id): + while not self.task_manager.task_finished(task_id): msg = self.transport.pop(topic=self.get_topic( task_id, task_type, dst_inst), timeout_ms=self.pop_msg_timeout) # wait for the msg if the task is running @@ -55,6 +58,9 @@ def select_node(self, route_type: RouteType, dst_agency: str, dst_component: str return self.transport.select_node_by_route_policy(route_type=route_type, dst_inst=dst_agency, dst_component=dst_component) + def stop(self): + self.transport.stop() + class ModelRouter(ModelRouterApi): def __init__(self, logger, transport: ModelTransport, participant_id_list): @@ -62,30 +68,33 @@ def __init__(self, logger, transport: ModelTransport, participant_id_list): self.transport = transport self.participant_id_list = participant_id_list self.router_info = {} - - def __init_routers__(self): for participant in self.participant_id_list: - result = self.transport.select_node(route_type=RouteType.ROUTE_THROUGH_COMPONENT, - dst_agency=participant, dst_component=self.transport.get_component_type()) - if result is None: - raise Exception( - f"No router can reach participant {participant}") - self.logger.info( - f"ModelRouter, select node {result} for participant {participant}") - self.router_info.update({participant: result}) + self.__init_router__(participant) + + def __init_router__(self, participant): + result = self.transport.select_node(route_type=RouteType.ROUTE_THROUGH_COMPONENT, + dst_agency=participant, + dst_component=self.transport.get_component_type()) + self.logger.info( + f"__init_router__ for {participant}, result: {result}, component: {self.transport.get_component_type()}") + if result is None: + raise Exception( + f"No router can reach participant {participant}") + self.logger.info( + f"ModelRouter, select node {result} for participant {participant}, " + f"component: {self.transport.get_component_type()}") + self.router_info.update({participant: result}) + return result def __get_dstnode_by_participant(self, participant): if participant in self.router_info.keys(): return self.router_info.get(participant) - return None + return self.__init_router__(participant) def push(self, task_id: str, task_type: str, dst_agency: str, payload: bytes, seq: int = 0): dst_node = self.__get_dstnode_by_participant(dst_agency) - if dst_node is None: - raise Exception( - f"send message to {dst_agency} failed for no router!") self.transport.push_by_nodeid( - task_id=task_id, task_type=task_type, dst_node=dst_node, dst_inst=dst_agency, payload=payload, seq=seq) + task_id=task_id, task_type=task_type, dst_node=dst_node, payload=payload, seq=seq) def pop(self, task_id: str, task_type: str, from_inst: str) -> bytes: result = self.transport.pop( diff --git a/python/ppc_model/network/wedpr_model_transport_api.py b/python/ppc_model/network/wedpr_model_transport_api.py index adab0c82..c2f60172 100644 --- a/python/ppc_model/network/wedpr_model_transport_api.py +++ b/python/ppc_model/network/wedpr_model_transport_api.py @@ -5,7 +5,7 @@ class ModelTransportApi(ABC): @abstractmethod - def push_by_nodeid(self, task_id: str, task_type: str, dst_node: str, dst_inst: str, payload: bytes, seq: int = 0): + def push_by_nodeid(self, task_id: str, task_type: str, dst_node: str, payload: bytes, seq: int = 0): pass @abstractmethod diff --git a/python/ppc_model/ppc_model_app.py b/python/ppc_model/ppc_model_app.py index b1c0d7bc..0a95bc9a 100644 --- a/python/ppc_model/ppc_model_app.py +++ b/python/ppc_model/ppc_model_app.py @@ -66,7 +66,10 @@ def register_task_handler(): TransLogger(app, setup_console_handler=False), numthreads=2) protocol = 'http' - message = f"Starting ppc model server at {protocol}://{app.config['HOST']}:{app.config['HTTP_PORT']}" + message = f"Starting ppc model server at {protocol}://{app.config['HOST']}:{app.config['HTTP_PORT']} successfully" print(message) components.logger().info(message) server.start() + # stop the nodes + components.transport.stop() + print("Stop ppc model server successfully") diff --git a/python/ppc_model/secure_lgbm/vertical/booster.py b/python/ppc_model/secure_lgbm/vertical/booster.py index 803d6c3d..1fa7a5b6 100644 --- a/python/ppc_model/secure_lgbm/vertical/booster.py +++ b/python/ppc_model/secure_lgbm/vertical/booster.py @@ -167,7 +167,7 @@ def _send_byte_data(self, ctx, key_type, byte_data, partner_index): partner_id = ctx.participant_id_list[partner_index] self.ctx.model_router.push( - task_id=ctx.task_id, task_type=key_type, dst_agency=partner_id, data=byte_data) + task_id=ctx.task_id, task_type=key_type, dst_agency=partner_id, payload=byte_data) self.logger.info( f"task {ctx.task_id}: Sending {key_type} to {partner_id} finished, " f"data_size: {len(byte_data) / 1024}KB, time_costs: {time.time() - start_time}s") diff --git a/python/ppc_model/secure_lr/vertical/booster.py b/python/ppc_model/secure_lr/vertical/booster.py index e7cc5f1f..97400116 100644 --- a/python/ppc_model/secure_lr/vertical/booster.py +++ b/python/ppc_model/secure_lr/vertical/booster.py @@ -24,7 +24,6 @@ class VerticalBooster(SecureModelBooster): def __init__(self, ctx: SecureLRContext, dataset: SecureDataset) -> None: super().__init__(ctx) self.dataset = dataset - self._stub = ctx.components.stub self._iter_id = None @@ -209,7 +208,7 @@ def _send_byte_data(self, ctx, key_type, byte_data, partner_index): start_time = time.time() partner_id = ctx.participant_id_list[partner_index] self.ctx.model_router.push( - task_id=ctx.task_id, task_type=key_type, dst_agency=partner_id, data=byte_data) + task_id=ctx.task_id, task_type=key_type, dst_agency=partner_id, payload=byte_data) self.logger.info( f"task {ctx.task_id}: Sending {key_type} to {partner_id} finished, " f"data_size: {len(byte_data) / 1024}KB, time_costs: {time.time() - start_time}s") diff --git a/python/ppc_model/task/task_manager.py b/python/ppc_model/task/task_manager.py index 48159c8f..92764d4a 100644 --- a/python/ppc_model/task/task_manager.py +++ b/python/ppc_model/task/task_manager.py @@ -89,7 +89,7 @@ def kill_one_task(self, task_id: str): def task_finished(self, task_id: str) -> bool: (status, _, _) = self.status(task_id) - if status == TaskStatus.RUNNING: + if status == TaskStatus.RUNNING.value: return False return True diff --git a/python/ppc_model/tools/start.sh b/python/ppc_model/tools/start.sh index 477ac215..2eb4834f 100644 --- a/python/ppc_model/tools/start.sh +++ b/python/ppc_model/tools/start.sh @@ -1,36 +1,37 @@ #!/bin/bash +SHELL_FOLDER=$(cd $(dirname $0);pwd) +LOG_ERROR() { + content=${1} + echo -e "\033[31m[ERROR] ${content}\033[0m" +} -dirpath="$(cd "$(dirname "$0")" && pwd)" -cd $dirpath +LOG_INFO() { + content=${1} + echo -e "\033[32m[INFO] ${content}\033[0m" +} +binary_path=${SHELL_FOLDER}/ppc_model_app.py +cd ${SHELL_FOLDER} +node=$(basename ${SHELL_FOLDER}) +node_pid=$(ps aux|grep ${binary_path}|grep -v grep|awk '{print $2}') -# kill crypto process -crypto_pro_num=`ps -ef | grep /ppc/scripts | grep j- | grep -v 'grep' | awk '{print $2}' | wc -l` -for i in $( seq 1 $crypto_pro_num ) +if [ ! -z ${node_pid} ];then + echo " ${node} is running, pid is $node_pid." + exit 0 +else + nohup python ${binary_path} > start.out 2>&1 & + sleep 1.5 +fi +try_times=4 +i=0 +while [ $i -lt ${try_times} ] do - crypto_pid=`ps -ef | grep /ppc/scripts | grep j- | grep -v 'grep' | awk '{print $2}' | awk 'NR==1{print}'` - kill -9 $crypto_pid + node_pid=$(ps aux|grep ${binary_path}|grep -v grep|awk '{print $2}') + success_flag=$(tail -n20 start.out | grep successfully) + if [[ ! -z ${node_pid} && ! -z "${success_flag}" ]];then + echo -e "\033[32m ${node} start successfully pid=${node_pid}\033[0m" + exit 0 + fi + sleep 0.5 + ((i=i+1)) done - -sleep 1 - -nohup python ppc_model_app.py > start.out 2>&1 & - -check_service() { - try_times=5 - i=0 - while [ -z `ps -ef | grep ${1} | grep python | grep -v grep | awk '{print $2}'` ]; do - sleep 1 - ((i = i + 1)) - if [ $i -lt ${try_times} ]; then - echo -e "\033[32m.\033[0m\c" - else - echo -e "\033[31m\nServer ${1} isn't running. \033[0m" - return - fi - done - - echo -e "\033[32mServer ${1} started \033[0m" -} - -sleep 5 -check_service ppc_model_app.py \ No newline at end of file +echo -e "\033[31m Exceed waiting time. Please try again to start ${node} \033[0m" \ No newline at end of file diff --git a/python/ppc_model/tools/stop.sh b/python/ppc_model/tools/stop.sh index 3b290668..ae35743c 100644 --- a/python/ppc_model/tools/stop.sh +++ b/python/ppc_model/tools/stop.sh @@ -1,19 +1,46 @@ #!/bin/bash +SHELL_FOLDER=$(cd $(dirname $0);pwd) -dirpath="$(cd "$(dirname "$0")" && pwd)" -cd $dirpath +LOG_ERROR() { + content=${1} + echo -e "\033[31m[ERROR] ${content}\033[0m" +} -# kill crypto process -crypto_pro_num=`ps -ef | grep /ppc/scripts | grep j- | grep -v 'grep' | awk '{print $2}' | wc -l` -for i in $( seq 1 $crypto_pro_num ) -do - crypto_pid=`ps -ef | grep /ppc/scripts | grep j- | grep -v 'grep' | awk '{print $2}' | awk 'NR==1{print}'` - kill -9 $crypto_pid -done +LOG_INFO() { + content=${1} + echo -e "\033[32m[INFO] ${content}\033[0m" +} -sleep 1 +binary_path=${SHELL_FOLDER}/ppc_model_app.py +node=$(basename ${SHELL_FOLDER}) +node_pid=$(ps aux|grep ${binary_path}|grep -v grep|awk '{print $2}') +try_times=10 +i=0 +if [ -z ${node_pid} ];then + echo " ${node} isn't running." + exit 0 +fi -ppc_model_app_pid=`ps aux |grep ppc_model_app.py |grep -v grep |awk '{print $2}'` -kill -9 $ppc_model_app_pid +#Stop monitor here +dirs=($(ls -l ${SHELL_FOLDER} | awk '/^d/ {print $NF}')) +for dir in ${dirs[*]} +do + if [[ -f "${SHELL_FOLDER}/${dir}/node.mtail" && -f "${SHELL_FOLDER}/${dir}/stop_mtail_monitor.sh" ]];then + echo "try to start ${dir}" + bash ${SHELL_FOLDER}/${dir}/stop_mtail_monitor.sh & + fi +done -echo -e "\033[32mServer ppc_model_app.py killed. \033[0m" +[ ! -z ${node_pid} ] && kill ${node_pid} > /dev/null +while [ $i -lt ${try_times} ] +do + sleep 1 + node_pid=$(ps aux|grep ${binary_path}|grep -v grep|awk '{print $2}') + if [ -z ${node_pid} ];then + echo -e "\033[32m stop ${node} success.\033[0m" + exit 0 + fi + ((i=i+1)) +done +echo " Exceed maximum number of retries. Please try again to stop ${node}" +exit 1