Skip to content

Commit

Permalink
add asyncSendResponse implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
cyjseagull committed Sep 10, 2024
1 parent d8b5ecf commit 449a68c
Show file tree
Hide file tree
Showing 21 changed files with 89 additions and 32 deletions.
4 changes: 2 additions & 2 deletions cpp/ppc-framework/front/FrontInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ class FrontInterface
uint32_t _timeout, ErrorCallbackFunc _callback, CallbackFunc _respCallback) = 0;

// send response when receiving message from given agencyID
virtual void asyncSendResponse(const std::string& _agencyID, std::string const& _uuid,
front::PPCMessageFace::Ptr _message, ErrorCallbackFunc _callback) = 0;
virtual void asyncSendResponse(bcos::bytes const& dstNode, std::string const& traceID,
front::PPCMessageFace::Ptr message, ErrorCallbackFunc _callback) = 0;

virtual void registerMessageHandler(uint8_t _taskType, uint8_t _algorithmType,
std::function<void(front::PPCMessageFace::Ptr)> _handler) = 0;
Expand Down
3 changes: 3 additions & 0 deletions cpp/ppc-framework/front/IFront.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ class IFront : virtual public IFrontClient
long timeout, ppc::protocol::ReceiveMsgFunc errorCallback,
ppc::protocol::MessageCallback callback) = 0;

virtual void asyncSendResponse(bcos::bytes const& dstNode, std::string const& traceID,
bcos::bytes&& payload, int seq, ppc::protocol::ReceiveMsgFunc errorCallback) = 0;

// the sync interface for async_send_message
virtual bcos::Error::Ptr push(ppc::protocol::RouteType routeType,
ppc::protocol::MessageOptionalHeader::Ptr const& routeInfo, bcos::bytes&& payload, int seq,
Expand Down
3 changes: 2 additions & 1 deletion cpp/ppc-framework/protocol/INodeInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ class INodeInfo

virtual bool equal(INodeInfo::Ptr const& info)
{
return (nodeID() == info->nodeID()) && (components() == info->components());
return (nodeID().toBytes() == info->nodeID().toBytes()) &&
(components() == info->components());
}

virtual void toJson(Json::Value& jsonObject) const = 0;
Expand Down
3 changes: 3 additions & 0 deletions cpp/ppc-framework/protocol/PPCMessageFace.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ class PPCMessageFace
virtual std::string const& taskID() const = 0;
virtual void setTaskID(std::string const&) = 0;
virtual std::string const& sender() const = 0;
virtual bcos::bytes const& senderNode() const = 0;
virtual void setSender(std::string const&) = 0;
virtual void setSenderNode(bcos::bytes const&) = 0;

virtual std::shared_ptr<bcos::bytes> data() const = 0;
virtual void setData(std::shared_ptr<bcos::bytes>) = 0;
virtual std::map<std::string, std::string> header() = 0;
Expand Down
4 changes: 2 additions & 2 deletions cpp/test-utils/FakeFront.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class FakeFront : public FrontInterface
bcos::Error::Ptr eraseTaskInfo(std::string const&) override { return nullptr; }

// send response when receiving message from given agencyID
void asyncSendResponse(const std::string& _agencyID, std::string const& _uuid,
void asyncSendResponse(bcos::bytes const& peer, std::string const& _uuid,
front::PPCMessageFace::Ptr _message, ErrorCallbackFunc _callback) override
{
if (m_uuidToCallback.count(_uuid))
Expand All @@ -193,7 +193,7 @@ class FakeFront : public FrontInterface
removeCallback(_uuid);
if (callback)
{
callback(nullptr, _agencyID, _message, nullptr);
callback(nullptr, std::string(peer.begin(), peer.end()), _message, nullptr);
}
}
}
Expand Down
6 changes: 6 additions & 0 deletions cpp/test-utils/FakePPCMessage.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,13 @@ class FakePPCMessage : public PPCMessageFace
// set the message to be response
void setResponse() override { m_response = true; }


bcos::bytes const& senderNode() const override { return m_senderNode; }

void setSenderNode(bcos::bytes const& senderNode) override { m_senderNode = senderNode; }

private:
bcos::bytes m_senderNode;
uint8_t m_version;
uint8_t m_taskType;
uint8_t m_algorithmType;
Expand Down
10 changes: 6 additions & 4 deletions cpp/wedpr-computing/ppc-psi/src/PSIConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#pragma once
#include "Common.h"
#include "bcos-utilities/Common.h"
#include "ppc-framework/Helper.h"
#include "ppc-framework/front/FrontInterface.h"
#include "ppc-framework/io/DataResourceLoader.h"
#include "ppc-framework/protocol/Protocol.h"
Expand Down Expand Up @@ -96,14 +97,15 @@ class PSIConfig
_responseCallback);
}

