diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 75586dc4..ca01df86 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -55,6 +55,7 @@ project(WeDPR-Component VERSION ${VERSION}) include(CompilerSettings) include(BuildInfoGenerator) +set(PROTO_OUTPUT_PATH ${CMAKE_CURRENT_BINARY_DIR}/generated/pb) include(IncludeDirectories) # the target settings @@ -67,17 +68,24 @@ set(JNI_SOURCE_PATH wedpr-component-sdk/bindings/java/src/main/c) set(SDK_SOURCE_LIST ppc-homo ppc-crypto-core wedpr-component-sdk ${JNI_SOURCE_PATH}) # Note: the udf depends on mysql, not enabled in the full node mode set(UDF_SOURCE_LIST ${SDK_SOURCE_LIST} ppc-udf) -set(ALL_SOURCE_LIST - ${SDK_SOURCE_LIST} ppc-crypto - libhelper libinitializer ppc-io ppc-protocol - ppc-gateway ppc-front ppc-tars-protocol - ppc-tools ppc-storage ppc-psi ppc-rpc - ppc-http ppc-mpc ppc-pir - ${CEM_SOURCE} ppc-main) set(CEM_SOURCE "") if(BUILD_CEM) set(CEM_SOURCE "ppc-cem") endif() +#set(ALL_SOURCE_LIST +# ${SDK_SOURCE_LIST} ppc-crypto +# libhelper libinitializer ppc-io ppc-protocol +# ppc-gateway ppc-front ppc-tars-protocol +# ppc-tools ppc-storage ppc-psi ppc-rpc +# ppc-http ppc-mpc ppc-pir +# ${CEM_SOURCE} ppc-main) + +set(ALL_SOURCE_LIST + ${SDK_SOURCE_LIST} ppc-crypto + libhelper ppc-io wedpr-protocol + ppc-gateway ppc-front + ppc-tools ppc-storage ppc-psi ppc-rpc + ppc-http ppc-mpc ppc-pir ${CEM_SOURCE}) if(BUILD_WEDPR_TOOLKIT) # fetch the python dependencies diff --git a/cpp/cmake/Dependencies.cmake b/cpp/cmake/Dependencies.cmake index ff1000b4..4a2183ca 100644 --- a/cpp/cmake/Dependencies.cmake +++ b/cpp/cmake/Dependencies.cmake @@ -31,6 +31,7 @@ if(BUILD_ALL) find_package(SEAL REQUIRED) find_package(Kuku REQUIRED) + find_package(gRPC REQUIRED) # APSI: Note: APSI depends on seal 4.0 and Kuku 2.1 include(ProjectAPSI) diff --git a/cpp/cmake/IncludeDirectories.cmake b/cpp/cmake/IncludeDirectories.cmake index d6be1b25..b2f9e8c8 100644 --- a/cpp/cmake/IncludeDirectories.cmake +++ b/cpp/cmake/IncludeDirectories.cmake @@ -1,10 +1,11 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}) include_directories(${CMAKE_CURRENT_BINARY_DIR}) include_directories(${CMAKE_BINARY_DIR}/generated/) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/wedpr-protocol) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ppc-front) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ppc-gateway) include_directories(${CMAKE_CURRENT_SOURCE_DIR}/wedpr-component-sdk) -include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ppc-tars-protocol) +include_directories(${PROTO_OUTPUT_PATH}) set(VCPKG_INCLUDE_PATH "${VCPKG_INSTALLED_DIR}/${VCPKG_TARGET_TRIPLET}/include") include_directories(${VCPKG_INCLUDE_PATH}) \ No newline at end of file diff --git a/cpp/cmake/TargetSettings.cmake b/cpp/cmake/TargetSettings.cmake index 7f29bb81..73882793 100644 --- a/cpp/cmake/TargetSettings.cmake +++ b/cpp/cmake/TargetSettings.cmake @@ -17,8 +17,15 @@ set(TOOLS_TARGET "ppc-tools") # ppc-protocol set(PROTOCOL_TARGET "ppc-protocol") -# ppc-tars-protocol -set(TARS_PROTOCOL_TARGET "ppc-protocol-tars") +# wedpr-protocol/tars +set(TARS_PROTOCOL_TARGET "wedpr-tars-protocol") + +# wedpr-protocol/protobuf +set(PB_PROTOCOL_TARGET "wedpr-pb-protocol") + +# wedpr-protocol/grpc-client +set(SERVICE_CLIENT_TARGET "service-client") +set(SERVICE_CLIENT_PB_TARGET "service-client-pb") # ppc-front SET(FRONT_TARGET "ppc-front") diff --git a/cpp/libinitializer/Initializer.cpp b/cpp/libinitializer/Initializer.cpp index cff5c1a3..99465a3a 100644 --- a/cpp/libinitializer/Initializer.cpp +++ b/cpp/libinitializer/Initializer.cpp @@ -25,9 +25,9 @@ #include "ppc-framework/protocol/Protocol.h" #include "ppc-pir/src/OtPIRFactory.h" #include "ppc-pir/src/OtPIRImpl.h" -#include "ppc-protocol/src/PPCMessage.h" #include "ppc-psi/src/bs-ecdh-psi/BsEcdhPSIFactory.h" #include "ppc-psi/src/cm2020-psi/CM2020PSIFactory.h" +#include "protocol/src/PPCMessage.h" #if 0 //TODO: optimize here #include "ppc-psi/src/ecdh-conn-psi/EcdhConnPSIFactory.h" diff --git a/cpp/libinitializer/ProtocolInitializer.h b/cpp/libinitializer/ProtocolInitializer.h index 4d9546d4..56200e5a 100644 --- a/cpp/libinitializer/ProtocolInitializer.h +++ b/cpp/libinitializer/ProtocolInitializer.h @@ -28,10 +28,10 @@ #include "ppc-framework/protocol/GlobalConfig.h" #include "ppc-framework/protocol/PPCMessageFace.h" #include "ppc-io/src/DataResourceLoaderImpl.h" -#include "ppc-protocol/src/PPCMessage.h" #include "ppc-storage/src/FileStorageFactoryImpl.h" #include "ppc-storage/src/SQLStorageFactoryImpl.h" #include "ppc-tools/src/config/PPCConfig.h" +#include "protocol/src/PPCMessage.h" #include #include diff --git a/cpp/ppc-framework/front/Channel.h b/cpp/ppc-framework/front/Channel.h index 784a06fa..c344578e 100644 --- a/cpp/ppc-framework/front/Channel.h +++ b/cpp/ppc-framework/front/Channel.h @@ -46,7 +46,7 @@ class Channel * @brief notice task info to gateway by front * @param _taskInfo the latest task information */ - virtual bcos::Error::Ptr notifyTaskInfo(protocol::GatewayTaskInfo::Ptr _taskInfo) = 0; + virtual bcos::Error::Ptr notifyTaskInfo(std::string const& taskID) = 0; /** * @brief: send message diff --git a/cpp/ppc-framework/front/FrontInterface.h b/cpp/ppc-framework/front/FrontInterface.h index 41e491d7..e7648725 100644 --- a/cpp/ppc-framework/front/FrontInterface.h +++ b/cpp/ppc-framework/front/FrontInterface.h @@ -44,20 +44,6 @@ class FrontInterface using Ptr = std::shared_ptr; FrontInterface() = default; virtual ~FrontInterface() {} - /** - * @brief: start/stop service - */ - virtual void start() = 0; - virtual void stop() = 0; - - /** - * @brief: receive message from gateway, call by gateway - * @param _message: received ppc message - * @return void - */ - virtual void onReceiveMessage( - front::PPCMessageFace::Ptr _message, ErrorCallbackFunc _callback) = 0; - /** * @brief: send message to other party by gateway @@ -78,11 +64,11 @@ class FrontInterface * @brief notice task info to gateway * @param _taskInfo the latest task information */ - virtual bcos::Error::Ptr notifyTaskInfo(protocol::GatewayTaskInfo::Ptr _taskInfo) = 0; + virtual bcos::Error::Ptr notifyTaskInfo(std::string const& taskID) = 0; // erase the task-info when task finished virtual bcos::Error::Ptr eraseTaskInfo(std::string const& _taskID) = 0; - + // get the agencyList from the gateway virtual void asyncGetAgencyList(GetAgencyListCallback _callback) = 0; diff --git a/cpp/ppc-framework/front/IFront.h b/cpp/ppc-framework/front/IFront.h index be850b65..5de205f3 100644 --- a/cpp/ppc-framework/front/IFront.h +++ b/cpp/ppc-framework/front/IFront.h @@ -19,32 +19,47 @@ */ #pragma once #include "FrontConfig.h" +#include "ppc-framework/protocol/INodeInfo.h" #include "ppc-framework/protocol/Message.h" #include "ppc-framework/protocol/RouteType.h" #include namespace ppc::front { -class IFront +class IFrontClient +{ +public: + using Ptr = std::shared_ptr; + IFrontClient() = default; + virtual ~IFrontClient() = default; + /** + * @brief: receive message from gateway, call by gateway + * @param _message: received ppc message + * @return void + */ + virtual void onReceiveMessage( + ppc::protocol::Message::Ptr const& _msg, ppc::protocol::ReceiveMsgFunc _callback) = 0; +}; +class IFront : public virtual IFrontClient { public: using Ptr = std::shared_ptr; IFront() = default; - virtual ~IFront() = default; + ~IFront() override = default; /** * @brief start the IFront * * @param front the IFront to start */ - virtual void start() const = 0; + virtual void start() = 0; /** * @brief stop the IFront * * @param front the IFront to stop */ - virtual void stop() const = 0; + virtual void stop() = 0; /** * @@ -53,38 +68,61 @@ class IFront * @param callback the callback called when receive specified topic */ virtual void registerTopicHandler( - std::string const& topic, ppc::protocol::MessageCallback callback) = 0; + std::string const& topic, ppc::protocol::MessageDispatcherCallback callback) = 0; /** * @brief async send message * * @param routeType the route type - * @param topic the topic - * @param dstInst the dst agency(must set when 'route by agency' and 'route by + * @param routeInfo the route info, include + * - topic the topic + * - dstInst the dst agency(must set when 'route by agency' and 'route by * component') - * @param dstNodeID the dst nodeID(must set when 'route by nodeID') - * @param componentType the componentType(must set when 'route by component') + * - dstNodeID the dst nodeID(must set when 'route by nodeID') + * - componentType the componentType(must set when 'route by component') * @param payload the payload to send * @param seq the message seq * @param timeout timeout * @param callback callback */ - virtual void asyncSendMessage(ppc::protocol::RouteType routeType, std::string const& topic, - std::string const& dstInst, bcos::bytes const& dstNodeID, std::string const& componentType, - bcos::bytes&& payload, int seq, long timeout, ppc::protocol::MessageCallback callback) = 0; + virtual void asyncSendMessage(ppc::protocol::RouteType routeType, + ppc::protocol::MessageOptionalHeader::Ptr const& routeInfo, bcos::bytes&& payload, int seq, + long timeout, ppc::protocol::ReceiveMsgFunc errorCallback, + ppc::protocol::MessageCallback callback) = 0; // the sync interface for async_send_message - virtual ppc::protocol::Message::Ptr push(ppc::protocol::RouteType routeType, std::string topic, - std::string dstInst, std::string dstNodeID, std::string const& componentType, - bcos::bytes&& payload, int seq, long timeout) = 0; + virtual bcos::Error::Ptr push(ppc::protocol::RouteType routeType, + ppc::protocol::MessageOptionalHeader::Ptr const& routeInfo, bcos::bytes&& payload, int seq, + long timeout) = 0; + + virtual ppc::protocol::Message::Ptr pop(std::string const& topic, long timeoutMs) = 0; + virtual ppc::protocol::Message::Ptr peek(std::string const& topic); + /** - * @brief: receive message from gateway, call by gateway - * @param _message: received ppc message - * @return void + * @brief register the nodeInfo to the gateway + * @param nodeInfo the nodeInfo */ - virtual void onReceiveMessage( - ppc::protocol::Message::Ptr const& _msg, ppc::protocol::ReceiveMsgFunc _callback) = 0; + virtual void registerNodeInfo(ppc::protocol::INodeInfo::Ptr const& nodeInfo) = 0; + + /** + * @brief unRegister the nodeInfo to the gateway + */ + virtual void unRegisterNodeInfo() = 0; + + /** + * @brief register the topic + * + * @param topic the topic to register + */ + virtual void registerTopic(std::string const& topic) = 0; + + /** + * @brief unRegister the topic + * + * @param topic the topic to unregister + */ + virtual void unRegisterTopic(std::string const& topic) = 0; }; class IFrontBuilder @@ -101,6 +139,6 @@ class IFrontBuilder * @return IFront::Ptr he created Front */ virtual IFront::Ptr build(ppc::front::FrontConfig::Ptr config) const = 0; - virtual IFront::Ptr buildClient(std::string endPoint) const = 0; + virtual IFrontClient::Ptr buildClient(std::string endPoint) const = 0; }; } // namespace ppc::front \ No newline at end of file diff --git a/cpp/ppc-framework/gateway/GatewayInterface.h b/cpp/ppc-framework/gateway/GatewayInterface.h deleted file mode 100644 index 52ceda88..00000000 --- a/cpp/ppc-framework/gateway/GatewayInterface.h +++ /dev/null @@ -1,81 +0,0 @@ -/** - * Copyright (C) 2022 WeDPR. - * SPDX-License-Identifier: Apache-2.0 - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * @brief interface for Gateway module - * @file GatewayInterface.h - * @author: shawnhe - * @date 2022-10-19 - */ -#pragma once -#include "ppc-framework/front/FrontInterface.h" -#include "ppc-framework/protocol/PPCMessageFace.h" -#include "ppc-framework/protocol/Task.h" -#include - -namespace ppc -{ -namespace gateway -{ -using ErrorCallbackFunc = std::function; - -/** - * @brief: A list of interfaces provided by the gateway which are called by the front service. - */ -class GatewayInterface -{ -public: - using Ptr = std::shared_ptr; - GatewayInterface() = default; - virtual ~GatewayInterface() {} - - /** - * @brief: start/stop service - */ - virtual void start() = 0; - virtual void stop() = 0; - -public: - /** - * @brief: send message to gateway - * @param _agencyID: agency ID of receiver - * @param _message: ppc message data - * @param _callback: callback - * @return void - */ - virtual void asyncSendMessage(const std::string& _agencyID, front::PPCMessageFace::Ptr _message, - ErrorCallbackFunc _callback) = 0; - - - /** - * @brief notice task info to gateway - * @param _taskInfo the latest task information - */ - virtual bcos::Error::Ptr notifyTaskInfo(protocol::GatewayTaskInfo::Ptr _taskInfo) = 0; - - // erase the task-info when task finished - virtual bcos::Error::Ptr eraseTaskInfo(std::string const& _taskID) = 0; - - // register the gateway info of other parties - virtual bcos::Error::Ptr registerGateway( - const std::vector& _gatewayList) = 0; - - virtual void registerFront(std::string const& _endPoint, front::FrontInterface::Ptr _front) = 0; - virtual void unregisterFront(std::string const&) {} - - virtual void asyncGetAgencyList(ppc::front::GetAgencyListCallback _callback) = 0; -}; - -} // namespace gateway -} // namespace ppc diff --git a/cpp/ppc-framework/gateway/IGateway.h b/cpp/ppc-framework/gateway/IGateway.h index 9af59dba..d65171c0 100644 --- a/cpp/ppc-framework/gateway/IGateway.h +++ b/cpp/ppc-framework/gateway/IGateway.h @@ -57,14 +57,16 @@ class IGateway * @param timeout timeout * @param callback callback */ - virtual void asyncSendMessage(ppc::protocol::RouteType routeType, std::string const& topic, - std::string const& dstInst, bcos::bytes const& dstNodeID, std::string const& componentType, - bcos::bytes&& payload, long timeout, ppc::protocol::ReceiveMsgFunc callback) = 0; + virtual void asyncSendMessage(ppc::protocol::RouteType routeType, + ppc::protocol::MessageOptionalHeader::Ptr const& routeInfo, bcos::bytes&& payload, + long timeout, ppc::protocol::ReceiveMsgFunc callback) = 0; - virtual void registerNodeInfo(ppc::protocol::INodeInfo::Ptr const& nodeInfo); - virtual void unRegisterNodeInfo(bcos::bytesConstRef nodeID); - virtual void registerTopic(bcos::bytesConstRef nodeID, std::string const& topic); - virtual void unRegisterTopic(bcos::bytesConstRef nodeID, std::string const& topic); + virtual void asyncSendbroadcastMessage(ppc::protocol::RouteType routeType, + ppc::protocol::MessageOptionalHeader::Ptr const& routeInfo, bcos::bytes&& payload) = 0; + virtual void registerNodeInfo(ppc::protocol::INodeInfo::Ptr const& nodeInfo) = 0; + virtual void unRegisterNodeInfo(bcos::bytesConstRef nodeID) = 0; + virtual void registerTopic(bcos::bytesConstRef nodeID, std::string const& topic) = 0; + virtual void unRegisterTopic(bcos::bytesConstRef nodeID, std::string const& topic) = 0; }; } // namespace ppc::gateway diff --git a/cpp/ppc-framework/protocol/INodeInfo.h b/cpp/ppc-framework/protocol/INodeInfo.h index 859d6ac6..2d4f8e19 100644 --- a/cpp/ppc-framework/protocol/INodeInfo.h +++ b/cpp/ppc-framework/protocol/INodeInfo.h @@ -18,9 +18,16 @@ * @date 2024-08-26 */ #pragma once -#include "ppc-framework/front/IFront.h" +#include #include +#include +#include +#include +namespace ppc::front +{ +class IFrontClient; +} namespace ppc::protocol { // the node information @@ -35,14 +42,14 @@ class INodeInfo virtual bcos::bytesConstRef nodeID() const = 0; // components - virtual void setComponents(std::vector const& components) = 0; + virtual void setComponents(std::set const& components) = 0; virtual std::set const& components() const = 0; virtual void encode(bcos::bytes& data) const = 0; virtual void decode(bcos::bytesConstRef data) = 0; - virtual void setFront(ppc::front::IFront::Ptr&& front) = 0; - virtual ppc::front::IFront::Ptr const& getFront() const = 0; + virtual void setFront(std::shared_ptr&& front) = 0; + virtual std::shared_ptr const& getFront() const = 0; virtual bool equal(INodeInfo::Ptr const& info) { @@ -62,4 +69,21 @@ class INodeInfoFactory protected: bcos::bytes m_nodeID; }; + +inline std::string printNodeInfo(INodeInfo::Ptr const& nodeInfo) +{ + if (!nodeInfo) + { + return "nullptr"; + } + std::ostringstream stringstream; + stringstream << LOG_KV("endPoint", nodeInfo->endPoint()); + std::string components = ""; + for (auto const& it : nodeInfo->components()) + { + components = components + it + ","; + } + stringstream << LOG_KV("components", components); + return stringstream.str(); +} } // namespace ppc::protocol \ No newline at end of file diff --git a/cpp/ppc-framework/protocol/Message.h b/cpp/ppc-framework/protocol/Message.h index b318c1c0..95f7225e 100644 --- a/cpp/ppc-framework/protocol/Message.h +++ b/cpp/ppc-framework/protocol/Message.h @@ -19,9 +19,11 @@ */ #pragma once #include "../Common.h" +#include "MessagePayload.h" #include "RouteType.h" #include #include +#include #include #include #include @@ -59,12 +61,17 @@ class MessageOptionalHeader virtual void setTopic(std::string&& topic) { m_topic = std::move(topic); } virtual void setTopic(std::string const& topic) { m_topic = topic; } + virtual std::string srcInst() const { return m_srcInst; } + virtual void setSrcInst(std::string const& srcInst) { m_srcInst = srcInst; } + protected: std::string m_topic; // the componentType std::string m_componentType; // the source nodeID that send the message bcos::bytes m_srcNode; + // the source agency + std::string m_srcInst; // the target nodeID that should receive the message bcos::bytes m_dstNode; // the target agency that need receive the message @@ -126,6 +133,7 @@ class MessageHeader virtual uint16_t routeType() const = 0; virtual void setRouteType(ppc::protocol::RouteType type) = 0; + virtual bool hasOptionalField() const = 0; protected: // the msg version, used to support compatibility @@ -181,9 +189,18 @@ class Message : virtual public bcos::boostssl::MessageFace m_payload = std::move(_payload); } + void setFrontMessage(MessagePayload::Ptr frontMessage) + { + m_frontMessage = std::move(frontMessage); + } + + MessagePayload::Ptr const& frontMessage() const { return m_frontMessage; } + + protected: MessageHeader::Ptr m_header; std::shared_ptr m_payload; + MessagePayload::Ptr m_frontMessage; }; class MessageHeaderBuilder @@ -195,6 +212,7 @@ class MessageHeaderBuilder virtual MessageHeader::Ptr build(bcos::bytesConstRef _data) = 0; virtual MessageHeader::Ptr build() = 0; + virtual MessageOptionalHeader::Ptr build(MessageOptionalHeader::Ptr const& optionalHeader) = 0; }; class MessageBuilder : public bcos::boostssl::MessageFaceFactory @@ -206,13 +224,43 @@ class MessageBuilder : public bcos::boostssl::MessageFaceFactory virtual Message::Ptr build() = 0; virtual Message::Ptr build(bcos::bytesConstRef buffer) = 0; - virtual Message::Ptr build(ppc::protocol::RouteType routeType, std::string const& topic, - std::string const& dstInst, bcos::bytes const& dstNodeID, std::string const& componentType, - bcos::bytes&& payload) = 0; + virtual Message::Ptr build(ppc::protocol::RouteType routeType, + ppc::protocol::MessageOptionalHeader::Ptr const& routeInfo, bcos::bytes&& payload) = 0; }; +class MessageOptionalHeaderBuilder +{ +public: + using Ptr = std::shared_ptr; + MessageOptionalHeaderBuilder() = default; + virtual ~MessageOptionalHeaderBuilder() = default; + + virtual MessageOptionalHeader::Ptr build(MessageOptionalHeader::Ptr const& optionalHeader) = 0; + virtual MessageOptionalHeader::Ptr build() = 0; +}; + +inline std::string printOptionalField(MessageOptionalHeader::Ptr optionalHeader) +{ + if (!optionalHeader) + { + return "nullptr"; + } + std::ostringstream stringstream; + stringstream << LOG_KV("topic", optionalHeader->topic()) + << LOG_KV("componentType", optionalHeader->componentType()) + << LOG_KV("srcNode", *(bcos::toHexString(optionalHeader->srcNode()))) + << LOG_KV("dstNode", *(bcos::toHexString(optionalHeader->dstNode()))) + << LOG_KV("dstInst", optionalHeader->dstInst()); + + return stringstream.str(); +} + inline std::string printMessage(Message::Ptr const& _msg) { + if (!_msg) + { + return "nullptr"; + } std::ostringstream stringstream; stringstream << LOG_KV("from", _msg->header()->srcP2PNodeIDView()) << LOG_KV("to", _msg->header()->dstP2PNodeIDView()) @@ -221,11 +269,19 @@ inline std::string printMessage(Message::Ptr const& _msg) << LOG_KV("traceID", _msg->header()->traceID()) << LOG_KV("packetType", _msg->header()->packetType()) << LOG_KV("length", _msg->length()); + if (_msg->header()->hasOptionalField()) + { + stringstream << printOptionalField(_msg->header()->optionalField()); + } return stringstream.str(); } inline std::string printWsMessage(bcos::boostssl::MessageFace::Ptr const& _msg) { + if (!_msg) + { + return "nullptr"; + } std::ostringstream stringstream; stringstream << LOG_KV("rsp", _msg->isRespPacket()) << LOG_KV("traceID", _msg->seq()) << LOG_KV("packetType", _msg->packetType()) << LOG_KV("length", _msg->length()); @@ -233,9 +289,9 @@ inline std::string printWsMessage(bcos::boostssl::MessageFace::Ptr const& _msg) } // function to send response -using SendResponseFunction = std::function; +using SendResponseFunction = std::function&& payload)>; using ReceiveMsgFunc = std::function; using MessageCallback = std::function; - +using MessageDispatcherCallback = std::function; } // namespace ppc::protocol \ No newline at end of file diff --git a/cpp/ppc-framework/protocol/MessagePayload.h b/cpp/ppc-framework/protocol/MessagePayload.h index 14304a16..f5fbf9e3 100644 --- a/cpp/ppc-framework/protocol/MessagePayload.h +++ b/cpp/ppc-framework/protocol/MessagePayload.h @@ -23,6 +23,10 @@ namespace ppc::protocol { +enum class FrontMsgExtFlag : uint16_t +{ + Response = 0x1 +}; class MessagePayload { public: @@ -46,12 +50,26 @@ class MessagePayload // the length virtual int64_t length() const { return m_length; } + // the traceID + virtual std::string const& traceID() const { return m_traceID; } + virtual void setTraceID(std::string const& traceID) { m_traceID = traceID; } + + virtual uint16_t ext() const { return m_ext; } + virtual void setExt(uint16_t ext) { m_ext = ext; } + + virtual void setRespPacket() { m_ext |= (uint16_t)FrontMsgExtFlag::Response; } + + virtual bool isRespPacket() { return m_ext &= (uint16_t)FrontMsgExtFlag::Response; } + protected: // the front payload version, used to support compatibility uint8_t m_version; // the seq uint16_t m_seq; + // the traceID + std::string m_traceID; bcos::bytes m_data; + uint16_t m_ext; int64_t mutable m_length; }; diff --git a/cpp/ppc-framework/protocol/PPCMessageFace.h b/cpp/ppc-framework/protocol/PPCMessageFace.h index 275a473d..f6de1791 100644 --- a/cpp/ppc-framework/protocol/PPCMessageFace.h +++ b/cpp/ppc-framework/protocol/PPCMessageFace.h @@ -21,6 +21,7 @@ #pragma once #include "Protocol.h" +#include "ppc-framework/protocol/Message.h" #include #include #include @@ -58,8 +59,6 @@ class PPCMessageFace virtual void setTaskID(std::string const&) = 0; virtual std::string const& sender() const = 0; virtual void setSender(std::string const&) = 0; - virtual uint16_t ext() const = 0; - virtual void setExt(uint16_t) = 0; virtual std::shared_ptr data() const = 0; virtual void setData(std::shared_ptr) = 0; virtual std::map header() = 0; @@ -87,6 +86,7 @@ class PPCMessageFaceFactory public: virtual ~PPCMessageFaceFactory() {} virtual PPCMessageFace::Ptr buildPPCMessage() = 0; + virtual PPCMessageFace::Ptr buildPPCMessage(ppc::protocol::Message::Ptr msg) = 0; virtual PPCMessageFace::Ptr buildPPCMessage(bcos::bytesConstRef _data) = 0; virtual PPCMessageFace::Ptr buildPPCMessage(bcos::bytesPointer _buffer) = 0; virtual PPCMessageFace::Ptr buildPPCMessage(uint8_t _taskType, uint8_t _algorithmType, diff --git a/cpp/ppc-front/ppc-front/CallbackManager.cpp b/cpp/ppc-front/ppc-front/CallbackManager.cpp new file mode 100644 index 00000000..51031f5b --- /dev/null +++ b/cpp/ppc-front/ppc-front/CallbackManager.cpp @@ -0,0 +1,158 @@ +/** + * Copyright (C) 2023 WeDPR. + * SPDX-License-Identifier: Apache-2.0 + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @file CallbackManager.cpp + * @author: yujiechen + * @date 2024-08-30 + */ +#include "CallbackManager.h" +#include "Common.h" +#include "ppc-framework/protocol/Protocol.h" +#include + +using namespace bcos; +using namespace ppc; +using namespace ppc::front; +using namespace ppc::protocol; + +void CallbackManager::addCallback( + std::string const& traceID, long timeout, ppc::protocol::MessageCallback msgCallback) +{ + if (!msgCallback) + { + return; + } + auto callback = std::make_shared(msgCallback); + // set the timeout handler + if (timeout > 0) + { + callback->timeoutHandler = std::make_shared( + *m_ioService, boost::posix_time::milliseconds(timeout)); + auto self = weak_from_this(); + callback->timeoutHandler->async_wait([self, traceID](const boost::system::error_code& e) { + auto front = self.lock(); + if (front) + { + front->onMessageTimeout(e, traceID); + } + }); + } + // insert the callback into m_traceID2Callback + WriteGuard l(x_traceID2Callback); + m_traceID2Callback.insert(std::make_pair(traceID, callback)); +} + +Callback::Ptr CallbackManager::pop(std::string const& traceID) +{ + bcos::UpgradableGuard l(x_traceID2Callback); + auto it = m_traceID2Callback.find(traceID); + if (it == m_traceID2Callback.end()) + { + return nullptr; + } + auto callback = it->second; + m_traceID2Callback.erase(it); + return callback; +} + +void CallbackManager::onMessageTimeout( + const boost::system::error_code& e, std::string const& traceID) +{ + // the timer has been canceled + if (e) + { + return; + } + try + { + auto callback = pop(traceID); + if (!callback) + { + return; + } + if (callback->timeoutHandler) + { + callback->timeoutHandler->cancel(); + } + auto errorMsg = "send message with traceID=" + traceID + " timeout"; + auto error = std::make_shared(PPCRetCode::TIMEOUT, errorMsg); + m_threadPool->enqueue( + [callback, error]() { callback->msgCallback(error, nullptr, nullptr); }); + FRONT_LOG(WARNING) << LOG_BADGE("onMessageTimeout") << LOG_KV("traceID", traceID); + } + catch (std::exception& e) + { + FRONT_LOG(ERROR) << "onMessageTimeout" << LOG_KV("traceID", traceID) + << LOG_KV("error", boost::diagnostic_information(e)); + } +} + + +void CallbackManager::handleCallback(bcos::Error::Ptr const& error, std::string const& traceID, + Message::Ptr message, SendResponseFunction resFunc) +{ + auto callback = pop(traceID); + if (!callback) + { + return; + } + if (callback->timeoutHandler) + { + callback->timeoutHandler->cancel(); + } + if (!message) + { + return; + } + m_threadPool->enqueue( + [error, callback, message, resFunc] { callback->msgCallback(error, message, resFunc); }); +} + + +void CallbackManager::registerTopicHandler( + std::string const& topic, ppc::protocol::MessageDispatcherCallback callback) +{ + bcos::WriteGuard l(x_topicHandlers); + m_topicHandlers.insert(std::make_pair(topic, callback)); +} + +void CallbackManager::onReceiveMessage(std::string const& topic, Message::Ptr msg) +{ + MessageDispatcherCallback callback = nullptr; + { + bcos::ReadGuard l(x_topicHandlers); + auto it = m_topicHandlers.find(topic); + if (it == m_topicHandlers.end()) + { + FRONT_LOG(DEBUG) << LOG_DESC( + "onReceiveMessage: not find the handler, put into the buffer") + << LOG_KV("topic", topic); + addMsgCache(topic, msg); + return; + } + callback = it->second; + } + m_threadPool->enqueue([callback, msg]() { + try + { + callback(std::move(msg)); + } + catch (Exception e) + { + FRONT_LOG(WARNING) << LOG_DESC("onReceiveMessage: dispatcher exception") + << LOG_KV("error", boost::diagnostic_information(e)); + } + }); +} \ No newline at end of file diff --git a/cpp/ppc-front/ppc-front/CallbackManager.h b/cpp/ppc-front/ppc-front/CallbackManager.h new file mode 100644 index 00000000..95894d52 --- /dev/null +++ b/cpp/ppc-front/ppc-front/CallbackManager.h @@ -0,0 +1,113 @@ +/** + * Copyright (C) 2023 WeDPR. + * SPDX-License-Identifier: Apache-2.0 + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @file CallbackManager.h + * @author: yujiechen + * @date 2024-08-30 + */ +#pragma once +#include "ppc-framework/protocol/Message.h" +#include +#include +#include +#include +#include +#include +#define TBB_PREVIEW_CONCURRENT_ORDERED_CONTAINERS 1 +#include +#include + +namespace ppc::front +{ +class Callback +{ +public: + using Ptr = std::shared_ptr; + Callback(ppc::protocol::MessageCallback callback) : msgCallback(std::move(callback)) {} + + ppc::protocol::MessageCallback msgCallback; + std::shared_ptr timeoutHandler; +}; +class CallbackManager : public std::enable_shared_from_this +{ +public: + using MsgQueueType = bcos::ConcurrentQueue; + using Ptr = std::shared_ptr; + CallbackManager( + bcos::ThreadPool::Ptr threadPool, std::shared_ptr ioService) + : m_threadPool(std::move(threadPool)), m_ioService(std::move(ioService)) + {} + virtual ~CallbackManager() = default; + + virtual void addCallback( + std::string const& traceID, long timeout, ppc::protocol::MessageCallback msgCallback); + + virtual Callback::Ptr pop(std::string const& traceID); + + virtual void handleCallback(bcos::Error::Ptr const& error, std::string const& traceID, + ppc::protocol::Message::Ptr message, ppc::protocol::SendResponseFunction resFunc); + + virtual void onReceiveMessage(std::string const& topic, ppc::protocol::Message::Ptr msg); + + virtual void registerTopicHandler( + std::string const& topic, ppc::protocol::MessageDispatcherCallback callback); + + virtual ppc::protocol::Message::Ptr pop(std::string const& topic, int timeoutMs) + { + auto it = m_msgCache.find(topic); + if (it == m_msgCache.end()) + { + return nullptr; + } + auto msgQueue = it->second; + if (msgQueue->empty()) + { + return nullptr; + } + auto result = msgQueue->tryPop(timeoutMs); + return result.second; + } + +private: + void onMessageTimeout(const boost::system::error_code& e, std::string const& traceID); + void addMsgCache(std::string const& topic, ppc::protocol::Message::Ptr msg) + { + auto it = m_msgCache.find(topic); + if (it == m_msgCache.end()) + { + m_msgCache.insert(std::make_pair(topic, std::make_shared())); + } + auto msgQueue = m_msgCache[topic]; + // push + msgQueue->push(std::move(msg)); + } + +private: + bcos::ThreadPool::Ptr m_threadPool; + std::shared_ptr m_ioService; + // traceID => callback + std::unordered_map m_traceID2Callback; + mutable bcos::SharedMutex x_traceID2Callback; + + // topic => messageDispatcherCallback + std::map m_topicHandlers; + mutable bcos::SharedMutex x_topicHandlers; + + // the messageCache for the message with no topic handler defined + uint64_t m_maxMsgCacheSize = 10000; + // TODO: check the queueSize + tbb::concurrent_unordered_map> m_msgCache; +}; +} // namespace ppc::front \ No newline at end of file diff --git a/cpp/ppc-front/ppc-front/Front.cpp b/cpp/ppc-front/ppc-front/Front.cpp index 4fa711c0..23b379e0 100644 --- a/cpp/ppc-front/ppc-front/Front.cpp +++ b/cpp/ppc-front/ppc-front/Front.cpp @@ -17,349 +17,76 @@ * @author: shawnhe * @date 2022-10-20 */ - #include "Front.h" -#include -#include -#include -#include +using namespace ppc; using namespace bcos; -using namespace ppc::front; using namespace ppc::protocol; - -void Front::start() -{ - if (m_running) - { - FRONT_LOG(INFO) << LOG_DESC("Front has already been started"); - return; - } - m_running = true; - FRONT_LOG(INFO) << LOG_DESC("start the Front"); - m_thread = std::make_shared([&] { - bcos::pthread_setThreadName("front_io_service"); - while (m_running) - { - try - { - m_ioService->run(); - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - if (m_running && m_ioService->stopped()) - { - m_ioService->restart(); - } - } - catch (std::exception& e) - { - FRONT_LOG(WARNING) - << LOG_DESC("Exception in Front Thread:") << boost::diagnostic_information(e); - } - } - FRONT_LOG(INFO) << "Front exit"; - }); -} - - -/** - * @brief: receive message from gateway, call by gateway - * @param _message: received ppc message - * @return void - */ -void Front::onReceiveMessage(PPCMessageFace::Ptr _message, ErrorCallbackFunc _callback) -{ - if (_callback) - { - if (m_threadPool) - { - m_threadPool->enqueue([_callback] { _callback(nullptr); }); - } - else - { - _callback(nullptr); - } - } - - FRONT_LOG(TRACE) << LOG_BADGE("onReceiveMessage") << printPPCMsg(_message); - // response package - if (_message->response()) - { - handleCallback(nullptr, _message->uuid(), _message, _message->sender()); - return; - } - - uint16_t type = ((uint16_t)_message->taskType() << 8) | _message->algorithmType(); - auto it = m_handlers.find(type); - if (it != m_handlers.end()) - { - if (m_threadPool) - { - auto handler = it->second; - m_threadPool->enqueue([handler, _message] { handler(_message); }); - } - else - { - it->second(std::move(_message)); - } - } - else - { - FRONT_LOG(WARNING) << LOG_BADGE("onReceiveMessage") << LOG_DESC("message handler not found") - << LOG_KV("taskType", unsigned(_message->taskType())) - << LOG_KV("algorithmType", unsigned(_message->algorithmType())) - << LOG_KV("messageType", unsigned(_message->messageType())) - << LOG_KV("seq", _message->seq()) << LOG_KV("taskID", _message->taskID()) - << LOG_KV("sender", _message->sender()); - } -} - - +using namespace ppc::front; /** * @brief: send message to other party by gateway * @param _agencyID: agency ID of receiver * @param _message: ppc message data + * @param _callback: callback called when the message sent successfully + * @param _respCallback: callback called when receive the response from peer * @return void */ -void Front::asyncSendMessage(const std::string& _agencyID, PPCMessageFace::Ptr _message, +void Front::asyncSendMessage(const std::string& _agencyID, front::PPCMessageFace::Ptr _message, uint32_t _timeout, ErrorCallbackFunc _callback, CallbackFunc _respCallback) { - // generate uuid for the message - static thread_local auto uuid_gen = boost::uuids::basic_random_generator(); - std::string uuid = boost::uuids::to_string(uuid_gen()); - _message->setUuid(uuid); - // call gateway interface to send the message - _message->setSender(m_selfAgencyId); - auto taskID = _message->taskID(); - FRONT_LOG(TRACE) << LOG_BADGE("asyncSendMessage") - << LOG_KV("taskType", unsigned(_message->taskType())) - << LOG_KV("algorithmType", unsigned(_message->algorithmType())) - << LOG_KV("messageType", unsigned(_message->messageType())) - << LOG_KV("seq", _message->seq()) << LOG_KV("taskID", _message->taskID()) - << LOG_KV("receiver", _agencyID) << LOG_KV("uud", _message->uuid()); - // timeout logic - if (_respCallback) - { - auto callback = std::make_shared(_respCallback); - addCallback(uuid, callback); - if (_timeout > 0) - { - auto timeoutHandler = std::make_shared( - *m_ioService, boost::posix_time::milliseconds(_timeout)); - callback->timeoutHandler = timeoutHandler; - auto self = weak_from_this(); - timeoutHandler->async_wait( - [self, _agencyID, taskID, uuid](const boost::system::error_code& e) { - auto front = self.lock(); - if (front) - { - front->onMessageTimeout(e, _agencyID, taskID, uuid); - } - }); - } - } + auto routeInfo = m_front->routerInfoBuilder()->build(); + routeInfo->setDstInst(_agencyID); + routeInfo->setTopic(_message->taskID()); + bcos::bytes data; + _message->encode(data); auto self = weak_from_this(); - sendMessageToGateway(_agencyID, std::move(_message), uuid, false, - [self, uuid, _agencyID, taskID, _callback](const bcos::Error::Ptr& _error) { + m_front->asyncSendMessage(RouteType::ROUTE_THROUGH_AGENCY, routeInfo, std::move(data), + _message->seq(), _timeout, _callback, + [self, _agencyID, _respCallback]( + Error::Ptr error, Message::Ptr msg, SendResponseFunction resFunc) { auto front = self.lock(); if (!front) { return; } - // send message to gateway error, try to handleCallback - if (_error && (_error->errorCode() != 0)) - { - front->handleCallback(_error, uuid, nullptr, _agencyID); - } - if (_callback) - { - _callback(_error); - } - }); -} - -void Front::handleCallback(bcos::Error::Ptr const& _error, std::string const& _uuid, - PPCMessageFace::Ptr _message, std::string const& _agencyID) -{ - auto callback = getAndEraseCallback(_uuid); - if (!callback) - { - return; - } - // cancel the timer - if (callback->timeoutHandler) - { - callback->timeoutHandler->cancel(); - } - if (!_message) - { - return; - } - auto self = weak_from_this(); - auto respFunc = [self, _agencyID, _uuid](PPCMessageFace::Ptr _resp) { - auto front = self.lock(); - if (!front) - { - return; - } - front->sendMessageToGateway(front->m_selfAgencyId, std::move(_resp), _uuid, true, - [_agencyID](const bcos::Error::Ptr& _error) { - if (!_error) + auto responseCallback = [resFunc](PPCMessageFace::Ptr msg) { + if (!msg) { return; } - FRONT_LOG(WARNING) - << LOG_DESC("asyncSendResponse message error") << LOG_KV("agency", _agencyID) - << LOG_KV("code", _error->errorCode()) << LOG_KV("msg", _error->errorMessage()); - }); - }; - if (m_threadPool) - { - m_threadPool->enqueue([_error, callback, _message, _agencyID, respFunc] { - callback->callback(_error, _agencyID, _message, respFunc); + std::shared_ptr payload = std::make_shared(); + msg->encode(*payload); + resFunc(std::move(payload)); + }; + if (msg == nullptr) + { + _respCallback(error, _agencyID, nullptr, responseCallback); + } + // get the agencyID + _respCallback(error, msg->header()->optionalField()->srcInst(), + front->m_messageFactory->buildPPCMessage(msg), responseCallback); }); - } - else - { - callback->callback(_error, _agencyID, std::move(_message), std::move(respFunc)); - } -} - -void Front::sendMessageToGateway(std::string const& _agencyID, PPCMessageFace::Ptr _message, - std::string const& _uuid, bool _response, ErrorCallbackFunc _callback) -{ - _message->setSender(m_selfAgencyId); - _message->setUuid(_uuid); - if (_response) - { - _message->setResponse(); - } - m_gatewayInterface->asyncSendMessage(_agencyID, std::move(_message), std::move(_callback)); } -// send response +// send response when receiving message from given agencyID void Front::asyncSendResponse(const std::string& _agencyID, std::string const& _uuid, front::PPCMessageFace::Ptr _message, ErrorCallbackFunc _callback) -{ - FRONT_LOG(TRACE) << LOG_DESC("asyncSendResponse") << printPPCMsg(_message); - sendMessageToGateway(_agencyID, std::move(_message), _uuid, true, std::move(_callback)); -} - -void Front::onMessageTimeout(const boost::system::error_code& e, std::string const& _agencyID, - std::string const& _taskID, std::string const& _uuid) -{ - // the timer has been canceled - if (e) - { - return; - } - - try - { - auto callback = getAndEraseCallback(_uuid); - if (!callback) - { - return; - } - if (callback->timeoutHandler) - { - callback->timeoutHandler->cancel(); - } - auto errorMsg = "send message with uuid=" + _uuid + ", agency = " + _agencyID + - ", task = " + _taskID + " timeout"; - auto error = std::make_shared(PPCRetCode::TIMEOUT, errorMsg); - if (m_threadPool) - { - m_threadPool->enqueue([callback, _agencyID, error]() { - callback->callback(error, _agencyID, nullptr, nullptr); - }); - } - else - { - callback->callback(std::move(error), _agencyID, nullptr, nullptr); - } - FRONT_LOG(WARNING) << LOG_BADGE("onMessageTimeout") << LOG_KV("uuid", _uuid) - << LOG_KV("agency", _agencyID) << LOG_KV("task", _taskID); - } - catch (std::exception& e) - { - FRONT_LOG(ERROR) << "onMessageTimeout" << LOG_KV("uuid", _uuid) - << LOG_KV("error", boost::diagnostic_information(e)); - } -} - +{} /** * @brief notice task info to gateway * @param _taskInfo the latest task information */ -bcos::Error::Ptr Front::notifyTaskInfo(GatewayTaskInfo::Ptr _taskInfo) +bcos::Error::Ptr Front::notifyTaskInfo(std::string const& taskID) { - auto startT = bcos::utcSteadyTime(); - if (_taskInfo->serviceEndpoint.empty()) - { - _taskInfo->serviceEndpoint = m_selfEndPoint; - } - auto ret = m_gatewayInterface->notifyTaskInfo(_taskInfo); - FRONT_LOG(INFO) << LOG_BADGE("notifyTaskInfo") << LOG_KV("taskID", _taskInfo->taskID) - << LOG_KV("serviceEndpoint", _taskInfo->serviceEndpoint) - << LOG_KV("timecost", bcos::utcSteadyTime() - startT); - return ret; + m_front->registerTopic(taskID); } // erase the task-info when task finished bcos::Error::Ptr Front::eraseTaskInfo(std::string const& _taskID) { - auto startT = bcos::utcSteadyTime(); - auto ret = m_gatewayInterface->eraseTaskInfo(_taskID); - FRONT_LOG(INFO) << LOG_BADGE("eraseTaskInfo") << LOG_KV("taskID", _taskID) - << LOG_KV("timecost", bcos::utcSteadyTime() - startT); - return ret; + m_front->unRegisterTopic(_taskID); } // get the agencyList from the gateway -void Front::asyncGetAgencyList(ppc::front::GetAgencyListCallback _callback) -{ - FRONT_LOG(TRACE) << LOG_BADGE("asyncGetAgencyList"); - if (!m_gatewayInterface) - { - std::vector emptyAgencies; - _callback(std::make_shared( - -1, "asyncGetAgencyList failed for the gateway not been inited into front!"), - std::move(emptyAgencies)); - return; - } - m_gatewayInterface->asyncGetAgencyList(std::move(_callback)); -} - -void Front::stop() -{ - if (!m_running) - { - FRONT_LOG(INFO) << LOG_DESC("Front has already been stopped"); - return; - } - m_running = false; - if (m_ioService) - { - m_ioService->stop(); - } - if (m_thread) - { - // stop the thread - if (m_thread->get_id() != std::this_thread::get_id()) - { - m_thread->join(); - } - else - { - m_thread->detach(); - } - } -} - - -Front::Ptr FrontFactory::buildFront(std::shared_ptr _ioService) -{ - FRONT_LOG(INFO) << LOG_BADGE("buildFront") << LOG_KV("agencyID", m_selfAgencyId); - return std::make_shared(std::move(_ioService), m_selfAgencyId, m_threadPool); -} \ No newline at end of file +void Front::asyncGetAgencyList(GetAgencyListCallback _callback) {} diff --git a/cpp/ppc-front/ppc-front/Front.h b/cpp/ppc-front/ppc-front/Front.h index 0796b195..6394df73 100644 --- a/cpp/ppc-front/ppc-front/Front.h +++ b/cpp/ppc-front/ppc-front/Front.h @@ -19,73 +19,31 @@ */ #pragma once - -#include "Common.h" -#include "FrontService.h" +#include "FrontImpl.h" #include "ppc-framework/front/FrontInterface.h" -#include "ppc-framework/gateway/GatewayInterface.h" -#include "ppc-tars-protocol/ppc-tars-protocol/Common.h" -#include "ppc-tars-protocol/ppc-tars-protocol/client/FrontServiceClient.h" -#include -#include - - -#include +#include "ppc-framework/protocol/PPCMessageFace.h" -namespace ppc +namespace ppc::front { -namespace front -{ -struct Callback -{ - using Ptr = std::shared_ptr; - Callback(CallbackFunc _callback) : callback(std::move(_callback)) {} - std::shared_ptr timeoutHandler; - CallbackFunc callback; -}; -class Front : public front::FrontInterface, public std::enable_shared_from_this +class Front : public FrontInterface, public std::enable_shared_from_this { public: using Ptr = std::shared_ptr; - - Front(std::shared_ptr _ioService, std::string const& _selfAgencyId, - bcos::ThreadPool::Ptr _threadPool, std::string const& _selfEndPoint = "localhost") - : m_ioService(std::move(_ioService)), - m_selfAgencyId(_selfAgencyId), - m_threadPool(std::move(_threadPool)) - { - m_selfEndPoint = _selfEndPoint; - } - - Front(const Front&) = delete; - Front(Front&&) = delete; - - Front& operator=(const Front&) = delete; - Front& operator=(Front&&) = delete; - - virtual ~Front() override = default; - - void start() override; - void stop() override; - - /** - * @brief: receive message from gateway, call by gateway - * @param _message: received ppc message - * @return void - */ - void onReceiveMessage( - front::PPCMessageFace::Ptr _message, ErrorCallbackFunc _callback) override; - + Front(FrontImpl::Ptr front) : m_front(std::move(front)) {} + ~Front() override {} /** * @brief: send message to other party by gateway * @param _agencyID: agency ID of receiver * @param _message: ppc message data + * @param _callback: callback called when the message sent successfully + * @param _respCallback: callback called when receive the response from peer * @return void */ - void asyncSendMessage(const std::string& _agencyID, PPCMessageFace::Ptr _message, + void asyncSendMessage(const std::string& _agencyID, front::PPCMessageFace::Ptr _message, uint32_t _timeout, ErrorCallbackFunc _callback, CallbackFunc _respCallback) override; + // send response when receiving message from given agencyID void asyncSendResponse(const std::string& _agencyID, std::string const& _uuid, front::PPCMessageFace::Ptr _message, ErrorCallbackFunc _callback) override; @@ -93,98 +51,38 @@ class Front : public front::FrontInterface, public std::enable_shared_from_this< * @brief notice task info to gateway * @param _taskInfo the latest task information */ - bcos::Error::Ptr notifyTaskInfo(protocol::GatewayTaskInfo::Ptr _taskInfo) override; + bcos::Error::Ptr notifyTaskInfo(std::string const& taskID) override; + // erase the task-info when task finished bcos::Error::Ptr eraseTaskInfo(std::string const& _taskID) override; // get the agencyList from the gateway - void asyncGetAgencyList(ppc::front::GetAgencyListCallback _callback) override; + void asyncGetAgencyList(GetAgencyListCallback _callback) override; // register message handler for algorithm void registerMessageHandler(uint8_t _taskType, uint8_t _algorithmType, std::function _handler) { uint16_t type = ((uint16_t)_taskType << 8) | _algorithmType; - m_handlers[type] = std::move(_handler); - } - - ppc::gateway::GatewayInterface::Ptr gatewayInterface() { return m_gatewayInterface; } - void setGatewayInterface(ppc::gateway::GatewayInterface::Ptr _gatewayInterface) - { - m_gatewayInterface = std::move(_gatewayInterface); + auto self = weak_from_this(); + m_front->registerTopicHandler( + std::to_string(type), [self, _handler](ppc::protocol::Message::Ptr msg) { + auto front = self.lock(); + if (!front) + { + return; + } + if (msg == nullptr) + { + _handler(nullptr); + return; + } + _handler(front->m_messageFactory->buildPPCMessage(msg)); + }); } - std::shared_ptr threadPool() const { return m_threadPool; } - const std::string& selfAgencyId() const { return m_selfAgencyId; } - // Note: the selfEndPoint must be setted before start the front - virtual void setSelfEndPoint(std::string const& _selfEndPoint) - { - FRONT_LOG(INFO) << LOG_DESC("setSelfEndPoint: ") << _selfEndPoint; - m_selfEndPoint = _selfEndPoint; - } - -private: - void addCallback(std::string const& _uuid, Callback::Ptr _callback) - { - bcos::WriteGuard l(x_uuidToCallback); - m_uuidToCallback[_uuid] = std::move(_callback); - } - - Callback::Ptr getAndEraseCallback(std::string const& _uuid) - { - bcos::UpgradableGuard l(x_uuidToCallback); - auto it = m_uuidToCallback.find(_uuid); - if (it != m_uuidToCallback.end()) - { - auto callback = it->second; - bcos::UpgradeGuard ul(l); - m_uuidToCallback.erase(it); - return callback; - } - return nullptr; - } - - void onMessageTimeout(const boost::system::error_code& e, std::string const& _agencyID, - std::string const& _taskID, std::string const& _uuid); - void handleCallback(bcos::Error::Ptr const& _error, std::string const& _uuid, - PPCMessageFace::Ptr _message, std::string const& _agencyID); - void sendMessageToGateway(std::string const& _agencyID, PPCMessageFace::Ptr _msg, - std::string const& _uuid, bool _response, ErrorCallbackFunc _callback); - private: - std::shared_ptr m_ioService; - std::string m_selfAgencyId; - bcos::ThreadPool::Ptr m_threadPool; - - // gatewayInterface - ppc::gateway::GatewayInterface::Ptr m_gatewayInterface; - std::unordered_map> m_handlers; - - // uuid->callback - std::unordered_map m_uuidToCallback; - bcos::SharedMutex x_uuidToCallback; - - bool m_running = false; - // the thread to run ioservice - std::shared_ptr m_thread; + FrontImpl::Ptr m_front; + ppc::front::PPCMessageFaceFactory::Ptr m_messageFactory; }; - -class FrontFactory -{ -public: - using Ptr = std::shared_ptr; - FrontFactory(std::string _selfAgencyId, std::shared_ptr _threadPool) - : m_selfAgencyId(std::move(_selfAgencyId)), m_threadPool(std::move(_threadPool)) - {} - -public: - Front::Ptr buildFront(std::shared_ptr _ioService); - -private: - std::string m_selfAgencyId; - // thread pool - std::shared_ptr m_threadPool; -}; - -} // namespace front -} // namespace ppc +} // namespace ppc::front \ No newline at end of file diff --git a/cpp/ppc-front/ppc-front/FrontImpl.cpp b/cpp/ppc-front/ppc-front/FrontImpl.cpp new file mode 100644 index 00000000..467e4bf6 --- /dev/null +++ b/cpp/ppc-front/ppc-front/FrontImpl.cpp @@ -0,0 +1,256 @@ +/** + * Copyright (C) 2023 WeDPR. + * SPDX-License-Identifier: Apache-2.0 + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @file FrontImpl.cpp + * @author: yujiechen + * @date 2024-08-30 + */ +#include "FrontImpl.h" +#include "ppc-utilities/Utilities.h" + +using namespace bcos; +using namespace ppc; +using namespace ppc::front; +using namespace ppc::protocol; + +FrontImpl::FrontImpl(std::shared_ptr threadPool, + ppc::protocol::INodeInfo::Ptr nodeInfo, MessagePayloadBuilder::Ptr messageFactory, + ppc::protocol::MessageOptionalHeaderBuilder::Ptr routerInfoBuilder, + ppc::gateway::IGateway::Ptr const& gateway, std::shared_ptr ioService) + : m_threadPool(std::move(threadPool)), + m_nodeInfo(std::move(nodeInfo)), + m_messageFactory(std::move(messageFactory)), + m_routerInfoBuilder(std::move(routerInfoBuilder)), + m_ioService(std::move(ioService)), + m_gatewayClient(gateway) +{ + m_nodeID = m_nodeInfo->nodeID().toBytes(); + m_callbackManager = std::make_shared(m_threadPool, ioService); +} + +/** + * @brief start the IFront + * + * @param front the IFront to start + */ +void FrontImpl::start() +{ + if (m_running) + { + FRONT_LOG(INFO) << LOG_DESC("The front has already been started"); + return; + } + m_running = true; + m_thread = std::make_shared([&] { + bcos::pthread_setThreadName("front_io_service"); + while (m_running) + { + try + { + m_ioService->run(); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + if (m_running && m_ioService->stopped()) + { + m_ioService->restart(); + } + } + catch (std::exception& e) + { + FRONT_LOG(WARNING) + << LOG_DESC("Exception in Front Thread:") << boost::diagnostic_information(e); + } + } + FRONT_LOG(INFO) << "Front exit"; + }); +} +/** + * @brief stop the IFront + * + * @param front the IFront to stop + */ +void FrontImpl::stop() +{ + if (!m_running) + { + FRONT_LOG(INFO) << LOG_DESC("The front has already been stopped"); + return; + } + m_running = false; + if (m_ioService) + { + m_ioService->stop(); + } + if (m_thread) + { + // stop the thread + if (m_thread->get_id() != std::this_thread::get_id()) + { + m_thread->join(); + } + else + { + m_thread->detach(); + } + } +} + +/** + * @brief async send message + * + * @param routeType the route type + * @param routeInfo the route info, include + * - topic the topic + * - dstInst the dst agency(must set when 'route by agency' and 'route by + * component') + * - dstNodeID the dst nodeID(must set when 'route by nodeID') + * - componentType the componentType(must set when 'route by component') + * @param payload the payload to send + * @param seq the message seq + * @param timeout timeout + * @param callback callback + */ +void FrontImpl::asyncSendMessage(RouteType routeType, MessageOptionalHeader::Ptr const& routeInfo, + bcos::bytes&& payload, int seq, long timeout, ReceiveMsgFunc errorCallback, + MessageCallback callback) +{ + // generate the frontMessage + MessagePayload::Ptr frontMessage = m_messageFactory->build(); + auto traceID = ppc::generateUUID(); + frontMessage->setTraceID(traceID); + frontMessage->setSeq(seq); + frontMessage->setData(std::move(payload)); + m_callbackManager->addCallback(traceID, timeout, callback); + auto self = weak_from_this(); + // send the message to the gateway + asyncSendMessageToGateway(false, std::move(frontMessage), routeType, routeInfo, timeout, + [self, traceID, routeInfo, errorCallback](bcos::Error::Ptr error) { + auto front = self.lock(); + if (!front) + { + return; + } + // send success + if (error && error->errorCode() != 0) + { + // send failed + FRONT_LOG(ERROR) << LOG_DESC("asyncSendMessage failed") + << LOG_KV("routeInfo", printOptionalField(routeInfo)) + << LOG_KV("traceID", traceID) << LOG_KV("code", error->errorCode()) + << LOG_KV("msg", error->errorMessage()); + // try to trigger the callback + front->handleCallback(error, traceID, nullptr); + } + // Note: be careful block here when use push + if (errorCallback) + { + errorCallback(error); + } + }); +} + +void FrontImpl::handleCallback( + bcos::Error::Ptr const& error, std::string const& traceID, Message::Ptr message) +{ + auto self = weak_from_this(); + m_callbackManager->handleCallback(error, traceID, std::move(message), + [self, message](std::shared_ptr&& payload) { + auto front = self.lock(); + if (!front) + { + return; + } + auto frontMessage = front->m_messageFactory->build(); + // set the traceID + frontMessage->setTraceID(message->header()->traceID()); + ///// populate the route info + auto routerInfo = front->m_routerInfoBuilder->build(message->header()->optionalField()); + // set the dstNodeID + routerInfo->setDstNode(message->header()->optionalField()->srcNode()); + // set the srcNodeID + routerInfo->setSrcNode(message->header()->optionalField()->dstNode()); + front->asyncSendMessageToGateway(true, std::move(frontMessage), + RouteType::ROUTE_THROUGH_NODEID, routerInfo, 0, + [routerInfo](bcos::Error::Ptr error) { + if (!error || error->errorCode() == 0) + { + return; + } + FRONT_LOG(WARNING) << LOG_DESC("send response message error") + << LOG_KV("routeInfo", printOptionalField(routerInfo)) + << LOG_KV("code", error->errorCode()) + << LOG_KV("msg", error->errorMessage()); + }); + }); +} + +void FrontImpl::asyncSendMessageToGateway(bool responsePacket, MessagePayload::Ptr&& frontMessage, + RouteType routeType, MessageOptionalHeader::Ptr const& routeInfo, long timeout, + ReceiveMsgFunc callback) +{ + if (responsePacket) + { + frontMessage->setRespPacket(); + } + routeInfo->setSrcNode(m_nodeID); + auto payload = std::make_shared(); + frontMessage->encode(*payload); + m_gatewayClient->asyncSendMessage(routeType, routeInfo, std::move(*payload), timeout, callback); +} + + +/** + * @brief: receive message from gateway, call by gateway + * @param _message: received ppc message + * @return void + */ +void FrontImpl::onReceiveMessage(Message::Ptr const& msg, ReceiveMsgFunc callback) +{ + try + { + // response to the gateway + if (callback) + { + m_threadPool->enqueue([callback] { callback(nullptr); }); + } + FRONT_LOG(TRACE) << LOG_BADGE("onReceiveMessage") << LOG_KV("msg", printMessage(msg)); + auto frontMessage = m_messageFactory->build(bcos::ref(*(msg->payload()))); + msg->setFrontMessage(frontMessage); + // the response packet, dispatcher by callback + if (frontMessage->isRespPacket()) + { + handleCallback(nullptr, msg->header()->traceID(), msg); + return; + } + // dispatcher by topic + m_callbackManager->onReceiveMessage(msg->header()->optionalField()->topic(), msg); + } + catch (Exception const& e) + { + FRONT_LOG(WARNING) << LOG_DESC("onReceiveMessage exception") + << LOG_KV("msg", printMessage(msg)) + << LOG_KV("error", boost::diagnostic_information(e)); + } +} + +// the sync interface for asyncSendMessage +bcos::Error::Ptr FrontImpl::push(RouteType routeType, MessageOptionalHeader::Ptr const& routeInfo, + bcos::bytes&& payload, int seq, long timeout) +{ + auto promise = std::make_shared>(); + asyncSendMessage( + routeType, routeInfo, std::move(payload), seq, timeout, + [promise](bcos::Error::Ptr error) { promise->set_value(error); }, nullptr); + return promise->get_future().get(); +} \ No newline at end of file diff --git a/cpp/ppc-front/ppc-front/FrontImpl.h b/cpp/ppc-front/ppc-front/FrontImpl.h new file mode 100644 index 00000000..3156b6c7 --- /dev/null +++ b/cpp/ppc-front/ppc-front/FrontImpl.h @@ -0,0 +1,183 @@ +/** + * Copyright (C) 2023 WeDPR. + * SPDX-License-Identifier: Apache-2.0 + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @file FrontImpl.h + * @author: yujiechen + * @date 2024-08-30 + */ +#pragma once +#include "CallbackManager.h" +#include "Common.h" +#include "ppc-framework/front/IFront.h" +#include "ppc-framework/gateway/IGateway.h" +#include +#include +#include + +namespace ppc::front +{ +class FrontImpl : public IFront, public std::enable_shared_from_this +{ +public: + using Ptr = std::shared_ptr; + FrontImpl(std::shared_ptr threadPool, ppc::protocol::INodeInfo::Ptr nodeInfo, + ppc::protocol::MessagePayloadBuilder::Ptr messageFactory, + ppc::protocol::MessageOptionalHeaderBuilder::Ptr routerInfoBuilder, + ppc::gateway::IGateway::Ptr const& gateway, + std::shared_ptr ioService); + ~FrontImpl() override = default; + + /** + * @brief start the IFront + * + * @param front the IFront to start + */ + void start() override; + /** + * @brief stop the IFront + * + * @param front the IFront to stop + */ + void stop() override; + + bcos::Error::Ptr push(ppc::protocol::RouteType routeType, + ppc::protocol::MessageOptionalHeader::Ptr const& routeInfo, bcos::bytes&& payload, int seq, + long timeout) override; + /** + * @brief async send message + * + * @param routeType the route type + * @param routeInfo the route info, include + * - topic the topic + * - dstInst the dst agency(must set when 'route by agency' and 'route by + * component') + * - dstNodeID the dst nodeID(must set when 'route by nodeID') + * - componentType the componentType(must set when 'route by component') + * @param payload the payload to send + * @param seq the message seq + * @param timeout timeout + * @param callback callback + */ + void asyncSendMessage(ppc::protocol::RouteType routeType, + ppc::protocol::MessageOptionalHeader::Ptr const& routeInfo, bcos::bytes&& payload, int seq, + long timeout, ppc::protocol::ReceiveMsgFunc errorCallback, + ppc::protocol::MessageCallback callback) override; + + /** + * @brief: receive message from gateway, call by gateway + * @param _message: received ppc message + * @return void + */ + void onReceiveMessage( + ppc::protocol::Message::Ptr const& _msg, ppc::protocol::ReceiveMsgFunc _callback) override; + + ppc::protocol::Message::Ptr pop(std::string const& topic, long timeoutMs) override + { + return m_callbackManager->pop(topic, timeoutMs); + } + + ppc::protocol::Message::Ptr peek(std::string const& topic) override + { + return m_callbackManager->pop(topic, 0); + } + /** + * + * @param front the front object + * @param topic the topic + * @param callback the callback called when receive specified topic + */ + void registerTopicHandler( + std::string const& topic, ppc::protocol::MessageDispatcherCallback callback) override + { + m_callbackManager->registerTopicHandler(topic, callback); + } + + /** + * @brief register the nodeInfo to the gateway + * @param nodeInfo the nodeInfo + */ + void registerNodeInfo(ppc::protocol::INodeInfo::Ptr const& nodeInfo) override + { + FRONT_LOG(INFO) << LOG_DESC("registerNodeInfo") + << LOG_KV("nodeInfo", printNodeInfo(m_nodeInfo)); + m_gatewayClient->registerNodeInfo(m_nodeInfo); + } + + /** + * @brief unRegister the nodeInfo to the gateway + */ + void unRegisterNodeInfo() override + { + FRONT_LOG(INFO) << LOG_DESC("unRegisterNodeInfo"); + m_gatewayClient->unRegisterNodeInfo(bcos::ref(m_nodeID)); + } + + /** + * @brief register the topic + * + * @param topic the topic to register + */ + void registerTopic(std::string const& topic) override + { + FRONT_LOG(INFO) << LOG_DESC("register topic: ") << topic; + m_gatewayClient->registerTopic(bcos::ref(m_nodeID), topic); + } + + /** + * @brief unRegister the topic + * + * @param topic the topic to unregister + */ + void unRegisterTopic(std::string const& topic) override + { + FRONT_LOG(INFO) << LOG_DESC("unregister topic: ") << topic; + m_gatewayClient->unRegisterTopic(bcos::ref(m_nodeID), topic); + } + + ppc::protocol::MessageOptionalHeaderBuilder::Ptr const routerInfoBuilder() const + { + return m_routerInfoBuilder; + } + ppc::protocol::MessagePayloadBuilder::Ptr const payloadFactory() const + { + return m_messageFactory; + } + +private: + void asyncSendMessageToGateway(bool responsePacket, + ppc::protocol::MessagePayload::Ptr&& frontMessage, ppc::protocol::RouteType routeType, + ppc::protocol::MessageOptionalHeader::Ptr const& routeInfo, long timeout, + ppc::protocol::ReceiveMsgFunc callback); + + void handleCallback(bcos::Error::Ptr const& error, std::string const& traceID, + ppc::protocol::Message::Ptr message); + +private: + bcos::bytes m_nodeID; + std::shared_ptr m_threadPool; + ppc::protocol::INodeInfo::Ptr m_nodeInfo; + ppc::protocol::MessagePayloadBuilder::Ptr m_messageFactory; + ppc::protocol::MessageOptionalHeaderBuilder::Ptr m_routerInfoBuilder; + + ppc::gateway::IGateway::Ptr m_gatewayClient; + std::shared_ptr m_ioService; + + CallbackManager::Ptr m_callbackManager; + + bool m_running = false; + // the thread to run ioservice + std::shared_ptr m_thread; +}; +} // namespace ppc::front \ No newline at end of file diff --git a/cpp/ppc-front/ppc-front/PPCChannel.h b/cpp/ppc-front/ppc-front/PPCChannel.h index 8b8ec506..f881a915 100644 --- a/cpp/ppc-front/ppc-front/PPCChannel.h +++ b/cpp/ppc-front/ppc-front/PPCChannel.h @@ -52,9 +52,9 @@ class PPCChannel : public Channel, public std::enable_shared_from_thisnotifyTaskInfo(std::move(_taskInfo)); + return m_front->notifyTaskInfo(std::move(taskID)); }; /** diff --git a/cpp/ppc-front/ppc-front/PPCChannelManager.cpp b/cpp/ppc-front/ppc-front/PPCChannelManager.cpp index 897e902d..c42b88d3 100644 --- a/cpp/ppc-front/ppc-front/PPCChannelManager.cpp +++ b/cpp/ppc-front/ppc-front/PPCChannelManager.cpp @@ -41,7 +41,6 @@ void PPCChannelManager::registerMsgHandlerForChannel(uint8_t _taskType, uint8_t }); } - Channel::Ptr PPCChannelManager::buildChannelForTask(const std::string& _taskID) { FRONT_LOG(INFO) << LOG_BADGE("buildChannelForTask") << LOG_KV("taskID", _taskID); @@ -130,7 +129,6 @@ void PPCChannelManager::onMessageArrived(PPCMessageFace::Ptr _message) } } - void PPCChannelManager::removeHoldingMessages(const std::string& _taskID) { WriteGuard lock(x_message_channel); diff --git a/cpp/ppc-front/test/unittests/PPCChannelTest.cpp b/cpp/ppc-front/test/unittests/PPCChannelTest.cpp index b301bcfb..d0d2d6cf 100644 --- a/cpp/ppc-front/test/unittests/PPCChannelTest.cpp +++ b/cpp/ppc-front/test/unittests/PPCChannelTest.cpp @@ -20,7 +20,7 @@ #include "ppc-front/ppc-front/PPCChannel.h" #include "ppc-front/ppc-front/PPCChannelManager.h" -#include "ppc-protocol/src/PPCMessage.h" +#include "protocol/src/PPCMessage.h" #include #include #include diff --git a/cpp/ppc-gateway/CMakeLists.txt b/cpp/ppc-gateway/CMakeLists.txt index 730bbda6..05c438e2 100644 --- a/cpp/ppc-gateway/CMakeLists.txt +++ b/cpp/ppc-gateway/CMakeLists.txt @@ -10,7 +10,9 @@ file(GLOB_RECURSE SRCS ppc-gateway/*.cpp) find_package(tarscpp REQUIRED) add_library(${GATEWAY_TARGET} ${SRCS}) -target_link_libraries(${GATEWAY_TARGET} PUBLIC ${TOOLS_TARGET} jsoncpp_static Boost::filesystem ${BCOS_BOOSTSSL_TARGET} ${BCOS_UTILITIES_TARGET} ${HTTP_TARGET} ${PROTOCOL_TARGET} ${TARS_PROTOCOL_TARGET} tarscpp::tarsservant tarscpp::tarsutil TBB::tbb) +target_link_libraries(${GATEWAY_TARGET} PUBLIC ${TOOLS_TARGET} jsoncpp_static Boost::filesystem ${BCOS_BOOSTSSL_TARGET} ${BCOS_UTILITIES_TARGET} + ${HTTP_TARGET} ${PROTOCOL_TARGET} + ${TARS_PROTOCOL_TARGET} ${PB_PROTOCOL_TARGET} tarscpp::tarsservant tarscpp::tarsutil TBB::tbb) if (APPLE) # target_compile_options(${GATEWAY_TARGET} PRIVATE -faligned-allocation) diff --git a/cpp/ppc-gateway/ppc-gateway/Common.h b/cpp/ppc-gateway/ppc-gateway/Common.h index 02f1f2de..c6dc211c 100644 --- a/cpp/ppc-gateway/ppc-gateway/Common.h +++ b/cpp/ppc-gateway/ppc-gateway/Common.h @@ -32,18 +32,7 @@ namespace ppc::gateway { #define GATEWAY_LOG(LEVEL) BCOS_LOG(LEVEL) << "[GATEWAY]" -#define GATEWAY_WS_CLIENT_MODULE "m_gateway_websocket_client" -#define GATEWAY_WS_SERVER_MODULE "m_gateway_websocket_server" -#define GATEWAY_THREAD_POOL_MODULE "t_gateway" - -#define SEND_MESSAGE_TO_FRONT_SUCCESS "success" -#define SEND_MESSAGE_TO_FRONT_ERROR "error" -#define SEND_MESSAGE_TO_FRONT_TIMEOUT "timeout" - -#define SEND_MESSAGE_TO_FRONT_SUCCESS_CODE "E0000000000" -#define SEND_MESSAGE_TO_FRONT_ERROR_CODE "-1" - -//HTTP HEADER DEFINE +// HTTP HEADER DEFINE #define HEAD_TASK_ID "x-ptp-session-id" #define HEAD_ALGO_TYPE "x-ptp-algorithm-type" #define HEAD_TASK_TYPE "x-ptp-task-type" diff --git a/cpp/ppc-gateway/ppc-gateway/FrontNodeManager.cpp b/cpp/ppc-gateway/ppc-gateway/FrontNodeManager.cpp deleted file mode 100644 index b4c2e5a9..00000000 --- a/cpp/ppc-gateway/ppc-gateway/FrontNodeManager.cpp +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Copyright (C) 2022 WeDPR. - * SPDX-License-Identifier: Apache-2.0 - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * @file FrontNodeManager.cpp - * @author: shawnhe - * @date 2022-10-23 - */ - -#include "FrontNodeManager.h" - -using namespace bcos; -using namespace ppc::gateway; -using namespace ppc::front; - -FrontInterface::Ptr FrontNodeManager::getFront(const std::string& _serviceEndpoint) -{ - bcos::ReadGuard lock(x_frontNodes); - auto it = m_frontNodes.find(_serviceEndpoint); - if (it != m_frontNodes.end()) - { - return it->second; - } - return nullptr; -} - -void FrontNodeManager::registerFront( - std::string const& _endPoint, front::FrontInterface::Ptr _front) -{ - bcos::UpgradableGuard l; - if (m_frontNodes.count(_endPoint)) - { - return; - } - bcos::UpgradeGuard ul(l); - m_frontNodes[_endPoint] = _front; - GATEWAY_LOG(INFO) << LOG_DESC("registerFront success") << LOG_KV("endPoint", _endPoint); -} - -void FrontNodeManager::unregisterFront(std::string const& _endPoint) -{ - bcos::UpgradableGuard l; - auto it = m_frontNodes.find(_endPoint); - if (it == m_frontNodes.end()) - { - return; - } - bcos::UpgradeGuard ul(l); - m_frontNodes.erase(it); - GATEWAY_LOG(INFO) << LOG_DESC("unregisterFront success") << LOG_KV("endPoint", _endPoint); -} diff --git a/cpp/ppc-gateway/ppc-gateway/FrontNodeManager.h b/cpp/ppc-gateway/ppc-gateway/FrontNodeManager.h deleted file mode 100644 index d5bac0f1..00000000 --- a/cpp/ppc-gateway/ppc-gateway/FrontNodeManager.h +++ /dev/null @@ -1,54 +0,0 @@ -/** - * Copyright (C) 2022 WeDPR. - * SPDX-License-Identifier: Apache-2.0 - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * @file FrontNodeManager.h - * @author: shawnhe - * @date 2022-10-23 - */ - -#pragma once -#include "Common.h" -#include "ppc-framework/front/FrontInterface.h" -#include "ppc-front/ppc-front/Front.h" -#include -#include - -namespace ppc::gateway -{ -class FrontNodeManager : public std::enable_shared_from_this -{ -public: - using Ptr = std::shared_ptr; - FrontNodeManager() = default; - virtual ~FrontNodeManager() = default; - - front::FrontInterface::Ptr getFront(const std::string& _serviceEndpoint); - virtual void registerFront(std::string const& _endPoint, front::FrontInterface::Ptr _front); - - virtual void unregisterFront(std::string const& _endPoint); - - virtual std::unordered_map getAllFront() const - { - bcos::ReadGuard l(x_frontNodes); - return m_frontNodes; - } - -private: - // key: serviceEndpoint, value: FrontInterface - std::unordered_map m_frontNodes; - mutable bcos::SharedMutex x_frontNodes; -}; - -} // namespace ppc::gateway diff --git a/cpp/ppc-gateway/ppc-gateway/Gateway.cpp b/cpp/ppc-gateway/ppc-gateway/Gateway.cpp deleted file mode 100644 index f85f80f1..00000000 --- a/cpp/ppc-gateway/ppc-gateway/Gateway.cpp +++ /dev/null @@ -1,774 +0,0 @@ -/** - * Copyright (C) 2022 WeDPR. - * SPDX-License-Identifier: Apache-2.0 - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * @file Gateway.cpp - * @author: shawnhe - * @date 2022-10-20 - */ - -#include "Gateway.h" -#include "ProTaskManager.h" -#include -#include -#include - -using namespace bcos; -using namespace bcos::boostssl; -using namespace bcos::boostssl::ws; -using namespace bcos::boostssl::context; -using namespace ppc; -using namespace ppc::gateway; -using namespace ppc::protocol; -using namespace ppc::front; - - -void Gateway::start() -{ - if (m_running) - { - GATEWAY_LOG(INFO) << LOG_DESC("Gateway already started"); - return; - } - m_running = true; - GATEWAY_LOG(INFO) << LOG_DESC("start the Gateway"); - // register handler when receiving message from other agencies - if (m_protocol == m_webSocketService->gatewayConfig() - ->config() - ->gatewayConfig() - .networkConfig.PROTOCOL_WEBSOCKET) - { - registerWebSocketMsgHandler(); - } - else if (m_protocol == m_webSocketService->gatewayConfig() - ->config() - ->gatewayConfig() - .networkConfig.PROTOCOL_HTTP) - { -#if 0 - // TODO: optimize here - registerUrlMsgHandler(); -#endif - } - m_webSocketService->start(); - m_gatewayThread = std::make_shared([&] { - bcos::pthread_setThreadName("gw_io_service"); - while (m_running.load()) - { - try - { - m_ioContext->run(); - m_ioService->run(); - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - if (m_running) - { - if (m_ioService->stopped()) - { - m_ioService->restart(); - } - if (m_ioContext->stopped()) - { - m_ioContext->restart(); - } - } - } - catch (std::exception& e) - { - GATEWAY_LOG(WARNING) - << LOG_DESC("Exception in Gateway Thread:") << boost::diagnostic_information(e); - } - } - GATEWAY_LOG(INFO) << "Gateway exit"; - }); - GATEWAY_LOG(INFO) << LOG_BADGE("start the gateway end"); -} - - -void Gateway::stop() -{ - if (!m_running) - { - GATEWAY_LOG(INFO) << LOG_DESC("Gateway already stopped"); - return; - } - m_running = false; - GATEWAY_LOG(INFO) << LOG_DESC("stop the Gateway"); - if (m_webSocketService) - { - m_webSocketService->stop(); - } - if (m_ioService) - { - m_ioService->stop(); - } - if (m_ioContext) - { - m_ioContext->stop(); - } - // stop the gateway-thread - if (m_gatewayThread) - { - if (m_gatewayThread->get_id() != std::this_thread::get_id()) - { - m_gatewayThread->join(); - } - else - { - m_gatewayThread->detach(); - } - } - GATEWAY_LOG(INFO) << LOG_BADGE("stop the gateway end"); -} - -#if 0 -// TODO: optimize here -void Gateway::registerUrlMsgHandler() -{ - auto _url = m_webSocketService->gatewayConfig()->config()->gatewayConfig().networkConfig.url; - m_webSocketService->httpServer()->registerUrlHandler( - _url, [self = weak_from_this()](bcos::boostssl::http::HttpRequest const& _httpReq, - ppc::http::RespUrlFunc _handler) { - auto gateway = self.lock(); - if (!gateway) - { - return; - } - std::string code = SEND_MESSAGE_TO_FRONT_SUCCESS_CODE; - std::string _resMsg = SEND_MESSAGE_TO_FRONT_SUCCESS; - try - { - // handle message - auto senderStr = _httpReq[HEAD_SENDER_ID].to_string(); - auto taskIdStr = _httpReq[HEAD_TASK_ID].to_string(); - auto algoTypeStr = _httpReq[HEAD_ALGO_TYPE].to_string(); - auto taskTypeStr = _httpReq[HEAD_TASK_TYPE].to_string(); - auto messageTypeStr = _httpReq[HEAD_MESSAGE_TYPE].to_string(); - auto seqStr = _httpReq[HEAD_SEQ].to_string(); - auto uuidStr = _httpReq[HEAD_UUID].to_string(); - bool isResponse = _httpReq[HEAD_IS_RESPONSE].to_string() == "1"; - auto reqBodyBytes = bcos::bytes(_httpReq.body().begin(), _httpReq.body().end()); - - GATEWAY_LOG(TRACE) << LOG_BADGE("registerUrlMsgHandler") - << LOG_KV("req size: ", reqBodyBytes.size()) - << LOG_KV("taskId: ", taskIdStr) - << LOG_KV("algorithmType: ", algoTypeStr) - << LOG_KV("taskType: ", taskTypeStr) - << LOG_KV("messageType: ", messageTypeStr) - << LOG_KV("UUID: ", uuidStr) << LOG_KV("seq: ", seqStr) - << LOG_KV("isResponse: ", isResponse); - - auto message = gateway->messageFactory()->buildPPCMessage(); - message->setData(std::make_shared(reqBodyBytes)); - message->setTaskID(taskIdStr); -#ifdef ENABLE_CONN - message->setAlgorithmType(5); - message->setTaskType(0); -#else - message->setAlgorithmType(0x00 + atoi(algoTypeStr.c_str())); - message->setTaskType(0x00 + atoi(taskTypeStr.c_str())); -#endif - message->setMessageType(0x00 + atoi(messageTypeStr.c_str())); - message->setSeq(0x00 + atoi(seqStr.c_str())); - message->setSender(senderStr); - message->setUuid(uuidStr); - if (isResponse) - { - message->setResponse(); - } - gateway->onMessageArrived(std::move(message)); - } - catch (std::exception& e) - { - GATEWAY_LOG(ERROR) << LOG_BADGE("onReceiveMsgFromOtherGateway") - << LOG_KV("exception", boost::diagnostic_information(e)); - code = SEND_MESSAGE_TO_FRONT_ERROR_CODE; - _resMsg = SEND_MESSAGE_TO_FRONT_ERROR; - } - - org::interconnection::link::TransportOutbound transportOutbound; - transportOutbound.set_allocated_message(new std::string(_resMsg)); - transportOutbound.set_allocated_code(new std::string(code)); - std::string responseStr; - transportOutbound.SerializeToString(&responseStr); - _handler(nullptr, std::move(bcos::bytes(responseStr.begin(), responseStr.end()))); - }); -} -#endif - -void Gateway::registerWebSocketMsgHandler() -{ - GATEWAY_LOG(INFO) << LOG_BADGE("registerWebSocketMsgHandler"); - m_webSocketService->webSocketServer()->registerMsgHandler( - 0, [self = weak_from_this()](const std::shared_ptr& _wsMessage, - const std::shared_ptr& _session) { - auto gateway = self.lock(); - if (!gateway) - { - return; - } - try - { - // prepare ack - auto ack = gateway->wsMessageFactory()->buildMessage(); - ack->setRespPacket(); - ack->setPacketType(_wsMessage->packetType()); - ack->setSeq(_wsMessage->seq()); - gateway->addAckCallback(_wsMessage->seq(), ack, _session); - - // handle message - auto payload = _wsMessage->payload(); - auto message = gateway->messageFactory()->buildPPCMessage(payload); - if (!message) - { - GATEWAY_LOG(ERROR) << LOG_BADGE("onReceiveMsgFromOtherGateway") - << LOG_DESC("decode ppc message error"); - gateway->sendAck(_wsMessage->seq(), SEND_MESSAGE_TO_FRONT_ERROR); - return; - } - - GATEWAY_LOG(TRACE) - << LOG_BADGE("onReceiveMsgFromOtherGateway") - << LOG_KV("taskType", unsigned(message->taskType())) - << LOG_KV("algorithmType", unsigned(message->algorithmType())) - << LOG_KV("messageType", unsigned(message->messageType())) - << LOG_KV("seq", message->seq()) << LOG_KV("taskID", message->taskID()) - << LOG_KV("sender", message->sender()); - - gateway->onMessageArrived(std::move(message)); - } - catch (std::exception& e) - { - GATEWAY_LOG(ERROR) << LOG_BADGE("onReceiveMsgFromOtherGateway") - << LOG_KV("exception", boost::diagnostic_information(e)); - gateway->sendAck(_wsMessage->seq(), SEND_MESSAGE_TO_FRONT_ERROR); - } - }); -} - - -/** - * @brief: send message to other agency - * @param _agencyID: agency ID of receiver - * @param _message: ppc message data - * @return void - */ -void Gateway::asyncSendMessage( - const std::string& _agencyID, front::PPCMessageFace::Ptr _message, ErrorCallbackFunc _callback) -{ - auto taskID = _message->taskID(); - GATEWAY_LOG(TRACE) << LOG_BADGE("asyncSendMessage") << printPPCMsg(_message); - try - { - if (m_protocol == m_webSocketService->gatewayConfig() - ->config() - ->gatewayConfig() - .networkConfig.PROTOCOL_WEBSOCKET) - { - auto wsClient = m_webSocketService->webSocketClient(_agencyID); - if (!wsClient) - { - onError("asyncSendMessage", taskID, - BCOS_ERROR_PTR((int)PPCRetCode::NETWORK_ERROR, - "WebSocket client not found for " + _agencyID), - _callback); - return; - } - - auto payload = std::make_shared(); - _message->encode(*payload); - auto wsMessage = m_wsMessageFactory->buildMessage(0, payload); - wsMessage->setSeq(_message->uuid()); - - // forward to other agency - auto self = weak_from_this(); - wsClient->asyncSendMessage(wsMessage, Options(m_holdingMessageMinutes * 60 * 1000), - [self, wsMessage, wsClient, taskID, _callback](const Error::Ptr& _error, - const std::shared_ptr& _msg, - const std::shared_ptr& _session) { - Error::Ptr error; - // send success - if (!_error || _error->errorCode() == 0) - { - // check ack - auto payload = _msg->payload(); - if (payload) - { - std::string status = std::string(payload->begin(), payload->end()); - if (SEND_MESSAGE_TO_FRONT_ERROR == status || - SEND_MESSAGE_TO_FRONT_TIMEOUT == status) - { - error = std::make_shared(PPCRetCode::NETWORK_ERROR, - "send message to target front error, status = " + status); - } - } - if (!error) - { - GATEWAY_LOG(TRACE) - << LOG_DESC("asyncSendMessage success") << LOG_KV("task", taskID); - // response to the client in-case of tars-error - _callback(nullptr); - return; - } - } - - if (!error) - { - error = _error; - } - - auto gateway = self.lock(); - if (!gateway) - { - return; - } - gateway->onError("asyncSendMessage", taskID, error, _callback); - }); - } -#if 0 - else if (m_protocol == m_webSocketService->gatewayConfig() - ->config() - ->gatewayConfig() - .networkConfig.PROTOCOL_HTTP) - { - Error::Ptr error; - auto httpClient = m_webSocketService->httpClient(_agencyID); - auto _url = - m_webSocketService->gatewayConfig()->config()->gatewayConfig().networkConfig.url; - auto header = _message->header(); - appendHeader(header, _message); - auto body = _message->data(); - auto response = httpClient->post(_url, header, *body); - // parse response to TransportOutbound - std::string reponseStr(response.begin(), response.end()); - org::interconnection::link::TransportOutbound transportOutbound; - transportOutbound.ParseFromString(reponseStr); - if (transportOutbound.code() != SEND_MESSAGE_TO_FRONT_SUCCESS_CODE) - { - error = std::make_shared(PPCRetCode::NETWORK_ERROR, - "send message to target front error, code = " + transportOutbound.code()); - } - if (!error) - { - GATEWAY_LOG(TRACE) - << LOG_DESC("asyncSendMessage success") << LOG_KV("task", taskID); - _callback(nullptr); - return; - } - onError("asyncSendMessage", taskID, error, _callback); - } -#endif - } - catch (std::exception& e) - { - onError("asyncSendMessage", taskID, - BCOS_ERROR_PTR( - (int)PPCRetCode::EXCEPTION, std::string(boost::diagnostic_information(e))), - std::move(_callback)); - return; - } -} - -void Gateway::onError(std::string const& _desc, std::string const& _taskID, bcos::Error::Ptr _error, - ErrorCallbackFunc _callback) -{ - if (!_error || _error->errorCode() == 0) - { - return; - } - - if (_error->errorCode() == WsError::TimeOut) - { - // Lower the log level because the caller will type out the error message - GATEWAY_LOG(INFO) << LOG_BADGE(_desc) << LOG_KV("taskID", _taskID) - << LOG_KV("code", _error->errorCode()) - << LOG_KV("msg", _error->errorMessage()); - } - else - { - GATEWAY_LOG(ERROR) << LOG_BADGE(_desc) << LOG_KV("taskID", _taskID) - << LOG_KV("code", _error->errorCode()) - << LOG_KV("msg", _error->errorMessage()); - } - - if (!_callback) - { - return; - } - - if (m_threadPool) - { - m_threadPool->enqueue([_callback, _error]() { _callback(_error); }); - } - else - { - _callback(std::move(_error)); - } -} - - -/** - * @brief notice task info to gateway - * @param _taskInfo the latest task information - */ -bcos::Error::Ptr Gateway::notifyTaskInfo(GatewayTaskInfo::Ptr _taskInfo) -{ - auto startT = bcos::utcSteadyTime(); - auto error = std::make_shared(); - auto taskID = _taskInfo->taskID; - auto serviceEndpoint = _taskInfo->serviceEndpoint; - try - { - m_taskManager->registerTaskInfo(taskID, serviceEndpoint); - // check to see if any message has arrived - handleHoldingMessageQueue(std::move(_taskInfo)); - } - catch (std::exception& e) - { - error->setErrorCode(PPCRetCode::EXCEPTION); - error->setErrorMessage(boost::diagnostic_information(e)); - } - - if (error->errorCode()) - { - GATEWAY_LOG(ERROR) << LOG_BADGE("notifyTaskInfo") << LOG_KV("taskID", taskID) - << LOG_DESC(error->errorMessage()); - } - GATEWAY_LOG(INFO) << LOG_BADGE("notifyTaskInfo") << LOG_KV("taskID", taskID) - << LOG_KV("serviceEndpoint", serviceEndpoint) - << LOG_KV("timecost", (bcos::utcSteadyTime() - startT)); - return error; -} - -bcos::Error::Ptr Gateway::eraseTaskInfo(std::string const& _taskID) -{ - auto startT = bcos::utcSteadyTime(); - try - { - // release held message - getAndRemoveHoldingMessages(_taskID); - m_taskManager->removeTaskInfo(_taskID); - GATEWAY_LOG(INFO) << LOG_BADGE("eraseTaskInfo") << LOG_KV("taskID", _taskID) - << LOG_KV("timecost", bcos::utcSteadyTime() - startT); - return nullptr; - } - catch (std::exception const& e) - { - GATEWAY_LOG(ERROR) << LOG_DESC("eraseTaskInfo error") - << LOG_KV("exception", boost::diagnostic_information(e)) - << LOG_KV("timecost", bcos::utcSteadyTime() - startT); - return BCOS_ERROR_PTR( - PPCRetCode::EXCEPTION, "eraseTaskInfo error: " + boost::diagnostic_information(e)); - } -} - -// register gateway url for other parties -bcos::Error::Ptr Gateway::registerGateway( - const std::vector& _gatewayList) -{ - try - { - for (const auto& gateway : _gatewayList) - { - m_webSocketService->registerGatewayUrl(gateway.agencyID, gateway.endpoint); - } - - return nullptr; - } - catch (std::exception const& e) - { - GATEWAY_LOG(ERROR) << LOG_DESC("registerGateway error") - << LOG_KV("exception", boost::diagnostic_information(e)); - return BCOS_ERROR_PTR( - PPCRetCode::EXCEPTION, "registerGateway error: " + boost::diagnostic_information(e)); - } -} - -void Gateway::asyncGetAgencyList(ppc::front::GetAgencyListCallback _callback) -{ - GATEWAY_LOG(TRACE) << LOG_BADGE("asyncGetAgencyList"); - auto const& agencies = m_webSocketService->gatewayConfig()->config()->gatewayConfig().agencies; - std::vector agencyList; - for (auto const& it : agencies) - { - agencyList.emplace_back(it.first); - } - if (!_callback) - { - return; - } - _callback(nullptr, std::move(agencyList)); -} - -void Gateway::handleHoldingMessageQueue(protocol::GatewayTaskInfo::Ptr _taskInfo) -{ - HoldingMessageQueue::Ptr queue; - { - WriteGuard l(x_holdingMessageQueue); - auto it = m_holdingMessageQueue.find(_taskInfo->taskID); - // not find the holding-queue related to the task-info - if (it == m_holdingMessageQueue.end()) - { - return; - } - queue = it->second; - // erase the queue - m_holdingMessageQueue.erase(it); - } - // cancel the timer - if (queue->timer) - { - queue->timer->cancel(); - } - - auto frontInterface = m_frontNodeManager->getFront(_taskInfo->serviceEndpoint); - if (!frontInterface) - { - GATEWAY_LOG(WARNING) << LOG_DESC( - "handleHoldingMessageQueue error for not find the corresponding front"); - return; - } - // dispatch the message - for (auto& msg : queue->messages) - { - if (!frontInterface) - { - GATEWAY_LOG(WARNING) << LOG_DESC("send message error for the target front not found"); - sendAck(msg->uuid(), SEND_MESSAGE_TO_FRONT_ERROR); - continue; - } - // forward to self node - frontInterface->onReceiveMessage(msg, [self = weak_from_this(), msg]( - const bcos::Error::Ptr& _error) { - auto gateway = self.lock(); - if (!gateway) - { - return; - } - - if (_error && _error->errorCode() != 0) - { - GATEWAY_LOG(WARNING) - << LOG_DESC("handleHoldingMessageQueue: dispatch the message error") - << LOG_KV("code", _error->errorCode()) << LOG_KV("msg", _error->errorMessage()); - gateway->sendAck(msg->uuid(), SEND_MESSAGE_TO_FRONT_ERROR); - return; - } - - gateway->sendAck(msg->uuid(), SEND_MESSAGE_TO_FRONT_SUCCESS); - }); - } -} - -// broadcast the message to all front when the task-id is not specified -void Gateway::broadcastMsgToAllFront(ppc::front::PPCMessageFace::Ptr const& _message) -{ - auto frontList = m_frontNodeManager->getAllFront(); - GATEWAY_LOG(TRACE) << LOG_DESC("broadcastMsgToAllFront") - << LOG_KV("frontSize", frontList.size()); - for (auto const& it : frontList) - { - auto const& front = it.second; - auto const& serviceEndpoint = it.first; - dispatchMessageToFront(front, _message, serviceEndpoint); - } -} - -void Gateway::dispatchMessageToFront(ppc::front::FrontInterface::Ptr const& _front, - ppc::front::PPCMessageFace::Ptr const& _message, std::string const& _serviceEndpoint) -{ - auto taskID = _message->taskID(); - // dispatch message to the given front - auto startT = utcSteadyTime(); - _front->onReceiveMessage(_message, [self = weak_from_this(), taskID, _serviceEndpoint, _message, - startT](const bcos::Error::Ptr& _error) { - auto gateway = self.lock(); - if (!gateway) - { - return; - } - - if (_error && _error->errorCode() != 0) - { - GATEWAY_LOG(WARNING) << LOG_DESC("onReceiveMessage: dispatch message to front error") - << printPPCMsg(_message) << LOG_KV("task", taskID) - << LOG_KV("front", _serviceEndpoint) - << LOG_KV("code", _error->errorCode()) - << LOG_KV("msg", _error->errorMessage()) - << LOG_KV("timecost", (utcSteadyTime() - startT)); - gateway->sendAck(_message->uuid(), SEND_MESSAGE_TO_FRONT_ERROR); - return; - } - GATEWAY_LOG(TRACE) << LOG_DESC("onReceiveMessage success") << printPPCMsg(_message) - << LOG_KV("timecost", (utcSteadyTime() - startT)); - gateway->sendAck(_message->uuid(), SEND_MESSAGE_TO_FRONT_SUCCESS); - }); -} - -void Gateway::onMessageArrived(PPCMessageFace::Ptr _message) -{ - GATEWAY_LOG(TRACE) << LOG_DESC("onMessageArrived") << printPPCMsg(_message); - auto taskID = _message->taskID(); - // broadcast the message to all front when the task-id is not specified - if (taskID.empty()) - { - broadcastMsgToAllFront(_message); - return; - } - - bcos::UpgradableGuard l(x_holdingMessageQueue); - auto serviceEndpoint = m_taskManager->getServiceEndpoint(taskID); - if (!serviceEndpoint.empty()) - { - auto frontInterface = m_frontNodeManager->getFront(serviceEndpoint); - if (!frontInterface) - { - GATEWAY_LOG(WARNING) - << LOG_DESC( - "onMessageArrived: can't find the front to dispatch the receive message") - << printPPCMsg(_message) << LOG_KV("task", taskID) - << LOG_KV("frontEndPoint", serviceEndpoint); - sendAck(_message->uuid(), SEND_MESSAGE_TO_FRONT_ERROR); - return; - } - dispatchMessageToFront(frontInterface, _message, serviceEndpoint); - return; - } - // hold the message - GATEWAY_LOG(INFO) << LOG_BADGE("holdMessage") << LOG_KV("taskID", taskID); - - bcos::UpgradeGuard ul(l); - auto it = m_holdingMessageQueue.find(taskID); - if (it != m_holdingMessageQueue.end()) - { - it->second->messages.emplace_back(_message); - return; - } - // insert new holding-queue - auto queue = std::make_shared(); - queue->messages.emplace_back(_message); - // create timer to handle timeout - queue->timer = std::make_shared( - *m_ioService, boost::posix_time::minutes(m_holdingMessageMinutes)); - queue->timer->async_wait([self = weak_from_this(), taskID](boost::system::error_code _error) { - if (!_error) - { - auto gateway = self.lock(); - if (gateway) - { - // remove timeout message - auto msgQueue = gateway->getAndRemoveHoldingMessages(taskID); - gateway->handleTimeoutHoldingMessage(msgQueue); - } - } - }); - m_holdingMessageQueue[taskID] = queue; -} - -HoldingMessageQueue::Ptr Gateway::getAndRemoveHoldingMessages(const std::string& _taskID) -{ - WriteGuard lock(x_holdingMessageQueue); - auto it = m_holdingMessageQueue.find(_taskID); - if (it == m_holdingMessageQueue.end()) - { - return nullptr; - } - - HoldingMessageQueue::Ptr ret = it->second; - m_holdingMessageQueue.erase(_taskID); - return ret; -} - -void Gateway::handleTimeoutHoldingMessage(HoldingMessageQueue::Ptr _queue) -{ - if (!_queue) - { - return; - } - // dispatch the ack - for (auto& msg : _queue->messages) - { - sendAck(msg->uuid(), SEND_MESSAGE_TO_FRONT_TIMEOUT); - } -} - -void Gateway::addAckCallback( - std::string const& _uuid, MessageFace::Ptr _msg, boostssl::ws::WsSession::Ptr _session) -{ - WriteGuard lock(x_ackCallbacks); - m_ackCallbacks[_uuid] = {std::move(_msg), std::move(_session)}; -} - - -void Gateway::sendAck(std::string const& _uuid, std::string const& _status) -{ - WriteGuard lock(x_ackCallbacks); - auto it = m_ackCallbacks.find(_uuid); - if (it == m_ackCallbacks.end()) - { - return; - } - - auto payload = std::make_shared(_status.begin(), _status.end()); - auto& msg = it->second.first; - auto& session = it->second.second; - msg->setPayload(payload); - session->asyncSendMessage(msg, Options(), nullptr); - - m_ackCallbacks.erase(it); -} - -void Gateway::appendHeader( - std::map& origin_header, front::PPCMessageFace::Ptr _message) -{ - origin_header["Content-Type"] = "application/octet-stream;charset=utf-8"; - origin_header["has_uri"] = std::to_string(true); - origin_header[HEAD_TASK_ID] = _message->taskID(); - origin_header[HEAD_ALGO_TYPE] = std::to_string(_message->algorithmType()); - origin_header[HEAD_TASK_TYPE] = std::to_string(_message->taskType()); - origin_header[HEAD_SENDER_ID] = _message->sender(); - origin_header[HEAD_MESSAGE_TYPE] = std::to_string(_message->messageType()); - origin_header[HEAD_SEQ] = std::to_string(_message->seq()); - origin_header[HEAD_IS_RESPONSE] = std::to_string(_message->response()); - origin_header[HEAD_UUID] = _message->uuid(); -} - -Gateway::Ptr GatewayFactory::buildGateway(NodeArch _arch, ppc::tools::PPCConfig::Ptr _config, - storage::CacheStorage::Ptr _cache, front::PPCMessageFaceFactory::Ptr _messageFactory, - std::shared_ptr _threadPool) -{ - auto wsMessageFactory = std::make_shared(); - auto webSocketServiceFactory = std::make_shared(); - auto ioService = std::make_shared(); - auto ioContext = std::make_shared(); - - auto webSocketService = webSocketServiceFactory->buildWebSocketService(_config, ioContext); - - TaskManager::Ptr taskManager = nullptr; - if (_arch == NodeArch::AIR || _cache == nullptr) - { - GATEWAY_LOG(INFO) << LOG_BADGE("buildGateway without cache"); - taskManager = std::make_shared(ioService); - } - else - { - GATEWAY_LOG(INFO) << LOG_BADGE("buildGateway with cache"); - taskManager = std::make_shared(_cache, ioService); - } - auto frontNodeManager = std::make_shared(); - - return std::make_shared(std::move(webSocketService), std::move(wsMessageFactory), - std::move(ioService), std::move(_messageFactory), std::move(frontNodeManager), - std::move(taskManager), std::move(_threadPool), _config->holdingMessageMinutes(), - std::move(ioContext)); -} diff --git a/cpp/ppc-gateway/ppc-gateway/Gateway.h b/cpp/ppc-gateway/ppc-gateway/Gateway.h deleted file mode 100644 index 4a0a8a77..00000000 --- a/cpp/ppc-gateway/ppc-gateway/Gateway.h +++ /dev/null @@ -1,210 +0,0 @@ -/** - * Copyright (C) 2022 WeDPR. - * SPDX-License-Identifier: Apache-2.0 - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * @file Gateway.h - * @author: shawnhe - * @date 2022-10-20 - */ - -#pragma once - -#include "Common.h" -#include "FrontNodeManager.h" -#include "GatewayService.h" -#include "TaskManager.h" -#include "WebSocketService.h" -#include "ppc-framework/gateway/GatewayInterface.h" -#include "ppc-protocol/src/PPCMessage.h" -#if 0 -//TODO: optimize here -#include "ppc-protocol/src/protobuf/transport.pb.h" -#endif - -#include "tbb/concurrent_vector.h" -#include -#include - -#include - -namespace ppc::gateway -{ -struct HoldingMessageQueue -{ - using Ptr = std::shared_ptr; - HoldingMessageQueue() = default; - - tbb::concurrent_vector messages; - std::shared_ptr timer; -}; - -class Gateway : public GatewayInterface, public std::enable_shared_from_this -{ -public: - using Ptr = std::shared_ptr; - - Gateway(WebSocketService::Ptr _webSocketService, - bcos::boostssl::ws::WsMessageFactory::Ptr _wsMessageFactory, - std::shared_ptr _ioService, - front::PPCMessageFaceFactory::Ptr _messageFactory, FrontNodeManager::Ptr _frontNodeManager, - TaskManager::Ptr _taskManager, std::shared_ptr _threadPool, - int _holdingMessageMinutes = 10, - std::shared_ptr _ioContext = nullptr) - : m_holdingMessageMinutes(_holdingMessageMinutes), - m_webSocketService(std::move(_webSocketService)), - m_wsMessageFactory(std::move(_wsMessageFactory)), - m_ioService(std::move(_ioService)), - m_messageFactory(std::move(_messageFactory)), - m_frontNodeManager(std::move(_frontNodeManager)), - m_taskManager(std::move(_taskManager)), - m_threadPool(std::move(_threadPool)), - m_ioContext(_ioContext), - m_protocol( - _webSocketService->gatewayConfig()->config()->gatewayConfig().networkConfig.protocol) - { - GATEWAY_LOG(INFO) << LOG_KV("holdingMessageMinutes", m_holdingMessageMinutes); - } - - Gateway(const Gateway&) = delete; - Gateway(Gateway&&) = delete; - - Gateway& operator=(const Gateway&) = delete; - Gateway& operator=(Gateway&&) = delete; - - virtual ~Gateway() override { stop(); } - - void start() override; - void stop() override; - - void registerWebSocketMsgHandler(); - -#if 0 - // TODO: optimize here - void registerUrlMsgHandler(); -#endif - - /** - * @brief: send message to other agency - * @param _agencyID: agency ID of receiver - * @param _message: ppc message data - * @return void - */ - void asyncSendMessage(const std::string& _agencyID, front::PPCMessageFace::Ptr _message, - ErrorCallbackFunc _callback) override; - - /** - * @brief notice task info to gateway - * @param _taskInfo the latest task information - */ - bcos::Error::Ptr notifyTaskInfo(protocol::GatewayTaskInfo::Ptr _taskInfo) override; - - // erase the task info - bcos::Error::Ptr eraseTaskInfo(std::string const& _taskID) override; - - // register gateway url for other parties - bcos::Error::Ptr registerGateway( - const std::vector& _gatewayList) override; - - // get the agency-list - void asyncGetAgencyList(ppc::front::GetAgencyListCallback _callback) override; - - FrontNodeManager::Ptr frontNodeManager() { return m_frontNodeManager; } - TaskManager::Ptr taskManager() { return m_taskManager; } - void setTaskManager(TaskManager::Ptr _taskManager) { m_taskManager = std::move(_taskManager); } - - WebSocketService::Ptr webSocketService() { return m_webSocketService; } - - front::PPCMessageFaceFactory::Ptr messageFactory() { return m_messageFactory; } - bcos::boostssl::ws::WsMessageFactory::Ptr wsMessageFactory() { return m_wsMessageFactory; } - std::shared_ptr threadPool() { return m_threadPool; } - - void addAckCallback(std::string const& _uuid, bcos::boostssl::MessageFace::Ptr _msg, - bcos::boostssl::ws::WsSession::Ptr _session); - - void sendAck(std::string const& _uuid, std::string const& _status); - - // Note: since the front will periodically register the status, no need to response message to - // the front - void registerFront(std::string const& _endPoint, front::FrontInterface::Ptr _front) override - { - m_frontNodeManager->registerFront(_endPoint, _front); - } - - void unregisterFront(std::string const& _endPoint) override - { - m_frontNodeManager->unregisterFront(_endPoint); - } - -protected: - virtual void handleHoldingMessageQueue(protocol::GatewayTaskInfo::Ptr _taskInfo); - virtual void onMessageArrived(front::PPCMessageFace::Ptr _message); - virtual HoldingMessageQueue::Ptr getAndRemoveHoldingMessages(const std::string& _taskID); - virtual void handleTimeoutHoldingMessage(HoldingMessageQueue::Ptr _queue); - virtual void onError(std::string const& _desc, std::string const& _taskID, - bcos::Error::Ptr _error, ErrorCallbackFunc _callback); - - void broadcastMsgToAllFront(ppc::front::PPCMessageFace::Ptr const& _message); - void dispatchMessageToFront(ppc::front::FrontInterface::Ptr const& _front, - ppc::front::PPCMessageFace::Ptr const& _message, std::string const& _serviceEndpoint); - void appendHeader( - std::map& origin_header, front::PPCMessageFace::Ptr _message); - -private: - int m_holdingMessageMinutes = 30; - int m_protocol; - WebSocketService::Ptr m_webSocketService; - bcos::boostssl::ws::WsMessageFactory::Ptr m_wsMessageFactory; - std::shared_ptr m_ioService; - std::shared_ptr m_ioContext; - front::PPCMessageFaceFactory::Ptr m_messageFactory; - - FrontNodeManager::Ptr m_frontNodeManager; - TaskManager::Ptr m_taskManager; - - std::shared_ptr m_threadPool; - - // the thread to make ioservice run - std::shared_ptr m_gatewayThread; - - /** - * hold the message for the situation that - * gateway receives message from the other side while the task has not been registered. - */ - mutable boost::shared_mutex x_holdingMessageQueue; - std::unordered_map m_holdingMessageQueue; - - std::atomic_bool m_running = {false}; - - mutable boost::shared_mutex x_ackCallbacks; - std::unordered_map > - m_ackCallbacks; -}; - - -class GatewayFactory -{ -public: - using Ptr = std::shared_ptr; - -public: - GatewayFactory() = default; - ~GatewayFactory() = default; - - Gateway::Ptr buildGateway(ppc::protocol::NodeArch _arch, ppc::tools::PPCConfig::Ptr _config, - storage::CacheStorage::Ptr _cache, front::PPCMessageFaceFactory::Ptr _messageFactory, - std::shared_ptr _threadPool); -}; - -} // namespace ppc::gateway \ No newline at end of file diff --git a/cpp/ppc-gateway/ppc-gateway/GatewayConfigContext.cpp b/cpp/ppc-gateway/ppc-gateway/GatewayConfigContext.cpp index 9a1a37ba..8243d133 100644 --- a/cpp/ppc-gateway/ppc-gateway/GatewayConfigContext.cpp +++ b/cpp/ppc-gateway/ppc-gateway/GatewayConfigContext.cpp @@ -22,23 +22,27 @@ using namespace bcos; using namespace ppc::gateway; +using namespace bcos::boostssl::context; + void GatewayConfigContext::initContextConfig() { - m_contextConfig = std::make_shared(); + m_contextConfig = std::make_shared(); auto const& gatewayConfig = m_config->gatewayConfig().networkConfig; // non-sm-ssl if (!gatewayConfig.enableSM) { - boostssl::context::ContextConfig::CertConfig certConfig; + ContextConfig::CertConfig certConfig; certConfig.caCert = gatewayConfig.caCertPath; certConfig.nodeCert = gatewayConfig.sslCertPath; certConfig.nodeKey = gatewayConfig.sslKeyPath; m_contextConfig->setCertConfig(certConfig); m_contextConfig->setSslType("ssl"); + // parse the nodeID + NodeInfoTools::initCert2PubHexHandler()(certConfig.nodeCert, m_nodeID); GATEWAY_LOG(INFO) << LOG_DESC("initConfig: rpc work in non-sm-ssl model") << LOG_KV("caCert", certConfig.caCert) << LOG_KV("nodeCert", certConfig.nodeCert) - << LOG_KV("nodeKey", certConfig.nodeKey); + << LOG_KV("nodeKey", certConfig.nodeKey) << LOG_KV("nodeID", m_nodeID); GATEWAY_LOG(INFO) << LOG_DESC("initContextConfig: non-sm-ssl"); return; } @@ -51,5 +55,6 @@ void GatewayConfigContext::initContextConfig() certConfig.enNodeKey = gatewayConfig.smEnSslKeyPath; m_contextConfig->setSmCertConfig(certConfig); m_contextConfig->setSslType("sm_ssl"); - GATEWAY_LOG(INFO) << LOG_DESC("initContextConfig: sm-ssl"); + NodeInfoTools::initCert2PubHexHandler()(certConfig.enNodeCert, m_nodeID); + GATEWAY_LOG(INFO) << LOG_DESC("initContextConfig: sm-ssl") << LOG_KV("nodeID", m_nodeID); } diff --git a/cpp/ppc-gateway/ppc-gateway/GatewayConfigContext.h b/cpp/ppc-gateway/ppc-gateway/GatewayConfigContext.h index 41320ebc..259ceb71 100644 --- a/cpp/ppc-gateway/ppc-gateway/GatewayConfigContext.h +++ b/cpp/ppc-gateway/ppc-gateway/GatewayConfigContext.h @@ -21,8 +21,7 @@ #pragma once #include "Common.h" -#include "ppc-framework/storage/CacheStorage.h" -#include "ppc-storage/src/redis/RedisStorage.h" +#include "bcos-boostssl/context/NodeInfoTools.h" #include #include #include @@ -51,12 +50,15 @@ class GatewayConfigContext } ppc::tools::PPCConfig::Ptr const& config() const { return m_config; } + std::string const& nodeID() const { return m_nodeID; } + private: void initContextConfig(); private: ppc::tools::PPCConfig::Ptr m_config; std::shared_ptr m_contextConfig; + std::string m_nodeID; }; } // namespace ppc::gateway \ No newline at end of file diff --git a/cpp/ppc-gateway/ppc-gateway/GatewayConfigLoader.cpp b/cpp/ppc-gateway/ppc-gateway/GatewayConfigLoader.cpp new file mode 100644 index 00000000..8d87f282 --- /dev/null +++ b/cpp/ppc-gateway/ppc-gateway/GatewayConfigLoader.cpp @@ -0,0 +1,150 @@ +/** + * Copyright (C) 2023 WeDPR. + * SPDX-License-Identifier: Apache-2.0 + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @file GatewayConfigLoader.cpp + * @author: yujiechen + * @date 2024-08-26 + */ +#include "GatewayConfigLoader.h" +#include "Common.h" +#include "bcos-utilities/FileUtility.h" +#include + +using namespace ppc; +using namespace bcos; +using namespace ppc::tools; +using namespace ppc::gateway; +using namespace bcos::boostssl; + +// load p2p connected peers +void GatewayConfigLoader::loadP2pConnectedNodes() +{ + std::string nodeFilePath = + m_config->gatewayConfig().nodePath + "/" + m_config->gatewayConfig().nodeFileName; + // load p2p connected nodes + auto jsonContent = readContentsToString(boost::filesystem::path(nodeFilePath)); + if (!jsonContent || jsonContent->empty()) + { + BOOST_THROW_EXCEPTION( + WeDPRException() << errinfo_comment( + "loadP2pConnectedNodes: unable to read nodes json file, path=" + nodeFilePath)); + } + + parseConnectedJson(*jsonContent.get(), *m_nodeIPEndpointSet); + GATEWAY_LOG(INFO) << LOG_DESC("loadP2pConnectedNodes success!") + << LOG_KV("nodePath", m_config->gatewayConfig().nodePath) + << LOG_KV("nodeFileName", m_config->gatewayConfig().nodeFileName) + << LOG_KV("nodes", m_nodeIPEndpointSet->size()); +} + +void GatewayConfigLoader::parseConnectedJson( + const std::string& _json, EndPointSet& _nodeIPEndpointSet) +{ + /* + {"nodes":["127.0.0.1:30355","127.0.0.1:30356"}]} + */ + Json::Value root; + Json::Reader jsonReader; + try + { + if (!jsonReader.parse(_json, root)) + { + GATEWAY_LOG(ERROR) << "unable to parse connected nodes json" << LOG_KV("json:", _json); + BOOST_THROW_EXCEPTION( + WeDPRException() << errinfo_comment("GatewayConfig: unable to parse p2p " + "connected nodes json")); + } + Json::Value jNodes = root["nodes"]; + if (jNodes.isArray()) + { + unsigned int jNodesSize = jNodes.size(); + for (unsigned int i = 0; i < jNodesSize; i++) + { + std::string host = jNodes[i].asString(); + + NodeIPEndpoint endpoint; + hostAndPort2Endpoint(host, endpoint); + _nodeIPEndpointSet.insert(endpoint); + + GATEWAY_LOG(INFO) << LOG_DESC("add one connected node") << LOG_KV("host", host); + } + } + } + catch (const std::exception& e) + { + GATEWAY_LOG(ERROR) << LOG_KV( + "parseConnectedJson error: ", boost::diagnostic_information(e)); + BOOST_THROW_EXCEPTION(e); + } +} + +bool GatewayConfigLoader::isValidPort(int port) +{ + if (port <= 1024 || port > 65535) + return false; + return true; +} + +void GatewayConfigLoader::hostAndPort2Endpoint(const std::string& _host, NodeIPEndpoint& _endpoint) +{ + std::string ip; + uint16_t port; + + std::vector s; + boost::split(s, _host, boost::is_any_of("]"), boost::token_compress_on); + if (s.size() == 2) + { // ipv6 + ip = s[0].data() + 1; + port = boost::lexical_cast(s[1].data() + 1); + } + else if (s.size() == 1) + { // ipv4 + std::vector v; + boost::split(v, _host, boost::is_any_of(":"), boost::token_compress_on); + if (v.size() < 2) + { + BOOST_THROW_EXCEPTION( + WeDPRException() << errinfo_comment("GatewayConfig: invalid host , host=" + _host)); + } + ip = v[0]; + port = boost::lexical_cast(v[1]); + } + else + { + GATEWAY_LOG(ERROR) << LOG_DESC("not valid host value") << LOG_KV("host", _host); + BOOST_THROW_EXCEPTION(WeDPRException() << errinfo_comment( + "GatewayConfig: the host is invalid, host=" + _host)); + } + + if (!isValidPort(port)) + { + GATEWAY_LOG(ERROR) << LOG_DESC("the port is not valid") << LOG_KV("port", port); + BOOST_THROW_EXCEPTION( + WeDPRException() << errinfo_comment( + "GatewayConfig: the port is invalid, port=" + std::to_string(port))); + } + + boost::system::error_code ec; + boost::asio::ip::address ip_address = boost::asio::ip::make_address(ip, ec); + if (ec.value() != 0) + { + GATEWAY_LOG(ERROR) << LOG_DESC("the host is invalid, make_address error") + << LOG_KV("host", _host); + BOOST_THROW_EXCEPTION( + WeDPRException() << errinfo_comment( + "GatewayConfig: the host is invalid make_address error, host=" + _host)); + } + _endpoint = NodeIPEndpoint{ip_address, port}; +} diff --git a/cpp/ppc-gateway/ppc-gateway/GatewayConfigLoader.h b/cpp/ppc-gateway/ppc-gateway/GatewayConfigLoader.h new file mode 100644 index 00000000..1c6b418e --- /dev/null +++ b/cpp/ppc-gateway/ppc-gateway/GatewayConfigLoader.h @@ -0,0 +1,52 @@ +/** + * Copyright (C) 2023 WeDPR. + * SPDX-License-Identifier: Apache-2.0 + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @file GatewayConfigLoader.h + * @author: yujiechen + * @date 2024-08-26 + */ +#pragma once +#include "bcos-boostssl/interfaces/NodeInfoDef.h" +#include "ppc-tools/src/config/PPCConfig.h" + +namespace ppc::gateway +{ +class GatewayConfigLoader +{ +public: + using EndPointSet = std::set; + using Ptr = std::shared_ptr; + GatewayConfigLoader(ppc::tools::PPCConfig::Ptr config) + : m_config(std::move(config)), m_nodeIPEndpointSet(std::make_shared()) + { + loadP2pConnectedNodes(); + } + virtual ~GatewayConfigLoader() = default; + + EndPointSet const& nodeIPEndpointSet() const { return *m_nodeIPEndpointSet; } + + std::shared_ptr const& nodeIPEndpointSetPtr() const { return m_nodeIPEndpointSet; } + +protected: + void parseConnectedJson(const std::string& _json, EndPointSet& nodeIPEndpointSet); + void loadP2pConnectedNodes(); + void hostAndPort2Endpoint(const std::string& _host, bcos::boostssl::NodeIPEndpoint& _endpoint); + bool isValidPort(int port); + +private: + ppc::tools::PPCConfig::Ptr m_config; + std::shared_ptr m_nodeIPEndpointSet; +}; +} // namespace ppc::gateway \ No newline at end of file diff --git a/cpp/ppc-gateway/ppc-gateway/GatewayFactory.cpp b/cpp/ppc-gateway/ppc-gateway/GatewayFactory.cpp new file mode 100644 index 00000000..51f047f3 --- /dev/null +++ b/cpp/ppc-gateway/ppc-gateway/GatewayFactory.cpp @@ -0,0 +1,77 @@ +/** + * Copyright (C) 2023 WeDPR. + * SPDX-License-Identifier: Apache-2.0 + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @file GatewayFactory.cpp + * @author: yujiechen + * @date 2024-08-26 + */ +#include "GatewayFactory.h" +#include "Common.h" +#include "bcos-boostssl/websocket/WsInitializer.h" +#include "ppc-gateway/p2p/Service.h" +#include "ppc-gateway/p2p/router/RouterTableImpl.h" +#include "protocol/src/v1/MessageHeaderImpl.h" +#include "protocol/src/v1/MessageImpl.h" + +using namespace ppc; +using namespace bcos; +using namespace ppc::tools; +using namespace ppc::protocol; +using namespace ppc::gateway; +using namespace bcos::boostssl::ws; +using namespace bcos::boostssl; + +WsConfig::Ptr GatewayFactory::createServiceConfig(GatewayConfig const& config) const +{ + auto wsConfig = std::make_shared(); + wsConfig->setModel(WsModel::Mixed); + wsConfig->setListenIP(config.networkConfig.listenIp); + wsConfig->setListenPort(config.networkConfig.listenPort); + wsConfig->setSmSSL(config.networkConfig.enableSM); + wsConfig->setMaxMsgSize(config.maxAllowedMsgSize); + wsConfig->setReconnectPeriod(config.reconnectTime); + // TODO: setHeartbeatPeriod, setSendMsgTimeout + wsConfig->setThreadPoolSize(config.networkConfig.threadPoolSize); + // connected peers + wsConfig->setConnectPeers(m_gatewayConfig->nodeIPEndpointSetPtr()); + wsConfig->setDisableSsl(config.networkConfig.disableSsl); + wsConfig->setContextConfig(m_contextConfig->contextConfig()); + return wsConfig; +} + +Service::Ptr GatewayFactory::buildService() const +{ + auto wsConfig = createServiceConfig(m_config->gatewayConfig()); + auto wsInitializer = std::make_shared(); + // set the messageFactory + wsInitializer->setMessageFactory( + std::make_shared(std::make_shared())); + // set the config + wsInitializer->setConfig(wsConfig); + auto p2pService = std::make_shared(m_contextConfig->nodeID(), + std::make_shared(), m_config->gatewayConfig().unreachableDistance, + "Service"); + p2pService->setNodeEndpoints(m_gatewayConfig->nodeIPEndpointSet()); + + wsInitializer->initWsService(p2pService); + return p2pService; +} + +IGateway::Ptr GatewayFactory::build(ppc::front::IFrontBuilder::Ptr const& frontBuilder) const +{ + auto service = buildService(); + return std::make_shared( + service, frontBuilder, std::make_shared(), m_config->agencyID()); +} \ No newline at end of file diff --git a/cpp/ppc-gateway/ppc-gateway/GatewayFactory.h b/cpp/ppc-gateway/ppc-gateway/GatewayFactory.h new file mode 100644 index 00000000..2393f240 --- /dev/null +++ b/cpp/ppc-gateway/ppc-gateway/GatewayFactory.h @@ -0,0 +1,54 @@ +/** + * Copyright (C) 2023 WeDPR. + * SPDX-License-Identifier: Apache-2.0 + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @file GatewayFactory.h + * @author: yujiechen + * @date 2024-08-26 + */ +#pragma once +#include "GatewayConfigContext.h" +#include "GatewayConfigLoader.h" +#include "bcos-boostssl/websocket/WsConfig.h" +#include "gateway/GatewayImpl.h" +#include "ppc-gateway/p2p/Service.h" +#include "ppc-tools/src/config/PPCConfig.h" + +namespace ppc::gateway +{ +class GatewayFactory +{ +public: + using Ptr = std::shared_ptr; + GatewayFactory(ppc::tools::PPCConfig::Ptr config) : m_config(std::move(config)) + { + m_contextConfig = std::make_shared(m_config); + m_gatewayConfig = std::make_shared(m_config); + } + virtual ~GatewayFactory() = default; + + IGateway::Ptr build(ppc::front::IFrontBuilder::Ptr const& frontBuilder) const; + +protected: + Service::Ptr buildService() const; + + bcos::boostssl::ws::WsConfig::Ptr createServiceConfig( + ppc::tools::GatewayConfig const& config) const; + +private: + ppc::tools::PPCConfig::Ptr m_config; + GatewayConfigContext::Ptr m_contextConfig; + GatewayConfigLoader::Ptr m_gatewayConfig; +}; +} // namespace ppc::gateway \ No newline at end of file diff --git a/cpp/ppc-gateway/ppc-gateway/ProTaskManager.h b/cpp/ppc-gateway/ppc-gateway/ProTaskManager.h deleted file mode 100644 index fc8b8066..00000000 --- a/cpp/ppc-gateway/ppc-gateway/ProTaskManager.h +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Copyright (C) 2022 WeDPR. - * SPDX-License-Identifier: Apache-2.0 - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * @file ProTaskManager.h - * @author: shawnhe - * @date 2022-10-23 - */ -#pragma once -#include "TaskManager.h" - -#include "ppc-framework/storage/CacheStorage.h" -#include -namespace ppc::gateway -{ -class ProTaskManager : public TaskManager -{ -public: - using Ptr = std::shared_ptr; - ProTaskManager( - storage::CacheStorage::Ptr _cache, std::shared_ptr _ioService) - : TaskManager(std::move(_ioService)), m_cache(std::move(_cache)) - {} - ~ProTaskManager() override = default; - - void registerTaskInfo(const std::string& _taskID, const std::string& _serviceEndpoint) override - { - // throw exception if taskID existed - TaskManager::registerTaskInfo(_taskID, _serviceEndpoint); - try - { - // add task info to cache server - m_cache->setValue(_taskID, _serviceEndpoint, TASK_TIMEOUT_M * 60); - } - catch (std::exception const& e) - { - GATEWAY_LOG(WARNING) << LOG_DESC( - "set value failed: " + std::string(boost::diagnostic_information(e))); - } - } - - void removeTaskInfo(const std::string& _taskID) override - { - // Note: remove the memory-task-info in-case-of the redis exception - TaskManager::removeTaskInfo(_taskID); - m_cache->deleteKey(_taskID); - } - - std::string getServiceEndpoint(const std::string& _taskID) override - { - // find task info in memory first - try - { - auto endPoint = TaskManager::getServiceEndpoint(_taskID); - if (!endPoint.empty()) - { - return endPoint; - } - // Note: different node should not share the cache with same database - // find task info in cache service - auto serviceEndpoint = m_cache->getValue(_taskID); - if (serviceEndpoint == std::nullopt) - { - GATEWAY_LOG(ERROR) << LOG_BADGE("keyNotFoundInCache") << LOG_KV("key", _taskID); - return ""; - } - GATEWAY_LOG(TRACE) << LOG_DESC("getServiceEndpoint: find the task from redis cache") - << LOG_KV("task", _taskID); - // add task info to memory - auto taskInfo = prepareTaskInfo(_taskID, *serviceEndpoint); - addTaskInfo(_taskID, taskInfo); - return *serviceEndpoint; - } - catch (std::exception const& e) - { - GATEWAY_LOG(ERROR) << LOG_DESC("getServiceEndpoint error") - << LOG_KV("exception", boost::diagnostic_information(e)); - return ""; - } - } - -private: - storage::CacheStorage::Ptr m_cache; -}; -} // namespace ppc::gateway \ No newline at end of file diff --git a/cpp/ppc-gateway/ppc-gateway/TaskManager.cpp b/cpp/ppc-gateway/ppc-gateway/TaskManager.cpp deleted file mode 100644 index 51ad5f15..00000000 --- a/cpp/ppc-gateway/ppc-gateway/TaskManager.cpp +++ /dev/null @@ -1,113 +0,0 @@ -/** - * Copyright (C) 2022 WeDPR. - * SPDX-License-Identifier: Apache-2.0 - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * @file TaskManager.cpp - * @author: shawnhe - * @date 2022-10-23 - */ - -#include "TaskManager.h" -#include "ppc-framework/protocol/Protocol.h" - -using namespace bcos; -using namespace ppc::gateway; -using namespace ppc::storage; - -void TaskManager::registerTaskInfo(const std::string& _taskID, const std::string& _serviceEndpoint) -{ - GATEWAY_LOG(INFO) << LOG_BADGE("registerTaskInfo") << LOG_KV("taskID", _taskID) - << LOG_KV("serviceEndpoint", _serviceEndpoint); - if (getTaskInfo(_taskID)) - { - BOOST_THROW_EXCEPTION( - BCOS_ERROR(protocol::PPCRetCode::EXCEPTION, "task id already exists")); - } - - // add task info to memory - auto taskInfo = prepareTaskInfo(_taskID, _serviceEndpoint); - addTaskInfo(_taskID, taskInfo); -} - - -std::string TaskManager::getServiceEndpoint(const std::string& _taskID) -{ - // find task info in memory first - auto taskInfo = getTaskInfo(_taskID); - if (taskInfo) - { - return taskInfo->serviceEndpoint; - } - return ""; -} - - -TaskManager::TaskInfo::Ptr TaskManager::prepareTaskInfo( - const std::string& _taskID, const std::string& _serviceEndpoint) -{ - auto taskInfo = std::make_shared(); - taskInfo->serviceEndpoint = _serviceEndpoint; - - // create timer to handle timeout - taskInfo->timer = std::make_shared( - *m_ioService, boost::posix_time::minutes(TASK_TIMEOUT_M)); - - taskInfo->timer->async_wait( - [self = weak_from_this(), _taskID](boost::system::error_code _error) { - if (!_error) - { - auto taskManager = self.lock(); - if (taskManager) - { - // remove timeout event - taskManager->removeTaskInfo(_taskID); - } - } - }); - - return taskInfo; -} - - -TaskManager::TaskInfo::Ptr TaskManager::getTaskInfo(const std::string& _taskID) -{ - ReadGuard lock(x_tasks); - auto it = m_tasks.find(_taskID); - if (it != m_tasks.end()) - { - return it->second; - } - else - { - return nullptr; - } -} - - -void TaskManager::addTaskInfo( - const std::string& _taskID, const TaskManager::TaskInfo::Ptr& _taskInfo) -{ - WriteGuard lock(x_tasks); - GATEWAY_LOG(INFO) << LOG_BADGE("addTaskInfo") << LOG_KV("taskID", _taskID) - << LOG_KV("serviceEndpoint", _taskInfo->serviceEndpoint); - m_tasks.emplace(_taskID, _taskInfo); -} - - -void TaskManager::removeTaskInfo(const std::string& _taskID) -{ - WriteGuard lock(x_tasks); - GATEWAY_LOG(INFO) << LOG_BADGE("removeTaskInfo") << LOG_KV("taskID", _taskID); - m_tasks.erase(_taskID); -} diff --git a/cpp/ppc-gateway/ppc-gateway/TaskManager.h b/cpp/ppc-gateway/ppc-gateway/TaskManager.h deleted file mode 100644 index 214888a0..00000000 --- a/cpp/ppc-gateway/ppc-gateway/TaskManager.h +++ /dev/null @@ -1,69 +0,0 @@ -/** - * Copyright (C) 2022 WeDPR. - * SPDX-License-Identifier: Apache-2.0 - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * @file TaskManager.h - * @author: shawnhe - * @date 2022-10-23 - */ - -#pragma once - -#include "Common.h" -#include "GatewayConfigContext.h" -#include -#include -#include -#include -#include - -namespace ppc::gateway -{ -class TaskManager : public std::enable_shared_from_this -{ -public: - using Ptr = std::shared_ptr; - TaskManager(std::shared_ptr _ioService) - : m_ioService(std::move(_ioService)) - {} - virtual ~TaskManager() = default; - - virtual void registerTaskInfo(const std::string& _taskID, const std::string& _serviceEndpoint); - - virtual std::string getServiceEndpoint(const std::string& _taskID); - - virtual void removeTaskInfo(const std::string& _taskID); - -protected: - struct TaskInfo - { - using Ptr = std::shared_ptr; - std::string serviceEndpoint; - // timeout of the task - std::shared_ptr timer; - }; - - TaskInfo::Ptr prepareTaskInfo(const std::string& _taskID, const std::string& _serviceEndpoint); - TaskInfo::Ptr getTaskInfo(const std::string& _taskID); - void addTaskInfo(const std::string& _taskID, const TaskInfo::Ptr& _taskInfo); - -protected: - std::shared_ptr m_ioService; - // key: taskID, value: TaskInfo - std::unordered_map m_tasks; - mutable bcos::SharedMutex x_tasks; - - constexpr static uint32_t TASK_TIMEOUT_M = 24 * 60; // minutes -}; -} // namespace ppc::gateway diff --git a/cpp/ppc-gateway/ppc-gateway/WebSocketService.cpp b/cpp/ppc-gateway/ppc-gateway/WebSocketService.cpp deleted file mode 100644 index 1c644fa8..00000000 --- a/cpp/ppc-gateway/ppc-gateway/WebSocketService.cpp +++ /dev/null @@ -1,387 +0,0 @@ -/** - * Copyright (C) 2022 WeDPR. - * SPDX-License-Identifier: Apache-2.0 - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * @file WebSocketService.cpp - * @author: shawnhe - * @date 2022-10-23 - */ - -#include "WebSocketService.h" -#include "ppc-tools/src/config/ParamChecker.h" -#include - -using namespace bcos; -using namespace bcos::boostssl; -using namespace bcos::boostssl::ws; -using namespace ppc::gateway; -using namespace ppc::tools; -using namespace ppc::http; - -void WebSocketService::start() -{ - if (m_timer) - { - m_timer->registerTimeoutHandler( - boost::bind(&WebSocketService::reconnectUnconnectedClient, this)); - } - if (m_protocol == m_config->config()->gatewayConfig().networkConfig.PROTOCOL_WEBSOCKET) - { - m_wsServer->start(); - GATEWAY_LOG(INFO) << LOG_BADGE("start the WebSocketService end"); - } - else if (m_protocol == m_config->config()->gatewayConfig().networkConfig.PROTOCOL_HTTP) - { - m_httpServer->start(); - GATEWAY_LOG(INFO) << LOG_BADGE("start the HttpService end"); - } - startConnect(); - m_timer->start(); -} - -void WebSocketService::stop() -{ - GATEWAY_LOG(INFO) << LOG_BADGE("stop the WebSocketService"); - if (m_timer) - { - m_timer->stop(); - } - ReadGuard l(x_agencyClients); - if (m_protocol == m_config->config()->gatewayConfig().networkConfig.PROTOCOL_WEBSOCKET) - { - for (auto& client : m_agencyClients) - { - auto it = client.second; - if (it) - { - it->stop(); - } - } - if (m_wsServer) - { - m_wsServer->stop(); - } - GATEWAY_LOG(INFO) << LOG_BADGE("stop the WebSocketService success"); - } - else if (m_protocol == m_config->config()->gatewayConfig().networkConfig.PROTOCOL_HTTP) - { - if (m_httpServer) - { - m_httpServer->stop(); - } - GATEWAY_LOG(INFO) << LOG_BADGE("stop the HttpService success"); - } -} - - -void WebSocketService::registerGatewayUrl( - const std::string& _agencyID, const std::string& _agencyUrl) -{ - if (!insertAgency(_agencyID, _agencyUrl)) - { - return; - } - GATEWAY_LOG(INFO) << LOG_BADGE("registerGatewayUrl") << LOG_KV("agencyID", _agencyID) - << LOG_KV("agencyUrl", _agencyUrl); - try - { - if (m_protocol == m_config->config()->gatewayConfig().networkConfig.PROTOCOL_WEBSOCKET) - { - auto client = buildWebSocketClient(_agencyID); - insertIntoMap(_agencyID, client, x_agencyClients, m_agencyClients); - } - else if (m_protocol == m_config->config()->gatewayConfig().networkConfig.PROTOCOL_HTTP) - { - auto client = buildHttpClient(_agencyID, m_ioContext); - insertIntoMap(_agencyID, client, x_agencyClients, m_agencyHttpClients); - } - } - catch (std::exception const& e) - { - GATEWAY_LOG(WARNING) << LOG_DESC("connect to the agency failed") - << LOG_KV("agency", _agencyID) << LOG_KV("url", _agencyUrl); - { - bcos::WriteGuard ucl(x_unConnectedAgencies); - m_unConnectedAgencies.insert(_agencyID); - } - } -} - -bool WebSocketService::insertAgency(const std::string& _agencyID, const std::string& _agencyUrl) -{ - std::vector endpoints; - boost::split(endpoints, _agencyUrl, boost::is_any_of(",")); - - WriteGuard l(x_agencies); - - if (m_urls.find(_agencyID) != m_urls.end()) - { - if (m_urls[_agencyID] == _agencyUrl) - { - // no need update - return false; - } - } - - m_urls[_agencyID] = _agencyUrl; - m_agencies[_agencyID] = endpoints; - return true; -} - -void WebSocketService::reconnectUnconnectedClient() -{ - { - // print connected clients - bcos::ReadGuard rl(x_agencyClients); - if (m_protocol == m_config->config()->gatewayConfig().networkConfig.PROTOCOL_WEBSOCKET) - { - GATEWAY_LOG(INFO) << LOG_DESC("connectedWebsocketClient") - << LOG_KV("size", m_agencyClients.size()); - } - else if (m_protocol == m_config->config()->gatewayConfig().networkConfig.PROTOCOL_HTTP) - { - GATEWAY_LOG(INFO) << LOG_DESC("connectedHttpClient") - << LOG_KV("size", m_agencyHttpClients.size()); - } - } - - // start reconnecting - bcos::UpgradableGuard l(x_unConnectedAgencies); - GATEWAY_LOG(INFO) << LOG_DESC("reconnectUnconnectedClient") - << LOG_KV("size", m_unConnectedAgencies.size()); - for (auto it = m_unConnectedAgencies.begin(); it != m_unConnectedAgencies.end();) - { - try - { - if (m_protocol == m_config->config()->gatewayConfig().networkConfig.PROTOCOL_WEBSOCKET) - { - GATEWAY_LOG(INFO) << LOG_BADGE("WebSocketService") - << LOG_DESC("reconnectUnconnectedClient: connect to peer") - << LOG_KV("agency", *it); - auto client = buildWebSocketClient(*it); - GATEWAY_LOG(INFO) << LOG_BADGE("WebSocketService") - << LOG_DESC("reconnectUnconnectedClient: connect to peer success") - << LOG_KV("agency", *it); - // insert the successfully started client into the m_agencyClients - insertIntoMap(*it, client, x_agencyClients, m_agencyClients); - // erase the connected client from the m_unConnectedAgencies - bcos::UpgradeGuard ul(l); - it = m_unConnectedAgencies.erase(it); - } - else if (m_protocol == m_config->config()->gatewayConfig().networkConfig.PROTOCOL_HTTP) - { - GATEWAY_LOG(INFO) << LOG_BADGE("HttpService") - << LOG_DESC("reconnectUnconnectedClient: connect to peer") - << LOG_KV("agency", *it); - auto client = buildHttpClient(*it, m_ioContext); - GATEWAY_LOG(INFO) << LOG_BADGE("HttpService") - << LOG_DESC("reconnectUnconnectedClient: connect to peer success") - << LOG_KV("agency", *it); - // insert the successfully started client into the m_agencyHttpClients - insertIntoMap(*it, client, x_agencyClients, m_agencyHttpClients); - // erase the connected client from the m_unConnectedAgencies - bcos::UpgradeGuard ul(l); - it = m_unConnectedAgencies.erase(it); - } - } - catch (std::exception const& e) - { - it++; - GATEWAY_LOG(INFO) << LOG_BADGE("reconnectUnconnectedClient failed"); - } - } - m_timer->restart(); -} - -WsService::Ptr WebSocketService::webSocketClient(const std::string& _agencyID) -{ - return getValueFromMap( - _agencyID, x_agencyClients, m_agencyClients); -} - -HttpClient::Ptr WebSocketService::httpClient(const std::string& _agencyID) -{ - return getValueFromMap( - _agencyID, x_agencyClients, m_agencyHttpClients); -} - -void WebSocketService::startConnect() -{ - GATEWAY_LOG(INFO) << LOG_DESC("WebSocketService: startConnect"); - auto const& agencyConfig = m_config->config()->gatewayConfig().agencies; - auto protocol = m_config->config()->gatewayConfig().networkConfig.protocol; - for (auto const& it : agencyConfig) - { - try - { - if (protocol == m_config->config()->gatewayConfig().networkConfig.PROTOCOL_WEBSOCKET) - { - auto client = buildWebSocketClient(it.first); - insertIntoMap(it.first, client, x_agencyClients, m_agencyClients); - } - else if (protocol == m_config->config()->gatewayConfig().networkConfig.PROTOCOL_HTTP) - { - auto client = buildHttpClient(it.first, m_ioContext); - insertIntoMap(it.first, client, x_agencyClients, m_agencyHttpClients); - } - } - catch (std::exception const& e) - { - { - bcos::WriteGuard l(x_unConnectedAgencies); - m_unConnectedAgencies.insert(it.first); - } - GATEWAY_LOG(WARNING) << LOG_BADGE("startConnect") - << LOG_DESC("connect to the agency failed") - << LOG_KV("agency", it.first) - << LOG_KV("exception", boost::diagnostic_information(e)); - } - } - GATEWAY_LOG(INFO) << LOG_DESC("WebSocketService: startConnect success"); -} - -HttpClient::Ptr WebSocketService::buildHttpClient( - const std::string& _agencyID, std::shared_ptr _ioContext) -{ - GATEWAY_LOG(INFO) << LOG_BADGE("WebSocketService: buildHttpClient") - << LOG_KV("agency", _agencyID); - { - bcos::ReadGuard l(x_agencies); - // one agencyID => one httpClient - for (const auto& endpoint : m_agencies.at(_agencyID)) - { - if (!checkEndpoint(endpoint)) - { - BOOST_THROW_EXCEPTION( - InvalidParameter() << bcos::errinfo_comment("Invalid endpoint: " + endpoint)); - } - std::vector url; - boost::split(url, endpoint, boost::is_any_of(":"), boost::token_compress_on); - auto client = std::make_shared(*_ioContext, url[0], std::stoi(url[1])); - return client; - } - } -} - -WsService::Ptr WebSocketService::buildWebSocketClient(std::string const& _agencyID) -{ - GATEWAY_LOG(INFO) << LOG_BADGE("buildWebSocketClient") << LOG_KV("agency", _agencyID); - auto peers = std::make_shared(); - { - bcos::ReadGuard l(x_agencies); - for (const auto& endpoint : m_agencies.at(_agencyID)) - { - if (!checkEndpoint(endpoint)) - { - BOOST_THROW_EXCEPTION( - InvalidParameter() << bcos::errinfo_comment("Invalid endpoint: " + endpoint)); - } - - std::vector url; - boost::split(url, endpoint, boost::is_any_of(":"), boost::token_compress_on); - NodeIPEndpoint nodeIpEndpoint = NodeIPEndpoint(url[0], std::stoi(url[1])); - peers->insert(nodeIpEndpoint); - } - } - - auto const& gatewayConfig = m_config->config()->gatewayConfig(); - auto wsConfig = std::make_shared(); - wsConfig->setModel(WsModel::Client); - wsConfig->setConnectPeers(peers); - wsConfig->setThreadPoolSize(gatewayConfig.networkConfig.threadPoolSize); - wsConfig->setDisableSsl(gatewayConfig.networkConfig.disableSsl); - wsConfig->setMaxMsgSize(gatewayConfig.maxAllowedMsgSize); - if (!wsConfig->disableSsl()) - { - wsConfig->setContextConfig(m_config->contextConfig()); - } - - auto wsInitializer = std::make_shared(); - wsInitializer->setConfig(wsConfig); - auto wsClient = std::make_shared(GATEWAY_WS_CLIENT_MODULE); - wsClient->setTimerFactory(std::make_shared()); - wsInitializer->initWsService(wsClient); - - wsClient->start(); - GATEWAY_LOG(INFO) << LOG_BADGE("WebSocketService") << LOG_DESC("connect to peer success") - << LOG_KV("agency", _agencyID); - return wsClient; -} - -WsService::Ptr WebSocketServiceFactory::buildWebSocketServer( - const GatewayConfigContext::Ptr& _config) -{ - GATEWAY_LOG(INFO) << LOG_BADGE("buildWebSocketServer"); - auto wsConfig = std::make_shared(); - wsConfig->setModel(WsModel::Server); - - auto const& gatewayConfig = _config->config()->gatewayConfig(); - wsConfig->setListenIP(gatewayConfig.networkConfig.listenIp); - wsConfig->setListenPort(gatewayConfig.networkConfig.listenPort); - wsConfig->setThreadPoolSize(gatewayConfig.networkConfig.threadPoolSize); - wsConfig->setDisableSsl(gatewayConfig.networkConfig.disableSsl); - if (!wsConfig->disableSsl()) - { - wsConfig->setContextConfig(_config->contextConfig()); - } - wsConfig->setMaxMsgSize(gatewayConfig.maxAllowedMsgSize); - auto wsInitializer = std::make_shared(); - wsInitializer->setConfig(wsConfig); - auto wsService = std::make_shared(GATEWAY_WS_SERVER_MODULE); - wsService->setTimerFactory(std::make_shared()); - wsInitializer->initWsService(wsService); - - return wsService; -} - -Http::Ptr WebSocketServiceFactory::buildHttpServer(const GatewayConfigContext::Ptr& _config) -{ - GATEWAY_LOG(INFO) << LOG_BADGE("buildHttpServer"); - auto ppcConfig = _config->config(); - auto httpFactory = std::make_shared(ppcConfig->agencyID()); - return httpFactory->buildHttp(ppcConfig); -} - -WebSocketService::Ptr WebSocketServiceFactory::buildWebSocketService( - ppc::tools::PPCConfig::Ptr const& _config, std::shared_ptr _ioContext) -{ - try - { - auto gatewayConfig = std::make_shared(_config); - auto _protocol = _config->gatewayConfig().networkConfig.protocol; - // init websocket service - if (_protocol == _config->gatewayConfig().networkConfig.PROTOCOL_WEBSOCKET) - { - GATEWAY_LOG(INFO) << LOG_BADGE("buildWebSocketService"); - auto wsServer = buildWebSocketServer(gatewayConfig); - auto webSocketService = std::make_shared(gatewayConfig, wsServer); - return webSocketService; - } - else if (_protocol == _config->gatewayConfig().networkConfig.PROTOCOL_HTTP) - { - GATEWAY_LOG(INFO) << LOG_BADGE("buildHttpService"); - auto httpServer = buildHttpServer(gatewayConfig); - auto webSocketService = - std::make_shared(gatewayConfig, httpServer, _ioContext); - return webSocketService; - } - } - catch (std::exception const& e) - { - GATEWAY_LOG(ERROR) << LOG_BADGE("buildWebSocketService") - << LOG_DESC("init gateway websocket service failed, error: " + - boost::diagnostic_information(e)); - throw e; - } -} diff --git a/cpp/ppc-gateway/ppc-gateway/WebSocketService.h b/cpp/ppc-gateway/ppc-gateway/WebSocketService.h deleted file mode 100644 index 73c03ad5..00000000 --- a/cpp/ppc-gateway/ppc-gateway/WebSocketService.h +++ /dev/null @@ -1,147 +0,0 @@ -/** - * Copyright (C) 2022 WeDPR. - * SPDX-License-Identifier: Apache-2.0 - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * @file WebSocketService.h - * @author: shawnhe - * @date 2022-10-23 - */ - -#pragma once - -#include "Common.h" -#include "GatewayConfigContext.h" -#include "ppc-http/src/Http.h" -#include "ppc-http/src/HttpClient.h" -#include "ppc-http/src/HttpFactory.h" -#include "ppc-tools/src/config/PPCConfig.h" -#include -#include -#include -#include -#include -#include - -namespace ppc::gateway -{ -class WebSocketService -{ -public: - using Ptr = std::shared_ptr; - using WebSocketClientMap = std::unordered_map; - using HttpClientMap = std::unordered_map; - WebSocketService( - GatewayConfigContext::Ptr _config, bcos::boostssl::ws::WsService::Ptr _wsServer) - : m_config(_config), - m_agencies(_config->config()->gatewayConfig().agencies), - m_wsServer(std::move(_wsServer)), - m_protocol(_config->config()->gatewayConfig().networkConfig.protocol), - m_timer(std::make_shared( - _config->config()->gatewayConfig().reconnectTime, "connectTimer")) - {} - - WebSocketService(GatewayConfigContext::Ptr _config, ppc::http::Http::Ptr _httpServer, - std::shared_ptr _ioContext) - : m_config(_config), - m_agencies(_config->config()->gatewayConfig().agencies), - m_httpServer(std::move(_httpServer)), - m_protocol(_config->config()->gatewayConfig().networkConfig.protocol), - m_ioContext(_ioContext), - m_timer(std::make_shared( - _config->config()->gatewayConfig().reconnectTime, "connectTimer")) - {} - - virtual ~WebSocketService() = default; - - void start(); - void stop(); - - void registerGatewayUrl(const std::string& _agencyID, const std::string& _agencyUrl); - bool insertAgency(const std::string& _agencyID, const std::string& _agencyUrl); - bcos::boostssl::ws::WsService::Ptr const& webSocketServer() const { return m_wsServer; } - ppc::http::Http::Ptr httpServer() const { return m_httpServer; } - bcos::boostssl::ws::WsService::Ptr webSocketClient(const std::string& _agencyID); - ppc::http::HttpClient::Ptr httpClient(const std::string& _agencyID); - GatewayConfigContext::Ptr const& gatewayConfig() const { return m_config; } - - -protected: - virtual bcos::boostssl::ws::WsService::Ptr buildWebSocketClient(std::string const& _agencyID); - virtual ppc::http::HttpClient::Ptr buildHttpClient( - const std::string& _agencyID, std::shared_ptr _ioContext); - virtual void startConnect(); - virtual void reconnectUnconnectedClient(); - - template - void insertIntoMap(std::string const& _key, T const& _value, bcos::SharedMutex& lock, S& _map) - { - bcos::WriteGuard l(lock); - _map[_key] = _value; - } - - template - T getValueFromMap(std::string const& _key, bcos::SharedMutex& lock, S const& _map) - { - bcos::ReadGuard l(lock); - auto it = _map.find(_key); - if (it != _map.end()) - { - return it->second; - } - return nullptr; - } - -private: - int m_protocol; - GatewayConfigContext::Ptr m_config; - bcos::SharedMutex x_agencies; - std::map> m_agencies; - std::map m_urls; - - bcos::boostssl::ws::WsService::Ptr m_wsServer; - ppc::http::Http::Ptr m_httpServer; - // the timer used to try connecting to the un-connected-clients - std::shared_ptr m_timer; - - std::shared_ptr m_ioContext; - // key: agencyID, value: WebSocketClient - WebSocketClientMap m_agencyClients; - HttpClientMap m_agencyHttpClients; - bcos::SharedMutex x_agencyClients; - // connect failed for all the agency-nodes are offline - std::set m_unConnectedAgencies; - mutable bcos::SharedMutex x_unConnectedAgencies; -}; - - -class WebSocketServiceFactory -{ -public: - using Ptr = std::shared_ptr; - -public: - WebSocketServiceFactory() = default; - ~WebSocketServiceFactory() = default; - - WebSocketService::Ptr buildWebSocketService(const ppc::tools::PPCConfig::Ptr& _config, - std::shared_ptr _ioContext = nullptr); - -private: - bcos::boostssl::ws::WsService::Ptr buildWebSocketServer( - const GatewayConfigContext::Ptr& _config); - - ppc::http::Http::Ptr buildHttpServer(const GatewayConfigContext::Ptr& _config); -}; - -} // namespace ppc::gateway \ No newline at end of file diff --git a/cpp/ppc-gateway/ppc-gateway/gateway/GatewayImpl.cpp b/cpp/ppc-gateway/ppc-gateway/gateway/GatewayImpl.cpp index 6707fae9..de92b96b 100644 --- a/cpp/ppc-gateway/ppc-gateway/gateway/GatewayImpl.cpp +++ b/cpp/ppc-gateway/ppc-gateway/gateway/GatewayImpl.cpp @@ -84,12 +84,14 @@ void GatewayImpl::stop() } void GatewayImpl::asyncSendbroadcastMessage(ppc::protocol::RouteType routeType, - std::string const& topic, std::string const& dstInst, std::string const& componentType, - bcos::bytes&& payload) + MessageOptionalHeader::Ptr const& routeInfo, bcos::bytes&& payload) { // dispatcher to all the local front - auto p2pMessage = m_msgBuilder->build( - routeType, topic, dstInst, bcos::bytes(), componentType, std::move(payload)); + routeInfo->setDstNode(bcos::bytes()); + routeInfo->setSrcInst(m_agency); + + auto p2pMessage = m_msgBuilder->build(routeType, routeInfo, std::move(payload)); + p2pMessage->setPacketType((uint16_t)GatewayPacketType::BroadcastMessage); m_localRouter->dispatcherMessage(p2pMessage, nullptr); // broadcast message to all peers @@ -97,13 +99,14 @@ void GatewayImpl::asyncSendbroadcastMessage(ppc::protocol::RouteType routeType, } -void GatewayImpl::asyncSendMessage(ppc::protocol::RouteType routeType, std::string const& topic, - std::string const& dstInst, bcos::bytes const& dstNodeID, std::string const& componentType, - bcos::bytes&& payload, long timeout, ReceiveMsgFunc callback) +void GatewayImpl::asyncSendMessage(ppc::protocol::RouteType routeType, + ppc::protocol::MessageOptionalHeader::Ptr const& routeInfo, bcos::bytes&& payload, long timeout, + ReceiveMsgFunc callback) { + routeInfo->setSrcInst(m_agency); // check the localRouter - auto p2pMessage = m_msgBuilder->build( - routeType, topic, dstInst, dstNodeID, componentType, std::move(payload)); + auto p2pMessage = m_msgBuilder->build(routeType, routeInfo, std::move(payload)); + p2pMessage->setPacketType((uint16_t)GatewayPacketType::P2PMessage); auto nodeList = m_localRouter->chooseReceiver(p2pMessage); // case send to the same agency diff --git a/cpp/ppc-gateway/ppc-gateway/gateway/GatewayImpl.h b/cpp/ppc-gateway/ppc-gateway/gateway/GatewayImpl.h index 9d2ced23..40f97fa7 100644 --- a/cpp/ppc-gateway/ppc-gateway/gateway/GatewayImpl.h +++ b/cpp/ppc-gateway/ppc-gateway/gateway/GatewayImpl.h @@ -53,12 +53,12 @@ class GatewayImpl : public IGateway, public std::enable_shared_from_this(-1, SEND_MESSAGE_TO_FRONT_TIMEOUT)); + msgInfo.callback(std::make_shared(-1, "timeout")); } } } \ No newline at end of file diff --git a/cpp/ppc-gateway/ppc-gateway/gateway/router/GatewayNodeInfo.h b/cpp/ppc-gateway/ppc-gateway/gateway/router/GatewayNodeInfo.h index e73f2f1d..9338f07c 100644 --- a/cpp/ppc-gateway/ppc-gateway/gateway/router/GatewayNodeInfo.h +++ b/cpp/ppc-gateway/ppc-gateway/gateway/router/GatewayNodeInfo.h @@ -43,10 +43,11 @@ class GatewayNodeInfo virtual bool tryAddNodeInfo(ppc::protocol::INodeInfo::Ptr const& nodeInfo) = 0; virtual void removeNodeInfo(bcos::bytes const& nodeID) = 0; - virtual std::vector chooseRouteByComponent( + virtual std::vector> chooseRouteByComponent( bool selectAll, std::string const& component) const = 0; - virtual std::vector chooseRouterByAgency(bool selectAll) const = 0; - virtual std::vector chooseRouterByTopic( + virtual std::vector> chooseRouterByAgency( + bool selectAll) const = 0; + virtual std::vector> chooseRouterByTopic( bool selectAll, std::string const& topic) const = 0; virtual void encode(bcos::bytes& data) const = 0; diff --git a/cpp/ppc-gateway/ppc-gateway/gateway/router/GatewayNodeInfoImpl.cpp b/cpp/ppc-gateway/ppc-gateway/gateway/router/GatewayNodeInfoImpl.cpp index 3c718fe9..690ded72 100644 --- a/cpp/ppc-gateway/ppc-gateway/gateway/router/GatewayNodeInfoImpl.cpp +++ b/cpp/ppc-gateway/ppc-gateway/gateway/router/GatewayNodeInfoImpl.cpp @@ -18,8 +18,9 @@ * @date 2024-08-26 */ #include "GatewayNodeInfoImpl.h" -#include "ppc-tars-protocol/Common.h" -#include "ppc-tars-protocol/impl/NodeInfoImpl.h" +#include "wedpr-protocol/protobuf/Common.h" +#include "wedpr-protocol/protobuf/NodeInfoImpl.h" +#include "wedpr-protocol/tars/Common.h" using namespace ppctars; using namespace ppc::protocol; @@ -29,21 +30,21 @@ using namespace ppc::gateway; // the gateway nodeID std::string const& GatewayNodeInfoImpl::p2pNodeID() const { - return m_inner()->p2pNodeID; + return m_inner()->p2pnodeid(); } // the agency std::string const& GatewayNodeInfoImpl::agency() const { - return m_inner()->agency; + return m_inner()->agency(); } uint32_t GatewayNodeInfoImpl::statusSeq() const { - return m_inner()->statusSeq; + return m_inner()->statusseq(); } void GatewayNodeInfoImpl::setStatusSeq(uint32_t statusSeq) { - m_inner()->statusSeq = statusSeq; + m_inner()->set_statusseq(statusSeq); } // get the node information by nodeID @@ -96,10 +97,10 @@ void GatewayNodeInfoImpl::removeNodeInfo(bcos::bytes const& nodeID) } } -std::vector GatewayNodeInfoImpl::chooseRouteByComponent( +std::vector> GatewayNodeInfoImpl::chooseRouteByComponent( bool selectAll, std::string const& component) const { - std::vector result; + std::vector> result; bcos::ReadGuard l(x_nodeList); for (auto const& it : m_nodeList) { @@ -116,9 +117,10 @@ std::vector GatewayNodeInfoImpl::chooseRouteByComponent } -vector GatewayNodeInfoImpl::chooseRouterByAgency(bool selectAll) const +std::vector> GatewayNodeInfoImpl::chooseRouterByAgency( + bool selectAll) const { - std::vector result; + std::vector> result; bcos::ReadGuard l(x_nodeList); for (auto const& it : m_nodeList) { @@ -131,10 +133,10 @@ vector GatewayNodeInfoImpl::chooseRouterByAgency(bool s return result; } -std::vector GatewayNodeInfoImpl::chooseRouterByTopic( +std::vector> GatewayNodeInfoImpl::chooseRouterByTopic( bool selectAll, std::string const& topic) const { - std::vector result; + std::vector> result; bcos::ReadGuard l(x_topicInfo); for (auto const& it : m_topicInfo) { @@ -182,34 +184,30 @@ void GatewayNodeInfoImpl::unRegisterTopic(bcos::bytes const& nodeID, std::string void GatewayNodeInfoImpl::encode(bcos::bytes& data) const { - m_inner()->nodeList.clear(); + m_inner()->clear_nodelist(); { bcos::ReadGuard l(x_nodeList); // encode nodeList for (auto const& it : m_nodeList) { auto nodeInfo = std::dynamic_pointer_cast(it.second); - m_inner()->nodeList.emplace_back(nodeInfo->inner()); + m_inner()->mutable_nodelist()->UnsafeArenaAddAllocated(nodeInfo->innerFunc()()); } } - tars::TarsOutputStream output; - m_inner()->writeTo(output); - output.getByteBuffer().swap(data); + encodePBObject(data, m_inner()); } void GatewayNodeInfoImpl::decode(bcos::bytesConstRef data) { - tars::TarsInputStream input; - input.setBuffer((const char*)data.data(), data.size()); - m_inner()->readFrom(input); + decodePBObject(m_inner(), data); { bcos::WriteGuard l(x_nodeList); // decode into m_nodeList m_nodeList.clear(); - for (auto& it : m_inner()->nodeList) + for (int i = 0; i < m_inner()->nodelist_size(); i++) { - auto nodeInfoPtr = - std::make_shared([m_entry = it]() mutable { return &m_entry; }); + auto nodeInfoPtr = std::make_shared( + [m_entry = m_inner()->nodelist(i)]() mutable { return &m_entry; }); m_nodeList.insert(std::make_pair(nodeInfoPtr->nodeID().toBytes(), nodeInfoPtr)); } } diff --git a/cpp/ppc-gateway/ppc-gateway/gateway/router/GatewayNodeInfoImpl.h b/cpp/ppc-gateway/ppc-gateway/gateway/router/GatewayNodeInfoImpl.h index 72c6d440..4e3ba366 100644 --- a/cpp/ppc-gateway/ppc-gateway/gateway/router/GatewayNodeInfoImpl.h +++ b/cpp/ppc-gateway/ppc-gateway/gateway/router/GatewayNodeInfoImpl.h @@ -19,7 +19,7 @@ */ #pragma once #include "GatewayNodeInfo.h" -#include "ppc-tars-protocol/tars/NodeInfo.h" +#include "NodeInfo.pb.h" #include #include @@ -30,12 +30,19 @@ class GatewayNodeInfoImpl : public GatewayNodeInfo public: using Ptr = std::shared_ptr; GatewayNodeInfoImpl(std::string const& p2pNodeID, std::string const& agency) - : m_inner([inner = ppctars::GatewayNodeInfo()]() mutable { return &inner; }) + : m_inner([inner = ppc::proto::GatewayNodeInfo()]() mutable { return &inner; }) { - m_inner()->p2pNodeID = p2pNodeID; - m_inner()->agency = agency; + m_inner()->set_p2pnodeid(p2pNodeID); + m_inner()->set_agency(agency); + } + ~GatewayNodeInfoImpl() override + { + auto allocatedNodeListSize = m_inner()->nodelist_size(); + for (int i = 0; i < allocatedNodeListSize; i++) + { + m_inner()->mutable_nodelist()->UnsafeArenaReleaseLast(); + } } - ~GatewayNodeInfoImpl() override = default; // the gateway nodeID std::string const& p2pNodeID() const override; @@ -52,10 +59,11 @@ class GatewayNodeInfoImpl : public GatewayNodeInfo bool tryAddNodeInfo(ppc::protocol::INodeInfo::Ptr const& nodeInfo) override; void removeNodeInfo(bcos::bytes const& nodeID) override; - std::vector chooseRouteByComponent( + std::vector> chooseRouteByComponent( bool selectAll, std::string const& component) const override; - std::vector chooseRouterByAgency(bool selectAll) const override; - std::vector chooseRouterByTopic( + std::vector> chooseRouterByAgency( + bool selectAll) const override; + std::vector> chooseRouterByTopic( bool selectAll, std::string const& topic) const override; void registerTopic(bcos::bytes const& nodeID, std::string const& topic) override; @@ -72,7 +80,7 @@ class GatewayNodeInfoImpl : public GatewayNodeInfo virtual uint16_t nodeSize() const override { return m_nodeList.size(); } private: - std::function m_inner; + std::function m_inner; // NodeID => nodeInfo std::map m_nodeList; mutable bcos::SharedMutex x_nodeList; diff --git a/cpp/ppc-gateway/ppc-gateway/gateway/router/LocalRouter.cpp b/cpp/ppc-gateway/ppc-gateway/gateway/router/LocalRouter.cpp index 7d934daa..2db66e76 100644 --- a/cpp/ppc-gateway/ppc-gateway/gateway/router/LocalRouter.cpp +++ b/cpp/ppc-gateway/ppc-gateway/gateway/router/LocalRouter.cpp @@ -80,10 +80,10 @@ bool LocalRouter::dispatcherMessage(Message::Ptr const& msg, ReceiveMsgFunc call return false; } -std::vector LocalRouter::chooseReceiver( +std::vector LocalRouter::chooseReceiver( ppc::protocol::Message::Ptr const& msg) { - std::vector receivers; + std::vector receivers; if (msg->header()->optionalField()->dstInst() != m_routerInfo->agency()) { return receivers; diff --git a/cpp/ppc-gateway/ppc-gateway/gateway/router/LocalRouter.h b/cpp/ppc-gateway/ppc-gateway/gateway/router/LocalRouter.h index 8b931c9c..2bd18e7f 100644 --- a/cpp/ppc-gateway/ppc-gateway/gateway/router/LocalRouter.h +++ b/cpp/ppc-gateway/ppc-gateway/gateway/router/LocalRouter.h @@ -59,7 +59,7 @@ class LocalRouter virtual void registerTopic(bcos::bytesConstRef nodeID, std::string const& topic); virtual void unRegisterTopic(bcos::bytesConstRef nodeID, std::string const& topic); - virtual std::vector chooseReceiver( + virtual std::vector chooseReceiver( ppc::protocol::Message::Ptr const& msg); // TODO: register component @@ -81,7 +81,6 @@ class LocalRouter return statusSeq; } - private: ppc::front::IFrontBuilder::Ptr m_frontBuilder; GatewayNodeInfo::Ptr m_routerInfo; diff --git a/cpp/ppc-gateway/ppc-gateway/p2p/Service.h b/cpp/ppc-gateway/ppc-gateway/p2p/Service.h index e78aff23..3be13906 100644 --- a/cpp/ppc-gateway/ppc-gateway/p2p/Service.h +++ b/cpp/ppc-gateway/ppc-gateway/p2p/Service.h @@ -57,6 +57,19 @@ class Service : public bcos::boostssl::ws::WsService return m_messageFactory; } + void setNodeEndpoints(std::set const& endPointList) + { + bcos::WriteGuard l(x_configuredNode2ID); + for (auto const& it : endPointList) + { + if (m_configuredNode2ID.count(it)) + { + continue; + } + m_configuredNode2ID.insert(std::make_pair(it, "")); + } + } + protected: void onRecvMessage(bcos::boostssl::MessageFace::Ptr _msg, bcos::boostssl::ws::WsSession::Ptr _session) override; diff --git a/cpp/ppc-gateway/ppc-gateway/p2p/router/RouterTableImpl.h b/cpp/ppc-gateway/ppc-gateway/p2p/router/RouterTableImpl.h index 72eda232..55ae8d8b 100644 --- a/cpp/ppc-gateway/ppc-gateway/p2p/router/RouterTableImpl.h +++ b/cpp/ppc-gateway/ppc-gateway/p2p/router/RouterTableImpl.h @@ -22,8 +22,8 @@ #pragma GCC diagnostic ignored "-Wunused-parameter" #include "RouterTableInterface.h" -#include -#include +#include "tars/RouterTable.h" +#include #include namespace ppc::gateway diff --git a/cpp/ppc-gateway/test/demo/gateway_demo.cpp b/cpp/ppc-gateway/test/demo/gateway_demo.cpp index b5d6f6f1..aa28ad9b 100644 --- a/cpp/ppc-gateway/test/demo/gateway_demo.cpp +++ b/cpp/ppc-gateway/test/demo/gateway_demo.cpp @@ -17,7 +17,7 @@ * @author: shawnhe * @date 2022-10-28 */ - +#if 0 #include "ppc-gateway/ppc-gateway/Gateway.h" #include "ppc-gateway/ppc-gateway/GatewayConfigContext.h" #include "ppc-tools/src/config/PPCConfig.h" @@ -98,7 +98,7 @@ int main(int argc, char* argv[]) auto info = std::make_shared(); info->taskID = taskID; info->serviceEndpoint = "endpoint1001"; - front->notifyTaskInfo(info); + front->notifyTaskInfo(taskID); auto message = buildMessage(taskID, 0); std::cout << "send message\n" @@ -127,10 +127,11 @@ int main(int argc, char* argv[]) auto info1 = std::make_shared(); info->taskID = taskID1; info->serviceEndpoint = "endpoint1001"; - front->notifyTaskInfo(info); + front->notifyTaskInfo(taskID1); while (flag != 3) { std::this_thread::sleep_for(std::chrono::milliseconds(100)); } } +#endif diff --git a/cpp/ppc-gateway/test/unittests/GatewayTest.cpp b/cpp/ppc-gateway/test/unittests/GatewayTest.cpp index 50ea67c3..49d9da28 100644 --- a/cpp/ppc-gateway/test/unittests/GatewayTest.cpp +++ b/cpp/ppc-gateway/test/unittests/GatewayTest.cpp @@ -18,10 +18,11 @@ * @date 2022-10-28 */ +#if 0 #include "ppc-gateway/ppc-gateway/Gateway.h" #include "MockCache.h" #include "ppc-gateway/ppc-gateway/GatewayConfigContext.h" -#include "ppc-gateway/ppc-gateway/TaskManager.h" +#include "ppc-gateway/ppc-gateway/" #include "ppc-tools/src/config/PPCConfig.h" #include #include @@ -93,3 +94,4 @@ BOOST_AUTO_TEST_CASE(test_frontNodeManager) } BOOST_AUTO_TEST_SUITE_END() +#endif diff --git a/cpp/ppc-pir/src/OtPIRConfig.h b/cpp/ppc-pir/src/OtPIRConfig.h index dd4d6bda..9f6deb6f 100644 --- a/cpp/ppc-pir/src/OtPIRConfig.h +++ b/cpp/ppc-pir/src/OtPIRConfig.h @@ -28,8 +28,8 @@ #include "ppc-framework/crypto/CryptoBox.h" #include "ppc-framework/crypto/Oprf.h" #include "ppc-framework/protocol/Protocol.h" -#include "ppc-protocol/src/PPCMessage.h" #include "ppc-psi/src/PSIConfig.h" +#include "protocol/src/PPCMessage.h" using namespace ppc::psi; @@ -67,4 +67,4 @@ class OtPIRConfig : public PSIConfig uint16_t m_parallelism; }; -} // namespace ppc::psi \ No newline at end of file +} // namespace ppc::pir \ No newline at end of file diff --git a/cpp/ppc-pir/src/OtPIRImpl.cpp b/cpp/ppc-pir/src/OtPIRImpl.cpp index 3ee59301..cd39535b 100644 --- a/cpp/ppc-pir/src/OtPIRImpl.cpp +++ b/cpp/ppc-pir/src/OtPIRImpl.cpp @@ -21,10 +21,10 @@ #include "OtPIRImpl.h" #include "BaseOT.h" #include "Common.h" +#include "OtPIR.h" #include "ppc-framework/protocol/Protocol.h" #include "ppc-front/Common.h" -#include "OtPIR.h" -#include "ppc-tars-protocol/TarsSerialize.h" +#include "wedpr-protocol/tars/TarsSerialize.h" #include @@ -43,8 +43,8 @@ OtPIRImpl::OtPIRImpl(const OtPIRConfig::Ptr& _config, unsigned _idleTimeMs) m_ioService(std::make_shared()), m_parallelism(m_config->parallelism()) { - m_senderThreadPool = std::make_shared( - "OT-PIR-Sender", std::thread::hardware_concurrency()); + m_senderThreadPool = + std::make_shared("OT-PIR-Sender", std::thread::hardware_concurrency()); m_ot = std::make_shared(m_config->eccCrypto(), m_config->hash()); } @@ -76,8 +76,8 @@ void OtPIRImpl::start() } catch (std::exception& e) { - FRONT_LOG(WARNING) << LOG_DESC("Exception in OT-PIR Thread:") - << boost::diagnostic_information(e); + FRONT_LOG(WARNING) + << LOG_DESC("Exception in OT-PIR Thread:") << boost::diagnostic_information(e); } } PIR_LOG(INFO) << "OT-PIR exit"; @@ -143,7 +143,7 @@ void OtPIRImpl::onReceiveMessage(ppc::front::PPCMessageFace::Ptr _msg) catch (std::exception const& e) { PIR_LOG(WARNING) << LOG_DESC("onReceiveMessage exception") << printPPCMsg(_msg) - << LOG_KV("error", boost::diagnostic_information(e)); + << LOG_KV("error", boost::diagnostic_information(e)); } } @@ -159,12 +159,11 @@ void OtPIRImpl::onReceivedErrorNotification(const std::string& _taskID) } } -void OtPIRImpl::onSelfError( - const std::string& _taskID, bcos::Error::Ptr _error, bool _noticePeer) +void OtPIRImpl::onSelfError(const std::string& _taskID, bcos::Error::Ptr _error, bool _noticePeer) { PIR_LOG(WARNING) << LOG_DESC("onSelfError") << LOG_KV("task", _taskID) - << LOG_KV("error", _error->errorMessage()) - << LOG_KV("noticePeer", _noticePeer); + << LOG_KV("error", _error->errorMessage()) + << LOG_KV("noticePeer", _noticePeer); auto taskState = findPendingTask(_taskID); if (!taskState) @@ -261,23 +260,23 @@ void OtPIRImpl::handleReceivedMessage(const ppc::front::PPCMessageFace::Ptr& _me } default: { - PIR_LOG(WARNING) - << LOG_DESC("unsupported messageType ") << unsigned(_message->messageType()); + PIR_LOG(WARNING) << LOG_DESC("unsupported messageType ") + << unsigned(_message->messageType()); break; } } } catch (std::exception const& e) { - PIR_LOG(WARNING) - << LOG_DESC("handleReceivedMessage exception") - << LOG_KV("type", unsigned(_message->messageType())) << printPPCMsg(_message) - << LOG_KV("error", boost::diagnostic_information(e)); + PIR_LOG(WARNING) << LOG_DESC("handleReceivedMessage exception") + << LOG_KV("type", unsigned(_message->messageType())) + << printPPCMsg(_message) + << LOG_KV("error", boost::diagnostic_information(e)); } }); } -void OtPIRImpl::onHelloReceiver(const ppc::front::PPCMessageFace::Ptr& _message) +void OtPIRImpl::onHelloReceiver(const ppc::front::PPCMessageFace::Ptr& _message) { // 接收方不需要记录taskID // if (m_taskState->taskDone()) @@ -285,21 +284,27 @@ void OtPIRImpl::onHelloReceiver(const ppc::front::PPCMessageFace::Ptr& _message) // return; // } PIR_LOG(DEBUG) << LOG_BADGE("onHelloReceiver") << LOG_KV("taskID", _message->taskID()) - << LOG_KV("seq", _message->seq()) << LOG_KV("length", _message->length()); + << LOG_KV("seq", _message->seq()) << LOG_KV("length", _message->length()); ppctars::SenderMessageParams senderMessageParams; ppctars::serialize::decode(*_message->data(), senderMessageParams); // crypto::SenderMessage senderMessage; // TODO: how to find my dataset // m_taskState->setReader(io::LineReader::Ptr _reader, int64_t _readerParam) - try { + try + { // auto writer = m_taskState->reader(); auto receiver = findReceiver(_message->taskID()); auto path = receiver.path; - PIR_LOG(INFO) << LOG_BADGE("onHelloReceiver") << LOG_KV("taskID", _message->taskID()) << LOG_KV("requestAgencyDataset", path) << LOG_KV("sendObfuscatedHash", std::string(senderMessageParams.sendObfuscatedHash.begin(), senderMessageParams.sendObfuscatedHash.end())); + PIR_LOG(INFO) << LOG_BADGE("onHelloReceiver") << LOG_KV("taskID", _message->taskID()) + << LOG_KV("requestAgencyDataset", path) + << LOG_KV("sendObfuscatedHash", + std::string(senderMessageParams.sendObfuscatedHash.begin(), + senderMessageParams.sendObfuscatedHash.end())); auto messageKeypair = m_ot->prepareDataset(senderMessageParams.sendObfuscatedHash, path); - auto receiverMessage = m_ot->receiverGenerateMessage(senderMessageParams.pointX, senderMessageParams.pointY, messageKeypair, senderMessageParams.pointZ); + auto receiverMessage = m_ot->receiverGenerateMessage(senderMessageParams.pointX, + senderMessageParams.pointY, messageKeypair, senderMessageParams.pointZ); ppctars::ReceiverMessageParams receiverMessageParams; receiverMessageParams.encryptMessagePair = receiverMessage.encryptMessagePair; receiverMessageParams.encryptCipher = receiverMessage.encryptCipher; @@ -307,60 +312,65 @@ void OtPIRImpl::onHelloReceiver(const ppc::front::PPCMessageFace::Ptr& _message) // PIR_LOG(INFO) << LOG_BADGE("buildPPCMessage"); auto message = m_config->ppcMsgFactory()->buildPPCMessage(uint8_t(protocol::TaskType::PIR), - uint8_t(protocol::PSIAlgorithmType::OT_PIR_2PC), m_taskID, - std::make_shared()); + uint8_t(protocol::PSIAlgorithmType::OT_PIR_2PC), m_taskID, + std::make_shared()); message->setMessageType(uint8_t(OTPIRMessageType::RESULTS)); ppctars::serialize::encode(receiverMessageParams, *message->data()); // PIR_LOG(INFO) << LOG_BADGE("asyncSendMessage"); m_config->front()->asyncSendMessage( - m_taskState->peerID(), message, m_config->networkTimeout(), - [self = weak_from_this()](bcos::Error::Ptr _error) { - auto receiver = self.lock(); - if (!receiver) - { - return; - } - if (_error && _error->errorCode()) - { - receiver->onReceiverTaskDone(std::move(_error)); - } - }, - nullptr); + m_taskState->peerID(), message, m_config->networkTimeout(), + [self = weak_from_this()](bcos::Error::Ptr _error) { + auto receiver = self.lock(); + if (!receiver) + { + return; + } + if (_error && _error->errorCode()) + { + receiver->onReceiverTaskDone(std::move(_error)); + } + }, + nullptr); auto endTask = std::make_shared(m_taskState->task()->id()); m_taskState->onTaskFinished(endTask, true); - } - catch (bcos::Error const& e) { - PIR_LOG(WARNING) << LOG_DESC("onHelloReceiver exception") - << LOG_KV("code", e.errorCode()) << LOG_KV("msg", e.errorMessage()); - onSelfError( - m_taskID, std::make_shared(e.errorCode(), e.errorMessage()), true); + } + catch (bcos::Error const& e) + { + PIR_LOG(WARNING) << LOG_DESC("onHelloReceiver exception") << LOG_KV("code", e.errorCode()) + << LOG_KV("msg", e.errorMessage()); + onSelfError(m_taskID, std::make_shared(e.errorCode(), e.errorMessage()), true); } } void OtPIRImpl::onSnederResults(ppc::front::PPCMessageFace::Ptr _message) { PIR_LOG(DEBUG) << LOG_BADGE("onSnederResults") << LOG_KV("taskID", _message->taskID()) - << LOG_KV("seq", _message->seq()); + << LOG_KV("seq", _message->seq()); ppctars::ReceiverMessageParams receiverMessageParams; ppctars::serialize::decode(*_message->data(), receiverMessageParams); crypto::SenderMessage senderMessage = findSender(_message->taskID()); - PIR_LOG(DEBUG) << LOG_BADGE("onSnederResults") << LOG_KV("scalarBlidingB", toHex(senderMessage.scalarBlidingB)) - << LOG_KV("pointWList Size", receiverMessageParams.pointWList.size()) - << LOG_KV("encryptCipher Size", receiverMessageParams.encryptCipher.size()) - << LOG_KV("encryptMessagePair Size", receiverMessageParams.encryptMessagePair.size()); - bcos::bytes result = m_ot->finishSender(senderMessage.scalarBlidingB, receiverMessageParams.pointWList, receiverMessageParams.encryptMessagePair, receiverMessageParams.encryptCipher); + PIR_LOG(DEBUG) << LOG_BADGE("onSnederResults") + << LOG_KV("scalarBlidingB", toHex(senderMessage.scalarBlidingB)) + << LOG_KV("pointWList Size", receiverMessageParams.pointWList.size()) + << LOG_KV("encryptCipher Size", receiverMessageParams.encryptCipher.size()) + << LOG_KV("encryptMessagePair Size", + receiverMessageParams.encryptMessagePair.size()); + bcos::bytes result = + m_ot->finishSender(senderMessage.scalarBlidingB, receiverMessageParams.pointWList, + receiverMessageParams.encryptMessagePair, receiverMessageParams.encryptCipher); saveResults(result); auto endTask = std::make_shared(m_taskState->task()->id()); m_taskState->onTaskFinished(endTask, true); } -void OtPIRImpl::asyncRunTask(ppc::protocol::Task::ConstPtr _task, TaskResponseCallback&& _onTaskFinished) +void OtPIRImpl::asyncRunTask( + ppc::protocol::Task::ConstPtr _task, TaskResponseCallback&& _onTaskFinished) { - //TODO + // TODO PIR_LOG(INFO) << LOG_DESC("receive a task") << LOG_KV("taskID", _task->id()); - m_taskID = _task->id(); + m_taskID = _task->id(); addTask(_task, [self = weak_from_this(), taskID = _task->id(), _onTaskFinished]( ppc::protocol::TaskResult::Ptr&& _result) { auto result = std::move(_result); @@ -429,14 +439,17 @@ void OtPIRImpl::asyncRunTask() PIR_LOG(TRACE) << LOG_DESC("originData") << LOG_KV("originData", originData); PirTaskMessage taskMessage = parseJson(originData); - PIR_LOG(TRACE) << LOG_DESC("taskMessage") << LOG_KV("requestAgencyDataset", taskMessage.requestAgencyDataset) << LOG_KV("prefixLength", taskMessage.prefixLength) << LOG_KV("searchId", taskMessage.searchId); + PIR_LOG(TRACE) << LOG_DESC("taskMessage") + << LOG_KV("requestAgencyDataset", taskMessage.requestAgencyDataset) + << LOG_KV("prefixLength", taskMessage.prefixLength) + << LOG_KV("searchId", taskMessage.searchId); auto writer = loadWriter(task->id(), dataResource, m_enableOutputExists); m_taskState->setWriter(writer); runSenderGenerateCipher(taskMessage); } else if (role == uint16_t(PartyType::Server)) { - // server接受任务请求,初始化reader + // server接受任务请求,初始化reader PIR_LOG(TRACE) << LOG_DESC("Server init"); crypto::ReceiverMessage receiverMessage; receiverMessage.path = dataResource->desc()->path(); @@ -444,7 +457,6 @@ void OtPIRImpl::asyncRunTask() // m_resource = dataResource; // auto reader = loadReader(task->id(), dataResource, DataSchema::Bytes); // m_taskState->setReader(reader, -1); - } else { @@ -458,14 +470,14 @@ void OtPIRImpl::asyncRunTask() catch (bcos::Error const& e) { PIR_LOG(WARNING) << LOG_DESC("asyncRunTask exception") << printTaskInfo(task) - << LOG_KV("code", e.errorCode()) << LOG_KV("msg", e.errorMessage()); + << LOG_KV("code", e.errorCode()) << LOG_KV("msg", e.errorMessage()); onSelfError( task->id(), std::make_shared(e.errorCode(), e.errorMessage()), true); } catch (const std::exception& e) { PIR_LOG(WARNING) << LOG_DESC("asyncRunTask exception") << printTaskInfo(task) - << LOG_KV("error", boost::diagnostic_information(e)); + << LOG_KV("error", boost::diagnostic_information(e)); onSelfError(task->id(), std::make_shared((int)OTPIRRetCode::ON_EXCEPTION, "exception caught while running task: " + boost::diagnostic_information(e)), @@ -473,7 +485,7 @@ void OtPIRImpl::asyncRunTask() } // notify the taskInfo to the front - error = m_config->front()->notifyTaskInfo(std::make_shared(task->id())); + error = m_config->front()->notifyTaskInfo(task->id()); if (error && error->errorCode()) { onSelfError(task->id(), error, true); @@ -486,7 +498,9 @@ void OtPIRImpl::runSenderGenerateCipher(PirTaskMessage taskMessage) { return; } - crypto::SenderMessage senderMessage = m_ot->senderGenerateCipher(bcos::bytes(taskMessage.searchId.begin(), taskMessage.searchId.end()), taskMessage.prefixLength); + crypto::SenderMessage senderMessage = m_ot->senderGenerateCipher( + bcos::bytes(taskMessage.searchId.begin(), taskMessage.searchId.end()), + taskMessage.prefixLength); ppctars::SenderMessageParams senderMessageParams; senderMessageParams.pointX = senderMessage.pointX; senderMessageParams.pointY = senderMessage.pointY; @@ -494,28 +508,31 @@ void OtPIRImpl::runSenderGenerateCipher(PirTaskMessage taskMessage) // senderMessageParams.requestAgencyDataset = taskMessage.requestAgencyDataset; senderMessageParams.sendObfuscatedHash = senderMessage.sendObfuscatedHash; auto message = m_config->ppcMsgFactory()->buildPPCMessage(uint8_t(protocol::TaskType::PIR), - uint8_t(protocol::PSIAlgorithmType::OT_PIR_2PC), m_taskID, - std::make_shared()); + uint8_t(protocol::PSIAlgorithmType::OT_PIR_2PC), m_taskID, std::make_shared()); message->setMessageType(uint8_t(OTPIRMessageType::HELLO_RECEIVER)); ppctars::serialize::encode(senderMessageParams, *message->data()); addSender(senderMessage); - // PIR_LOG(INFO) << LOG_BADGE("runSenderGenerateCipher") << LOG_KV("taskID", m_taskID) << LOG_KV("requestAgencyDataset", senderMessageParams.requestAgencyDataset); - PIR_LOG(INFO) << LOG_BADGE("runSenderGenerateCipher") << LOG_KV("taskID", m_taskID) << LOG_KV("sendObfuscatedHash", std::string(senderMessageParams.sendObfuscatedHash.begin(), senderMessageParams.sendObfuscatedHash.end())); + // PIR_LOG(INFO) << LOG_BADGE("runSenderGenerateCipher") << LOG_KV("taskID", m_taskID) << + // LOG_KV("requestAgencyDataset", senderMessageParams.requestAgencyDataset); + PIR_LOG(INFO) << LOG_BADGE("runSenderGenerateCipher") << LOG_KV("taskID", m_taskID) + << LOG_KV("sendObfuscatedHash", + std::string(senderMessageParams.sendObfuscatedHash.begin(), + senderMessageParams.sendObfuscatedHash.end())); // senderMessageParams.taskId = m_taskID; m_config->front()->asyncSendMessage( - m_taskState->peerID(), message, m_config->networkTimeout(), - [self = weak_from_this()](bcos::Error::Ptr _error) { - auto receiver = self.lock(); - if (!receiver) - { - return; - } - if (_error && _error->errorCode()) - { - receiver->onReceiverTaskDone(std::move(_error)); - } - }, - nullptr); + m_taskState->peerID(), message, m_config->networkTimeout(), + [self = weak_from_this()](bcos::Error::Ptr _error) { + auto receiver = self.lock(); + if (!receiver) + { + return; + } + if (_error && _error->errorCode()) + { + receiver->onReceiverTaskDone(std::move(_error)); + } + }, + nullptr); } @@ -537,7 +554,7 @@ void OtPIRImpl::onReceiverTaskDone(bcos::Error::Ptr _error) m_taskState->onTaskFinished(m_taskResult, true); PIR_LOG(INFO) << LOG_BADGE("receiverTaskDone") << LOG_KV("taskID", m_taskID) - << LOG_KV("detail", message); + << LOG_KV("detail", message); } @@ -546,17 +563,16 @@ void OtPIRImpl::saveResults(bcos::bytes result) PIR_LOG(INFO) << LOG_BADGE("saveResults") LOG_KV("taskID", m_taskID); try { - DataBatch::Ptr finalResults = std::make_shared(); - finalResults->append(result); - m_taskState->writeLines(finalResults, DataSchema::Bytes); + DataBatch::Ptr finalResults = std::make_shared(); + finalResults->append(result); + m_taskState->writeLines(finalResults, DataSchema::Bytes); } catch (const std::exception& e) { PIR_LOG(WARNING) << LOG_KV("taskID", m_taskID) - << LOG_KV("error", boost::diagnostic_information(e)); + << LOG_KV("error", boost::diagnostic_information(e)); auto error = std::make_shared( - (int)OTPIRRetCode::ON_EXCEPTION, boost::diagnostic_information(e)); + (int)OTPIRRetCode::ON_EXCEPTION, boost::diagnostic_information(e)); onReceiverTaskDone(error); } } - diff --git a/cpp/ppc-psi/src/bs-ecdh-psi/BsEcdhPSIImpl.h b/cpp/ppc-psi/src/bs-ecdh-psi/BsEcdhPSIImpl.h index 55c7f671..2b8de4ad 100644 --- a/cpp/ppc-psi/src/bs-ecdh-psi/BsEcdhPSIImpl.h +++ b/cpp/ppc-psi/src/bs-ecdh-psi/BsEcdhPSIImpl.h @@ -35,7 +35,6 @@ namespace ppc::psi { - class BsEcdhPSIImpl : public BsEcdhPSIInterface, public std::enable_shared_from_this { public: diff --git a/cpp/ppc-psi/src/bs-ecdh-psi/BsEcdhPSIInterface.h b/cpp/ppc-psi/src/bs-ecdh-psi/BsEcdhPSIInterface.h index 7c21bd95..6a3534fb 100644 --- a/cpp/ppc-psi/src/bs-ecdh-psi/BsEcdhPSIInterface.h +++ b/cpp/ppc-psi/src/bs-ecdh-psi/BsEcdhPSIInterface.h @@ -27,7 +27,6 @@ namespace ppc::psi { - class BsEcdhPSIInterface { public: diff --git a/cpp/ppc-psi/src/bs-ecdh-psi/Common.h b/cpp/ppc-psi/src/bs-ecdh-psi/Common.h index 4e45004e..3e241b5b 100644 --- a/cpp/ppc-psi/src/bs-ecdh-psi/Common.h +++ b/cpp/ppc-psi/src/bs-ecdh-psi/Common.h @@ -1,22 +1,22 @@ /** -* Copyright (C) 2023 WeDPR. -* SPDX-License-Identifier: Apache-2.0 -* Licensed under the Apache License, Version 2.0 (the "License"); -* you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -* -* @file Common.h -* @author: shawnhe -* @date 2023-09-22 -*/ + * Copyright (C) 2023 WeDPR. + * SPDX-License-Identifier: Apache-2.0 + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @file Common.h + * @author: shawnhe + * @date 2023-09-22 + */ #pragma once @@ -29,7 +29,6 @@ namespace ppc::psi { - DERIVE_PPC_EXCEPTION(BsEcdhException); #define BS_ECDH_PSI_LOG(LEVEL) BCOS_LOG(LEVEL) << LOG_BADGE("PSI: BS-ECDH-PSI") diff --git a/cpp/ppc-psi/src/bs-ecdh-psi/protocol/Message.h b/cpp/ppc-psi/src/bs-ecdh-psi/protocol/Message.h index e01179be..397c837e 100644 --- a/cpp/ppc-psi/src/bs-ecdh-psi/protocol/Message.h +++ b/cpp/ppc-psi/src/bs-ecdh-psi/protocol/Message.h @@ -22,7 +22,7 @@ #include "ppc-framework/protocol/Protocol.h" #include "ppc-framework/rpc/RpcTypeDef.h" -#include "ppc-protocol/src/JsonTaskImpl.h" +#include "protocol/src/JsonTaskImpl.h" #include #include #include @@ -31,7 +31,6 @@ namespace ppc::psi { - struct GetTaskStatusRequest { using Ptr = std::shared_ptr; diff --git a/cpp/ppc-psi/src/cm2020-psi/CM2020PSIConfig.h b/cpp/ppc-psi/src/cm2020-psi/CM2020PSIConfig.h index 59a74e84..2f07e2cc 100644 --- a/cpp/ppc-psi/src/cm2020-psi/CM2020PSIConfig.h +++ b/cpp/ppc-psi/src/cm2020-psi/CM2020PSIConfig.h @@ -28,8 +28,8 @@ #include "ppc-framework/crypto/CryptoBox.h" #include "ppc-framework/crypto/Oprf.h" #include "ppc-framework/protocol/Protocol.h" -#include "ppc-protocol/src/PPCMessage.h" #include "ppc-psi/src/PSIConfig.h" +#include "protocol/src/PPCMessage.h" namespace ppc::psi { diff --git a/cpp/ppc-psi/src/cm2020-psi/CM2020PSIImpl.cpp b/cpp/ppc-psi/src/cm2020-psi/CM2020PSIImpl.cpp index a1aa8a62..197670f9 100644 --- a/cpp/ppc-psi/src/cm2020-psi/CM2020PSIImpl.cpp +++ b/cpp/ppc-psi/src/cm2020-psi/CM2020PSIImpl.cpp @@ -137,7 +137,7 @@ void CM2020PSIImpl::asyncRunTask() taskPair.second(std::move(result)); // mark this taskID as occupied - m_config->front()->notifyTaskInfo(std::make_shared(task->id())); + m_config->front()->notifyTaskInfo(task->id()); return; } @@ -210,7 +210,7 @@ void CM2020PSIImpl::asyncRunTask() } // notify the taskInfo to the front - error = m_config->front()->notifyTaskInfo(std::make_shared(task->id())); + error = m_config->front()->notifyTaskInfo(task->id()); if (error && error->errorCode()) { onSelfError(task->id(), error, true); diff --git a/cpp/ppc-psi/src/cm2020-psi/CM2020PSIImpl.h b/cpp/ppc-psi/src/cm2020-psi/CM2020PSIImpl.h index 06012237..7cd416e0 100644 --- a/cpp/ppc-psi/src/cm2020-psi/CM2020PSIImpl.h +++ b/cpp/ppc-psi/src/cm2020-psi/CM2020PSIImpl.h @@ -35,8 +35,8 @@ #include "ppc-framework/protocol/Task.h" #include "ppc-framework/task/TaskFrameworkInterface.h" #include "ppc-front/ppc-front/PPCChannel.h" -#include "ppc-protocol/src/PPCMessage.h" #include "ppc-tools/src/common/TransTools.h" +#include "protocol/src/PPCMessage.h" namespace ppc::psi { diff --git a/cpp/ppc-psi/src/cm2020-psi/Common.h b/cpp/ppc-psi/src/cm2020-psi/Common.h index 52de3fb2..237f1a65 100644 --- a/cpp/ppc-psi/src/cm2020-psi/Common.h +++ b/cpp/ppc-psi/src/cm2020-psi/Common.h @@ -28,7 +28,6 @@ namespace ppc::psi { - DERIVE_PPC_EXCEPTION(CM2020Exception); #define CM2020_PSI_LOG(LEVEL) BCOS_LOG(LEVEL) << LOG_BADGE("PSI: CM2020-PSI") diff --git a/cpp/ppc-psi/src/cm2020-psi/core/CM2020PSIReceiver.cpp b/cpp/ppc-psi/src/cm2020-psi/core/CM2020PSIReceiver.cpp index f8b5d3dd..2e4f46b0 100644 --- a/cpp/ppc-psi/src/cm2020-psi/core/CM2020PSIReceiver.cpp +++ b/cpp/ppc-psi/src/cm2020-psi/core/CM2020PSIReceiver.cpp @@ -22,8 +22,8 @@ #include "CM2020PSI.h" #include "openssl/rand.h" #include "ppc-crypto/src/prng/AESPRNG.h" -#include "ppc-tars-protocol/ppc-tars-protocol/TarsSerialize.h" #include "ppc-tools/src/common/TransTools.h" +#include "wedpr-protocol/tars/TarsSerialize.h" #include using namespace ppc::psi; diff --git a/cpp/ppc-psi/src/cm2020-psi/core/CM2020PSIReceiver.h b/cpp/ppc-psi/src/cm2020-psi/core/CM2020PSIReceiver.h index 8bd33468..7c3184fd 100644 --- a/cpp/ppc-psi/src/cm2020-psi/core/CM2020PSIReceiver.h +++ b/cpp/ppc-psi/src/cm2020-psi/core/CM2020PSIReceiver.h @@ -31,7 +31,6 @@ namespace ppc::psi { - class CM2020PSIReceiver : public std::enable_shared_from_this { public: diff --git a/cpp/ppc-psi/src/cm2020-psi/core/CM2020PSISender.cpp b/cpp/ppc-psi/src/cm2020-psi/core/CM2020PSISender.cpp index 7fa95bb1..860f8bfd 100644 --- a/cpp/ppc-psi/src/cm2020-psi/core/CM2020PSISender.cpp +++ b/cpp/ppc-psi/src/cm2020-psi/core/CM2020PSISender.cpp @@ -22,8 +22,8 @@ #include "CM2020PSI.h" #include "openssl/rand.h" #include "ppc-crypto/src/prng/AESPRNG.h" -#include "ppc-tars-protocol/ppc-tars-protocol/TarsSerialize.h" #include "ppc-tools/src/common/TransTools.h" +#include "wedpr-protocol/tars/TarsSerialize.h" #include using namespace ppc::psi; @@ -499,7 +499,7 @@ void CM2020PSISender::onMatrixColumnReceived(PPCMessageFace::Ptr _message) return; } CM2020_PSI_LOG(INFO) << LOG_BADGE("onMatrixColumnReceived") << LOG_KV("taskID", m_taskID) - << LOG_KV("seq", _message->seq()); + << LOG_KV("seq", _message->seq()); try { m_channel->onMessageArrived(uint8_t(CM2020PSIMessageType::MATRIX), _message); diff --git a/cpp/ppc-psi/src/ecdh-conn-psi/EcdhConnPSIConfig.h b/cpp/ppc-psi/src/ecdh-conn-psi/EcdhConnPSIConfig.h index 63d09b3b..d42c8bf8 100644 --- a/cpp/ppc-psi/src/ecdh-conn-psi/EcdhConnPSIConfig.h +++ b/cpp/ppc-psi/src/ecdh-conn-psi/EcdhConnPSIConfig.h @@ -23,8 +23,8 @@ #include "ppc-framework/crypto/ECDHCrypto.h" #include "ppc-framework/io/DataResourceLoader.h" #include "ppc-framework/protocol/Protocol.h" -#include "ppc-protocol/src/PPCMessage.h" #include "ppc-tools/src/config/PPCConfig.h" +#include "protocol/src/PPCMessage.h" #include #include diff --git a/cpp/ppc-psi/src/ecdh-conn-psi/EcdhConnPSIImpl.cpp b/cpp/ppc-psi/src/ecdh-conn-psi/EcdhConnPSIImpl.cpp index 70bfa486..55e22be3 100644 --- a/cpp/ppc-psi/src/ecdh-conn-psi/EcdhConnPSIImpl.cpp +++ b/cpp/ppc-psi/src/ecdh-conn-psi/EcdhConnPSIImpl.cpp @@ -50,7 +50,7 @@ void EcdhConnPSIImpl::asyncRunTask( // init process triggleProcess((uint8_t)EcdhConnProcess::HandShakeProcess, -1); // notify the taskInfo to the front - m_config->front()->notifyTaskInfo(std::make_shared(_task->id())); + m_config->front()->notifyTaskInfo(_task->id()); if (role == uint16_t(PartyType::Client)) { ECDH_CONN_LOG(INFO) << LOG_DESC("Client do the Task") << LOG_KV("taskID", _task->id()); diff --git a/cpp/ppc-psi/src/ecdh-conn-psi/EcdhConnPSIImpl.h b/cpp/ppc-psi/src/ecdh-conn-psi/EcdhConnPSIImpl.h index 3806491e..53d7cd4a 100644 --- a/cpp/ppc-psi/src/ecdh-conn-psi/EcdhConnPSIImpl.h +++ b/cpp/ppc-psi/src/ecdh-conn-psi/EcdhConnPSIImpl.h @@ -27,9 +27,9 @@ #include "core/EcdhConnPSIServer.h" #include "ppc-framework/rpc/RpcInterface.h" #include "ppc-framework/task/TaskFrameworkInterface.h" -#include "ppc-protocol/src/PPCMessage.h" #include "ppc-rpc/src/RpcFactory.h" #include "ppc-tools/src/common/ConcurrentPool.h" +#include "protocol/src/PPCMessage.h" #include #include #include diff --git a/cpp/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIConfig.h b/cpp/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIConfig.h index 3f5b1bd7..9b256a72 100644 --- a/cpp/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIConfig.h +++ b/cpp/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIConfig.h @@ -2,8 +2,8 @@ #include "EcdhMultiPSIMessageFactory.h" #include "ppc-framework/crypto/CryptoBox.h" #include "ppc-framework/protocol/Protocol.h" -#include "ppc-protocol/src/PPCMessage.h" #include "ppc-psi/src/PSIConfig.h" +#include "protocol/src/PPCMessage.h" #include #include #include @@ -28,7 +28,8 @@ class EcdhMultiPSIConfig : public PSIConfig m_dataBatchSize(_dataBatchSize) {} - virtual ~EcdhMultiPSIConfig() { + virtual ~EcdhMultiPSIConfig() + { if (m_threadPool) { m_threadPool->stop(); diff --git a/cpp/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIImpl.cpp b/cpp/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIImpl.cpp index d881405d..24fa7194 100644 --- a/cpp/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIImpl.cpp +++ b/cpp/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIImpl.cpp @@ -168,7 +168,7 @@ void EcdhMultiPSIImpl::asyncRunTask( } // notify the taskInfo to the front - m_config->front()->notifyTaskInfo(std::make_shared(_task->id())); + m_config->front()->notifyTaskInfo(_task->id()); } catch (bcos::Error const& e) { diff --git a/cpp/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIImpl.h b/cpp/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIImpl.h index 91f898a5..b2d8e968 100644 --- a/cpp/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIImpl.h +++ b/cpp/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIImpl.h @@ -1,12 +1,12 @@ #pragma once -#include "ppc-psi/src/psi-framework/TaskGuarder.h" #include "Common.h" #include "EcdhMultiPSIConfig.h" #include "core/EcdhMultiPSICalculator.h" #include "core/EcdhMultiPSIMaster.h" #include "core/EcdhMultiPSIPartner.h" #include "ppc-framework/task/TaskFrameworkInterface.h" -#include "ppc-protocol/src/PPCMessage.h" +#include "ppc-psi/src/psi-framework/TaskGuarder.h" +#include "protocol/src/PPCMessage.h" #include #include #include diff --git a/cpp/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIMessageFactory.h b/cpp/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIMessageFactory.h index a51bd0b3..d4a8bde8 100644 --- a/cpp/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIMessageFactory.h +++ b/cpp/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIMessageFactory.h @@ -1,6 +1,6 @@ #pragma once -#include "ppc-psi/src/psi-framework/protocol/PSIMessage.h" #include "Common.h" +#include "ppc-psi/src/psi-framework/protocol/PSIMessage.h" namespace ppc::psi { class EcdhMultiPSIMessageFactory : public PSIMessageFactoryImpl diff --git a/cpp/ppc-psi/src/ecdh-psi/EcdhPSIImpl.cpp b/cpp/ppc-psi/src/ecdh-psi/EcdhPSIImpl.cpp index 4f79196b..051ee101 100644 --- a/cpp/ppc-psi/src/ecdh-psi/EcdhPSIImpl.cpp +++ b/cpp/ppc-psi/src/ecdh-psi/EcdhPSIImpl.cpp @@ -123,7 +123,7 @@ void EcdhPSIImpl::runPSI(TaskState::Ptr const& _taskState) { ECDH_LOG(INFO) << LOG_DESC("runPSI") << printTaskInfo(_taskState->task()); // notify the taskInfo to the front - m_config->front()->notifyTaskInfo(std::make_shared(_taskState->task()->id())); + m_config->front()->notifyTaskInfo(_taskState->task()->id()); // the psi-client send handshake request to the server if (_taskState->task()->selfParty()->partyIndex() == (int)PartyType::Client) { diff --git a/cpp/ppc-psi/src/labeled-psi/Common.h b/cpp/ppc-psi/src/labeled-psi/Common.h index 78ab15e3..95fe31f1 100644 --- a/cpp/ppc-psi/src/labeled-psi/Common.h +++ b/cpp/ppc-psi/src/labeled-psi/Common.h @@ -28,7 +28,6 @@ namespace ppc::psi { - DERIVE_PPC_EXCEPTION(ConfigPowersDagException); DERIVE_PPC_EXCEPTION(TooManyItemsException); DERIVE_PPC_EXCEPTION(ResultPackageException); diff --git a/cpp/ppc-psi/src/labeled-psi/LabeledPSIConfig.h b/cpp/ppc-psi/src/labeled-psi/LabeledPSIConfig.h index fc2524bc..01b7ebdd 100644 --- a/cpp/ppc-psi/src/labeled-psi/LabeledPSIConfig.h +++ b/cpp/ppc-psi/src/labeled-psi/LabeledPSIConfig.h @@ -24,8 +24,8 @@ #include "ppc-crypto/src/oprf/EcdhOprf.h" #include "ppc-framework/crypto/CryptoBox.h" #include "ppc-framework/crypto/Oprf.h" -#include "ppc-protocol/src/PPCMessage.h" #include "ppc-psi/src/PSIConfig.h" +#include "protocol/src/PPCMessage.h" #include #include diff --git a/cpp/ppc-psi/src/labeled-psi/LabeledPSIImpl.cpp b/cpp/ppc-psi/src/labeled-psi/LabeledPSIImpl.cpp index bac2fa70..c9726b47 100644 --- a/cpp/ppc-psi/src/labeled-psi/LabeledPSIImpl.cpp +++ b/cpp/ppc-psi/src/labeled-psi/LabeledPSIImpl.cpp @@ -25,7 +25,7 @@ #include "core/SenderDB.h" #include "core/TaskCommand.h" #include "ppc-psi/src/labeled-psi/core/LabeledPSIReceiver.h" -#include "ppc-tars-protocol/ppc-tars-protocol/TarsSerialize.h" +#include "wedpr-protocol/tars/TarsSerialize.h" using namespace ppc::psi; using namespace ppc::protocol; @@ -122,7 +122,7 @@ void LabeledPSIImpl::asyncRunTask( addReceiver(receiver); // notify the taskInfo to the front - error = m_config->front()->notifyTaskInfo(std::make_shared(_task->id())); + error = m_config->front()->notifyTaskInfo(_task->id()); if (error && error->errorCode()) { onSelfError(_task->id(), error, true); @@ -347,7 +347,7 @@ void LabeledPSIImpl::asyncRunSenderTask( addPendingTask(taskState); // notify the taskInfo to the front - error = m_config->front()->notifyTaskInfo(std::make_shared(_task->id())); + error = m_config->front()->notifyTaskInfo(_task->id()); if (error && error->errorCode()) { onSelfError(_task->id(), error, true); diff --git a/cpp/ppc-psi/src/labeled-psi/LabeledPSIImpl.h b/cpp/ppc-psi/src/labeled-psi/LabeledPSIImpl.h index bafa00c1..5765187b 100644 --- a/cpp/ppc-psi/src/labeled-psi/LabeledPSIImpl.h +++ b/cpp/ppc-psi/src/labeled-psi/LabeledPSIImpl.h @@ -33,8 +33,8 @@ #include "ppc-framework/protocol/Task.h" #include "ppc-framework/task/TaskFrameworkInterface.h" #include "ppc-front/ppc-front/PPCChannel.h" -#include "ppc-protocol/src/PPCMessage.h" #include "ppc-tools/src/common/TransTools.h" +#include "protocol/src/PPCMessage.h" namespace ppc::psi { @@ -119,7 +119,7 @@ class LabeledPSIImpl : public bcos::Worker, protected: // allow the output-path exists, for ut bool m_enableOutputExists = false; - + private: void waitSignal() { diff --git a/cpp/ppc-psi/src/labeled-psi/core/LabeledPSIParams.h b/cpp/ppc-psi/src/labeled-psi/core/LabeledPSIParams.h index 9cbf3b97..70b21e82 100644 --- a/cpp/ppc-psi/src/labeled-psi/core/LabeledPSIParams.h +++ b/cpp/ppc-psi/src/labeled-psi/core/LabeledPSIParams.h @@ -35,7 +35,6 @@ namespace ppc::psi { - struct BaseParams { size_t maxItemsPerBin; diff --git a/cpp/ppc-psi/src/labeled-psi/core/LabeledPSIReceiver.cpp b/cpp/ppc-psi/src/labeled-psi/core/LabeledPSIReceiver.cpp index f9b3ebe7..37221243 100644 --- a/cpp/ppc-psi/src/labeled-psi/core/LabeledPSIReceiver.cpp +++ b/cpp/ppc-psi/src/labeled-psi/core/LabeledPSIReceiver.cpp @@ -30,7 +30,7 @@ #include "bcos-utilities/DataConvertUtility.h" #include "ppc-psi/src/labeled-psi/Common.h" #include "ppc-psi/src/labeled-psi/protocol/LabeledPSIResult.h" -#include "ppc-tars-protocol/ppc-tars-protocol/TarsSerialize.h" +#include "wedpr-protocol/tars/TarsSerialize.h" using namespace ppc::psi; using namespace ppc::protocol; diff --git a/cpp/ppc-psi/src/labeled-psi/core/LabeledPSIReceiver.h b/cpp/ppc-psi/src/labeled-psi/core/LabeledPSIReceiver.h index 7e392d19..cfb54d5d 100644 --- a/cpp/ppc-psi/src/labeled-psi/core/LabeledPSIReceiver.h +++ b/cpp/ppc-psi/src/labeled-psi/core/LabeledPSIReceiver.h @@ -24,10 +24,10 @@ #include "ppc-framework/protocol/Task.h" #include "ppc-framework/task/TaskFrameworkInterface.h" #include "ppc-front/ppc-front/PPCChannel.h" -#include "ppc-protocol/src/PPCMessage.h" #include "ppc-psi/src/labeled-psi/LabeledPSIConfig.h" #include "ppc-psi/src/labeled-psi/protocol/Protocol.h" #include "ppc-tools/src/common/Progress.h" +#include "protocol/src/PPCMessage.h" #include #include diff --git a/cpp/ppc-psi/src/labeled-psi/core/LabeledPSISender.cpp b/cpp/ppc-psi/src/labeled-psi/core/LabeledPSISender.cpp index 6cf086f5..419140cf 100644 --- a/cpp/ppc-psi/src/labeled-psi/core/LabeledPSISender.cpp +++ b/cpp/ppc-psi/src/labeled-psi/core/LabeledPSISender.cpp @@ -25,8 +25,8 @@ #include "LabeledPSIParams.h" #include "LabeledPSISender.h" #include "QueryPackage.h" -#include "ppc-tars-protocol/ppc-tars-protocol/TarsSerialize.h" #include "ppc-tools/src/common/TransTools.h" +#include "wedpr-protocol/tars/TarsSerialize.h" using namespace ppc::psi; using namespace ppc::protocol; diff --git a/cpp/ppc-psi/src/labeled-psi/core/LabeledPSISender.h b/cpp/ppc-psi/src/labeled-psi/core/LabeledPSISender.h index 32497157..bddd01c6 100644 --- a/cpp/ppc-psi/src/labeled-psi/core/LabeledPSISender.h +++ b/cpp/ppc-psi/src/labeled-psi/core/LabeledPSISender.h @@ -33,9 +33,9 @@ #include "ppc-framework/crypto/Oprf.h" #include "ppc-framework/protocol/Task.h" #include "ppc-front/ppc-front/PPCChannel.h" -#include "ppc-protocol/src/PPCMessage.h" #include "ppc-psi/src/labeled-psi/LabeledPSIConfig.h" #include "ppc-psi/src/labeled-psi/protocol/Protocol.h" +#include "protocol/src/PPCMessage.h" namespace ppc::psi { diff --git a/cpp/ppc-psi/src/labeled-psi/core/ResultPackage.h b/cpp/ppc-psi/src/labeled-psi/core/ResultPackage.h index 7e0d32f0..b6a55ad2 100644 --- a/cpp/ppc-psi/src/labeled-psi/core/ResultPackage.h +++ b/cpp/ppc-psi/src/labeled-psi/core/ResultPackage.h @@ -29,7 +29,6 @@ namespace ppc::psi { - /** Stores a decrypted and decoded PSI response and optionally a labeled PSI response. */ diff --git a/cpp/ppc-psi/src/labeled-psi/core/SenderDB.h b/cpp/ppc-psi/src/labeled-psi/core/SenderDB.h index c9580b59..a84c6935 100644 --- a/cpp/ppc-psi/src/labeled-psi/core/SenderDB.h +++ b/cpp/ppc-psi/src/labeled-psi/core/SenderDB.h @@ -55,7 +55,7 @@ #include "../protocol/Protocol.h" #include "BinBundle.h" #include "SenderCache.h" -#include "ppc-tars-protocol/ppc-tars-protocol/TarsSerialize.h" +#include "wedpr-protocol/tars/TarsSerialize.h" namespace ppc::psi { diff --git a/cpp/ppc-psi/src/psi-framework/TaskGuarder.h b/cpp/ppc-psi/src/psi-framework/TaskGuarder.h index a51c48f1..aea7e3bf 100644 --- a/cpp/ppc-psi/src/psi-framework/TaskGuarder.h +++ b/cpp/ppc-psi/src/psi-framework/TaskGuarder.h @@ -24,7 +24,7 @@ #include "../PSIConfig.h" #include "TaskState.h" #include "ppc-framework/protocol/Protocol.h" -#include "ppc-protocol/src/PPCMessage.h" +#include "protocol/src/PPCMessage.h" #include namespace ppc::psi diff --git a/cpp/ppc-psi/src/psi-framework/protocol/PSIMessage.cpp b/cpp/ppc-psi/src/psi-framework/protocol/PSIMessage.cpp index 1a93d668..376a9f91 100644 --- a/cpp/ppc-psi/src/psi-framework/protocol/PSIMessage.cpp +++ b/cpp/ppc-psi/src/psi-framework/protocol/PSIMessage.cpp @@ -18,7 +18,7 @@ * @date 2022-11-9 */ #include "PSIMessage.h" -#include "ppc-tars-protocol/Common.h" +#include "wedpr-protocol/tars/Common.h" using namespace ppc::psi; using namespace bcos; diff --git a/cpp/ppc-psi/src/ra2018-psi/RA2018PSIImpl.cpp b/cpp/ppc-psi/src/ra2018-psi/RA2018PSIImpl.cpp index 362857c7..7decb401 100644 --- a/cpp/ppc-psi/src/ra2018-psi/RA2018PSIImpl.cpp +++ b/cpp/ppc-psi/src/ra2018-psi/RA2018PSIImpl.cpp @@ -72,7 +72,7 @@ void RA2018PSIImpl::asyncRunTask( { if (m_disabled) { - m_config->front()->notifyTaskInfo(std::make_shared(_task->id())); + m_config->front()->notifyTaskInfo(_task->id()); auto taskResult = std::make_shared(_task->id()); auto error = BCOS_ERROR_PTR( (int)RA2018PSIDisabled, "The ra2018-psi has been disabled by this node!"); @@ -189,7 +189,7 @@ void RA2018PSIImpl::runPSI(TaskState::Ptr const& _taskState) { auto task = _taskState->task(); // notify the taskInfo to the front - m_config->front()->notifyTaskInfo(std::make_shared(task->id())); + m_config->front()->notifyTaskInfo(task->id()); switch (task->selfParty()->partyIndex()) { case (int)ppc::protocol::PartyType::Client: diff --git a/cpp/ppc-psi/src/ra2018-psi/protocol/RA2018Message.h b/cpp/ppc-psi/src/ra2018-psi/protocol/RA2018Message.h index c14d7c81..e6a49166 100644 --- a/cpp/ppc-psi/src/ra2018-psi/protocol/RA2018Message.h +++ b/cpp/ppc-psi/src/ra2018-psi/protocol/RA2018Message.h @@ -21,7 +21,7 @@ #include "../../psi-framework/protocol/PSIMessage.h" #include "../Common.h" #include "../core/CuckooFilterInfo.h" -#include "ppc-tars-protocol/Common.h" +#include "wedpr-protocol/tars/Common.h" namespace ppc::psi { diff --git a/cpp/ppc-rpc/src/Rpc.h b/cpp/ppc-rpc/src/Rpc.h index a48a9b2b..59eadce5 100644 --- a/cpp/ppc-rpc/src/Rpc.h +++ b/cpp/ppc-rpc/src/Rpc.h @@ -21,7 +21,7 @@ #include "ppc-framework/front/FrontInterface.h" #include "ppc-framework/rpc/RpcInterface.h" #include "ppc-framework/rpc/RpcStatusInterface.h" -#include "ppc-protocol/src/JsonTaskImpl.h" +#include "protocol/src/JsonTaskImpl.h" #include #include #include diff --git a/cpp/ppc-rpc/src/RpcMemory.cpp b/cpp/ppc-rpc/src/RpcMemory.cpp index c44fd9c7..86074c47 100644 --- a/cpp/ppc-rpc/src/RpcMemory.cpp +++ b/cpp/ppc-rpc/src/RpcMemory.cpp @@ -109,22 +109,6 @@ TaskResult::Ptr RpcMemory::getTaskStatus(const std::string& _taskID) return m_tasks[_taskID].second; } -bcos::Error::Ptr RpcMemory::insertGateway( - const std::string& _agencyID, const std::string& _endpoint) -{ - try - { - std::vector gatewayList; - gatewayList.push_back({_agencyID, _endpoint}); - m_gateway->registerGateway(gatewayList); - return nullptr; - } - catch (std::exception const& e) - { - return std::make_shared( - PPCRetCode::EXCEPTION, "insertGateway error: " + boost::diagnostic_information(e)); - } -} bcos::Error::Ptr RpcMemory::deleteGateway(const std::string& _agencyID) { diff --git a/cpp/ppc-rpc/src/RpcMemory.h b/cpp/ppc-rpc/src/RpcMemory.h index e275b3ea..30430027 100644 --- a/cpp/ppc-rpc/src/RpcMemory.h +++ b/cpp/ppc-rpc/src/RpcMemory.h @@ -20,23 +20,18 @@ */ #pragma once -#include "ppc-framework/gateway/GatewayInterface.h" #include "ppc-framework/rpc/RpcStatusInterface.h" #include #include namespace ppc::rpc { - class RpcMemory : public RpcStatusInterface { public: using Ptr = std::shared_ptr; - RpcMemory(ppc::gateway::GatewayInterface::Ptr _gateway) - : m_gateway(std::move(_gateway)), - m_taskCleaner(std::make_shared(60 * 60 * 1000, "taskCleaner")) - {} + RpcMemory() : m_taskCleaner(std::make_shared(60 * 60 * 1000, "taskCleaner")) {} ~RpcMemory() override = default; void start() override; @@ -54,8 +49,6 @@ class RpcMemory : public RpcStatusInterface void cleanTask(); private: - ppc::gateway::GatewayInterface::Ptr m_gateway; - mutable bcos::SharedMutex x_tasks; std::unordered_map> m_tasks; std::shared_ptr m_taskCleaner; diff --git a/cpp/ppc-tars-protocol/ppc-tars-protocol/client/FrontServiceClient.h b/cpp/ppc-tars-protocol/ppc-tars-protocol/client/FrontServiceClient.h deleted file mode 100644 index 2209f62f..00000000 --- a/cpp/ppc-tars-protocol/ppc-tars-protocol/client/FrontServiceClient.h +++ /dev/null @@ -1,139 +0,0 @@ -/** - * Copyright (C) 2022 WeDPR. - * SPDX-License-Identifier: Apache-2.0 - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * @brief client for front service - * @file FrontServiceClient.h - * @author: shawnhe - * @date 2022-10-20 - */ - -#pragma once - -#pragma GCC diagnostic ignored "-Wunused-variable" -#pragma GCC diagnostic ignored "-Wunused-parameter" - -#include "FrontService.h" -#include "ppc-framework/front/FrontInterface.h" -#include "ppc-tars-protocol/ppc-tars-protocol/Common.h" -#include -#include - -#include - -#define FRONTCLIENT_LOG(LEVEL) BCOS_LOG(LEVEL) << LOG_BADGE("FrontServiceClient") - -namespace ppctars -{ -class FrontServiceClient : public ppc::front::FrontInterface -{ -public: - void start() override {} - void stop() override {} - - explicit FrontServiceClient(const ppctars::FrontServicePrx& proxy) : m_proxy(proxy) {} - - void onReceiveMessage( - ppc::front::PPCMessageFace::Ptr _message, ppc::front::ErrorCallbackFunc _callback) override - { - class Callback : public FrontServicePrxCallback - { - public: - explicit Callback(ppc::front::ErrorCallbackFunc callback) - : m_callback(std::move(callback)) - {} - - void callback_onReceiveMessage(const ppctars::Error& ret) override - { - if (!m_callback) - { - return; - } - m_callback(toBcosError(ret)); - } - - void callback_onReceiveMessage_exception(tars::Int32 ret) override - { - if (!m_callback) - { - return; - } - m_callback(toBcosError(ret)); - } - - private: - ppc::front::ErrorCallbackFunc m_callback; - }; - - auto startT = bcos::utcSteadyTime(); - // encode message to bytes - bcos::bytes buffer; - _message->encode(buffer); - - FRONTCLIENT_LOG(TRACE) << LOG_DESC("after decode") - << LOG_KV("taskType", unsigned(_message->taskType())) - << LOG_KV("algorithmType", unsigned(_message->algorithmType())) - << LOG_KV("messageType", unsigned(_message->messageType())) - << LOG_KV("seq", _message->seq()) - << LOG_KV("taskID", _message->taskID()); - - m_proxy->tars_set_timeout(c_networkTimeout) - ->async_onReceiveMessage( - new Callback(_callback), std::vector(buffer.begin(), buffer.end())); - BCOS_LOG(TRACE) << LOG_DESC("call front onReceiveMessage") - << LOG_KV("msgSize", buffer.size()) - << LOG_KV("timecost", bcos::utcSteadyTime() - startT); - } - - // Note: since ppc-front is integrated with the node, no-need to implement this method - bcos::Error::Ptr notifyTaskInfo(ppc::protocol::GatewayTaskInfo::Ptr) override - { - throw std::runtime_error("notifyTaskInfo: unimplemented interface!"); - } - - // Note: since ppc-front is integrated with the node, no-need to implement this method - // erase the task-info when task finished - bcos::Error::Ptr eraseTaskInfo(std::string const&) override - { - throw std::runtime_error("eraseTaskInfo: unimplemented interface!"); - } - - // Note: since ppc-front is integrated with the node, no-need to implement this method - void asyncSendMessage(const std::string&, ppc::front::PPCMessageFace::Ptr, uint32_t _timeout, - ppc::front::ErrorCallbackFunc _callback, ppc::front::CallbackFunc _respCallback) override - { - throw std::runtime_error("asyncSendMessage: unimplemented interface!"); - } - - // Note: since ppc-front is integrated with the node, no-need to implement this method - // send response when receiving message from given agencyID - void asyncSendResponse(const std::string&, std::string const&, ppc::front::PPCMessageFace::Ptr, - ppc::front::ErrorCallbackFunc) override - { - throw std::runtime_error("asyncSendResponse: unimplemented interface!"); - } - - // Note: since ppc-front is integrated with the node, no-need to implement this method - void asyncGetAgencyList(ppc::front::GetAgencyListCallback) override - { - throw std::runtime_error("asyncGetAgencyList: unimplemented interface!"); - } - -private: - // 1800s - const int c_networkTimeout = 1800000; - - ppctars::FrontServicePrx m_proxy; -}; -} // namespace ppctars diff --git a/cpp/ppc-tars-protocol/ppc-tars-protocol/client/GatewayServiceClient.h b/cpp/ppc-tars-protocol/ppc-tars-protocol/client/GatewayServiceClient.h deleted file mode 100644 index 52db1dcc..00000000 --- a/cpp/ppc-tars-protocol/ppc-tars-protocol/client/GatewayServiceClient.h +++ /dev/null @@ -1,298 +0,0 @@ -/** - * Copyright (C) 2022 WeDPR. - * SPDX-License-Identifier: Apache-2.0 - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * @brief client for gateway service - * @file GatewayServiceClient.h - * @author: shawnhe - * @date 2022-10-20 - */ - -#pragma once - -#pragma GCC diagnostic ignored "-Wunused-variable" -#pragma GCC diagnostic ignored "-Wunused-parameter" - -#include "GatewayService.h" -#include "ppc-framework/Common.h" -#include "ppc-framework/gateway/GatewayInterface.h" -#include "ppc-framework/protocol/Protocol.h" -#include "ppc-tars-protocol/Common.h" -#include "ppc-tars-protocol/TarsServantProxyCallback.h" -#include -#include -#include - -#include - -#define GATEWAYCLIENT_LOG(LEVEL) BCOS_LOG(LEVEL) << LOG_BADGE("GatewayServiceClient") - -namespace ppctars -{ -class GatewayServiceClient : public ppc::gateway::GatewayInterface -{ -public: - void start() override {} - void stop() override {} - - explicit GatewayServiceClient(std::string const& _gatewayServiceName, - const ppctars::GatewayServicePrx& _prx, int _holdingMessageMinutes) - : m_gatewayServiceName(_gatewayServiceName), - m_prx(_prx), - m_networkTimeout(_holdingMessageMinutes * 60 * 1000) // convert to ms - { - BCOS_LOG(INFO) << LOG_DESC("GatewayServiceClient") - << LOG_KV("networkTimeout", m_networkTimeout); - } - - - void asyncSendMessage(const std::string& _agencyID, ppc::front::PPCMessageFace::Ptr _message, - ppc::gateway::ErrorCallbackFunc _callback) override - { - class Callback : public GatewayServicePrxCallback - { - public: - explicit Callback(ppc::gateway::ErrorCallbackFunc callback) - : m_callback(std::move(callback)) - {} - - void callback_asyncSendMessage(const ppctars::Error& ret) override - { - s_tarsTimeoutCount.store(0); - if (!m_callback) - { - return; - } - m_callback(toBcosError(ret)); - } - - void callback_asyncSendMessage_exception(tars::Int32 ret) override - { - s_tarsTimeoutCount++; - if (!m_callback) - { - return; - } - m_callback(toBcosError(ret)); - } - - private: - ppc::gateway::ErrorCallbackFunc m_callback; - }; - - // encode message to bytes - bcos::bytes buffer; - _message->encode(buffer); - - GATEWAYCLIENT_LOG(TRACE) << LOG_DESC("send message to gateway by client") - << LOG_KV("taskType", unsigned(_message->taskType())) - << LOG_KV("algorithmType", unsigned(_message->algorithmType())) - << LOG_KV("messageType", unsigned(_message->messageType())) - << LOG_KV("seq", _message->seq()) - << LOG_KV("taskID", _message->taskID()) - << LOG_KV("receiver", _agencyID); - - m_prx->tars_set_timeout(m_networkTimeout) - ->async_asyncSendMessage(new Callback(_callback), _agencyID, - std::vector(buffer.begin(), buffer.end())); - } - - bcos::Error::Ptr notifyTaskInfo(ppc::protocol::GatewayTaskInfo::Ptr _taskInfo) override - { - auto tarsTaskInfo = toTarsTaskInfo(_taskInfo); - auto activeEndPoints = tarsProxyActiveEndPoints(m_prx); - // broadcast to all gateways - uint errorCount = 0; - std::string lastErrorMsg; - for (auto& endPoint : activeEndPoints) - { - auto prx = - ppctars::createServantProxy(m_gatewayServiceName, endPoint); - auto error = prx->tars_set_timeout(m_networkTimeout)->notifyTaskInfo(tarsTaskInfo); - if (error.errorCode) - { - ++errorCount; - lastErrorMsg = error.errorMessage; - } - } - if (errorCount) - { - return std::make_shared( - ppc::protocol::PPCRetCode::NOTIFY_TASK_ERROR, lastErrorMsg); - } - - return nullptr; - } - - bcos::Error::Ptr eraseTaskInfo(std::string const& _taskID) override - { - auto activeEndPoints = tarsProxyActiveEndPoints(m_prx); - // broadcast to all gateways - uint errorCount = 0; - for (auto& endPoint : activeEndPoints) - { - try - { - auto prx = - ppctars::createServantProxy(m_gatewayServiceName, endPoint); - auto error = prx->tars_set_timeout(m_networkTimeout)->eraseTaskInfo(_taskID); - if (error.errorCode) - { - ++errorCount; - } - } - catch (std::exception const& e) - { - ++errorCount; - BCOS_LOG(INFO) << LOG_DESC("eraseTaskInfo exception") - << LOG_KV("exception", boost::diagnostic_information(e)); - } - } - if (errorCount) - { - return std::make_shared( - -1, "eraseTaskInfo: error count: " + std::to_string(errorCount)); - } - - return nullptr; - } - - bcos::Error::Ptr registerGateway( - const std::vector& _gatewayList) override - { - std::vector gatewayList; - for (const auto& gateway : _gatewayList) - { - ppctars::GatewayInfo tarsGate; - tarsGate.agencyID = gateway.agencyID; - tarsGate.endpoint = gateway.endpoint; - gatewayList.push_back(tarsGate); - } - auto activeEndPoints = tarsProxyActiveEndPoints(m_prx); - // broadcast to all gateways - uint errorCount = 0; - std::string lastErrorMsg; - for (auto& endPoint : activeEndPoints) - { - auto prx = - ppctars::createServantProxy(m_gatewayServiceName, endPoint); - auto error = prx->tars_set_timeout(m_networkTimeout)->registerGateway(gatewayList); - if (error.errorCode) - { - ++errorCount; - lastErrorMsg = error.errorMessage; - } - } - if (errorCount) - { - return std::make_shared( - ppc::protocol::PPCRetCode::REGISTER_GATEWAY_URL_ERROR, lastErrorMsg); - } - - return nullptr; - } - - - void registerFront(std::string const& _endPoint, ppc::front::FrontInterface::Ptr) override - { - // Error registerFront(string endPoint); - class Callback : public GatewayServicePrxCallback - { - public: - explicit Callback(std::function _callback) - : GatewayServicePrxCallback(), m_callback(_callback) - {} - ~Callback() override {} - - void callback_asyncRegisterFront(const ppctars::Error& ret) override - { - m_callback(toBcosError(ret)); - } - void callback_asyncRegisterFront_exception(tars::Int32 ret) override - { - m_callback(toBcosError(ret)); - } - - private: - std::function m_callback; - }; - auto startT = bcos::utcSteadyTime(); - auto callback = [_endPoint, startT](bcos::Error::Ptr _error) { - if (!_error || _error->errorCode() == 0) - { - GATEWAYCLIENT_LOG(TRACE) - << LOG_DESC("registerFront success") << LOG_KV("endPoint", _endPoint) - << LOG_KV("timecost", (bcos::utcSteadyTime() - startT)); - return; - } - GATEWAYCLIENT_LOG(INFO) - << LOG_DESC("registerFront failed") << LOG_KV("code", _error->errorCode()) - << LOG_KV("endPoint", _endPoint) - << LOG_KV("timecost", (bcos::utcSteadyTime() - startT)); - }; - m_prx->tars_set_timeout(m_networkTimeout) - ->async_asyncRegisterFront(new Callback(callback), _endPoint); - } - - void asyncGetAgencyList(ppc::front::GetAgencyListCallback _callback) override - { - class Callback : public GatewayServicePrxCallback - { - public: - explicit Callback(ppc::front::GetAgencyListCallback _callback) - : GatewayServicePrxCallback(), m_callback(_callback) - {} - ~Callback() override {} - - void callback_asyncGetAgencyList( - const ppctars::Error& ret, std::vector const& _agencyList) override - { - auto tmpAgencyList = _agencyList; - m_callback(toBcosError(ret), std::move(tmpAgencyList)); - } - void callback_asyncGetAgencyList_exception(tars::Int32 ret) override - { - std::vector emptyAgencyList; - m_callback(toBcosError(ret), std::move(emptyAgencyList)); - } - - private: - ppc::front::GetAgencyListCallback m_callback; - }; - m_prx->async_asyncGetAgencyList(new Callback(_callback)); - } - - // Note: unregisterFront is the function of the front-inner, no need to implement the client - void unregisterFront(std::string const&) override - { - throw std::runtime_error("unregisterFront: unimplemented interface!"); - } - - -protected: - static bool shouldStopCall() { return (s_tarsTimeoutCount >= c_maxTarsTimeoutCount); } - - -private: - std::string m_gatewayServiceName; - ppctars::GatewayServicePrx m_prx; - - // 1800s - int m_networkTimeout = 1800000; - std::string const c_moduleName = "GatewayServiceClient"; - - static std::atomic s_tarsTimeoutCount; - static const int64_t c_maxTarsTimeoutCount; -}; -} // namespace ppctars \ No newline at end of file diff --git a/cpp/ppc-tars-protocol/ppc-tars-protocol/tars/Error.tars b/cpp/ppc-tars-protocol/ppc-tars-protocol/tars/Error.tars deleted file mode 100644 index 5fa26bd7..00000000 --- a/cpp/ppc-tars-protocol/ppc-tars-protocol/tars/Error.tars +++ /dev/null @@ -1,6 +0,0 @@ -module ppctars { - struct Error { - 1 optional int errorCode; - 2 optional string errorMessage; - }; -}; \ No newline at end of file diff --git a/cpp/ppc-tars-protocol/ppc-tars-protocol/tars/FrontService.tars b/cpp/ppc-tars-protocol/ppc-tars-protocol/tars/FrontService.tars deleted file mode 100644 index 763bf91f..00000000 --- a/cpp/ppc-tars-protocol/ppc-tars-protocol/tars/FrontService.tars +++ /dev/null @@ -1,7 +0,0 @@ -#include "Error.tars" -module ppctars { - interface FrontService { - // avoid using unsigned char - Error onReceiveMessage(vector message); - }; -}; diff --git a/cpp/ppc-tars-protocol/ppc-tars-protocol/tars/GatewayService.tars b/cpp/ppc-tars-protocol/ppc-tars-protocol/tars/GatewayService.tars deleted file mode 100644 index 017ca607..00000000 --- a/cpp/ppc-tars-protocol/ppc-tars-protocol/tars/GatewayService.tars +++ /dev/null @@ -1,19 +0,0 @@ -#include "TaskInfo.tars" -#include "Error.tars" -module ppctars -{ - struct GatewayInfo { - 1 optional string agencyID; - 2 optional string endpoint; - }; - interface GatewayService - { - Error asyncSendMessage(string agencyID, vector message); - - Error notifyTaskInfo(TaskInfo taskInfo); - Error eraseTaskInfo(string taskID); - Error registerGateway(vector gatewayList); - Error asyncRegisterFront(string endPoint); - Error asyncGetAgencyList(out vector _agencyList); - }; -}; diff --git a/cpp/ppc-tars-protocol/ppc-tars-protocol/tars/NodeInfo.tars b/cpp/ppc-tars-protocol/ppc-tars-protocol/tars/NodeInfo.tars deleted file mode 100644 index 06055197..00000000 --- a/cpp/ppc-tars-protocol/ppc-tars-protocol/tars/NodeInfo.tars +++ /dev/null @@ -1,16 +0,0 @@ -module ppctars -{ - struct NodeInfo - { - 1 require vector nodeID; - 2 require string endPoint; - 3 optional vector components; - }; - struct GatewayNodeInfo - { - 1 require string p2pNodeID; - 2 require string agency; - 3 optional vector nodeList; - 4 optional int statusSeq; - }; -}; diff --git a/cpp/ppc-tars-protocol/ppc-tars-protocol/tars/TaskInfo.tars b/cpp/ppc-tars-protocol/ppc-tars-protocol/tars/TaskInfo.tars deleted file mode 100644 index 1d6af6f4..00000000 --- a/cpp/ppc-tars-protocol/ppc-tars-protocol/tars/TaskInfo.tars +++ /dev/null @@ -1,7 +0,0 @@ -module ppctars { -struct TaskInfo -{ - 1 require string taskID; - 2 require string serviceEndpoint; -}; -}; diff --git a/cpp/ppc-tars-protocol/test/CMakeLists.txt b/cpp/ppc-tars-protocol/test/CMakeLists.txt deleted file mode 100644 index 21935b94..00000000 --- a/cpp/ppc-tars-protocol/test/CMakeLists.txt +++ /dev/null @@ -1,13 +0,0 @@ -file(GLOB_RECURSE SOURCES "*.cpp") - -# cmake settings -set(TARS_PROTOCOL_TEST_BINARY_NAME test-ppc-tars-protocol) - -add_executable(${TARS_PROTOCOL_TEST_BINARY_NAME} ${SOURCES}) -target_include_directories(${TARS_PROTOCOL_TEST_BINARY_NAME} PRIVATE .) - -target_compile_options(${TARS_PROTOCOL_TEST_BINARY_NAME} PRIVATE -Wno-error -Wno-unused-variable) - -target_link_libraries(${TARS_PROTOCOL_TEST_BINARY_NAME} ${FRONT_TARGET} ${PROTOCOL_TARGET} ${TARS_PROTOCOL_TARGET} ${BOOST_UNIT_TEST}) - -add_test(NAME test-tars-protocol WORKING_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY} COMMAND ${TARS_PROTOCOL_TEST_BINARY_NAME}) \ No newline at end of file diff --git a/cpp/ppc-tars-protocol/test/SerializeTest.cpp b/cpp/ppc-tars-protocol/test/SerializeTest.cpp deleted file mode 100644 index b6f33f9f..00000000 --- a/cpp/ppc-tars-protocol/test/SerializeTest.cpp +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright (C) 2022 WeDPR. - * SPDX-License-Identifier: Apache-2.0 - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * @file GatewayTest.cpp - * @author: shawnhe - * @date 2022-10-28 - */ - -#include "TaskInfo.h" -#include "ppc-protocol/src/PPCMessage.h" -#include "ppc-tars-protocol/ppc-tars-protocol/TarsSerialize.h" -#include -#include -#include -#include - -using namespace ppc; -using namespace ppctars; -using namespace ppctars::serialize; -using namespace ppc::protocol; -using namespace ppc::front; - -using namespace bcos; -using namespace bcos::test; - -BOOST_FIXTURE_TEST_SUITE(TarsProtocolTest, TestPromptFixture) - - -BOOST_AUTO_TEST_CASE(test_labeledPSIDataStructure) -{ - TaskInfo taskInfo; - taskInfo.taskID = "123fdsa456"; - taskInfo.serviceEndpoint = "234f567dfaj"; - - auto encodedData = std::make_shared(); - serialize::encode(taskInfo, *encodedData); - - TaskInfo taskInfoD; - serialize::decode(*encodedData, taskInfoD); - BOOST_CHECK(taskInfoD == taskInfo); - - auto messageFactory = std::make_shared(); - auto queryMessage = messageFactory->buildPPCMessage( - uint8_t(1), uint8_t(2), "345", std::make_shared()); - queryMessage->setMessageType(uint8_t(6)); - ppctars::serialize::encode(taskInfo, *queryMessage->data()); - auto buffer = std::make_shared(); - queryMessage->encode(*buffer); - - auto newMessage = messageFactory->buildPPCMessage(); - newMessage->decode(buffer); - - TaskInfo taskInfoN; - serialize::decode(*newMessage->data(), taskInfoN); - BOOST_CHECK(taskInfoN == taskInfo); -} - -BOOST_AUTO_TEST_SUITE_END() diff --git a/cpp/ppc-tars-protocol/test/main.cpp b/cpp/ppc-tars-protocol/test/main.cpp deleted file mode 100644 index 8034bf67..00000000 --- a/cpp/ppc-tars-protocol/test/main.cpp +++ /dev/null @@ -1,2 +0,0 @@ -#define BOOST_TEST_MAIN -#include diff --git a/cpp/ppc-tools/src/config/PPCConfig.cpp b/cpp/ppc-tools/src/config/PPCConfig.cpp index 07a773c6..ace6cee6 100644 --- a/cpp/ppc-tools/src/config/PPCConfig.cpp +++ b/cpp/ppc-tools/src/config/PPCConfig.cpp @@ -49,14 +49,17 @@ void PPCConfig::loadGatewayConfig( PPCConfig_LOG(INFO) << LOG_DESC("loadGatewayConfig: load the redis config success"); } - // load the agencies config - PPCConfig_LOG(INFO) << LOG_DESC("loadGatewayConfig: load the agency config"); - auto agencyInfo = parseAgencyConfig(_pt, "agency", "agency"); - m_gatewayConfig.agencies = std::move(agencyInfo); - PPCConfig_LOG(INFO) << LOG_DESC("loadGatewayConfig: load the agency config sucess") - << LOG_KV("agencySize", m_gatewayConfig.agencies.size()); + m_gatewayConfig.nodePath = _pt.get("gateway.nodes_path", "./"); + m_gatewayConfig.nodeFileName = _pt.get("gateway.nodes_file", "nodes.json"); m_gatewayConfig.reconnectTime = _pt.get("gateway.reconnect_time", 10000); + m_gatewayConfig.unreachableDistance = _pt.get("gateway.unreachable_distance", 10); + if (m_gatewayConfig.unreachableDistance < GatewayConfig::MinUnreachableDistance) + { + BOOST_THROW_EXCEPTION(InvalidConfig() << bcos::errinfo_comment( + "Invalid unreachable_distance, must no smaller than " + + std::to_string(GatewayConfig::MinUnreachableDistance))); + } // load the maxAllowedMsgSize, in MBytes m_gatewayConfig.maxAllowedMsgSize = _pt.get("gateway.max_allow_msg_size", GatewayConfig::DefaultMaxAllowedMsgSize) * @@ -78,7 +81,6 @@ void PPCConfig::loadGatewayConfig( << LOG_KV("holdingMessageMinutes", m_holdingMessageMinutes); } - int PPCConfig::loadHoldingMessageMinutes( const boost::property_tree::ptree& _pt, std::string const& _section) { @@ -127,62 +129,6 @@ void PPCConfig::initRedisConfigForGateway( << LOG_KV("socketTimeout", _redisConfig.socketTimeout); } -std::map> PPCConfig::parseAgencyConfig( - const boost::property_tree::ptree& _pt, std::string const& _sectionName, - std::string const& _subSectionName) -{ - PPCConfig_LOG(INFO) << LOG_DESC("parseAgencyConfig") << LOG_KV("section", _sectionName) - << LOG_KV("subSection", _subSectionName); - std::map> agencyInfo; - // without the config - if (!_pt.get_child_optional(_sectionName)) - { - PPCConfig_LOG(INFO) << LOG_DESC("parseAgencyConfig return for empty config") - << LOG_KV("section", _sectionName) - << LOG_KV("subSection", _subSectionName); - return agencyInfo; - } - // tranverse the child_section to parse the agencyInfo - for (auto const& it : _pt.get_child(_sectionName)) - { - if (it.first.find(_subSectionName) != 0) - { - continue; - } - // find and parse the agencyInfo - auto key = it.first.data(); - std::vector agencyIDInfo; - boost::split(agencyIDInfo, key, boost::is_any_of(".")); - // invalid agencyID - if (agencyIDInfo.size() < 2) - { - BOOST_THROW_EXCEPTION( - InvalidConfig() << bcos::errinfo_comment("Invalid agency key " + it.first + - ", the key must be in format of " + - _sectionName + ".${agencyID}")); - } - std::vector endPointInfos; - auto value = it.second.data(); - boost::split(endPointInfos, value, boost::is_any_of(",")); - for (auto& endpoint : endPointInfos) - { - if (!checkEndpoint(endpoint)) - { - BOOST_THROW_EXCEPTION( - InvalidConfig() << bcos::errinfo_comment("Invalid agency endpoint" + endpoint)); - } - } - auto& currentEndPoints = agencyInfo[agencyIDInfo.at(1)]; - currentEndPoints.reserve(currentEndPoints.size() + endPointInfos.size()); - std::move(std::begin(endPointInfos), std::end(endPointInfos), - std::back_inserter(currentEndPoints)); - PPCConfig_LOG(INFO) << LOG_DESC("parseAgencyConfig") - << LOG_KV("agencyID", agencyIDInfo.at(1)) - << LOG_KV("endPointSize", currentEndPoints.size()); - } - PPCConfig_LOG(INFO) << LOG_DESC("parseAgencyConfig") << LOG_KV("agencySize", agencyInfo.size()); - return agencyInfo; -} void PPCConfig::loadNetworkConfig(NetworkConfig& _config, const char* _certPath, boost::property_tree::ptree const& _pt, std::string const& _sectionName, int _defaultListenPort, diff --git a/cpp/ppc-tools/src/config/PPCConfig.h b/cpp/ppc-tools/src/config/PPCConfig.h index 6fd10a14..3bd29cfa 100644 --- a/cpp/ppc-tools/src/config/PPCConfig.h +++ b/cpp/ppc-tools/src/config/PPCConfig.h @@ -94,14 +94,17 @@ struct GatewayConfig constexpr static uint64_t DefaultMaxAllowedMsgSize = 100; constexpr static uint64_t MinMsgSize = 10 * 1024 * 1024; constexpr static uint64_t MaxMsgSize = 1024 * 1024 * 1024; + constexpr static int MinUnreachableDistance = 2; bool disableCache; NetworkConfig networkConfig; ppc::storage::CacheStorageConfig cacheStorageConfig; - // agencyID => endPointList - std::map> agencies; + std::string nodeFileName; + std::string nodePath; uint64_t maxAllowedMsgSize = DefaultMaxAllowedMsgSize; int reconnectTime = 10000; + // the unreachable distance + int unreachableDistance = 10; }; // the ecdh-psi config @@ -294,9 +297,6 @@ class PPCConfig void initRedisConfigForGateway( ppc::storage::CacheStorageConfig& _redisConfig, const boost::property_tree::ptree& _pt); - std::map> parseAgencyConfig( - const boost::property_tree::ptree& _pt, std::string const& _sectionName, - std::string const& _subSectionName); // load the tars-config for the given service, e.g: /* diff --git a/cpp/ppc-utilities/Utilities.h b/cpp/ppc-utilities/Utilities.h index 4b04ea21..13739719 100644 --- a/cpp/ppc-utilities/Utilities.h +++ b/cpp/ppc-utilities/Utilities.h @@ -21,6 +21,9 @@ #include "ppc-framework/Common.h" #include +#include +#include +#include namespace ppc { @@ -38,4 +41,10 @@ inline uint64_t decodeNetworkBuffer( curOffset += dataLen; return curOffset; } + +inline std::string generateUUID() +{ + static thread_local auto uuid_gen = boost::uuids::basic_random_generator(); + return boost::uuids::to_string(uuid_gen()); +} } // namespace ppc \ No newline at end of file diff --git a/cpp/test-utils/FakeFront.h b/cpp/test-utils/FakeFront.h index c5f36374..b7cdc97d 100644 --- a/cpp/test-utils/FakeFront.h +++ b/cpp/test-utils/FakeFront.h @@ -182,7 +182,7 @@ class FakeFront : public FrontInterface } } - bcos::Error::Ptr notifyTaskInfo(protocol::GatewayTaskInfo::Ptr) override { return nullptr; } + bcos::Error::Ptr notifyTaskInfo(std::string const&) override { return nullptr; } // erase the task-info when task finished bcos::Error::Ptr eraseTaskInfo(std::string const&) override { return nullptr; } diff --git a/cpp/test-utils/FakePPCMessage.h b/cpp/test-utils/FakePPCMessage.h index 279cb1cf..a9070eae 100644 --- a/cpp/test-utils/FakePPCMessage.h +++ b/cpp/test-utils/FakePPCMessage.h @@ -49,8 +49,6 @@ class FakePPCMessage : public PPCMessageFace std::string const& sender() const override { return m_sender; } void setSender(std::string const& _sender) override { m_sender = _sender; } - uint16_t ext() const override { return m_ext; } - void setExt(uint16_t _ext) override { m_ext = _ext; } std::shared_ptr data() const override { return m_data; } void setData(std::shared_ptr _data) override { m_data = _data; } // Note: we don't fake the encode-decode here diff --git a/cpp/vcpkg.json b/cpp/vcpkg.json index 635e6990..a56cfe04 100644 --- a/cpp/vcpkg.json +++ b/cpp/vcpkg.json @@ -51,6 +51,10 @@ { "name": "libsodium", "version": "1.0.18#9" + }, + { + "name": "grpc", + "version": "1.51.1#1" } ], "features": { @@ -90,6 +94,10 @@ "name": "mysql-connector-cpp", "version>=": "8.0.32" }, + { + "name": "grpc", + "version>=": "1.51.1" + }, "libhdfs3", "tarscpp", "tbb", diff --git a/cpp/wedpr-protocol/CMakeLists.txt b/cpp/wedpr-protocol/CMakeLists.txt new file mode 100644 index 00000000..33136407 --- /dev/null +++ b/cpp/wedpr-protocol/CMakeLists.txt @@ -0,0 +1,4 @@ +add_subdirectory(protocol) +add_subdirectory(protobuf) +add_subdirectory(tars) +add_subdirectory(grpc) diff --git a/cpp/wedpr-protocol/grpc/CMakeLists.txt b/cpp/wedpr-protocol/grpc/CMakeLists.txt new file mode 100644 index 00000000..e9a9279b --- /dev/null +++ b/cpp/wedpr-protocol/grpc/CMakeLists.txt @@ -0,0 +1,32 @@ +# proto generation +set(PROTO_INPUT_PATH ${CMAKE_SOURCE_DIR}/wedpr-protocol/proto/pb) + +file(GLOB_RECURSE MESSAGES_PROTOS "${PROTO_INPUT_PATH}/Service*.proto") + +# create PROTO_OUTPUT_PATH +file(MAKE_DIRECTORY ${PROTO_OUTPUT_PATH}) + +find_program(GRPC_CPP_PLUGIN grpc_cpp_plugin REQUIRED) +find_program(PROTOC_BINARY protoc REQUIRED) + +foreach(proto_file ${MESSAGES_PROTOS}) + get_filename_component(basename ${proto_file} NAME_WE) + set(generated_file ${PROTO_OUTPUT_PATH}/${basename}.grpc.pb.cc) + + list(APPEND GRPC_MESSAGES_SRCS ${generated_file}) + message("Command: ${PROTOC_BINARY} --grpc_out ${PROTO_OUTPUT_PATH} -I ${PROTO_INPUT_PATH} --plugin=protoc-gen-grpc=${GRPC_CPP_PLUGIN} ${proto_file}") + add_custom_command( + OUTPUT ${generated_file} + COMMAND ${PROTOC_BINARY} + ARGS --grpc_out ${PROTO_OUTPUT_PATH} + -I ${PROTO_INPUT_PATH} + --plugin=protoc-gen-grpc="${GRPC_CPP_PLUGIN}" + ${proto_file} DEPENDS ${proto_file} + COMMENT "Generating ${generated_file} from ${proto_file}" + ) +endforeach() + +add_library(${SERVICE_CLIENT_PB_TARGET} ${GRPC_MESSAGES_SRCS}) +target_link_libraries(${SERVICE_CLIENT_PB_TARGET} PUBLIC ${PB_PROTOCOL_TARGET} gRPC::grpc++_unsecure) + +add_subdirectory(client) \ No newline at end of file diff --git a/cpp/wedpr-protocol/grpc/client/CMakeLists.txt b/cpp/wedpr-protocol/grpc/client/CMakeLists.txt new file mode 100644 index 00000000..c7f6c983 --- /dev/null +++ b/cpp/wedpr-protocol/grpc/client/CMakeLists.txt @@ -0,0 +1,3 @@ +file(GLOB_RECURSE SRCS *.cpp) +add_library(${SERVICE_CLIENT_TARGET} ${SRCS}) +target_link_libraries(${SERVICE_CLIENT_TARGET} PUBLIC ${SERVICE_CLIENT_PB_TARGET} ${PB_PROTOCOL_TARGET}) \ No newline at end of file diff --git a/cpp/wedpr-protocol/grpc/client/Common.h b/cpp/wedpr-protocol/grpc/client/Common.h new file mode 100644 index 00000000..68deebe7 --- /dev/null +++ b/cpp/wedpr-protocol/grpc/client/Common.h @@ -0,0 +1,23 @@ +/** + * Copyright (C) 2021 FISCO BCOS. + * SPDX-License-Identifier: Apache-2.0 + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @file Common.h + * @author: yujiechen + * @date 2021-04-12 + */ +#pragma once +#include "ppc-framework/Common.h" + +#define GRPC_CLIENT_LOG(LEVEL) BCOS_LOG(LEVEL) << "[GRPC][CLIENT]" \ No newline at end of file diff --git a/cpp/wedpr-protocol/grpc/client/FrontClient.cpp b/cpp/wedpr-protocol/grpc/client/FrontClient.cpp new file mode 100644 index 00000000..4f351dc8 --- /dev/null +++ b/cpp/wedpr-protocol/grpc/client/FrontClient.cpp @@ -0,0 +1,51 @@ +/** + * Copyright (C) 2023 WeDPR. + * SPDX-License-Identifier: Apache-2.0 + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @file FrontClient.cpp + * @author: yujiechen + * @date 2024-09-02 + */ +#include "FrontClient.h" +#include "wedpr-protocol/protobuf/Common.h" + + +using namespace ppc::protocol; +using namespace ppc::proto; +using grpc::Channel; +using grpc::ClientAsyncResponseReader; +using grpc::ClientContext; +using grpc::CompletionQueue; +using grpc::Status; + +void FrontClient::onReceiveMessage(ppc::protocol::Message::Ptr const& msg, ReceiveMsgFunc callback) +{ + // TODO: optimize here + ReceivedMessage receivedMsg; + bcos::bytes encodedData; + msg->encode(encodedData); + receivedMsg.set_data(encodedData.data(), encodedData.size()); + + auto grpcCallback = [callback](ClientContext const&, Status const& status, Error&& response) { + auto error = std::make_shared(response.errorcode(), response.errormessage()); + callback(error); + }; + + auto call = std::make_shared(grpcCallback); + call->responseReader = + m_stub->PrepareAsynconReceiveMessage(&call->context, receivedMsg, &m_client->queue()); + call->responseReader->StartCall(); + // send request, upon completion of the RPC, "reply" be updated with the server's response + call->responseReader->Finish(&call->reply, &call->status, (void*)call.get()); +} \ No newline at end of file diff --git a/cpp/wedpr-protocol/grpc/client/FrontClient.h b/cpp/wedpr-protocol/grpc/client/FrontClient.h new file mode 100644 index 00000000..61205f7c --- /dev/null +++ b/cpp/wedpr-protocol/grpc/client/FrontClient.h @@ -0,0 +1,42 @@ +/** + * Copyright (C) 2023 WeDPR. + * SPDX-License-Identifier: Apache-2.0 + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @file FrontClient.h + * @author: yujiechen + * @date 2024-09-02 + */ +#pragma once +#include "GrpcClient.h" +#include "ppc-framework/front/IFront.h" + +namespace ppc::protocol +{ +class FrontClient : public ppc::front::IFrontClient +{ +public: + using Ptr = std::shared_ptr; + FrontClient(GrpcClient::Ptr client) + : m_client(std::move(client)), m_stub(ppc::proto::Front::NewStub(m_client->channel())) + {} + + ~FrontClient() override = default; + void onReceiveMessage( + ppc::protocol::Message::Ptr const& _msg, ppc::protocol::ReceiveMsgFunc _callback) override; + +private: + std::unique_ptr m_stub; + GrpcClient::Ptr m_client; +}; +} // namespace ppc::protocol \ No newline at end of file diff --git a/cpp/wedpr-protocol/grpc/client/GatewayClient.cpp b/cpp/wedpr-protocol/grpc/client/GatewayClient.cpp new file mode 100644 index 00000000..bdcb43b9 --- /dev/null +++ b/cpp/wedpr-protocol/grpc/client/GatewayClient.cpp @@ -0,0 +1,41 @@ +/** + * Copyright (C) 2023 WeDPR. + * SPDX-License-Identifier: Apache-2.0 + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @file GatewayClient.h + * @author: yujiechen + * @date 2024-09-02 + */ +#include "GatewayClient.h" + +using namespace ppc; +using namespace ppc::gateway; +using namespace ppc::protocol; + + +void GatewayClient::start() {} +void GatewayClient::stop() {} + +void GatewayClient::asyncSendMessage(RouteType routeType, + MessageOptionalHeader::Ptr const& routeInfo, bcos::bytes&& payload, long timeout, + ReceiveMsgFunc callback) +{} + +void GatewayClient::asyncSendbroadcastMessage( + RouteType routeType, MessageOptionalHeader::Ptr const& routeInfo, bcos::bytes&& payload) +{} +void GatewayClient::registerNodeInfo(INodeInfo::Ptr const& nodeInfo) {} +void GatewayClient::unRegisterNodeInfo(bcos::bytesConstRef nodeID) {} +void GatewayClient::registerTopic(bcos::bytesConstRef nodeID, std::string const& topic) {} +void GatewayClient::unRegisterTopic(bcos::bytesConstRef nodeID, std::string const& topic) {} \ No newline at end of file diff --git a/cpp/wedpr-protocol/grpc/client/GatewayClient.h b/cpp/wedpr-protocol/grpc/client/GatewayClient.h new file mode 100644 index 00000000..f1e9a012 --- /dev/null +++ b/cpp/wedpr-protocol/grpc/client/GatewayClient.h @@ -0,0 +1,61 @@ +/** + * Copyright (C) 2023 WeDPR. + * SPDX-License-Identifier: Apache-2.0 + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @file GatewayClient.h + * @author: yujiechen + * @date 2024-09-02 + */ +#pragma once +#include "ppc-framework/gateway/IGateway.h" + +namespace ppc::protocol +{ +class GatewayClient : public ppc::gateway::IGateway +{ +public: + using Ptr = std::shared_ptr; + GatewayClient() = default; + ~GatewayClient() override = default; + + + void start() override; + void stop() override; + + /** + * @brief send message to gateway + * + * @param routeType the route type + * @param topic the topic + * @param dstInst the dst agency(must set when 'route by agency' and 'route by + * component') + * @param dstNodeID the dst nodeID(must set when 'route by nodeID') + * @param componentType the componentType(must set when 'route by component') + * @param payload the payload to send + * @param seq the message seq + * @param timeout timeout + * @param callback callback + */ + void asyncSendMessage(ppc::protocol::RouteType routeType, + ppc::protocol::MessageOptionalHeader::Ptr const& routeInfo, bcos::bytes&& payload, + long timeout, ppc::protocol::ReceiveMsgFunc callback) override; + + void asyncSendbroadcastMessage(ppc::protocol::RouteType routeType, + ppc::protocol::MessageOptionalHeader::Ptr const& routeInfo, bcos::bytes&& payload) override; + void registerNodeInfo(ppc::protocol::INodeInfo::Ptr const& nodeInfo) override; + void unRegisterNodeInfo(bcos::bytesConstRef nodeID) override; + void registerTopic(bcos::bytesConstRef nodeID, std::string const& topic) override; + void unRegisterTopic(bcos::bytesConstRef nodeID, std::string const& topic) override; +}; +} // namespace ppc::protocol \ No newline at end of file diff --git a/cpp/wedpr-protocol/grpc/client/GrpcClient.cpp b/cpp/wedpr-protocol/grpc/client/GrpcClient.cpp new file mode 100644 index 00000000..019e5c10 --- /dev/null +++ b/cpp/wedpr-protocol/grpc/client/GrpcClient.cpp @@ -0,0 +1,56 @@ +/** + * Copyright (C) 2023 WeDPR. + * SPDX-License-Identifier: Apache-2.0 + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @file GrpcClient.cpp + * @author: yujiechen + * @date 2024-09-02 + */ +#include "GrpcClient.h" +#include "Common.h" + + +using namespace ppc::protocol; +using namespace grpc; + +void GrpcClient::handleRpcResponse() +{ + void* callback; + bool ok = false; + // Block until the next result is available in the completion queue "m_queue". + while (m_queue.Next(&callback, &ok)) + { + try + { + // The tag in this example is the memory location of the call object + // Note: the should been managed by shared_ptr + AsyncClientCall* call = static_cast(callback); + + // Verify that the request was completed successfully. Note that "ok" + // corresponds solely to the request for updates introduced by Finish(). + if (!ok) + { + GRPC_CLIENT_LOG(WARNING) + << LOG_DESC("handleRpcResponse: receive response with unormal status"); + return; + } + call->callback(call->context, call->status, std::move(call->reply)); + } + catch (std::exception const& e) + { + GRPC_CLIENT_LOG(WARNING) << LOG_DESC("handleRpcResponse exception") + << LOG_KV("error", boost::diagnostic_information(e)); + } + } +} \ No newline at end of file diff --git a/cpp/wedpr-protocol/grpc/client/GrpcClient.h b/cpp/wedpr-protocol/grpc/client/GrpcClient.h new file mode 100644 index 00000000..b00437ac --- /dev/null +++ b/cpp/wedpr-protocol/grpc/client/GrpcClient.h @@ -0,0 +1,65 @@ +/** + * Copyright (C) 2023 WeDPR. + * SPDX-License-Identifier: Apache-2.0 + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @file GrpcClient.h + * @author: yujiechen + * @date 2024-09-02 + */ +#pragma once +#include "Service.grpc.pb.h" +#include + +namespace ppc::protocol +{ +// struct for keeping state and data information +class AsyncClientCall +{ +public: + using CallbackDef = + std::function; + AsyncClientCall(CallbackDef _callback) : callback(std::move(_callback)) {} + + CallbackDef callback; + // Container for the data we expect from the server. + ppc::proto::Error reply; + // Context for the client. It could be used to convey extra information to + // the server and/or tweak certain RPC behaviors. + grpc::ClientContext context; + // Storage for the status of the RPC upon completion. + grpc::Status status; + std::unique_ptr> responseReader; +}; + +class GrpcClient +{ +public: + using Ptr = std::shared_ptr; + GrpcClient(std::shared_ptr channel) : m_channel(std::move(channel)) {} + + virtual ~GrpcClient() = default; + + std::shared_ptr const& channel() { return m_channel; } + grpc::CompletionQueue& queue() { return m_queue; } + + void handleRpcResponse(); + +private: + std::shared_ptr m_channel; + // The producer-consumer queue we use to communicate asynchronously with the + // gRPC runtime. + // TODO: check threadsafe + grpc::CompletionQueue m_queue; +}; +} // namespace ppc::protocol \ No newline at end of file diff --git a/cpp/wedpr-protocol/proto/pb/NodeInfo.proto b/cpp/wedpr-protocol/proto/pb/NodeInfo.proto new file mode 100644 index 00000000..f238ee1e --- /dev/null +++ b/cpp/wedpr-protocol/proto/pb/NodeInfo.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; +package ppc.proto; + +message NodeInfo{ + // the nodeID + bytes nodeID = 1; + // the endPoint + string endPoint = 2; + // the components + repeated string components = 3; + string topic = 4; +}; + +message GatewayNodeInfo{ + string p2pNodeID = 1; + string agency = 2; + repeated NodeInfo nodeList = 3; + int32 statusSeq = 4; +}; \ No newline at end of file diff --git a/cpp/wedpr-protocol/proto/pb/Service.proto b/cpp/wedpr-protocol/proto/pb/Service.proto new file mode 100644 index 00000000..b41fd3c4 --- /dev/null +++ b/cpp/wedpr-protocol/proto/pb/Service.proto @@ -0,0 +1,43 @@ +syntax = "proto3"; +import "NodeInfo.proto"; +option java_package = "ppc.proto.grpc"; + +package ppc.proto; + +message Error{ + // the errorCode + int64 errorCode = 1; + // the errorMessage + string errorMessage = 2; +}; + +message ReceivedMessage{ + bytes data = 1; +}; + +service Front { + rpc onReceiveMessage (ReceivedMessage) returns (Error) {} +} + +message RouteInfo{ + string topic = 1; + string componentType = 2; + bytes srcNode = 3; + bytes dstNode = 4; + bytes dstInst = 5; +}; + +message SendedMessageRequest{ + int32 routeType = 1; + RouteInfo routeInfo = 2; + bytes payload = 3; + int64 timeout = 4; +}; + +service Gateway{ + rpc asyncSendMessage(SendedMessageRequest) returns(Error){} + rpc registerNodeInfo(NodeInfo) returns(Error){} + rpc unRegisterNodeInfo(NodeInfo)returns(Error){} + rpc registerTopic(NodeInfo) returns(Error){} + rpc unRegisterTopic(NodeInfo) returns(Error){} +}; \ No newline at end of file diff --git a/cpp/ppc-tars-protocol/ppc-tars-protocol/tars/RouterTable.tars b/cpp/wedpr-protocol/proto/tars/RouterTable.tars similarity index 100% rename from cpp/ppc-tars-protocol/ppc-tars-protocol/tars/RouterTable.tars rename to cpp/wedpr-protocol/proto/tars/RouterTable.tars diff --git a/cpp/wedpr-protocol/protobuf/CMakeLists.txt b/cpp/wedpr-protocol/protobuf/CMakeLists.txt new file mode 100644 index 00000000..6dd7bc81 --- /dev/null +++ b/cpp/wedpr-protocol/protobuf/CMakeLists.txt @@ -0,0 +1,27 @@ +# proto generation +set(PROTO_INPUT_PATH ${CMAKE_SOURCE_DIR}/wedpr-protocol/proto/pb) + +file(GLOB_RECURSE MESSAGES_PROTOS "${PROTO_INPUT_PATH}/*.proto") + +find_program(PROTOC_BINARY protoc REQUIRED) + +# create PROTO_OUTPUT_PATH +file(MAKE_DIRECTORY ${PROTO_OUTPUT_PATH}) +foreach(proto_file ${MESSAGES_PROTOS}) + get_filename_component(basename ${proto_file} NAME_WE) + set(generated_file ${PROTO_OUTPUT_PATH}/${basename}.pb.cc) + + list(APPEND MESSAGES_SRCS ${generated_file}) + + message("Command: protoc --cpp_out ${PROTO_OUTPUT_PATH} -I ${PROTO_INPUT_PATH} ${proto_file}") + add_custom_command( + OUTPUT ${generated_file} + COMMAND ${PROTOC_BINARY} --cpp_out ${PROTO_OUTPUT_PATH} -I ${PROTO_INPUT_PATH} ${proto_file} + COMMENT "Generating ${generated_file} from ${proto_file}" + VERBATIM + ) +endforeach() + +file(GLOB_RECURSE SRCS *.cpp) +add_library(${PB_PROTOCOL_TARGET} ${SRCS} ${MESSAGES_SRCS}) +target_link_libraries(${PB_PROTOCOL_TARGET} PUBLIC ${BCOS_UTILITIES_TARGET}) \ No newline at end of file diff --git a/cpp/wedpr-protocol/protobuf/Common.h b/cpp/wedpr-protocol/protobuf/Common.h new file mode 100644 index 00000000..09796003 --- /dev/null +++ b/cpp/wedpr-protocol/protobuf/Common.h @@ -0,0 +1,50 @@ +/** + * Copyright (C) 2021 FISCO BCOS. + * SPDX-License-Identifier: Apache-2.0 + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @file Common.h + * @author: yujiechen + * @date 2021-04-12 + */ +#pragma once +#include "bcos-utilities/Common.h" +#include "ppc-framework/Common.h" + +namespace ppc::protocol +{ +DERIVE_PPC_EXCEPTION(PBObjectEncodeException); +DERIVE_PPC_EXCEPTION(PBObjectDecodeException); + +template +void encodePBObject(bcos::bytes& _encodedData, T _pbObject) +{ + auto encodedData = std::make_shared(); + _encodedData.resize(_pbObject->ByteSizeLong()); + if (!_pbObject->SerializeToArray(_encodedData.data(), _encodedData.size())) + { + BOOST_THROW_EXCEPTION(PBObjectEncodeException() + << bcos::errinfo_comment("encode PBObject into bytes data failed")); + } +} + +template +void decodePBObject(T _pbObject, bcos::bytesConstRef _data) +{ + if (!_pbObject->ParseFromArray(_data.data(), _data.size())) + { + BOOST_THROW_EXCEPTION(PBObjectDecodeException() + << bcos::errinfo_comment("decode bytes data into PBObject failed")); + } +} +} // namespace ppc::protocol diff --git a/cpp/ppc-tars-protocol/ppc-tars-protocol/impl/NodeInfoImpl.cpp b/cpp/wedpr-protocol/protobuf/NodeInfoImpl.cpp similarity index 67% rename from cpp/ppc-tars-protocol/ppc-tars-protocol/impl/NodeInfoImpl.cpp rename to cpp/wedpr-protocol/protobuf/NodeInfoImpl.cpp index e8f585d4..742d53a8 100644 --- a/cpp/ppc-tars-protocol/ppc-tars-protocol/impl/NodeInfoImpl.cpp +++ b/cpp/wedpr-protocol/protobuf/NodeInfoImpl.cpp @@ -19,22 +19,22 @@ */ #include "NodeInfoImpl.h" -#include "../Common.h" +#include "Common.h" -using namespace ppctars; using namespace ppc::protocol; void NodeInfoImpl::encode(bcos::bytes& data) const { - tars::TarsOutputStream output; - m_inner()->writeTo(output); - output.getByteBuffer().swap(data); + // set the components + for (auto const& component : m_components) + { + m_inner()->add_components(component); + } + encodePBObject(data, m_inner()); } void NodeInfoImpl::decode(bcos::bytesConstRef data) { - tars::TarsInputStream input; - input.setBuffer((const char*)data.data(), data.size()); - m_inner()->readFrom(input); + decodePBObject(m_inner(), data); m_components = - std::set(m_inner()->components.begin(), m_inner()->components.end()); + std::set(m_inner()->components().begin(), m_inner()->components().end()); } \ No newline at end of file diff --git a/cpp/ppc-tars-protocol/ppc-tars-protocol/impl/NodeInfoImpl.h b/cpp/wedpr-protocol/protobuf/NodeInfoImpl.h similarity index 69% rename from cpp/ppc-tars-protocol/ppc-tars-protocol/impl/NodeInfoImpl.h rename to cpp/wedpr-protocol/protobuf/NodeInfoImpl.h index ff4e1627..1a5a4000 100644 --- a/cpp/ppc-tars-protocol/ppc-tars-protocol/impl/NodeInfoImpl.h +++ b/cpp/wedpr-protocol/protobuf/NodeInfoImpl.h @@ -18,8 +18,8 @@ * @date 2024-08-26 */ #pragma once +#include "NodeInfo.pb.h" #include "ppc-framework/protocol/INodeInfo.h" -#include "ppc-tars-protocol/tars/NodeInfo.h" #include namespace ppc::protocol { @@ -28,46 +28,49 @@ class NodeInfoImpl : public INodeInfo { public: using Ptr = std::shared_ptr; - explicit NodeInfoImpl(std::function inner) : m_inner(std::move(inner)) {} + explicit NodeInfoImpl(std::function inner) : m_inner(std::move(inner)) + {} NodeInfoImpl(bcos::bytesConstRef const& nodeID) - : m_inner([inner = ppctars::NodeInfo()]() mutable { return &inner; }) + : m_inner([inner = ppc::proto::NodeInfo()]() mutable { return &inner; }) { - m_inner()->nodeID = std::vector(nodeID.begin(), nodeID.end()); + m_inner()->set_nodeid(nodeID.data(), nodeID.size()); } NodeInfoImpl(bcos::bytesConstRef const& nodeID, std::string const& endPoint) : NodeInfoImpl(nodeID) { - m_inner()->endPoint = endPoint; + m_inner()->set_endpoint(endPoint); } ~NodeInfoImpl() override = default; - void setComponents(std::vector const& components) override + void setComponents(std::set const& components) override { - m_components = std::set(components.begin(), components.end()); - m_inner()->components = components; + m_components = components; } std::set const& components() const override { return m_components; } - std::string const& endPoint() const override { return m_inner()->endPoint; } + std::string const& endPoint() const override { return m_inner()->endpoint(); } bcos::bytesConstRef nodeID() const override { - return {reinterpret_cast(m_inner()->nodeID.data()), - m_inner()->nodeID.size()}; + return {reinterpret_cast(m_inner()->nodeid().data()), + m_inner()->nodeid().size()}; } void encode(bcos::bytes& data) const override; void decode(bcos::bytesConstRef data) override; - ppctars::NodeInfo const& inner() { return *(m_inner()); } + std::function innerFunc() { return m_inner; } - void setFront(ppc::front::IFront::Ptr&& front) override { m_front = std::move(front); } - ppc::front::IFront::Ptr const& getFront() const override { return m_front; } + void setFront(std::shared_ptr&& front) override + { + m_front = std::move(front); + } + std::shared_ptr const& getFront() const override { return m_front; } private: - ppc::front::IFront::Ptr m_front; + std::shared_ptr m_front; std::set m_components; - std::function m_inner; + std::function m_inner; }; class NodeInfoFactory : public INodeInfoFactory diff --git a/cpp/ppc-protocol/CMakeLists.txt b/cpp/wedpr-protocol/protocol/CMakeLists.txt similarity index 100% rename from cpp/ppc-protocol/CMakeLists.txt rename to cpp/wedpr-protocol/protocol/CMakeLists.txt diff --git a/cpp/ppc-protocol/src/CMakeLists.txt b/cpp/wedpr-protocol/protocol/src/CMakeLists.txt similarity index 100% rename from cpp/ppc-protocol/src/CMakeLists.txt rename to cpp/wedpr-protocol/protocol/src/CMakeLists.txt diff --git a/cpp/ppc-protocol/src/JsonTaskImpl.cpp b/cpp/wedpr-protocol/protocol/src/JsonTaskImpl.cpp similarity index 100% rename from cpp/ppc-protocol/src/JsonTaskImpl.cpp rename to cpp/wedpr-protocol/protocol/src/JsonTaskImpl.cpp diff --git a/cpp/ppc-protocol/src/JsonTaskImpl.h b/cpp/wedpr-protocol/protocol/src/JsonTaskImpl.h similarity index 100% rename from cpp/ppc-protocol/src/JsonTaskImpl.h rename to cpp/wedpr-protocol/protocol/src/JsonTaskImpl.h diff --git a/cpp/ppc-protocol/src/PPCMessage.cpp b/cpp/wedpr-protocol/protocol/src/PPCMessage.cpp similarity index 62% rename from cpp/ppc-protocol/src/PPCMessage.cpp rename to cpp/wedpr-protocol/protocol/src/PPCMessage.cpp index 3f724e87..6c2a0fa0 100644 --- a/cpp/ppc-protocol/src/PPCMessage.cpp +++ b/cpp/wedpr-protocol/protocol/src/PPCMessage.cpp @@ -29,29 +29,12 @@ void PPCMessage::encode(bytes& _buffer) { _buffer.clear(); - uint32_t seq = boost::asio::detail::socket_ops::host_to_network_long(m_seq); - uint16_t taskIDLength = boost::asio::detail::socket_ops::host_to_network_short(m_taskID.size()); - uint16_t senderLength = boost::asio::detail::socket_ops::host_to_network_short(m_sender.size()); uint32_t dataLength = boost::asio::detail::socket_ops::host_to_network_long(m_data->size()); - uint16_t ext = boost::asio::detail::socket_ops::host_to_network_short(m_ext); _buffer.insert(_buffer.end(), (byte*)&m_version, (byte*)&m_version + 1); _buffer.insert(_buffer.end(), (byte*)&m_taskType, (byte*)&m_taskType + 1); _buffer.insert(_buffer.end(), (byte*)&m_algorithmType, (byte*)&m_algorithmType + 1); _buffer.insert(_buffer.end(), (byte*)&m_messageType, (byte*)&m_messageType + 1); - _buffer.insert(_buffer.end(), (byte*)&seq, (byte*)&seq + 4); - _buffer.insert(_buffer.end(), (byte*)&taskIDLength, (byte*)&taskIDLength + 2); - _buffer.insert(_buffer.end(), m_taskID.begin(), m_taskID.end()); - _buffer.insert(_buffer.end(), (byte*)&senderLength, (byte*)&senderLength + 2); - _buffer.insert(_buffer.end(), m_sender.begin(), m_sender.end()); - _buffer.insert(_buffer.end(), (byte*)&ext, (byte*)&ext + 2); - // encode the uuid: uuidLen, uuidData - auto uuidLen = m_uuid.size(); - _buffer.insert(_buffer.end(), (byte*)&uuidLen, (byte*)&uuidLen + 1); - if (uuidLen > 0) - { - _buffer.insert(_buffer.end(), m_uuid.begin(), m_uuid.end()); - } // encode the data: dataLen, dataData _buffer.insert(_buffer.end(), (byte*)&dataLength, (byte*)&dataLength + 4); if (dataLength > 0) @@ -99,53 +82,6 @@ int64_t PPCMessage::decode(uint32_t _length, bcos::byte* _data) m_messageType = *((uint8_t*)p); p += 1; - // seq field - m_seq = boost::asio::detail::socket_ops::network_to_host_long(*((uint32_t*)p)); - p += 4; - - // taskIDLength - uint16_t taskIDLength = boost::asio::detail::socket_ops::network_to_host_short(*((uint16_t*)p)); - p += 2; - minLen += taskIDLength; - if (_length < minLen) - { - return -1; - } - - // taskID field - m_taskID.insert(m_taskID.begin(), p, p + taskIDLength); - p += taskIDLength; - - // senderLength - uint16_t senderLength = boost::asio::detail::socket_ops::network_to_host_short(*((uint16_t*)p)); - p += 2; - minLen += senderLength; - if (_length < minLen) - { - return -1; - } - // sender field - m_sender.insert(m_sender.begin(), p, p + senderLength); - p += senderLength; - - // ext field - m_ext = boost::asio::detail::socket_ops::network_to_host_short(*((uint16_t*)p)); - p += 2; - - // decode the uuid - auto uuidLen = *((byte*)p); - p += 1; - minLen += uuidLen; - if (_length < minLen) - { - return -1; - } - if (uuidLen > 0) - { - m_uuid.assign(p, p + uuidLen); - p += uuidLen; - } - // dataLength uint32_t dataLength = boost::asio::detail::socket_ops::network_to_host_long(*((uint32_t*)p)); p += 4; diff --git a/cpp/ppc-protocol/src/PPCMessage.h b/cpp/wedpr-protocol/protocol/src/PPCMessage.h similarity index 82% rename from cpp/ppc-protocol/src/PPCMessage.h rename to cpp/wedpr-protocol/protocol/src/PPCMessage.h index ca34657b..fc6983f6 100644 --- a/cpp/ppc-protocol/src/PPCMessage.h +++ b/cpp/wedpr-protocol/protocol/src/PPCMessage.h @@ -36,10 +36,10 @@ namespace front class PPCMessage : public PPCMessageFace { public: - // version(1) + taskType(1) + algorithmType(1) + messageType(1) + seq(4) - // + taskIDLength(2) + senderLength(2) + ext(2) + uuidLen(1) + dataLen(4) + data(N) + // version(1) + taskType(1) + algorithmType(1) + messageType(1) + // + dataLen(4) + data(N) // + header(M) - const static size_t MESSAGE_MIN_LENGTH = 15; + const static size_t MESSAGE_MIN_LENGTH = 8; using Ptr = std::shared_ptr; PPCMessage() { m_data = std::make_shared(); } @@ -59,8 +59,6 @@ class PPCMessage : public PPCMessageFace void setTaskID(std::string const& _taskID) override { m_taskID = _taskID; } std::string const& sender() const override { return m_sender; } void setSender(std::string const& _sender) override { m_sender = _sender; } - virtual uint16_t ext() const override { return m_ext; } - virtual void setExt(uint16_t _ext) override { m_ext = _ext; } std::shared_ptr data() const override { return m_data; } // Note: here directly use passed-in _data, make-sure _data not changed before send the message void setData(std::shared_ptr _data) override { m_data = _data; } @@ -81,9 +79,9 @@ class PPCMessage : public PPCMessageFace uint32_t length() const override { return m_length; } // determine the message is response or not - bool response() const override { return m_ext & MessageExtFlag::ResponseFlag; } + bool response() const override { return m_isResponse; } // set the message to be response - void setResponse() override { m_ext |= MessageExtFlag::ResponseFlag; } + void setResponse() override { m_isResponse = true; } protected: std::string encodeMap(const std::map& _map); @@ -97,7 +95,7 @@ class PPCMessage : public PPCMessageFace uint32_t m_seq = 0; std::string m_taskID; std::string m_sender; - uint16_t m_ext = 0; + bool m_isResponse; // the uuid used to find the response-callback std::string m_uuid; std::shared_ptr m_data; @@ -149,6 +147,28 @@ class PPCMessageFactory : public PPCMessageFaceFactory } return msg; } + + PPCMessageFace::Ptr buildPPCMessage(ppc::protocol::Message::Ptr msg) override + { + auto ppcMsg = buildPPCMessage(); + auto frontMsg = msg->frontMessage(); + if (frontMsg) + { + ppcMsg->setSeq(frontMsg->seq()); + ppcMsg->setUuid(frontMsg->traceID()); + if (frontMsg->isRespPacket()) + { + ppcMsg->setResponse(); + } + } + if (msg->header() && msg->header()->optionalField()) + { + auto const& routeInfo = msg->header()->optionalField(); + ppcMsg->setTaskID(routeInfo->topic()); + ppcMsg->setSender(routeInfo->srcInst()); + } + return ppcMsg; + } }; } // namespace front diff --git a/cpp/ppc-protocol/src/v1/MessageHeaderImpl.cpp b/cpp/wedpr-protocol/protocol/src/v1/MessageHeaderImpl.cpp similarity index 94% rename from cpp/ppc-protocol/src/v1/MessageHeaderImpl.cpp rename to cpp/wedpr-protocol/protocol/src/v1/MessageHeaderImpl.cpp index f14e1cac..117f304c 100644 --- a/cpp/ppc-protocol/src/v1/MessageHeaderImpl.cpp +++ b/cpp/wedpr-protocol/protocol/src/v1/MessageHeaderImpl.cpp @@ -37,6 +37,10 @@ void MessageOptionalHeaderImpl::encode(bcos::bytes& buffer) const uint16_t srcNodeLen = boost::asio::detail::socket_ops::host_to_network_short(m_srcNode.size()); buffer.insert(buffer.end(), (byte*)&srcNodeLen, (byte*)&srcNodeLen + 2); buffer.insert(buffer.end(), m_srcNode.begin(), m_srcNode.end()); + // the source agency + uint16_t srcInstLen = boost::asio::detail::socket_ops::host_to_network_short(m_srcInst.size()); + buffer.insert(buffer.end(), (byte*)&srcInstLen, (byte*)&srcInstLen + 2); + buffer.insert(buffer.end(), m_srcInst.begin(), m_srcInst.end()); // the target nodeID that should receive the message uint16_t dstNodeLen = boost::asio::detail::socket_ops::host_to_network_short(m_dstNode.size()); buffer.insert(buffer.end(), (byte*)&dstNodeLen, (byte*)&dstNodeLen + 2); @@ -61,6 +65,10 @@ int64_t MessageOptionalHeaderImpl::decode(bcos::bytesConstRef data, uint64_t con m_componentType = std::string(componentType.begin(), componentType.end()); // srcNode offset = decodeNetworkBuffer(m_srcNode, data.data(), data.size(), offset); + // source inst + bcos::bytes sourceInst; + offset = decodeNetworkBuffer(sourceInst, data.data(), data.size(), offset); + m_srcInst = std::string(sourceInst.begin(), sourceInst.end()); // dstNode offset = decodeNetworkBuffer(m_dstNode, data.data(), data.size(), offset); // dstInst, TODO: optimize here diff --git a/cpp/ppc-protocol/src/v1/MessageHeaderImpl.h b/cpp/wedpr-protocol/protocol/src/v1/MessageHeaderImpl.h similarity index 72% rename from cpp/ppc-protocol/src/v1/MessageHeaderImpl.h rename to cpp/wedpr-protocol/protocol/src/v1/MessageHeaderImpl.h index 9d9c3850..17f4b99d 100644 --- a/cpp/ppc-protocol/src/v1/MessageHeaderImpl.h +++ b/cpp/wedpr-protocol/protocol/src/v1/MessageHeaderImpl.h @@ -28,6 +28,18 @@ class MessageOptionalHeaderImpl : public MessageOptionalHeader public: using Ptr = std::shared_ptr; MessageOptionalHeaderImpl() = default; + MessageOptionalHeaderImpl(MessageOptionalHeader::Ptr const& optionalHeader) + { + if (!optionalHeader) + { + return; + } + setTopic(optionalHeader->topic()); + setComponentType(optionalHeader->componentType()); + setSrcNode(optionalHeader->srcNode()); + setDstNode(optionalHeader->dstNode()); + setDstInst(optionalHeader->dstInst()); + } MessageOptionalHeaderImpl(bcos::bytesConstRef data, uint64_t const offset) { decode(data, offset); @@ -50,7 +62,7 @@ class MessageHeaderImpl : public MessageHeader void encode(bcos::bytes& buffer) const override; int64_t decode(bcos::bytesConstRef data) override; - virtual bool hasOptionalField() const + bool hasOptionalField() const override { return m_packetType == (uint16_t)ppc::gateway::GatewayPacketType::P2PMessage; } @@ -81,5 +93,21 @@ class MessageHeaderBuilderImpl : public MessageHeaderBuilder return std::make_shared(data); } MessageHeader::Ptr build() override { return std::make_shared(); } + MessageOptionalHeader::Ptr build(MessageOptionalHeader::Ptr const& optionalHeader) override + { + return std::make_shared(optionalHeader); + } +}; +class MessageOptionalHeaderBuilderImpl : public MessageOptionalHeaderBuilder +{ +public: + using Ptr = std::shared_ptr; + MessageOptionalHeaderBuilderImpl() = default; + ~MessageOptionalHeaderBuilderImpl() override = default; + + MessageOptionalHeader::Ptr build(MessageOptionalHeader::Ptr const& optionalHeader) override + { + return std::make_shared(optionalHeader); + } }; } // namespace ppc::protocol \ No newline at end of file diff --git a/cpp/ppc-protocol/src/v1/MessageImpl.cpp b/cpp/wedpr-protocol/protocol/src/v1/MessageImpl.cpp similarity index 100% rename from cpp/ppc-protocol/src/v1/MessageImpl.cpp rename to cpp/wedpr-protocol/protocol/src/v1/MessageImpl.cpp diff --git a/cpp/ppc-protocol/src/v1/MessageImpl.h b/cpp/wedpr-protocol/protocol/src/v1/MessageImpl.h similarity index 81% rename from cpp/ppc-protocol/src/v1/MessageImpl.h rename to cpp/wedpr-protocol/protocol/src/v1/MessageImpl.h index aec9cb3d..e261a7a2 100644 --- a/cpp/ppc-protocol/src/v1/MessageImpl.h +++ b/cpp/wedpr-protocol/protocol/src/v1/MessageImpl.h @@ -20,6 +20,7 @@ #pragma once #include "ppc-framework/Common.h" #include "ppc-framework/protocol/Message.h" +#include "ppc-utilities/Utilities.h" #include #include #include @@ -78,16 +79,13 @@ class MessageBuilderImpl : public MessageBuilder { return std::make_shared(m_msgHeaderBuilder, m_maxMessageLen, buffer); } - Message::Ptr build(ppc::protocol::RouteType routeType, std::string const& topic, - std::string const& dstInst, bcos::bytes const& dstNodeID, std::string const& componentType, - bcos::bytes&& payload) override + + Message::Ptr build(ppc::protocol::RouteType routeType, + ppc::protocol::MessageOptionalHeader::Ptr const& routeInfo, bcos::bytes&& payload) override { auto msg = build(); msg->header()->setRouteType(routeType); - msg->header()->optionalField()->setDstInst(dstInst); - msg->header()->optionalField()->setDstNode(dstNodeID); - msg->header()->optionalField()->setTopic(topic); - msg->header()->optionalField()->setComponentType(componentType); + msg->header()->setOptionalField(routeInfo); msg->setPayload(std::make_shared(std::move(payload))); return msg; } @@ -97,12 +95,11 @@ class MessageBuilderImpl : public MessageBuilder return std::make_shared(m_msgHeaderBuilder, m_maxMessageLen); } - std::string newSeq() override + virtual MessageOptionalHeader::Ptr build(MessageOptionalHeader::Ptr const& optionalHeader) { - std::string seq = boost::uuids::to_string(boost::uuids::random_generator()()); - seq.erase(std::remove(seq.begin(), seq.end(), '-'), seq.end()); - return seq; + return m_msgHeaderBuilder->build(optionalHeader); } + std::string newSeq() override { return generateUUID(); } private: MessageHeaderBuilder::Ptr m_msgHeaderBuilder; diff --git a/cpp/ppc-protocol/src/v1/MessagePayloadImpl.cpp b/cpp/wedpr-protocol/protocol/src/v1/MessagePayloadImpl.cpp similarity index 72% rename from cpp/ppc-protocol/src/v1/MessagePayloadImpl.cpp rename to cpp/wedpr-protocol/protocol/src/v1/MessagePayloadImpl.cpp index 360567b4..7f1f676a 100644 --- a/cpp/ppc-protocol/src/v1/MessagePayloadImpl.cpp +++ b/cpp/wedpr-protocol/protocol/src/v1/MessagePayloadImpl.cpp @@ -34,6 +34,13 @@ int64_t MessagePayloadImpl::encode(bcos::bytes& buffer) const // seq uint16_t seq = boost::asio::detail::socket_ops::host_to_network_short(m_seq); buffer.insert(buffer.end(), (byte*)&seq, (byte*)&seq + 2); + // ext field + uint16_t ext = boost::asio::detail::socket_ops::host_to_network_short(m_ext); + buffer.insert(buffer.end(), (byte*)&ext, (byte*)&ext + 2); + // traceID + uint16_t traceIDLen = boost::asio::detail::socket_ops::host_to_network_short(m_traceID.size()); + buffer.insert(buffer.end(), (byte*)&traceIDLen, (byte*)&traceIDLen + 2); + buffer.insert(buffer.end(), m_traceID.begin(), m_traceID.end()); // data uint16_t dataLen = boost::asio::detail::socket_ops::host_to_network_short(m_data.size()); buffer.insert(buffer.end(), (byte*)&dataLen, (byte*)&dataLen + 2); @@ -59,6 +66,15 @@ int64_t MessagePayloadImpl::decode(bcos::bytesConstRef buffer) CHECK_OFFSET_WITH_THROW_EXCEPTION((pointer - buffer.data()), buffer.size()); m_seq = boost::asio::detail::socket_ops::network_to_host_short(*((uint16_t*)pointer)); pointer += 2; + // the ext + CHECK_OFFSET_WITH_THROW_EXCEPTION((pointer - buffer.data()), buffer.size()); + m_ext = boost::asio::detail::socket_ops::network_to_host_short(*((uint16_t*)pointer)); + pointer += 2; + // the traceID + bcos::bytes traceID; + auto offset = + decodeNetworkBuffer(traceID, buffer.data(), buffer.size(), (pointer - buffer.data())); + m_traceID = std::string(traceID.begin(), traceID.end()); // data - return decodeNetworkBuffer(m_data, buffer.data(), buffer.size(), (pointer - buffer.data())); + return decodeNetworkBuffer(m_data, buffer.data(), buffer.size(), offset); } \ No newline at end of file diff --git a/cpp/ppc-protocol/src/v1/MessagePayloadImpl.h b/cpp/wedpr-protocol/protocol/src/v1/MessagePayloadImpl.h similarity index 100% rename from cpp/ppc-protocol/src/v1/MessagePayloadImpl.h rename to cpp/wedpr-protocol/protocol/src/v1/MessagePayloadImpl.h diff --git a/cpp/ppc-protocol/tests/CMakeLists.txt b/cpp/wedpr-protocol/protocol/tests/CMakeLists.txt similarity index 100% rename from cpp/ppc-protocol/tests/CMakeLists.txt rename to cpp/wedpr-protocol/protocol/tests/CMakeLists.txt diff --git a/cpp/ppc-protocol/tests/PPCMessageTest.cpp b/cpp/wedpr-protocol/protocol/tests/PPCMessageTest.cpp similarity index 100% rename from cpp/ppc-protocol/tests/PPCMessageTest.cpp rename to cpp/wedpr-protocol/protocol/tests/PPCMessageTest.cpp diff --git a/cpp/ppc-protocol/tests/TestTaskImpl.cpp b/cpp/wedpr-protocol/protocol/tests/TestTaskImpl.cpp similarity index 100% rename from cpp/ppc-protocol/tests/TestTaskImpl.cpp rename to cpp/wedpr-protocol/protocol/tests/TestTaskImpl.cpp diff --git a/cpp/ppc-protocol/tests/main.cpp b/cpp/wedpr-protocol/protocol/tests/main.cpp similarity index 100% rename from cpp/ppc-protocol/tests/main.cpp rename to cpp/wedpr-protocol/protocol/tests/main.cpp diff --git a/cpp/ppc-tars-protocol/CMakeLists.txt b/cpp/wedpr-protocol/tars/CMakeLists.txt similarity index 67% rename from cpp/ppc-tars-protocol/CMakeLists.txt rename to cpp/wedpr-protocol/tars/CMakeLists.txt index 7c684a1a..d016016e 100644 --- a/cpp/ppc-tars-protocol/CMakeLists.txt +++ b/cpp/wedpr-protocol/tars/CMakeLists.txt @@ -1,13 +1,14 @@ cmake_minimum_required(VERSION 3.14) include(Version) -project(ppc-tars-protocol VERSION ${VERSION}) +project(wedpr-tars-protocol VERSION ${VERSION}) # for tars generator -set(TARS_HEADER_DIR ${CMAKE_BINARY_DIR}/generated/ppc-tars-protocol/tars) +set(TARS_HEADER_DIR ${CMAKE_BINARY_DIR}/generated/tars) find_program(TARS_TARS2CPP tars2cpp REQUIRED) -file(GLOB_RECURSE TARS_INPUT "*.tars") +set(PROTO_INPUT_PATH ${CMAKE_SOURCE_DIR}/wedpr-protocol/proto/tars) +file(GLOB_RECURSE TARS_INPUT "${PROTO_INPUT_PATH}/*.tars") # generate tars if (TARS_INPUT) @@ -28,7 +29,7 @@ endif () set_directory_properties(PROPERTIES ADDITIONAL_MAKE_CLEAN_FILES "${OUT_TARS_H_LIST}") -file(GLOB_RECURSE SRC_LIST "ppc-tars-protocol/*.cpp" "ppc-tars-protocol/*.h") +file(GLOB_RECURSE SRC_LIST *.cpp) find_package(tarscpp REQUIRED) add_library(${TARS_PROTOCOL_TARGET} ${SRC_LIST} ${OUT_TARS_H_LIST}) @@ -36,7 +37,7 @@ target_include_directories(${TARS_PROTOCOL_TARGET} PUBLIC $ $ $ - $) + $) target_link_libraries(${TARS_PROTOCOL_TARGET} PUBLIC ${BCOS_UTILITIES_TARGET} tarscpp::tarsservant tarscpp::tarsutil) # ut @@ -45,9 +46,3 @@ if (TESTS) set(CTEST_OUTPUT_ON_FAILURE TRUE) add_subdirectory(test) endif () - - -include(GNUInstallDirs) -#install(TARGETS ${TARS_PROTOCOL_TARGET} EXPORT ppcTargets ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}") -install(DIRECTORY "ppc-tars-protocol" DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}" FILES_MATCHING PATTERN "*.h") -install(DIRECTORY "${CMAKE_BINARY_DIR}/generated/ppc-tars-protocol" DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}" FILES_MATCHING PATTERN "*.h") \ No newline at end of file diff --git a/cpp/ppc-tars-protocol/ppc-tars-protocol/Common.h b/cpp/wedpr-protocol/tars/Common.h similarity index 79% rename from cpp/ppc-tars-protocol/ppc-tars-protocol/Common.h rename to cpp/wedpr-protocol/tars/Common.h index 576e3ec6..90f62d33 100644 --- a/cpp/ppc-tars-protocol/ppc-tars-protocol/Common.h +++ b/cpp/wedpr-protocol/tars/Common.h @@ -20,14 +20,10 @@ #pragma once -#include "Error.h" #include "TarsServantProxyCallback.h" -#include "TaskInfo.h" -#include "ppc-framework/gateway/GatewayInterface.h" #include "ppc-framework/protocol/Task.h" #include "ppc-tools/src/config/ParamChecker.h" #include -#include #include #include #include @@ -118,97 +114,6 @@ using BufferWriterStdByteVector = BufferWriter>; using BufferWriterString = BufferWriter; } // namespace serialize - -template -bool checkConnection(std::string const& _module, std::string const& _func, const T& prx, - std::function _errorCallback, bool _callsErrorCallback = true) -{ - auto cb = prx->tars_get_push_callback(); - assert(cb); - auto* tarsServantProxyCallback = (TarsServantProxyCallback*)cb.get(); - - if (tarsServantProxyCallback->available()) - { - return true; - } - - if (_errorCallback && _callsErrorCallback) - { - std::string errorMessage = - _module + " calls interface " + _func + " failed for empty connection"; - _errorCallback(std::make_shared(-1, errorMessage)); - } - return false; -} - - -inline ppctars::TaskInfo toTarsTaskInfo(ppc::protocol::GatewayTaskInfo::Ptr _taskInfo) -{ - ppctars::TaskInfo tarsTaskInfo; - if (!_taskInfo) - { - return tarsTaskInfo; - } - - tarsTaskInfo.taskID = _taskInfo->taskID; - tarsTaskInfo.serviceEndpoint = _taskInfo->serviceEndpoint; - - return tarsTaskInfo; -} - -inline ppc::protocol::GatewayTaskInfo::Ptr toGatewayTaskInfo(ppctars::TaskInfo _taskInfo) -{ - auto gatewayTaskInfo = std::make_shared(); - - gatewayTaskInfo->taskID = _taskInfo.taskID; - gatewayTaskInfo->serviceEndpoint = _taskInfo.serviceEndpoint; - - return gatewayTaskInfo; -} - -inline ppctars::Error toTarsError(const bcos::Error& error) -{ - ppctars::Error tarsError; - tarsError.errorCode = error.errorCode(); - tarsError.errorMessage = error.errorMessage(); - - return tarsError; -} - -template -inline ppctars::Error toTarsError(const T& error) -{ - ppctars::Error tarsError; - - if (error) - { - tarsError.errorCode = error->errorCode(); - tarsError.errorMessage = error->errorMessage(); - } - - return tarsError; -} - -inline bcos::Error::Ptr toBcosError(const ppctars::Error& error) -{ - if (error.errorCode == 0) - { - return nullptr; - } - - return std::make_shared(error.errorCode, error.errorMessage); -} - -inline bcos::Error::Ptr toBcosError(tars::Int32 ret) -{ - if (ret == 0) - { - return nullptr; - } - - return std::make_shared(ret, "TARS error!"); -} - inline std::string getProxyDesc(std::string const& _servantName) { std::string desc = diff --git a/cpp/ppc-tars-protocol/ppc-tars-protocol/TarsSerialize.h b/cpp/wedpr-protocol/tars/TarsSerialize.h similarity index 97% rename from cpp/ppc-tars-protocol/ppc-tars-protocol/TarsSerialize.h rename to cpp/wedpr-protocol/tars/TarsSerialize.h index b15ea3f4..51bb4591 100644 --- a/cpp/ppc-tars-protocol/ppc-tars-protocol/TarsSerialize.h +++ b/cpp/wedpr-protocol/tars/TarsSerialize.h @@ -25,7 +25,6 @@ namespace ppctars::serialize { - void encode(TarsStruct auto const& object, bcos::bytes& out) { tars::TarsOutputStream output; @@ -40,4 +39,4 @@ void decode(const bcos::bytes& in, TarsStruct auto& out) out.readFrom(input); } -} // namespace ppctars +} // namespace ppctars::serialize diff --git a/cpp/ppc-tars-protocol/ppc-tars-protocol/TarsServantProxyCallback.h b/cpp/wedpr-protocol/tars/TarsServantProxyCallback.h similarity index 100% rename from cpp/ppc-tars-protocol/ppc-tars-protocol/TarsServantProxyCallback.h rename to cpp/wedpr-protocol/tars/TarsServantProxyCallback.h diff --git a/cpp/ppc-tars-protocol/ppc-tars-protocol/client/GatewayServiceClient.cpp b/cpp/wedpr-protocol/tars/TarsStruct.cpp similarity index 73% rename from cpp/ppc-tars-protocol/ppc-tars-protocol/client/GatewayServiceClient.cpp rename to cpp/wedpr-protocol/tars/TarsStruct.cpp index 534a993f..064378c9 100644 --- a/cpp/ppc-tars-protocol/ppc-tars-protocol/client/GatewayServiceClient.cpp +++ b/cpp/wedpr-protocol/tars/TarsStruct.cpp @@ -13,12 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. * - * @file GatewayServiceClient.cpp + * @file TarsStruct.cpp * @author: shawnhe - * @date 2022-10-20 + * @date 2022-11-4 */ -#include "GatewayServiceClient.h" - -std::atomic ppctars::GatewayServiceClient::s_tarsTimeoutCount = {0}; -const int64_t ppctars::GatewayServiceClient::c_maxTarsTimeoutCount = 500; +#include "TarsStruct.h" \ No newline at end of file diff --git a/cpp/ppc-tars-protocol/ppc-tars-protocol/TarsStruct.h b/cpp/wedpr-protocol/tars/TarsStruct.h similarity index 89% rename from cpp/ppc-tars-protocol/ppc-tars-protocol/TarsStruct.h rename to cpp/wedpr-protocol/tars/TarsStruct.h index 7d775a45..395725cd 100644 --- a/cpp/ppc-tars-protocol/ppc-tars-protocol/TarsStruct.h +++ b/cpp/wedpr-protocol/tars/TarsStruct.h @@ -24,16 +24,17 @@ namespace ppctars::serialize { - template concept TarsStruct = requires(TarsStructType tarsStruct) { { tarsStruct.className() - } -> std::same_as; + } + ->std::same_as; { tarsStruct.MD5() - } -> std::same_as; + } + ->std::same_as; tarsStruct.resetDefautlt(); }; -} +} // namespace ppctars::serialize