From 114e7900d870b1da073791271e11b8789e11db2d Mon Sep 17 00:00:00 2001 From: cyjseagull Date: Mon, 25 Nov 2024 20:09:33 +0800 Subject: [PATCH] fix cm2020 check hostResource (#101) * fix cm2020 check hostResource * complement log * psi support retry * unregister front taskInfo --- cpp/ppc-framework/protocol/DataResource.h | 4 +++ cpp/ppc-framework/protocol/PartyResource.h | 4 +++ cpp/ppc-framework/protocol/Task.h | 14 ++++++++ cpp/ppc-framework/rpc/RpcStatusInterface.h | 2 -- cpp/wedpr-computing/ppc-pir/src/OtPIRImpl.cpp | 10 +++--- cpp/wedpr-computing/ppc-pir/src/OtPIRImpl.h | 4 +-- .../ppc-pir/tests/FakeOtPIRFactory.h | 4 +-- .../ppc-pir/tests/TestBaseOT.cpp | 2 ++ .../ppc-psi/src/cm2020-psi/CM2020PSIImpl.cpp | 18 +++++----- .../ppc-psi/src/cm2020-psi/CM2020PSIImpl.h | 6 +--- .../src/ecdh-conn-psi/EcdhConnPSIImpl.cpp | 6 ++-- .../src/ecdh-conn-psi/EcdhConnPSIImpl.h | 5 +-- .../src/ecdh-multi-psi/EcdhMultiPSIImpl.cpp | 16 ++++----- .../src/ecdh-multi-psi/EcdhMultiPSIImpl.h | 3 +- .../core/EcdhMultiPSICalculator.cpp | 2 +- .../core/EcdhMultiPSIMaster.cpp | 2 +- .../core/EcdhMultiPSIPartner.cpp | 2 +- .../ppc-psi/src/ecdh-psi/EcdhPSIImpl.cpp | 2 +- .../ppc-psi/src/ecdh-psi/EcdhPSIImpl.h | 3 -- .../src/labeled-psi/LabeledPSIImpl.cpp | 11 +++--- .../ppc-psi/src/labeled-psi/LabeledPSIImpl.h | 5 +-- .../ppc-psi/src/psi-framework/TaskGuarder.cpp | 2 ++ .../ppc-psi/src/psi-framework/TaskGuarder.h | 2 +- .../ppc-psi/src/psi-framework/TaskState.cpp | 14 +++++--- .../ppc-psi/src/ra2018-psi/RA2018PSIImpl.cpp | 2 +- .../ppc-psi/src/ra2018-psi/RA2018PSIImpl.h | 3 -- .../tests/cm2020-psi/FakeCM2020PSIFactory.h | 4 +-- .../tests/cm2020-psi/TestCM2020Impl.cpp | 2 ++ .../tests/labeled-psi/FakeLabeledPSIFactory.h | 4 +-- .../tests/labeled-psi/TestLabeledPSIImpl.cpp | 4 +++ .../tests/ra2018-psi/TestEcdhPSIImpl.cpp | 4 +++ .../tests/ra2018-psi/TestRA2018Impl.cpp | 3 ++ .../tests/ra2018-psi/mock/EcdhPSIFixture.h | 1 - .../tests/ra2018-psi/mock/RA2018PSIFixture.h | 1 - .../air-node/AirNodeInitializer.cpp | 3 +- .../pro-node/ProNodeInitializer.cpp | 3 +- .../protocol/src/JsonTaskImpl.cpp | 5 +++ .../ppc-io/src/DataResourceLoaderImpl.cpp | 4 +-- cpp/wedpr-transport/ppc-rpc/src/RpcMemory.cpp | 34 +++++++++++-------- cpp/wedpr-transport/ppc-rpc/src/RpcMemory.h | 10 ++++-- 40 files changed, 132 insertions(+), 98 deletions(-) diff --git a/cpp/ppc-framework/protocol/DataResource.h b/cpp/ppc-framework/protocol/DataResource.h index 012314c9..b69a5f65 100644 --- a/cpp/ppc-framework/protocol/DataResource.h +++ b/cpp/ppc-framework/protocol/DataResource.h @@ -119,6 +119,10 @@ class DataResource inline std::string printDataResourceInfo(DataResource::ConstPtr _dataResource) { + if (!_dataResource) + { + return "empty"; + } std::ostringstream stringstream; stringstream << LOG_KV("dataResource", _dataResource->resourceID()); if (_dataResource->desc()) diff --git a/cpp/ppc-framework/protocol/PartyResource.h b/cpp/ppc-framework/protocol/PartyResource.h index 98bb4984..0c0e66f7 100644 --- a/cpp/ppc-framework/protocol/PartyResource.h +++ b/cpp/ppc-framework/protocol/PartyResource.h @@ -67,6 +67,10 @@ using ConstParties = std::vector; inline std::string printPartyInfo(PartyResource::ConstPtr _party) { + if (!_party) + { + return "empty"; + } std::ostringstream stringstream; stringstream << LOG_KV("partyId", _party->id()) << LOG_KV("partyIndex", _party->partyIndex()) << LOG_KV("desc", _party->desc()); diff --git a/cpp/ppc-framework/protocol/Task.h b/cpp/ppc-framework/protocol/Task.h index 1c9afccf..705fa670 100644 --- a/cpp/ppc-framework/protocol/Task.h +++ b/cpp/ppc-framework/protocol/Task.h @@ -143,6 +143,15 @@ class Task // decode the task virtual void decode(std::string_view _taskData) = 0; virtual std::string encode() const = 0; + + virtual bool enableOutputExists() const { return m_enableOutputExists; } + virtual void setEnableOutputExists(bool enableOutputExists) + { + m_enableOutputExists = enableOutputExists; + } + +protected: + bool m_enableOutputExists = false; }; class TaskFactory @@ -160,9 +169,14 @@ class TaskFactory inline std::string printTaskInfo(Task::ConstPtr _task) { std::ostringstream stringstream; + if (!_task) + { + return "empty"; + } stringstream << LOG_KV("id", _task->id()) << LOG_KV("type", (ppc::protocol::TaskType)_task->type()) << LOG_KV("algorithm", (ppc::protocol::TaskAlgorithmType)_task->algorithm()) + << LOG_KV("enableOutputExists", _task->enableOutputExists()) << LOG_KV("taskPtr", _task); if (_task->selfParty()) { diff --git a/cpp/ppc-framework/rpc/RpcStatusInterface.h b/cpp/ppc-framework/rpc/RpcStatusInterface.h index 5a5b18d2..db971436 100644 --- a/cpp/ppc-framework/rpc/RpcStatusInterface.h +++ b/cpp/ppc-framework/rpc/RpcStatusInterface.h @@ -44,7 +44,5 @@ class RpcStatusInterface virtual bcos::Error::Ptr insertTask(protocol::Task::Ptr _task) = 0; virtual bcos::Error::Ptr updateTaskStatus(protocol::TaskResult::Ptr _taskResult) = 0; virtual protocol::TaskResult::Ptr getTaskStatus(const std::string& _taskID) = 0; - virtual bcos::Error::Ptr deleteGateway(const std::string& _agencyID) = 0; - virtual std::vector listGateway() = 0; }; } // namespace ppc::rpc \ No newline at end of file diff --git a/cpp/wedpr-computing/ppc-pir/src/OtPIRImpl.cpp b/cpp/wedpr-computing/ppc-pir/src/OtPIRImpl.cpp index e344408e..a37fd892 100644 --- a/cpp/wedpr-computing/ppc-pir/src/OtPIRImpl.cpp +++ b/cpp/wedpr-computing/ppc-pir/src/OtPIRImpl.cpp @@ -149,10 +149,11 @@ void OtPIRImpl::onReceiveMessage(ppc::front::PPCMessageFace::Ptr _msg) } } -void OtPIRImpl::onReceivedErrorNotification(const std::string& _taskID) +void OtPIRImpl::onReceivedErrorNotification(ppc::front::PPCMessageFace::Ptr const& _message) { + PIR_LOG(WARNING) << LOG_DESC("onReceivedErrorNotification") << printPPCMsg(_message); // finish the task while the peer is failed - auto taskState = findPendingTask(_taskID); + auto taskState = findPendingTask(_message->taskID()); if (taskState) { taskState->onPeerNotifyFinish(); @@ -242,7 +243,7 @@ void OtPIRImpl::handleReceivedMessage(const ppc::front::PPCMessageFace::Ptr& _me { case int(CommonMessageType::ErrorNotification): { - pir->onReceivedErrorNotification(_message->taskID()); + pir->onReceivedErrorNotification(_message); break; } case int(CommonMessageType::PingPeer): @@ -444,7 +445,8 @@ void OtPIRImpl::asyncRunTask() << LOG_KV("requestAgencyDataset", taskMessage.requestAgencyDataset) << LOG_KV("prefixLength", taskMessage.prefixLength) << LOG_KV("searchId", taskMessage.searchId); - auto writer = loadWriter(task->id(), dataResource, m_enableOutputExists); + auto writer = + loadWriter(task->id(), dataResource, m_taskState->task()->enableOutputExists()); m_taskState->setWriter(writer); runSenderGenerateCipher(taskMessage); } diff --git a/cpp/wedpr-computing/ppc-pir/src/OtPIRImpl.h b/cpp/wedpr-computing/ppc-pir/src/OtPIRImpl.h index 56aeac7d..d48773a2 100644 --- a/cpp/wedpr-computing/ppc-pir/src/OtPIRImpl.h +++ b/cpp/wedpr-computing/ppc-pir/src/OtPIRImpl.h @@ -59,7 +59,7 @@ class OtPIRImpl : public std::enable_shared_from_this, // register to the front to get the message related to ot-pir void onReceiveMessage(ppc::front::PPCMessageFace::Ptr _message) override; - void onReceivedErrorNotification(const std::string& _taskID) override; + void onReceivedErrorNotification(ppc::front::PPCMessageFace::Ptr const& _message) override; void onSelfError( const std::string& _taskID, bcos::Error::Ptr _error, bool _noticePeer) override; @@ -150,8 +150,6 @@ class OtPIRImpl : public std::enable_shared_from_this, m_senders.erase(it); } } - // allow the output-path exists, for ut - bool m_enableOutputExists = false; // 为true时启动时会从配置中加载文件作为匹配源 bool m_enableMemoryFile = false; ppc::protocol::DataResource m_resource; diff --git a/cpp/wedpr-computing/ppc-pir/tests/FakeOtPIRFactory.h b/cpp/wedpr-computing/ppc-pir/tests/FakeOtPIRFactory.h index 23ee2179..30e6ce60 100644 --- a/cpp/wedpr-computing/ppc-pir/tests/FakeOtPIRFactory.h +++ b/cpp/wedpr-computing/ppc-pir/tests/FakeOtPIRFactory.h @@ -46,9 +46,7 @@ class FakeOtPIRImpl : public OtPIRImpl using Ptr = std::shared_ptr; FakeOtPIRImpl(OtPIRConfig::Ptr const& _config, unsigned _idleTimeMs = 0) : OtPIRImpl(_config, _idleTimeMs) - { - m_enableOutputExists = true; - } + {} ~FakeOtPIRImpl() override = default; }; diff --git a/cpp/wedpr-computing/ppc-pir/tests/TestBaseOT.cpp b/cpp/wedpr-computing/ppc-pir/tests/TestBaseOT.cpp index 621a4255..615e309b 100644 --- a/cpp/wedpr-computing/ppc-pir/tests/TestBaseOT.cpp +++ b/cpp/wedpr-computing/ppc-pir/tests/TestBaseOT.cpp @@ -194,6 +194,7 @@ void testOTPIRImplFunc(const std::string& _taskID, const std::string& _params, b auto senderPIRTask = std::make_shared(senderAgencyName); senderPIRTask->setId(_taskID); senderPIRTask->setParam(_params); + senderPIRTask->setEnableOutputExists(true); senderPIRTask->setSelf(_senderParty); senderPIRTask->addParty(_receiverParty); senderPIRTask->setSyncResultToPeer(_syncResults); @@ -203,6 +204,7 @@ void testOTPIRImplFunc(const std::string& _taskID, const std::string& _params, b auto receiverPIRTask = std::make_shared(receiverAgencyName); receiverPIRTask->setId(_taskID); receiverPIRTask->setParam(_params); + receiverPIRTask->setEnableOutputExists(true); receiverPIRTask->setSelf(_receiverParty); receiverPIRTask->addParty(_senderParty); receiverPIRTask->setSyncResultToPeer(_syncResults); diff --git a/cpp/wedpr-computing/ppc-psi/src/cm2020-psi/CM2020PSIImpl.cpp b/cpp/wedpr-computing/ppc-psi/src/cm2020-psi/CM2020PSIImpl.cpp index 53d8f08e..fa56d7a0 100644 --- a/cpp/wedpr-computing/ppc-psi/src/cm2020-psi/CM2020PSIImpl.cpp +++ b/cpp/wedpr-computing/ppc-psi/src/cm2020-psi/CM2020PSIImpl.cpp @@ -152,15 +152,14 @@ void CM2020PSIImpl::asyncRunTask() { return; } - + CM2020_PSI_LOG(INFO) << LOG_DESC("noticePeerToFinish") << printTaskInfo(task); psi->noticePeerToFinish(task); }); - // check the memory - checkHostResource(m_config->minNeededMemoryGB()); - addPendingTask(taskState); - try { + addPendingTask(taskState); + // check the memory + checkHostResource(m_config->minNeededMemoryGB()); // prepare reader and writer auto dataResource = task->selfParty()->dataResource(); auto reader = loadReader(task->id(), dataResource, DataSchema::Bytes); @@ -169,7 +168,7 @@ void CM2020PSIImpl::asyncRunTask() auto role = task->selfParty()->partyIndex(); if (role == uint16_t(PartyType::Client) || task->syncResultToPeer()) { - auto writer = loadWriter(task->id(), dataResource, m_enableOutputExists); + auto writer = loadWriter(task->id(), dataResource, task->enableOutputExists()); taskState->setWriter(writer); } @@ -319,10 +318,11 @@ void CM2020PSIImpl::stop() CM2020_PSI_LOG(INFO) << LOG_DESC("CM2020-PSI stopped"); } -void CM2020PSIImpl::onReceivedErrorNotification(const std::string& _taskID) +void CM2020PSIImpl::onReceivedErrorNotification(ppc::front::PPCMessageFace::Ptr const& _message) { + CM2020_PSI_LOG(INFO) << LOG_DESC("onReceivedErrorNotification") << printPPCMsg(_message); // finish the task while the peer is failed - auto taskState = findPendingTask(_taskID); + auto taskState = findPendingTask(_message->taskID()); if (taskState) { taskState->onPeerNotifyFinish(); @@ -410,7 +410,7 @@ void CM2020PSIImpl::handleReceivedMessage(const ppc::front::PPCMessageFace::Ptr& { case int(CommonMessageType::ErrorNotification): { - psi->onReceivedErrorNotification(_message->taskID()); + psi->onReceivedErrorNotification(_message); break; } case int(CommonMessageType::PingPeer): diff --git a/cpp/wedpr-computing/ppc-psi/src/cm2020-psi/CM2020PSIImpl.h b/cpp/wedpr-computing/ppc-psi/src/cm2020-psi/CM2020PSIImpl.h index 003fea3c..01c6d0ce 100644 --- a/cpp/wedpr-computing/ppc-psi/src/cm2020-psi/CM2020PSIImpl.h +++ b/cpp/wedpr-computing/ppc-psi/src/cm2020-psi/CM2020PSIImpl.h @@ -62,7 +62,7 @@ class CM2020PSIImpl : public bcos::Worker, void start() override; void stop() override; - void onReceivedErrorNotification(const std::string& _taskID) override; + void onReceivedErrorNotification(ppc::front::PPCMessageFace::Ptr const& _message) override; void onSelfError( const std::string& _taskID, bcos::Error::Ptr _error, bool _noticePeer) override; @@ -148,10 +148,6 @@ class CM2020PSIImpl : public bcos::Worker, } } -protected: - // allow the output-path exists, for ut - bool m_enableOutputExists = false; - private: void waitSignal() { diff --git a/cpp/wedpr-computing/ppc-psi/src/ecdh-conn-psi/EcdhConnPSIImpl.cpp b/cpp/wedpr-computing/ppc-psi/src/ecdh-conn-psi/EcdhConnPSIImpl.cpp index 6e989d85..6e7a3717 100644 --- a/cpp/wedpr-computing/ppc-psi/src/ecdh-conn-psi/EcdhConnPSIImpl.cpp +++ b/cpp/wedpr-computing/ppc-psi/src/ecdh-conn-psi/EcdhConnPSIImpl.cpp @@ -55,7 +55,7 @@ void EcdhConnPSIImpl::asyncRunTask( if (role == uint16_t(PartyType::Client)) { ECDH_CONN_LOG(INFO) << LOG_DESC("Client do the Task") << LOG_KV("taskID", _task->id()); - auto writer = loadWriter(_task->id(), dataResource, m_enableOutputExists); + auto writer = loadWriter(_task->id(), dataResource, _task->enableOutputExists()); ecdhTaskState->setWriter(writer); auto client = std::make_shared(m_config, ecdhTaskState); addClient(client); @@ -66,7 +66,7 @@ void EcdhConnPSIImpl::asyncRunTask( ECDH_CONN_LOG(INFO) << LOG_DESC("Server do the Task") << LOG_KV("taskID", _task->id()); if (_task->syncResultToPeer()) { - auto writer = loadWriter(_task->id(), dataResource, m_enableOutputExists); + auto writer = loadWriter(_task->id(), dataResource, _task->enableOutputExists()); ecdhTaskState->setWriter(writer); } auto server = std::make_shared(m_config, ecdhTaskState); @@ -131,7 +131,7 @@ void EcdhConnPSIImpl::stop() } } -void EcdhConnPSIImpl::onReceivedErrorNotification(const std::string& _taskID) {} +void EcdhConnPSIImpl::onReceivedErrorNotification(ppc::front::PPCMessageFace::Ptr const&) {} void EcdhConnPSIImpl::onSelfError( const std::string& _taskID, bcos::Error::Ptr _error, bool _noticePeer) diff --git a/cpp/wedpr-computing/ppc-psi/src/ecdh-conn-psi/EcdhConnPSIImpl.h b/cpp/wedpr-computing/ppc-psi/src/ecdh-conn-psi/EcdhConnPSIImpl.h index 53d7cd4a..bb34a6a8 100644 --- a/cpp/wedpr-computing/ppc-psi/src/ecdh-conn-psi/EcdhConnPSIImpl.h +++ b/cpp/wedpr-computing/ppc-psi/src/ecdh-conn-psi/EcdhConnPSIImpl.h @@ -61,14 +61,11 @@ class EcdhConnPSIImpl : public std::enable_shared_from_this, void start() override; void stop() override; - void onReceivedErrorNotification(const std::string& _taskID) override; + void onReceivedErrorNotification(ppc::front::PPCMessageFace::Ptr const& _message) override; void onSelfError( const std::string& _taskID, bcos::Error::Ptr _error, bool _noticePeer) override; void executeWorker() override; -protected: - bool m_enableOutputExists = false; - void handlerPSIReceiveMessage(PSIConnMessage::Ptr _msg); void onHandShakeRequestHandler(const std::string& _taskId, const bcos::bytes& _msg); void onHandShakeResponseHandler(const std::string& _taskId, const bcos::bytes& _msg); diff --git a/cpp/wedpr-computing/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIImpl.cpp b/cpp/wedpr-computing/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIImpl.cpp index 94d7ec08..bb6d0d6d 100644 --- a/cpp/wedpr-computing/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIImpl.cpp +++ b/cpp/wedpr-computing/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIImpl.cpp @@ -131,9 +131,9 @@ void EcdhMultiPSIImpl::asyncRunTask( psi->removePartner(taskID); psi->removePendingTask(taskID); }); + addPendingTask(taskState); // check the memory checkHostResource(m_config->minNeededMemoryGB()); - addPendingTask(taskState); // over the peer limit if (_task->getAllPeerParties().size() > c_max_peer_size) { @@ -155,7 +155,7 @@ void EcdhMultiPSIImpl::asyncRunTask( << LOG_KV("roleId", role); if (role == uint16_t(PartiesType::Calculator)) { - auto writer = loadWriter(_task->id(), dataResource, m_enableOutputExists); + auto writer = loadWriter(_task->id(), dataResource, _task->enableOutputExists()); taskState->setWriter(writer); ECDH_MULTI_LOG(INFO) << LOG_DESC("Calculator do the Task") << LOG_KV("taskID", _task->id()); @@ -170,7 +170,7 @@ void EcdhMultiPSIImpl::asyncRunTask( if (_task->syncResultToPeer() && std::find(receivers.begin(), receivers.end(), m_config->selfParty()) != receivers.end()) { - auto writer = loadWriter(_task->id(), dataResource, m_enableOutputExists); + auto writer = loadWriter(_task->id(), dataResource, _task->enableOutputExists()); taskState->setWriter(writer); } auto partner = std::make_shared(m_config, taskState); @@ -183,7 +183,7 @@ void EcdhMultiPSIImpl::asyncRunTask( if (_task->syncResultToPeer() && std::find(receivers.begin(), receivers.end(), m_config->selfParty()) != receivers.end()) { - auto writer = loadWriter(_task->id(), dataResource, m_enableOutputExists); + auto writer = loadWriter(_task->id(), dataResource, _task->enableOutputExists()); taskState->setWriter(writer); } auto master = std::make_shared(m_config, taskState); @@ -266,11 +266,11 @@ void EcdhMultiPSIImpl::checkFinishedTask() } } -void EcdhMultiPSIImpl::onReceivedErrorNotification(const std::string& _taskID) +void EcdhMultiPSIImpl::onReceivedErrorNotification(ppc::front::PPCMessageFace::Ptr const& _message) { - ECDH_MULTI_LOG(INFO) << LOG_DESC("onReceivedErrorNotification") << LOG_KV("taskID", _taskID); + ECDH_MULTI_LOG(INFO) << LOG_DESC("onReceivedErrorNotification") << printPPCMsg(_message); // finish the task while the peer is failed - auto taskState = findPendingTask(_taskID); + auto taskState = findPendingTask(_message->taskID()); if (taskState) { taskState->onPeerNotifyFinish(); @@ -308,7 +308,7 @@ void EcdhMultiPSIImpl::executeWorker() auto pop_msg = _msg.second; if (pop_msg->messageType() == uint8_t(CommonMessageType::ErrorNotification)) { - onReceivedErrorNotification(pop_msg->taskID()); + onReceivedErrorNotification(pop_msg); return; } else if (pop_msg->messageType() == uint8_t(CommonMessageType::PingPeer)) diff --git a/cpp/wedpr-computing/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIImpl.h b/cpp/wedpr-computing/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIImpl.h index df4e8034..3be24c7b 100644 --- a/cpp/wedpr-computing/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIImpl.h +++ b/cpp/wedpr-computing/ppc-psi/src/ecdh-multi-psi/EcdhMultiPSIImpl.h @@ -40,14 +40,13 @@ class EcdhMultiPSIImpl : public std::enable_shared_from_this, void stop() override; void checkFinishedTask(); - void onReceivedErrorNotification(const std::string& _taskID) override; + void onReceivedErrorNotification(ppc::front::PPCMessageFace::Ptr const& _message) override; void onSelfError( const std::string& _taskID, bcos::Error::Ptr _error, bool _noticePeer) override; void executeWorker() override; protected: - bool m_enableOutputExists = false; virtual void onReceiveRandomA(PSIMessageInterface::Ptr _msg); virtual void onReceiveCalCipher(PSIMessageInterface::Ptr _msg); virtual void handlerPSIReceiveMessage(PSIMessageInterface::Ptr _msg); diff --git a/cpp/wedpr-computing/ppc-psi/src/ecdh-multi-psi/core/EcdhMultiPSICalculator.cpp b/cpp/wedpr-computing/ppc-psi/src/ecdh-multi-psi/core/EcdhMultiPSICalculator.cpp index c9e8b483..47d65bce 100644 --- a/cpp/wedpr-computing/ppc-psi/src/ecdh-multi-psi/core/EcdhMultiPSICalculator.cpp +++ b/cpp/wedpr-computing/ppc-psi/src/ecdh-multi-psi/core/EcdhMultiPSICalculator.cpp @@ -95,7 +95,7 @@ void EcdhMultiPSICalculator::blindData(std::string _taskID, bcos::bytes _randA) uint64_t dataOffset = 0; do { - if (m_taskState->loadFinished()) + if (m_taskState->loadFinished() || m_taskState->taskDone()) { break; } diff --git a/cpp/wedpr-computing/ppc-psi/src/ecdh-multi-psi/core/EcdhMultiPSIMaster.cpp b/cpp/wedpr-computing/ppc-psi/src/ecdh-multi-psi/core/EcdhMultiPSIMaster.cpp index 97295541..52632e4a 100644 --- a/cpp/wedpr-computing/ppc-psi/src/ecdh-multi-psi/core/EcdhMultiPSIMaster.cpp +++ b/cpp/wedpr-computing/ppc-psi/src/ecdh-multi-psi/core/EcdhMultiPSIMaster.cpp @@ -121,7 +121,7 @@ void EcdhMultiPSIMaster::blindData() auto reader = m_taskState->reader(); do { - if (m_taskState->loadFinished()) + if (m_taskState->loadFinished() || m_taskState->taskDone()) { break; } diff --git a/cpp/wedpr-computing/ppc-psi/src/ecdh-multi-psi/core/EcdhMultiPSIPartner.cpp b/cpp/wedpr-computing/ppc-psi/src/ecdh-multi-psi/core/EcdhMultiPSIPartner.cpp index c51617fc..67775c74 100644 --- a/cpp/wedpr-computing/ppc-psi/src/ecdh-multi-psi/core/EcdhMultiPSIPartner.cpp +++ b/cpp/wedpr-computing/ppc-psi/src/ecdh-multi-psi/core/EcdhMultiPSIPartner.cpp @@ -53,7 +53,7 @@ void EcdhMultiPSIPartner::onReceiveRandomA(bcos::bytesPointer _randA) uint64_t dataOffset = 0; do { - if (m_taskState->loadFinished()) + if (m_taskState->loadFinished() || m_taskState->taskDone()) { break; } diff --git a/cpp/wedpr-computing/ppc-psi/src/ecdh-psi/EcdhPSIImpl.cpp b/cpp/wedpr-computing/ppc-psi/src/ecdh-psi/EcdhPSIImpl.cpp index 694dfe86..3aa192bc 100644 --- a/cpp/wedpr-computing/ppc-psi/src/ecdh-psi/EcdhPSIImpl.cpp +++ b/cpp/wedpr-computing/ppc-psi/src/ecdh-psi/EcdhPSIImpl.cpp @@ -207,7 +207,7 @@ bool EcdhPSIImpl::initTaskState(TaskState::Ptr const& _taskState) if (!server) { m_cache->insertServerCipherCache(task->id(), _taskState); - if (!m_enableOutputExists) + if (!_taskState->task()->enableOutputExists()) { // Note: if the output-resource already exists, will throw exception m_config->dataResourceLoader()->checkResourceExists(dataResource->outputDesc()); diff --git a/cpp/wedpr-computing/ppc-psi/src/ecdh-psi/EcdhPSIImpl.h b/cpp/wedpr-computing/ppc-psi/src/ecdh-psi/EcdhPSIImpl.h index 107280bf..52887156 100644 --- a/cpp/wedpr-computing/ppc-psi/src/ecdh-psi/EcdhPSIImpl.h +++ b/cpp/wedpr-computing/ppc-psi/src/ecdh-psi/EcdhPSIImpl.h @@ -85,8 +85,5 @@ class EcdhPSIImpl : public PSIFramework, public std::enable_shared_from_thisnoticePeerToFinish(_task); }); + addPendingTask(taskState); // check the memory checkHostResource(m_config->minNeededMemoryGB()); - addPendingTask(taskState); auto oprfClient = std::make_shared( sizeof(apsi::Item::value_type) + sizeof(apsi::LabelKey), m_config->hash(), @@ -220,10 +220,11 @@ void LabeledPSIImpl::stop() LABELED_PSI_LOG(INFO) << LOG_DESC("LabeledPSI stopped"); } -void LabeledPSIImpl::onReceivedErrorNotification(const std::string& _taskID) +void LabeledPSIImpl::onReceivedErrorNotification(ppc::front::PPCMessageFace::Ptr const& _message) { + LABELED_PSI_LOG(WARNING) << LOG_DESC("onReceivedErrorNotification") << printPPCMsg(_message); // finish the task while the peer is failed - auto taskState = findPendingTask(_taskID); + auto taskState = findPendingTask(_message->taskID()); if (taskState) { taskState->onPeerNotifyFinish(); @@ -553,7 +554,7 @@ void LabeledPSIImpl::saveSenderCache(const ppc::protocol::Task::ConstPtr& _task) auto dataResource = _task->selfParty()->dataResource(); LineWriter::Ptr writer; - if (!m_enableOutputExists) + if (!_task->enableOutputExists()) { // Note: if the output-resource already exists, will throw exception m_config->dataResourceLoader()->checkResourceExists(dataResource->outputDesc()); @@ -690,7 +691,7 @@ void LabeledPSIImpl::handleReceivedMessage(const ppc::front::PPCMessageFace::Ptr { case int(CommonMessageType::ErrorNotification): { - psi->onReceivedErrorNotification(_message->taskID()); + psi->onReceivedErrorNotification(_message); break; } case int(CommonMessageType::PingPeer): diff --git a/cpp/wedpr-computing/ppc-psi/src/labeled-psi/LabeledPSIImpl.h b/cpp/wedpr-computing/ppc-psi/src/labeled-psi/LabeledPSIImpl.h index 9cd06c4e..0d213fa2 100644 --- a/cpp/wedpr-computing/ppc-psi/src/labeled-psi/LabeledPSIImpl.h +++ b/cpp/wedpr-computing/ppc-psi/src/labeled-psi/LabeledPSIImpl.h @@ -62,7 +62,7 @@ class LabeledPSIImpl : public bcos::Worker, void start() override; void stop() override; - void onReceivedErrorNotification(const std::string& _taskID) override; + void onReceivedErrorNotification(ppc::front::PPCMessageFace::Ptr const& _message) override; void onSelfError( const std::string& _taskID, bcos::Error::Ptr _error, bool _noticePeer) override; @@ -116,9 +116,6 @@ class LabeledPSIImpl : public bcos::Worker, } } -protected: - // allow the output-path exists, for ut - bool m_enableOutputExists = false; private: void waitSignal() diff --git a/cpp/wedpr-computing/ppc-psi/src/psi-framework/TaskGuarder.cpp b/cpp/wedpr-computing/ppc-psi/src/psi-framework/TaskGuarder.cpp index 07cf6732..208602b9 100644 --- a/cpp/wedpr-computing/ppc-psi/src/psi-framework/TaskGuarder.cpp +++ b/cpp/wedpr-computing/ppc-psi/src/psi-framework/TaskGuarder.cpp @@ -199,6 +199,8 @@ void TaskGuarder::checkPeerActivity() std::make_shared( (int)PSIRetCode::PeerNodeDown, "peer node is down, id: " + peerID), false); + PSI_LOG(INFO) << LOG_DESC("checkPeerActivity: peer node-down") + << LOG_KV("peer", peerID) << LOG_KV("task", task->id()); }, nullptr); } diff --git a/cpp/wedpr-computing/ppc-psi/src/psi-framework/TaskGuarder.h b/cpp/wedpr-computing/ppc-psi/src/psi-framework/TaskGuarder.h index c64c0604..8587d23c 100644 --- a/cpp/wedpr-computing/ppc-psi/src/psi-framework/TaskGuarder.h +++ b/cpp/wedpr-computing/ppc-psi/src/psi-framework/TaskGuarder.h @@ -44,7 +44,7 @@ class TaskGuarder } virtual ~TaskGuarder() = default; - virtual void onReceivedErrorNotification(const std::string& _taskID){}; + virtual void onReceivedErrorNotification(ppc::front::PPCMessageFace::Ptr const&){}; virtual void onSelfError( const std::string& _taskID, bcos::Error::Ptr _error, bool _noticePeer){}; diff --git a/cpp/wedpr-computing/ppc-psi/src/psi-framework/TaskState.cpp b/cpp/wedpr-computing/ppc-psi/src/psi-framework/TaskState.cpp index 2dbf0b07..1ea77322 100644 --- a/cpp/wedpr-computing/ppc-psi/src/psi-framework/TaskState.cpp +++ b/cpp/wedpr-computing/ppc-psi/src/psi-framework/TaskState.cpp @@ -201,13 +201,16 @@ void TaskState::removeGeneratedOutputFile() { return; } + if (!m_writer) + { + return; + } auto outputDataResource = m_task->selfParty()->dataResource(); if (!outputDataResource->desc()) { return; } - PSI_LOG(INFO) << LOG_DESC("removeGeneratedFilesForFailed") - << LOG_KV("task", printTaskInfo(m_task)); + PSI_LOG(INFO) << LOG_DESC("removeGeneratedOutputFile") << LOG_KV("task", printTaskInfo(m_task)); m_config->dataResourceLoader()->deleteResource(outputDataResource->desc()); } @@ -271,10 +274,11 @@ void TaskState::onTaskFinished(TaskResult::Ptr _result, bool _noticePeer) void TaskState::onPeerNotifyFinish() { PSI_LOG(WARNING) << LOG_BADGE("onReceivePeerError") << LOG_KV("taskID", m_task->id()); - auto tesult = std::make_shared(task()->id()); - tesult->setError(std::make_shared( + auto result = std::make_shared(task()->id()); + result->setError(std::make_shared( (int)PSIRetCode::PeerNotifyFinish, "job participant sent an error")); - onTaskFinished(std::move(tesult), false); + onTaskFinished(std::move(result), false); + removeGeneratedOutputFile(); } // Note: must store the result serially diff --git a/cpp/wedpr-computing/ppc-psi/src/ra2018-psi/RA2018PSIImpl.cpp b/cpp/wedpr-computing/ppc-psi/src/ra2018-psi/RA2018PSIImpl.cpp index 7b98051b..38f19ddb 100644 --- a/cpp/wedpr-computing/ppc-psi/src/ra2018-psi/RA2018PSIImpl.cpp +++ b/cpp/wedpr-computing/ppc-psi/src/ra2018-psi/RA2018PSIImpl.cpp @@ -646,7 +646,7 @@ void RA2018PSIImpl::runClientPSI(TaskState::Ptr const& _taskState) { return; } - if (!m_enableOutputExists) + if (!_taskState->task()->enableOutputExists()) { // Note: if the output-resource already exists, will throw exception m_config->dataResourceLoader()->checkResourceExists(dataResource->outputDesc()); diff --git a/cpp/wedpr-computing/ppc-psi/src/ra2018-psi/RA2018PSIImpl.h b/cpp/wedpr-computing/ppc-psi/src/ra2018-psi/RA2018PSIImpl.h index d47c7253..c3cf9a3c 100644 --- a/cpp/wedpr-computing/ppc-psi/src/ra2018-psi/RA2018PSIImpl.h +++ b/cpp/wedpr-computing/ppc-psi/src/ra2018-psi/RA2018PSIImpl.h @@ -179,8 +179,5 @@ class RA2018PSIImpl : public PSIFramework, public std::enable_shared_from_this; FakeCM2020PSIImpl(CM2020PSIConfig::Ptr const& _config, unsigned _idleTimeMs = 0) : CM2020PSIImpl(_config, _idleTimeMs) - { - m_enableOutputExists = true; - } + {} ~FakeCM2020PSIImpl() override = default; }; diff --git a/cpp/wedpr-computing/ppc-psi/tests/cm2020-psi/TestCM2020Impl.cpp b/cpp/wedpr-computing/ppc-psi/tests/cm2020-psi/TestCM2020Impl.cpp index c4fdf8eb..6014a1df 100644 --- a/cpp/wedpr-computing/ppc-psi/tests/cm2020-psi/TestCM2020Impl.cpp +++ b/cpp/wedpr-computing/ppc-psi/tests/cm2020-psi/TestCM2020Impl.cpp @@ -153,6 +153,7 @@ void testCM2020PSIImplFunc(const std::string& _taskID, const std::string& _param senderPSITask->setId(_taskID); senderPSITask->setParam(_params); senderPSITask->setSelf(_senderParty); + senderPSITask->setEnableOutputExists(true); senderPSITask->addParty(_receiverParty); senderPSITask->setSyncResultToPeer(_syncResults); senderPSITask->setAlgorithm((uint8_t)TaskAlgorithmType::CM_PSI_2PC); @@ -161,6 +162,7 @@ void testCM2020PSIImplFunc(const std::string& _taskID, const std::string& _param receiverPSITask->setId(_taskID); receiverPSITask->setParam(_params); receiverPSITask->setSelf(_receiverParty); + receiverPSITask->setEnableOutputExists(true); receiverPSITask->addParty(_senderParty); receiverPSITask->setSyncResultToPeer(_syncResults); receiverPSITask->setAlgorithm((uint8_t)TaskAlgorithmType::CM_PSI_2PC); diff --git a/cpp/wedpr-computing/ppc-psi/tests/labeled-psi/FakeLabeledPSIFactory.h b/cpp/wedpr-computing/ppc-psi/tests/labeled-psi/FakeLabeledPSIFactory.h index e369b6c1..350d861a 100644 --- a/cpp/wedpr-computing/ppc-psi/tests/labeled-psi/FakeLabeledPSIFactory.h +++ b/cpp/wedpr-computing/ppc-psi/tests/labeled-psi/FakeLabeledPSIFactory.h @@ -43,9 +43,7 @@ class FakeLabeledPSIImpl : public LabeledPSIImpl using Ptr = std::shared_ptr; FakeLabeledPSIImpl(LabeledPSIConfig::Ptr const& _config, unsigned _idleTimeMs = 0) : LabeledPSIImpl(_config, _idleTimeMs) - { - m_enableOutputExists = true; - } + {} ~FakeLabeledPSIImpl() override = default; }; diff --git a/cpp/wedpr-computing/ppc-psi/tests/labeled-psi/TestLabeledPSIImpl.cpp b/cpp/wedpr-computing/ppc-psi/tests/labeled-psi/TestLabeledPSIImpl.cpp index f341b6f2..62f6fd2e 100644 --- a/cpp/wedpr-computing/ppc-psi/tests/labeled-psi/TestLabeledPSIImpl.cpp +++ b/cpp/wedpr-computing/ppc-psi/tests/labeled-psi/TestLabeledPSIImpl.cpp @@ -59,6 +59,7 @@ void runSetup(LabeledPSIImpl::Ptr _sender) auto senderPSITask = std::make_shared(senderAgencyName); senderPSITask->setId("0x00000000"); + senderPSITask->setEnableOutputExists(true); senderPSITask->setParam(R"(["setup_sender_db","32"])"); senderPSITask->setSelf(senderParty); senderPSITask->setAlgorithm((uint8_t)TaskAlgorithmType::LABELED_PSI_2PC); @@ -95,6 +96,7 @@ void saveCache(LabeledPSIImpl::Ptr _sender) auto senderPSITask = std::make_shared(senderAgencyName); senderPSITask->setId("0x00000012"); + senderPSITask->setEnableOutputExists(true); senderPSITask->setParam(R"(["save_sender_cache"])"); senderPSITask->setSelf(senderParty); senderPSITask->setAlgorithm((uint8_t)TaskAlgorithmType::LABELED_PSI_2PC); @@ -234,6 +236,7 @@ void testLabeledPSIImplFunc(const std::string& _taskID, const std::string& _para // trigger the psi task auto senderPSITask = std::make_shared(senderAgencyName); senderPSITask->setId(_taskID); + senderPSITask->setEnableOutputExists(true); senderPSITask->setParam(_params); senderPSITask->setSelf(_senderParty); senderPSITask->setAlgorithm((uint8_t)TaskAlgorithmType::LABELED_PSI_2PC); @@ -241,6 +244,7 @@ void testLabeledPSIImplFunc(const std::string& _taskID, const std::string& _para auto receiverPSITask = std::make_shared(receiverAgencyName); receiverPSITask->setId(_taskID); + receiverPSITask->setEnableOutputExists(true); receiverPSITask->setSelf(_receiverParty); receiverPSITask->setAlgorithm((uint8_t)TaskAlgorithmType::LABELED_PSI_2PC); receiverPSITask->addParty(_senderParty); diff --git a/cpp/wedpr-computing/ppc-psi/tests/ra2018-psi/TestEcdhPSIImpl.cpp b/cpp/wedpr-computing/ppc-psi/tests/ra2018-psi/TestEcdhPSIImpl.cpp index 856eabbc..c0083254 100644 --- a/cpp/wedpr-computing/ppc-psi/tests/ra2018-psi/TestEcdhPSIImpl.cpp +++ b/cpp/wedpr-computing/ppc-psi/tests/ra2018-psi/TestEcdhPSIImpl.cpp @@ -87,6 +87,8 @@ void testEcdhImplFunc(int64_t _dataBatchSize, std::string const& _serverPSIDataS auto clientPSITask = std::make_shared(clientAgencyName); std::string taskID = "runPSI"; clientPSITask->setId(taskID); + clientPSITask->setEnableOutputExists(true); + clientPSITask->setEnableOutputExists(true); clientPSITask->setType((int8_t)TaskType::PSI); clientPSITask->setAlgorithm((int8_t)TaskAlgorithmType::ECDH_PSI_2PC); clientPSITask->setSelf(clientParty); @@ -103,6 +105,8 @@ void testEcdhImplFunc(int64_t _dataBatchSize, std::string const& _serverPSIDataS { serverPSITask->setId(taskID); } + serverPSITask->setEnableOutputExists(true); + serverPSITask->setEnableOutputExists(true); serverPSITask->setType((int8_t)TaskType::PSI); serverPSITask->setAlgorithm((int8_t)TaskAlgorithmType::ECDH_PSI_2PC); serverPSITask->addParty(clientParty); diff --git a/cpp/wedpr-computing/ppc-psi/tests/ra2018-psi/TestRA2018Impl.cpp b/cpp/wedpr-computing/ppc-psi/tests/ra2018-psi/TestRA2018Impl.cpp index 261cd401..2e6c06fa 100644 --- a/cpp/wedpr-computing/ppc-psi/tests/ra2018-psi/TestRA2018Impl.cpp +++ b/cpp/wedpr-computing/ppc-psi/tests/ra2018-psi/TestRA2018Impl.cpp @@ -105,6 +105,7 @@ void testRA2018PSIImplFunc(int _dataBatchSize, CuckoofilterOption::Ptr option, auto offlineFullEvaluateTask = std::make_shared(serverAgencyName); offlineFullEvaluateTask->setId("offlineFullEvaluate"); offlineFullEvaluateTask->setSelf(serverParty); + offlineFullEvaluateTask->setEnableOutputExists(true); // insert operation std::string param = "[\"data_preprocessing\", 0]"; offlineFullEvaluateTask->setParam(param); @@ -127,6 +128,7 @@ void testRA2018PSIImplFunc(int _dataBatchSize, CuckoofilterOption::Ptr option, auto clientPSITask = std::make_shared(clientAgencyName); std::string taskID = "runPSI"; clientPSITask->setId(taskID); + clientPSITask->setEnableOutputExists(true); clientPSITask->setSelf(clientParty); clientPSITask->addParty(serverParty); param = "[\"ra2018_psi\"]"; @@ -135,6 +137,7 @@ void testRA2018PSIImplFunc(int _dataBatchSize, CuckoofilterOption::Ptr option, // the server task auto serverPSITask = std::make_shared(serverAgencyName); serverPSITask->setSelf(serverParty); + serverPSITask->setEnableOutputExists(true); serverPSITask->setId(taskID); serverPSITask->setParam(param); serverPSITask->addParty(clientParty); diff --git a/cpp/wedpr-computing/ppc-psi/tests/ra2018-psi/mock/EcdhPSIFixture.h b/cpp/wedpr-computing/ppc-psi/tests/ra2018-psi/mock/EcdhPSIFixture.h index 37209c9a..4404c8b3 100644 --- a/cpp/wedpr-computing/ppc-psi/tests/ra2018-psi/mock/EcdhPSIFixture.h +++ b/cpp/wedpr-computing/ppc-psi/tests/ra2018-psi/mock/EcdhPSIFixture.h @@ -49,7 +49,6 @@ class FakeEcdhPSIImpl : public EcdhPSIImpl { // set the m_started flag to be true m_started = true; - m_enableOutputExists = true; m_taskSyncTimer->registerTimeoutHandler([this]() { syncTaskInfo(); }); m_taskSyncTimer->start(); } diff --git a/cpp/wedpr-computing/ppc-psi/tests/ra2018-psi/mock/RA2018PSIFixture.h b/cpp/wedpr-computing/ppc-psi/tests/ra2018-psi/mock/RA2018PSIFixture.h index c0bb0d34..83eaffbe 100644 --- a/cpp/wedpr-computing/ppc-psi/tests/ra2018-psi/mock/RA2018PSIFixture.h +++ b/cpp/wedpr-computing/ppc-psi/tests/ra2018-psi/mock/RA2018PSIFixture.h @@ -51,7 +51,6 @@ class FakeRA2018Impl : public RA2018PSIImpl { // set the m_started flag to be true m_started = true; - m_enableOutputExists = true; m_taskSyncTimer->registerTimeoutHandler([this]() { syncTaskInfo(); }); m_taskSyncTimer->start(); } diff --git a/cpp/wedpr-main/air-node/AirNodeInitializer.cpp b/cpp/wedpr-main/air-node/AirNodeInitializer.cpp index 38c6709d..a9636629 100644 --- a/cpp/wedpr-main/air-node/AirNodeInitializer.cpp +++ b/cpp/wedpr-main/air-node/AirNodeInitializer.cpp @@ -62,7 +62,8 @@ void AirNodeInitializer::init(std::string const& _configPath) INIT_LOG(INFO) << LOG_DESC("init the rpc"); // init RpcStatusInterface - RpcStatusInterface::Ptr rpcStatusInterface = std::make_shared(); + RpcStatusInterface::Ptr rpcStatusInterface = + std::make_shared(m_nodeInitializer->ppcFront()); auto rpcFactory = std::make_shared(m_nodeInitializer->config()->agencyID()); diff --git a/cpp/wedpr-main/pro-node/ProNodeInitializer.cpp b/cpp/wedpr-main/pro-node/ProNodeInitializer.cpp index 4d54658e..f92750c8 100644 --- a/cpp/wedpr-main/pro-node/ProNodeInitializer.cpp +++ b/cpp/wedpr-main/pro-node/ProNodeInitializer.cpp @@ -53,7 +53,8 @@ void ProNodeInitializer::init(std::string const& _configPath) INIT_LOG(INFO) << LOG_DESC("init the rpc"); // init RpcStatusInterface - RpcStatusInterface::Ptr rpcStatusInterface = std::make_shared(); + RpcStatusInterface::Ptr rpcStatusInterface = + std::make_shared(m_nodeInitializer->ppcFront()); auto rpcFactory = std::make_shared(m_nodeInitializer->config()->agencyID()); diff --git a/cpp/wedpr-protocol/protocol/src/JsonTaskImpl.cpp b/cpp/wedpr-protocol/protocol/src/JsonTaskImpl.cpp index 6e7af353..a0052b37 100644 --- a/cpp/wedpr-protocol/protocol/src/JsonTaskImpl.cpp +++ b/cpp/wedpr-protocol/protocol/src/JsonTaskImpl.cpp @@ -54,6 +54,10 @@ void JsonTaskImpl::decodeJsonValue(Json::Value const& root) { BOOST_THROW_EXCEPTION(InvalidParameter() << errinfo_comment("Must specify the taskType")); } + if (root.isMember("enableOutputExists")) + { + m_enableOutputExists = root["enableOutputExists"].asBool(); + } checkNull(root["type"], "taskType"); m_type = root["type"].asUInt(); // the taskAlgorithm @@ -309,6 +313,7 @@ std::string JsonTaskImpl::encode() const // sync-result or not taskInfo["syncResult"] = m_syncResultToPeer; taskInfo["lowBandwidth"] = m_lowBandwidth; + taskInfo["enableOutputExists"] = m_enableOutputExists; Json::Value receiverList; for (auto const& it : m_receiverLists) diff --git a/cpp/wedpr-storage/ppc-io/src/DataResourceLoaderImpl.cpp b/cpp/wedpr-storage/ppc-io/src/DataResourceLoaderImpl.cpp index 00d9cfdf..6d951bf4 100644 --- a/cpp/wedpr-storage/ppc-io/src/DataResourceLoaderImpl.cpp +++ b/cpp/wedpr-storage/ppc-io/src/DataResourceLoaderImpl.cpp @@ -213,7 +213,7 @@ void DataResourceLoaderImpl::deleteResource( { BOOST_THROW_EXCEPTION( UnSupportedDataResource() << errinfo_comment( - "checkResourceExists: Only support File/HDFS now, passed in resource type: " + + "deleteResource: Only support File/HDFS now, passed in resource type: " + std::to_string(_desc->type()))); } } @@ -264,7 +264,7 @@ void DataResourceLoaderImpl::renameResource(ppc::protocol::DataResourceDesc::Con { BOOST_THROW_EXCEPTION( UnSupportedDataResource() << errinfo_comment( - "checkResourceExists: Only support File/HDFS now, passed in resource type: " + + "renameResource: Only support File/HDFS now, passed in resource type: " + std::to_string(_desc->type()))); } } diff --git a/cpp/wedpr-transport/ppc-rpc/src/RpcMemory.cpp b/cpp/wedpr-transport/ppc-rpc/src/RpcMemory.cpp index 86074c47..31f923d1 100644 --- a/cpp/wedpr-transport/ppc-rpc/src/RpcMemory.cpp +++ b/cpp/wedpr-transport/ppc-rpc/src/RpcMemory.cpp @@ -72,9 +72,26 @@ void RpcMemory::cleanTask() bcos::Error::Ptr RpcMemory::insertTask(protocol::Task::Ptr _task) { WriteGuard l(x_tasks); - if (m_tasks.find(_task->id()) != m_tasks.end()) + auto it = m_tasks.find(_task->id()); + if (it != m_tasks.end()) { - return std::make_shared(PPCRetCode::WRITE_RPC_STATUS_ERROR, "task exists"); + auto taskResult = it->second.second; + // the task already exists case + if (!taskResult || taskResult->status() == toString(TaskStatus::RUNNING)) + { + return std::make_shared(PPCRetCode::WRITE_RPC_STATUS_ERROR, "task exists"); + } + if (taskResult) + { + RPC_STATUS_LOG(INFO) << LOG_DESC("find the existed not running-task") + << LOG_KV("task", _task->id()) + << LOG_KV("status", taskResult->status()); + if (taskResult->status() != toString(TaskState::COMPLETED)) + { + // erase the task_id + m_front->eraseTaskInfo(_task->id()); + } + } } auto taskResult = std::make_shared(_task->id()); taskResult->setStatus(toString(TaskStatus::RUNNING)); @@ -107,15 +124,4 @@ TaskResult::Ptr RpcMemory::getTaskStatus(const std::string& _taskID) } return m_tasks[_taskID].second; -} - - -bcos::Error::Ptr RpcMemory::deleteGateway(const std::string& _agencyID) -{ - return nullptr; -} - -std::vector RpcMemory::listGateway() -{ - return {}; -} +} \ No newline at end of file diff --git a/cpp/wedpr-transport/ppc-rpc/src/RpcMemory.h b/cpp/wedpr-transport/ppc-rpc/src/RpcMemory.h index 78a447c5..bb2b0a3c 100644 --- a/cpp/wedpr-transport/ppc-rpc/src/RpcMemory.h +++ b/cpp/wedpr-transport/ppc-rpc/src/RpcMemory.h @@ -20,6 +20,7 @@ */ #pragma once +#include "ppc-framework/front/FrontInterface.h" #include "ppc-framework/rpc/RpcStatusInterface.h" #include #include @@ -31,7 +32,10 @@ class RpcMemory : public RpcStatusInterface public: using Ptr = std::shared_ptr; - RpcMemory() : m_taskCleaner(std::make_shared(60 * 60 * 1000, "taskCleaner")) {} + RpcMemory(ppc::front::FrontInterface::Ptr front) + : m_front(std::move(front)), + m_taskCleaner(std::make_shared(60 * 60 * 1000, "taskCleaner")) + {} ~RpcMemory() override = default; void start() override; @@ -40,13 +44,13 @@ class RpcMemory : public RpcStatusInterface bcos::Error::Ptr insertTask(protocol::Task::Ptr _task) override; bcos::Error::Ptr updateTaskStatus(protocol::TaskResult::Ptr _taskResult) override; protocol::TaskResult::Ptr getTaskStatus(const std::string& _taskID) override; - bcos::Error::Ptr deleteGateway(const std::string& _agencyID) override; - std::vector listGateway() override; protected: void cleanTask(); private: + ppc::front::FrontInterface::Ptr m_front; + mutable bcos::SharedMutex x_tasks; std::unordered_map> m_tasks; std::shared_ptr m_taskCleaner;