Skip to content

Commit

Permalink
update script
Browse files Browse the repository at this point in the history
  • Loading branch information
cyjseagull committed Oct 21, 2024
1 parent 676b27e commit 2f3339c
Show file tree
Hide file tree
Showing 19 changed files with 171 additions and 117 deletions.
4 changes: 2 additions & 2 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ endif()
set(TRANSPORT_SDK_SOURCE_LIST
wedpr-protocol
wedpr-transport/ppc-front
wedpr-transport/sdk
wedpr-transport/sdk-wrapper)
wedpr-transport/sdk)

set(TRANSPORT_SDK_TOOLKIT_SOURCE_LIST
${TRANSPORT_SDK_SOURCE_LIST}
Expand All @@ -90,6 +89,7 @@ set(ALL_SOURCE_LIST
wedpr-helper/libhelper wedpr-helper/ppc-tools
wedpr-storage/ppc-io wedpr-storage/ppc-storage
wedpr-transport/ppc-gateway
wedpr-transport/ppc-rpc
wedpr-transport/ppc-http
wedpr-computing/ppc-psi wedpr-computing/ppc-mpc wedpr-computing/ppc-pir ${CEM_SOURCE}
wedpr-initializer wedpr-main)
Expand Down
3 changes: 2 additions & 1 deletion cpp/ppc-framework/protocol/INodeInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class INodeInfo
virtual void setComponents(std::set<std::string> const& components) = 0;
virtual bool addComponent(std::string const& component) = 0;
virtual bool eraseComponent(std::string const& component) = 0;
virtual bool componentExist(std::string const& component) const = 0;
virtual std::set<std::string> const& components() const = 0;
virtual std::vector<std::string> copiedComponents() const = 0;

Expand Down Expand Up @@ -88,7 +89,7 @@ inline std::string printNodeInfo(INodeInfo::Ptr const& nodeInfo)
stringstream << LOG_KV("endPoint", nodeInfo->endPoint())
<< LOG_KV("nodeID", printNodeID(nodeInfo->nodeID()));
std::string components = "";
for (auto const& it : nodeInfo->components())
for (auto const& it : nodeInfo->copiedComponents())
{
components = components + it + ",";
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/wedpr-protocol/grpc/client/RemoteFrontBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ IFrontClient::Ptr RemoteFrontBuilder::buildClient(std::string endPoint,
auto frontClient = std::make_shared<FrontClient>(m_grpcConfig, endPoint);
if (m_healthChecker)
{
auto healthCheckHandler = std::make_shared<HealthCheckHandler>("front" + endPoint);
auto healthCheckHandler = std::make_shared<HealthCheckHandler>("front_" + endPoint);
healthCheckHandler->checkHealthHandler = [frontClient]() {
return frontClient->checkHealth();
};
Expand Down
5 changes: 5 additions & 0 deletions cpp/wedpr-protocol/protobuf/src/NodeInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ class NodeInfoImpl : public INodeInfo
bcos::ReadGuard l(x_components);
return m_components;
}
bool componentExist(std::string const& component) const override
{
bcos::ReadGuard l(x_components);
return m_components.count(component);
}

std::vector<std::string> copiedComponents() const override
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ bcos::Error::Ptr GatewayImpl::registerNodeInfo(ppc::protocol::INodeInfo::Ptr con
return;
}
gateway->m_localRouter->unRegisterNode(nodeInfo->nodeID().toBytes());
gateway->m_localRouter->increaseSeq();
},
true);
return nullptr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ inline std::string printNodeStatus(GatewayNodeInfo::Ptr const& status)
stringstream << LOG_KV("p2pNodeID", printP2PIDElegantly(status->p2pNodeID()))
<< LOG_KV("agency", status->agency()) << LOG_KV("statusSeq", status->statusSeq())
<< LOG_KV("nodeSize", status->nodeSize());
auto nodeInfoList = status->nodeList();
for (auto const& it : nodeInfoList)
{
stringstream << printNodeInfo(it.second);
}
return stringstream.str();
}
} // namespace ppc::gateway
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ bool GatewayNodeInfoImpl::existComponent(std::string const& component) const
bcos::ReadGuard l(x_nodeList);
for (auto const& it : m_nodeList)
{
if (it.second->components().count(component))
if (it.second->componentExist(component))
{
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ class LocalRouter

GatewayNodeInfo::Ptr const& routerInfo() const { return m_routerInfo; }

private:
uint32_t increaseSeq()
{
uint32_t statusSeq = ++m_statusSeq;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ void PeerRouterTable::updateGatewayInfo(GatewayNodeInfo::Ptr const& gatewayInfo)
auto nodeList = gatewayInfo->nodeList();

removeP2PNodeIDFromNodeIDInfos(gatewayInfo);
removeP2PNodeIDFromAgencyInfos(gatewayInfo->p2pNodeID());
insertGatewayInfo(gatewayInfo);
}

Expand Down Expand Up @@ -170,18 +171,30 @@ std::vector<std::string> PeerRouterTable::selectTargetNodes(
auto selectedP2PNodes = selectRouter(routeType, routeInfo);
if (selectedP2PNodes.empty())
{
PEER_ROUTER_LOG(INFO) << LOG_DESC("selectTargetNodes with empty result")
<< LOG_KV("routeType", routeType)
<< LOG_KV("routeInfo", printOptionalField(routeInfo));
return std::vector<std::string>();
}
for (auto const& it : selectedP2PNodes)
{
auto nodeList = it->nodeList();
for (auto const& it : nodeList)
{
if (routeType == RouteType::ROUTE_THROUGH_COMPONENT)
{
if (it.second->componentExist(routeInfo->componentType()))
{
targetNodeList.insert(std::string(it.first.begin(), it.first.end()));
}
continue;
}
targetNodeList.insert(std::string(it.first.begin(), it.first.end()));
}
}
PEER_ROUTER_LOG(INFO) << LOG_DESC("selectTargetNodes, result: ")
<< printCollection(targetNodeList);
<< printCollection(targetNodeList) << LOG_KV("routeType", routeType)
<< LOG_KV("routeInfo", printOptionalField(routeInfo));
return std::vector<std::string>(targetNodeList.begin(), targetNodeList.end());
}

Expand Down Expand Up @@ -261,7 +274,7 @@ void PeerRouterTable::selectRouterByComponent(GatewayNodeInfos& choosedGateway,
auto const& nodeListInfo = it->nodeList();
for (auto const& nodeInfo : nodeListInfo)
{
if (nodeInfo.second->components().count(routeInfo->componentType()))
if (nodeInfo.second->componentExist(routeInfo->componentType()))
{
choosedGateway.insert(it);
break;
Expand All @@ -270,6 +283,7 @@ void PeerRouterTable::selectRouterByComponent(GatewayNodeInfos& choosedGateway,
}
}


void PeerRouterTable::asyncBroadcastMessage(ppc::protocol::Message::Ptr const& msg) const
{
bcos::ReadGuard l(x_mutex);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def load(transport_config: TransportConfig) -> Transport:
return Transport(transport, transport_config)

@staticmethod
def build(self, transport_threadpool_size: int = 4,
def build(transport_threadpool_size: int = 4,
transport_node_id: str = None,
transport_gateway_targets: str = None,
transport_host_ip: str = None,
Expand Down
76 changes: 33 additions & 43 deletions python/ppc_model/common/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from ppc_common.deps_services import storage_loader
from ppc_common.ppc_utils import common_func
from ppc_common.ppc_async_executor.thread_event_manager import ThreadEventManager
from wedpr_python_gateway_sdk.transport.impl.transport_loader import TransportLoader
from ppc_model.network.wedpr_model_transport import ModelTransport
from ppc_model.task.task_manager import TaskManager
Expand All @@ -15,41 +16,35 @@
class Initializer:
def __init__(self, log_config_path, config_path, plot_lock=None):
self.log_config_path = log_config_path
logging.config.fileConfig(self.log_config_path)
self.config_path = config_path
self.config_data = None
self.grpc_options = None
self.task_manager = None
self.thread_event_manager = None
self.storage_client = None
# default send msg timeout
self.transport = None
self.send_msg_timeout_ms = 5000
self.pop_msg_timeout_ms = 60000
self.MODEL_COMPONENT = "WEDPR_MODEL"
# 只用于测试
self.mock_logger = None
self.public_key_length = 2048
self.homo_algorithm = 0
self.init_config()
self.job_cache_dir = common_func.get_config_value(
"JOB_TEMP_DIR", "/tmp", self.config_data, False)
self.thread_event_manager = ThreadEventManager()
self.task_manager = TaskManager(
logger=self.logger(),
thread_event_manager=self.thread_event_manager,
task_timeout_h=self.config_data['TASK_TIMEOUT_H']
)
self.storage_client = storage_loader.load(
self.config_data, self.logger())
# default send msg timeout
self.MODEL_COMPONENT = "WEDPR_MODEL"
self.send_msg_timeout_ms = 5000
self.pop_msg_timeout_ms = 60000
# for UT
self.transport = None
# matplotlib 线程不安全,并行任务绘图增加全局锁
self.plot_lock = plot_lock
if plot_lock is None:
self.plot_lock = threading.Lock()

def init_all(self):
self.init_log()
self.init_config()
self.init_task_manager()
self.init_transport()
self.init_storage_client()
self.init_cache()

def init_log(self):
logging.config.fileConfig(self.log_config_path)

def init_cache(self):
self.job_cache_dir = common_func.get_config_value(
"JOB_TEMP_DIR", "/tmp", self.config_data, False)

def init_config(self):
with open(self.config_path, 'rb') as f:
self.config_data = yaml.safe_load(f.read())
Expand All @@ -59,34 +54,29 @@ def init_config(self):
if 'HOMO_ALGORITHM' in self.config_data:
self.homo_algorithm = self.config_data['HOMO_ALGORITHM']

def init_transport(self):
def init_all(self):
self.init_transport(task_manager=self.task_manager,
component_type=self.MODEL_COMPONENT,
send_msg_timeout_ms=self.send_msg_timeout_ms,
pop_msg_timeout_ms=self.pop_msg_timeout_ms)

def init_transport(self, task_manager: TaskManager, component_type: str, send_msg_timeout_ms: int, pop_msg_timeout_ms: int):
# create the transport
transport = TransportLoader.build(**self.config_data)
self.logger(
f"Create transport success, config: {self.get_config().desc()}")
f"Create transport success, config: {transport.get_config().desc()}")
# start the transport
transport.start()
self.logger().info(
f"Start transport success, config: {transport.get_config().desc()}")
transport.register_component(self.MODEL_COMPONENT)
transport.register_component(component_type)
self.logger().info(
f"Register the component {self.MODEL_COMPONENT} success")
f"Register the component {component_type} success")
self.transport = ModelTransport(transport=transport,
task_manager=self.task_manager,
component_type=self.MODEL_COMPONENT,
send_msg_timeout_ms=self.send_msg_timeout_ms,
pop_msg_timeout_ms=self.pop_msg_timeout_ms)

def init_task_manager(self):
self.task_manager = TaskManager(
logger=self.logger(),
thread_event_manager=self.thread_event_manager,
task_timeout_h=self.config_data['TASK_TIMEOUT_H']
)

def init_storage_client(self):
self.storage_client = storage_loader.load(
self.config_data, self.logger())
task_manager=task_manager,
component_type=component_type,
send_msg_timeout_ms=send_msg_timeout_ms,
pop_msg_timeout_ms=pop_msg_timeout_ms)

def logger(self, name=None):
if self.mock_logger is None:
Expand Down
2 changes: 1 addition & 1 deletion python/ppc_model/conf/application-sample.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,4 @@ transport_threadpool_size: 4
transport_node_id: "MODEL_WeBank_NODE"
transport_gateway_targets: "ipv4:127.0.0.1:40600,127.0.0.1:40601"
transport_host_ip: "127.0.0.1"
transport_listen_port: 6200
transport_listen_port: 6500
2 changes: 1 addition & 1 deletion python/ppc_model/conf/logging.conf
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ formatter=simpleFormatter
[handler_consoleHandler]
class=StreamHandler
args=(sys.stdout,)
level=ERROR
level=INFO
formatter=simpleFormatter

[formatters]
Expand Down
40 changes: 24 additions & 16 deletions python/ppc_model/network/wedpr_model_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@

class ModelTransport(ModelTransportApi):
def __init__(self, transport: Transport, task_manager: TaskManager,
component_type, send_msg_timeout_ms: int = 5000, pop_msg_timeout_ms: int = 60000):
component_type,
send_msg_timeout_ms: int = 5000,
pop_msg_timeout_ms: int = 60000):
self.transport = transport
# default send msg timeout
self.send_msg_timeout = send_msg_timeout_ms
Expand All @@ -20,7 +22,7 @@ def __init__(self, transport: Transport, task_manager: TaskManager,

@staticmethod
def get_topic(task_id: str, task_type: str, dst_agency: str):
return f"{dst_agency}_{task_id}{task_type}"
return f"{dst_agency}_{task_id}_{task_type}"

def push_by_component(self, task_id: str, task_type: str, dst_inst: str, data):
self.transport.push_by_component(topic=self.get_topic(task_id, task_type, dst_inst),
Expand Down Expand Up @@ -55,35 +57,41 @@ def select_node(self, route_type: RouteType, dst_agency: str, dst_component: str
return self.transport.select_node_by_route_policy(route_type=route_type,
dst_inst=dst_agency, dst_component=dst_component)

def stop(self):
self.transport.stop()


class ModelRouter(ModelRouterApi):
def __init__(self, logger, transport: ModelTransport, participant_id_list):
self.logger = logger
self.transport = transport
self.participant_id_list = participant_id_list
self.router_info = {}

def __init_routers__(self):
for participant in self.participant_id_list:
result = self.transport.select_node(route_type=RouteType.ROUTE_THROUGH_COMPONENT,
dst_agency=participant, dst_component=self.transport.get_component_type())
if result is None:
raise Exception(
f"No router can reach participant {participant}")
self.logger.info(
f"ModelRouter, select node {result} for participant {participant}")
self.router_info.update({participant: result})
self.__init_router__(participant)

def __init_router__(self, participant):
result = self.transport.select_node(route_type=RouteType.ROUTE_THROUGH_COMPONENT,
dst_agency=participant,
dst_component=self.transport.get_component_type())
self.logger.info(
f"__init_router__ for {participant}, result: {result}, component: {self.transport.get_component_type()}")
if result is None:
raise Exception(
f"No router can reach participant {participant}")
self.logger.info(
f"ModelRouter, select node {result} for participant {participant}, "
f"component: {self.transport.get_component_type()}")
self.router_info.update({participant: result})
return result

def __get_dstnode_by_participant(self, participant):
if participant in self.router_info.keys():
return self.router_info.get(participant)
return None
return self.__init_router__(participant)

def push(self, task_id: str, task_type: str, dst_agency: str, payload: bytes, seq: int = 0):
dst_node = self.__get_dstnode_by_participant(dst_agency)
if dst_node is None:
raise Exception(
f"send message to {dst_agency} failed for no router!")
self.transport.push_by_nodeid(
task_id=task_id, task_type=task_type, dst_node=dst_node, dst_inst=dst_agency, payload=payload, seq=seq)

Expand Down
5 changes: 4 additions & 1 deletion python/ppc_model/ppc_model_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,10 @@ def register_task_handler():
TransLogger(app, setup_console_handler=False), numthreads=2)

protocol = 'http'
message = f"Starting ppc model server at {protocol}://{app.config['HOST']}:{app.config['HTTP_PORT']}"
message = f"Starting ppc model server at {protocol}://{app.config['HOST']}:{app.config['HTTP_PORT']} successfully"
print(message)
components.logger().info(message)
server.start()
# stop the nodes
components.transport.stop()
print("Stop ppc model server successfully")
2 changes: 1 addition & 1 deletion python/ppc_model/secure_lgbm/vertical/booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def _send_byte_data(self, ctx, key_type, byte_data, partner_index):
partner_id = ctx.participant_id_list[partner_index]

self.ctx.model_router.push(
task_id=ctx.task_id, task_type=key_type, dst_agency=partner_id, data=byte_data)
task_id=ctx.task_id, task_type=key_type, dst_agency=partner_id, payload=byte_data)
self.logger.info(
f"task {ctx.task_id}: Sending {key_type} to {partner_id} finished, "
f"data_size: {len(byte_data) / 1024}KB, time_costs: {time.time() - start_time}s")
Expand Down
2 changes: 1 addition & 1 deletion python/ppc_model/secure_lr/vertical/booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def _send_byte_data(self, ctx, key_type, byte_data, partner_index):
start_time = time.time()
partner_id = ctx.participant_id_list[partner_index]
self.ctx.model_router.push(
task_id=ctx.task_id, task_type=key_type, dst_agency=partner_id, data=byte_data)
task_id=ctx.task_id, task_type=key_type, dst_agency=partner_id, payload=byte_data)
self.logger.info(
f"task {ctx.task_id}: Sending {key_type} to {partner_id} finished, "
f"data_size: {len(byte_data) / 1024}KB, time_costs: {time.time() - start_time}s")
Expand Down
Loading

0 comments on commit 2f3339c

Please sign in to comment.