void asyncSendResponse(std::string const& _peerID, std::string const& _taskID,
void asyncSendResponse(bcos::bytes const& fromNode, std::string const& _taskID,
std::string const& _uuid, PSIMessageInterface::Ptr const& _msg,
ppc::front::ErrorCallbackFunc _callback, uint32_t _seq = 0)
{
auto ppcMsg = generatePPCMsg(_taskID, _msg, _seq);
PSI_LOG(TRACE) << LOG_DESC("sendResponse") << LOG_KV("peer", _peerID) << printPPCMsg(ppcMsg)
<< LOG_KV("msgType", (int)_msg->packetType()) << LOG_KV("uuid", _uuid);
m_front->asyncSendResponse(_peerID, _uuid, ppcMsg, _callback);
PSI_LOG(TRACE) << LOG_DESC("sendResponse") << LOG_KV("peer", printNodeID(fromNode))
<< printPPCMsg(ppcMsg) << LOG_KV("msgType", (int)_msg->packetType())
<< LOG_KV("uuid", _uuid);
m_front->asyncSendResponse(fromNode, _uuid, ppcMsg, _callback);
}

ppc::io::DataResourceLoader::Ptr const& dataResourceLoader() const
Expand Down
2 changes: 1 addition & 1 deletion cpp/wedpr-computing/ppc-psi/src/ecdh-psi/EcdhPSIImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ void EcdhPSIImpl::onHandshakeResponse(PSIMessageInterface::Ptr const& _msg)
psiMsg->setErrorCode(0);
psiMsg->setErrorMessage("success");
auto startT = bcos::utcSteadyTime();
m_config->asyncSendResponse(taskState->peerID(), taskState->task()->id(), _msg->uuid(), psiMsg,
m_config->asyncSendResponse(_msg->fromNode(), taskState->task()->id(), _msg->uuid(), psiMsg,
[this, startT, _msg](bcos::Error::Ptr _error) {
if (!_error || _error->errorCode() == 0)
{
Expand Down
15 changes: 9 additions & 6 deletions cpp/wedpr-computing/ppc-psi/src/psi-framework/PSIFramework.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ void PSIFramework::onReceiveMessage(PPCMessageFace::Ptr _msg)
psiMsg->setTaskID(_msg->taskID());
psiMsg->setSeq(_msg->seq());
psiMsg->setUUID(_msg->uuid());
psiMsg->setFromNode(_msg->senderNode());
m_msgQueue->push(psiMsg);
PSI_FRAMEWORK_LOG(TRACE) << LOG_DESC("onReceiveMessage") << printPSIMessage(psiMsg)
<< LOG_KV("uuid", _msg->uuid());
Expand Down Expand Up @@ -699,7 +700,7 @@ void PSIFramework::sendHandshakeRequest(TaskState::Ptr const& _taskState)


void PSIFramework::responsePSIResultSyncStatus(int32_t _code, std::string const& _msg,
std::string const& _peer, std::string const& _taskID, std::string const& _uuid, uint32_t _seq)
bcos::bytes const& _peer, std::string const& _taskID, std::string const& _uuid, uint32_t _seq)
{
// response to the client
auto psiMsg =
Expand Down Expand Up @@ -732,15 +733,17 @@ void PSIFramework::handlePSIResultSyncMsg(PSIMessageInterface::Ptr _resultSyncMs
<< printPSIMessage(_resultSyncMsg);
std::string msg =
"sync psi result for task " + _resultSyncMsg->taskID() + " failed for task not found!";
responsePSIResultSyncStatus((int32_t)PSIRetCode::TaskNotFound, msg, _resultSyncMsg->from(),
_resultSyncMsg->taskID(), _resultSyncMsg->uuid(), _resultSyncMsg->seq());
responsePSIResultSyncStatus((int32_t)PSIRetCode::TaskNotFound, msg,
_resultSyncMsg->fromNode(), _resultSyncMsg->taskID(), _resultSyncMsg->uuid(),
_resultSyncMsg->seq());
return;
}
try
{
taskState->storePSIResult(m_dataResourceLoader, _resultSyncMsg->takeData());
responsePSIResultSyncStatus((int32_t)PSIRetCode::Success, "success", _resultSyncMsg->from(),
_resultSyncMsg->taskID(), _resultSyncMsg->uuid(), _resultSyncMsg->seq());
responsePSIResultSyncStatus((int32_t)PSIRetCode::Success, "success",
_resultSyncMsg->fromNode(), _resultSyncMsg->taskID(), _resultSyncMsg->uuid(),
_resultSyncMsg->seq());
}
catch (std::exception const& e)
{
Expand All @@ -749,7 +752,7 @@ void PSIFramework::handlePSIResultSyncMsg(PSIMessageInterface::Ptr _resultSyncMs
auto errorMessage = "sync psi result for " + _resultSyncMsg->taskID() +
" failed, error: " + std::string(boost::diagnostic_information(e));
responsePSIResultSyncStatus((int32_t)PSIRetCode::SyncPSIResultFailed, errorMessage,
_resultSyncMsg->from(), _resultSyncMsg->taskID(), _resultSyncMsg->uuid(),
_resultSyncMsg->fromNode(), _resultSyncMsg->taskID(), _resultSyncMsg->uuid(),
_resultSyncMsg->seq());
// cancel the task
auto error = BCOS_ERROR_PTR((int32_t)PSIRetCode::SyncPSIResultFailed, errorMessage);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ class PSIFramework : public bcos::Worker, public ppc::task::TaskFrameworkInterfa
m_signalled.wait_for(l, boost::chrono::milliseconds(5));
}
void responsePSIResultSyncStatus(int32_t _code, std::string const& _msg,
std::string const& _peer, std::string const& _taskID, std::string const& _uuid,
bcos::bytes const& _peer, std::string const& _taskID, std::string const& _uuid,
uint32_t _seq);

void broadcastSyncTaskInfo(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ class PSIMessageInterface
virtual void setTaskID(std::string const& _taskID) { m_taskID = _taskID; }
virtual void setSeq(uint32_t _seq) { m_seq = _seq; }
virtual void setFrom(std::string const& _from) { m_from = _from; }

virtual void setFromNode(bcos::bytes const& fromNode) { m_fromNode = fromNode; }
virtual bcos::bytes fromNode() const { return m_fromNode; }

virtual std::string const& taskID() const { return m_taskID; }
virtual uint32_t seq() const { return m_seq; }
Expand All @@ -88,7 +89,11 @@ class PSIMessageInterface
private:
std::string m_taskID;
int32_t m_seq;
// the agency
std::string m_from;
// the fromNode
bcos::bytes m_fromNode;

std::string m_uuid;
};

Expand Down
1 change: 1 addition & 0 deletions cpp/wedpr-protocol/protocol/src/PPCMessage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ PPCMessageFace::Ptr PPCMessageFactory::decodePPCMessage(Message::Ptr msg)
auto const& routeInfo = msg->header()->optionalField();
ppcMsg->setTaskID(routeInfo->topic());
ppcMsg->setSender(routeInfo->srcInst());
ppcMsg->setSenderNode(routeInfo->srcNode());
}
return ppcMsg;
}
Expand Down
6 changes: 6 additions & 0 deletions cpp/wedpr-protocol/protocol/src/PPCMessage.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ class PPCMessage : public PPCMessageFace
// set the message to be response
void setResponse() override { m_isResponse = true; }

void setSenderNode(bcos::bytes const& senderNode) override { m_senderNode = senderNode; }

bcos::bytes const& senderNode() const override { return m_senderNode; }

protected:
std::string encodeMap(const std::map<std::string, std::string>& _map);
std::map<std::string, std::string> decodeMap(const std::string& _encval);
Expand All @@ -95,6 +99,8 @@ class PPCMessage : public PPCMessageFace
uint32_t m_seq = 0;
std::string m_taskID;
std::string m_sender;
bcos::bytes m_senderNode;

bool m_isResponse;
// the uuid used to find the response-callback
std::string m_uuid;
Expand Down
10 changes: 7 additions & 3 deletions cpp/wedpr-transport/ppc-front/ppc-front/Front.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,13 @@ void Front::asyncSendMessage(const std::string& _agencyID, front::PPCMessageFace
}

// send response when receiving message from given agencyID
void Front::asyncSendResponse(const std::string& _agencyID, std::string const& _uuid,
front::PPCMessageFace::Ptr _message, ErrorCallbackFunc _callback)
{}
void Front::asyncSendResponse(bcos::bytes const& dstNode, std::string const& traceID,
PPCMessageFace::Ptr message, ErrorCallbackFunc _callback)
{
bcos::bytes data;
message->encode(data);
m_front->asyncSendResponse(dstNode, traceID, std::move(data), 0, _callback);
}

/**
* @brief notice task info to gateway
Expand Down
5 changes: 2 additions & 3 deletions cpp/wedpr-transport/ppc-front/ppc-front/Front.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ class Front : public FrontInterface, public std::enable_shared_from_this<Front>
uint32_t _timeout, ErrorCallbackFunc _callback, CallbackFunc _respCallback) override;

// send response when receiving message from given agencyID
void asyncSendResponse(const std::string& _agencyID, std::string const& _uuid,
front::PPCMessageFace::Ptr _message, ErrorCallbackFunc _callback) override;

void asyncSendResponse(bcos::bytes const& dstNode, std::string const& traceID,
front::PPCMessageFace::Ptr message, ErrorCallbackFunc _callback) override;
/**
* @brief notice task info to gateway
* @param _taskInfo the latest task information
Expand Down
19 changes: 18 additions & 1 deletion cpp/wedpr-transport/ppc-front/ppc-front/FrontImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ FrontImpl::FrontImpl(std::shared_ptr<bcos::ThreadPool> threadPool,
m_gatewayClient(gateway)
{
m_nodeID = m_nodeInfo->nodeID().toBytes();
m_callbackManager = std::make_shared<CallbackManager>(m_threadPool, ioService);
m_callbackManager = std::make_shared<CallbackManager>(m_threadPool, m_ioService);
}

/**
Expand Down Expand Up @@ -106,6 +106,23 @@ void FrontImpl::stop()
}
}

void FrontImpl::asyncSendResponse(bcos::bytes const& dstNode, std::string const& traceID,
bcos::bytes&& payload, int seq, ppc::protocol::ReceiveMsgFunc errorCallback)
{
// generate the frontMessage
auto frontMessage = m_messageFactory->build();
frontMessage->setTraceID(traceID);
frontMessage->setSeq(seq);
frontMessage->setData(std::move(payload));

auto routeInfo = m_routerInfoBuilder->build();
routeInfo->setSrcNode(m_nodeID);
routeInfo->setDstNode(dstNode);

asyncSendMessageToGateway(true, std::move(frontMessage), RouteType::ROUTE_THROUGH_NODEID,
traceID, routeInfo, -1, errorCallback);
}

/**
* @brief async send message
*
Expand Down
3 changes: 3 additions & 0 deletions cpp/wedpr-transport/ppc-front/ppc-front/FrontImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ class FrontImpl : public IFront, public IFrontClient, public std::enable_shared_
return m_messageFactory;
}

void asyncSendResponse(bcos::bytes const& dstNode, std::string const& traceID,
bcos::bytes&& payload, int seq, ppc::protocol::ReceiveMsgFunc errorCallback) override;

private:
void asyncSendMessageToGateway(bool responsePacket,
ppc::protocol::MessagePayload::Ptr&& frontMessage, ppc::protocol::RouteType routeType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class GatewayNodeInfo
virtual std::vector<std::shared_ptr<ppc::front::IFrontClient>> chooseRouterByAgency(
bool selectAll) const = 0;
virtual std::vector<std::shared_ptr<ppc::front::IFrontClient>> chooseRouterByTopic(
bool selectAll, std::string const& topic) const = 0;
bool selectAll, bcos::bytes const& fromNode, std::string const& topic) const = 0;

virtual void encode(bcos::bytes& data) const = 0;
virtual void decode(bcos::bytesConstRef data) = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ std::vector<std::shared_ptr<ppc::front::IFrontClient>> GatewayNodeInfoImpl::choo
}

std::vector<std::shared_ptr<ppc::front::IFrontClient>> GatewayNodeInfoImpl::chooseRouterByTopic(
bool selectAll, std::string const& topic) const
bool selectAll, bcos::bytes const& fromNode, std::string const& topic) const
{
std::vector<std::shared_ptr<ppc::front::IFrontClient>> result;
// empty topic means broadcast message to all front
Expand All @@ -175,7 +175,7 @@ std::vector<std::shared_ptr<ppc::front::IFrontClient>> GatewayNodeInfoImpl::choo
{
selectedNode = nodeInfo(it.first);
}
if (selectedNode != nullptr)
if (selectedNode != nullptr && selectedNode->nodeID().toBytes() != fromNode)
{
result.emplace_back(selectedNode->getFront());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class GatewayNodeInfoImpl : public GatewayNodeInfo
std::vector<std::shared_ptr<ppc::front::IFrontClient>> chooseRouterByAgency(
bool selectAll) const override;
std::vector<std::shared_ptr<ppc::front::IFrontClient>> chooseRouterByTopic(
bool selectAll, std::string const& topic) const override;
bool selectAll, bcos::bytes const& fromNode, std::string const& topic) const override;

void registerTopic(bcos::bytes const& nodeID, std::string const& topic) override;
void unRegisterTopic(bcos::bytes const& nodeID, std::string const& topic) override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ std::vector<ppc::front::IFrontClient::Ptr> LocalRouter::chooseReceiver(
ppc::protocol::Message::Ptr const& msg)
{
std::vector<ppc::front::IFrontClient::Ptr> receivers;
if (msg->header()->optionalField()->dstInst() != m_routerInfo->agency())
auto const& dstInst = msg->header()->optionalField()->dstInst();
if (!dstInst.empty() && dstInst != m_routerInfo->agency())
{
return receivers;
}
Expand All @@ -123,17 +124,20 @@ std::vector<ppc::front::IFrontClient::Ptr> LocalRouter::chooseReceiver(
}
case (uint16_t)RouteType::ROUTE_THROUGH_COMPONENT:
{
// Note: should check the dstInst when route-by-component
return m_routerInfo->chooseRouteByComponent(
selectAll, msg->header()->optionalField()->componentType());
}
case (uint16_t)RouteType::ROUTE_THROUGH_AGENCY:
{
// Note: should check the dstInst when route-by-agency
return m_routerInfo->chooseRouterByAgency(selectAll);
}
case (uint16_t)RouteType::ROUTE_THROUGH_TOPIC:
{
return m_routerInfo->chooseRouterByTopic(
selectAll, msg->header()->optionalField()->topic());
// Note: should ignore the srcNode when route-by-topic
return m_routerInfo->chooseRouterByTopic(selectAll,
msg->header()->optionalField()->srcNode(), msg->header()->optionalField()->topic());
}
default:
BOOST_THROW_EXCEPTION(WeDPRException() << errinfo_comment(
Expand Down

0 comments on commit 449a68c

Please sign in to comment.