Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix cm2020 check hostResource #101

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cpp/ppc-framework/protocol/DataResource.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ class DataResource

inline std::string printDataResourceInfo(DataResource::ConstPtr _dataResource)
{
if (!_dataResource)
{
return "empty";
}
std::ostringstream stringstream;
stringstream << LOG_KV("dataResource", _dataResource->resourceID());
if (_dataResource->desc())
Expand Down
4 changes: 4 additions & 0 deletions cpp/ppc-framework/protocol/PartyResource.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ using ConstParties = std::vector<PartyResource::ConstPtr>;

inline std::string printPartyInfo(PartyResource::ConstPtr _party)
{
if (!_party)
{
return "empty";
}
std::ostringstream stringstream;
stringstream << LOG_KV("partyId", _party->id()) << LOG_KV("partyIndex", _party->partyIndex())
<< LOG_KV("desc", _party->desc());
Expand Down
14 changes: 14 additions & 0 deletions cpp/ppc-framework/protocol/Task.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,15 @@ class Task
// decode the task
virtual void decode(std::string_view _taskData) = 0;
virtual std::string encode() const = 0;

virtual bool enableOutputExists() const { return m_enableOutputExists; }
virtual void setEnableOutputExists(bool enableOutputExists)
{
m_enableOutputExists = enableOutputExists;
}

protected:
bool m_enableOutputExists = false;
};

class TaskFactory
Expand All @@ -160,9 +169,14 @@ class TaskFactory
inline std::string printTaskInfo(Task::ConstPtr _task)
{
std::ostringstream stringstream;
if (!_task)
{
return "empty";
}
stringstream << LOG_KV("id", _task->id())
<< LOG_KV("type", (ppc::protocol::TaskType)_task->type())
<< LOG_KV("algorithm", (ppc::protocol::TaskAlgorithmType)_task->algorithm())
<< LOG_KV("enableOutputExists", _task->enableOutputExists())
<< LOG_KV("taskPtr", _task);
if (_task->selfParty())
{
Expand Down
2 changes: 0 additions & 2 deletions cpp/ppc-framework/rpc/RpcStatusInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,5 @@ class RpcStatusInterface
virtual bcos::Error::Ptr insertTask(protocol::Task::Ptr _task) = 0;
virtual bcos::Error::Ptr updateTaskStatus(protocol::TaskResult::Ptr _taskResult) = 0;
virtual protocol::TaskResult::Ptr getTaskStatus(const std::string& _taskID) = 0;
virtual bcos::Error::Ptr deleteGateway(const std::string& _agencyID) = 0;
virtual std::vector<protocol::GatewayInfo> listGateway() = 0;
};
} // namespace ppc::rpc
10 changes: 6 additions & 4 deletions cpp/wedpr-computing/ppc-pir/src/OtPIRImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,11 @@ void OtPIRImpl::onReceiveMessage(ppc::front::PPCMessageFace::Ptr _msg)
}
}

void OtPIRImpl::onReceivedErrorNotification(const std::string& _taskID)
void OtPIRImpl::onReceivedErrorNotification(ppc::front::PPCMessageFace::Ptr const& _message)
{
PIR_LOG(WARNING) << LOG_DESC("onReceivedErrorNotification") << printPPCMsg(_message);
// finish the task while the peer is failed
auto taskState = findPendingTask(_taskID);
auto taskState = findPendingTask(_message->taskID());
if (taskState)
{
taskState->onPeerNotifyFinish();
Expand Down Expand Up @@ -242,7 +243,7 @@ void OtPIRImpl::handleReceivedMessage(const ppc::front::PPCMessageFace::Ptr& _me
{
case int(CommonMessageType::ErrorNotification):
{
pir->onReceivedErrorNotification(_message->taskID());
pir->onReceivedErrorNotification(_message);
break;
}
case int(CommonMessageType::PingPeer):
Expand Down Expand Up @@ -444,7 +445,8 @@ void OtPIRImpl::asyncRunTask()
<< LOG_KV("requestAgencyDataset", taskMessage.requestAgencyDataset)
<< LOG_KV("prefixLength", taskMessage.prefixLength)
<< LOG_KV("searchId", taskMessage.searchId);
auto writer = loadWriter(task->id(), dataResource, m_enableOutputExists);
auto writer =
loadWriter(task->id(), dataResource, m_taskState->task()->enableOutputExists());
m_taskState->setWriter(writer);
runSenderGenerateCipher(taskMessage);
}
Expand Down
4 changes: 1 addition & 3 deletions cpp/wedpr-computing/ppc-pir/src/OtPIRImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class OtPIRImpl : public std::enable_shared_from_this<OtPIRImpl>,
// register to the front to get the message related to ot-pir
void onReceiveMessage(ppc::front::PPCMessageFace::Ptr _message) override;

void onReceivedErrorNotification(const std::string& _taskID) override;
void onReceivedErrorNotification(ppc::front::PPCMessageFace::Ptr const& _message) override;
void onSelfError(
const std::string& _taskID, bcos::Error::Ptr _error, bool _noticePeer) override;

Expand Down Expand Up @@ -150,8 +150,6 @@ class OtPIRImpl : public std::enable_shared_from_this<OtPIRImpl>,
m_senders.erase(it);
}
}
// allow the output-path exists, for ut
bool m_enableOutputExists = false;
// 为true时启动时会从配置中加载文件作为匹配源
bool m_enableMemoryFile = false;
ppc::protocol::DataResource m_resource;
Expand Down
4 changes: 1 addition & 3 deletions cpp/wedpr-computing/ppc-pir/tests/FakeOtPIRFactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ class FakeOtPIRImpl : public OtPIRImpl
using Ptr = std::shared_ptr<FakeOtPIRImpl>;
FakeOtPIRImpl(OtPIRConfig::Ptr const& _config, unsigned _idleTimeMs = 0)
: OtPIRImpl(_config, _idleTimeMs)
{
m_enableOutputExists = true;
}
{}
~FakeOtPIRImpl() override = default;
};

Expand Down
2 changes: 2 additions & 0 deletions cpp/wedpr-computing/ppc-pir/tests/TestBaseOT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ void testOTPIRImplFunc(const std::string& _taskID, const std::string& _params, b
auto senderPIRTask = std::make_shared<JsonTaskImpl>(senderAgencyName);
senderPIRTask->setId(_taskID);
senderPIRTask->setParam(_params);
senderPIRTask->setEnableOutputExists(true);
senderPIRTask->setSelf(_senderParty);
senderPIRTask->addParty(_receiverParty);
senderPIRTask->setSyncResultToPeer(_syncResults);
Expand All @@ -203,6 +204,7 @@ void testOTPIRImplFunc(const std::string& _taskID, const std::string& _params, b
auto receiverPIRTask = std::make_shared<JsonTaskImpl>(receiverAgencyName);
receiverPIRTask->setId(_taskID);
receiverPIRTask->setParam(_params);
receiverPIRTask->setEnableOutputExists(true);
receiverPIRTask->setSelf(_receiverParty);
receiverPIRTask->addParty(_senderParty);
receiverPIRTask->setSyncResultToPeer(_syncResults);
Expand Down
18 changes: 9 additions & 9 deletions cpp/wedpr-computing/ppc-psi/src/cm2020-psi/CM2020PSIImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,14 @@ void CM2020PSIImpl::asyncRunTask()
{
return;
}

CM2020_PSI_LOG(INFO) << LOG_DESC("noticePeerToFinish") << printTaskInfo(task);
psi->noticePeerToFinish(task);
});
// check the memory
checkHostResource(m_config->minNeededMemoryGB());
addPendingTask(taskState);

try
{
addPendingTask(taskState);
// check the memory
checkHostResource(m_config->minNeededMemoryGB());
// prepare reader and writer
auto dataResource = task->selfParty()->dataResource();
auto reader = loadReader(task->id(), dataResource, DataSchema::Bytes);
Expand All @@ -169,7 +168,7 @@ void CM2020PSIImpl::asyncRunTask()
auto role = task->selfParty()->partyIndex();
if (role == uint16_t(PartyType::Client) || task->syncResultToPeer())
{
auto writer = loadWriter(task->id(), dataResource, m_enableOutputExists);
auto writer = loadWriter(task->id(), dataResource, task->enableOutputExists());
taskState->setWriter(writer);
}

Expand Down Expand Up @@ -319,10 +318,11 @@ void CM2020PSIImpl::stop()
CM2020_PSI_LOG(INFO) << LOG_DESC("CM2020-PSI stopped");
}

void CM2020PSIImpl::onReceivedErrorNotification(const std::string& _taskID)
void CM2020PSIImpl::onReceivedErrorNotification(ppc::front::PPCMessageFace::Ptr const& _message)
{
CM2020_PSI_LOG(INFO) << LOG_DESC("onReceivedErrorNotification") << printPPCMsg(_message);
// finish the task while the peer is failed
auto taskState = findPendingTask(_taskID);
auto taskState = findPendingTask(_message->taskID());
if (taskState)
{
taskState->onPeerNotifyFinish();
Expand Down Expand Up @@ -410,7 +410,7 @@ void CM2020PSIImpl::handleReceivedMessage(const ppc::front::PPCMessageFace::Ptr&
{
case int(CommonMessageType::ErrorNotification):
{
psi->onReceivedErrorNotification(_message->taskID());
psi->onReceivedErrorNotification(_message);
break;
}
case int(CommonMessageType::PingPeer):
Expand Down
6 changes: 1 addition & 5 deletions cpp/wedpr-computing/ppc-psi/src/cm2020-psi/CM2020PSIImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class CM2020PSIImpl : public bcos::Worker,

void start() override;
void stop() override;
void onReceivedErrorNotification(const std::string& _taskID) override;
void onReceivedErrorNotification(ppc::front::PPCMessageFace::Ptr const& _message) override;
void onSelfError(
const std::string& _taskID, bcos::Error::Ptr _error, bool _noticePeer) override;

Expand Down Expand Up @@ -148,10 +148,6 @@ class CM2020PSIImpl : public bcos::Worker,
}
}

protected:
// allow the output-path exists, for ut
bool m_enableOutputExists = false;

private:
void waitSignal()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void EcdhConnPSIImpl::asyncRunTask(
if (role == uint16_t(PartyType::Client))
{
ECDH_CONN_LOG(INFO) << LOG_DESC("Client do the Task") << LOG_KV("taskID", _task->id());
auto writer = loadWriter(_task->id(), dataResource, m_enableOutputExists);
auto writer = loadWriter(_task->id(), dataResource, _task->enableOutputExists());
ecdhTaskState->setWriter(writer);
auto client = std::make_shared<EcdhConnPSIClient>(m_config, ecdhTaskState);
addClient(client);
Expand All @@ -66,7 +66,7 @@ void EcdhConnPSIImpl::asyncRunTask(
ECDH_CONN_LOG(INFO) << LOG_DESC("Server do the Task") << LOG_KV("taskID", _task->id());
if (_task->syncResultToPeer())
{
auto writer = loadWriter(_task->id(), dataResource, m_enableOutputExists);
auto writer = loadWriter(_task->id(), dataResource, _task->enableOutputExists());
ecdhTaskState->setWriter(writer);
}
auto server = std::make_shared<EcdhConnPSIServer>(m_config, ecdhTaskState);
Expand Down Expand Up @@ -131,7 +131,7 @@ void EcdhConnPSIImpl::stop()
}
}

void EcdhConnPSIImpl::onReceivedErrorNotification(const std::string& _taskID) {}
void EcdhConnPSIImpl::onReceivedErrorNotification(ppc::front::PPCMessageFace::Ptr const&) {}

void EcdhConnPSIImpl::onSelfError(
const std::string& _taskID, bcos::Error::Ptr _error, bool _noticePeer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,11 @@ class EcdhConnPSIImpl : public std::enable_shared_from_this<EcdhConnPSIImpl>,
void start() override;
void stop() override;

void onReceivedErrorNotification(const std::string& _taskID) override;
void onReceivedErrorNotification(ppc::front::PPCMessageFace::Ptr const& _message) override;
void onSelfError(
const std::string& _taskID, bcos::Error::Ptr _error, bool _noticePeer) override;
void executeWorker() override;

protected:
bool m_enableOutputExists = false;

void handlerPSIReceiveMessage(PSIConnMessage::Ptr _msg);
void onHandShakeRequestHandler(const std::string& _taskId, const bcos::bytes& _msg);
void onHandShakeResponseHandler(const std::string& _taskId, const bcos::bytes& _msg);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ void EcdhMultiPSIImpl::asyncRunTask(
psi->removePartner(taskID);
psi->removePendingTask(taskID);
});
addPendingTask(taskState);
// check the memory
checkHostResource(m_config->minNeededMemoryGB());
addPendingTask(taskState);
// over the peer limit
if (_task->getAllPeerParties().size() > c_max_peer_size)
{
Expand All @@ -155,7 +155,7 @@ void EcdhMultiPSIImpl::asyncRunTask(
<< LOG_KV("roleId", role);
if (role == uint16_t(PartiesType::Calculator))
{
auto writer = loadWriter(_task->id(), dataResource, m_enableOutputExists);
auto writer = loadWriter(_task->id(), dataResource, _task->enableOutputExists());
taskState->setWriter(writer);
ECDH_MULTI_LOG(INFO) << LOG_DESC("Calculator do the Task")
<< LOG_KV("taskID", _task->id());
Expand All @@ -170,7 +170,7 @@ void EcdhMultiPSIImpl::asyncRunTask(
if (_task->syncResultToPeer() && std::find(receivers.begin(), receivers.end(),
m_config->selfParty()) != receivers.end())
{
auto writer = loadWriter(_task->id(), dataResource, m_enableOutputExists);
auto writer = loadWriter(_task->id(), dataResource, _task->enableOutputExists());
taskState->setWriter(writer);
}
auto partner = std::make_shared<EcdhMultiPSIPartner>(m_config, taskState);
Expand All @@ -183,7 +183,7 @@ void EcdhMultiPSIImpl::asyncRunTask(
if (_task->syncResultToPeer() && std::find(receivers.begin(), receivers.end(),
m_config->selfParty()) != receivers.end())
{
auto writer = loadWriter(_task->id(), dataResource, m_enableOutputExists);
auto writer = loadWriter(_task->id(), dataResource, _task->enableOutputExists());
taskState->setWriter(writer);
}
auto master = std::make_shared<EcdhMultiPSIMaster>(m_config, taskState);
Expand Down Expand Up @@ -266,11 +266,11 @@ void EcdhMultiPSIImpl::checkFinishedTask()
}
}

void EcdhMultiPSIImpl::onReceivedErrorNotification(const std::string& _taskID)
void EcdhMultiPSIImpl::onReceivedErrorNotification(ppc::front::PPCMessageFace::Ptr const& _message)
{
ECDH_MULTI_LOG(INFO) << LOG_DESC("onReceivedErrorNotification") << LOG_KV("taskID", _taskID);
ECDH_MULTI_LOG(INFO) << LOG_DESC("onReceivedErrorNotification") << printPPCMsg(_message);
// finish the task while the peer is failed
auto taskState = findPendingTask(_taskID);
auto taskState = findPendingTask(_message->taskID());
if (taskState)
{
taskState->onPeerNotifyFinish();
Expand Down Expand Up @@ -308,7 +308,7 @@ void EcdhMultiPSIImpl::executeWorker()
auto pop_msg = _msg.second;
if (pop_msg->messageType() == uint8_t(CommonMessageType::ErrorNotification))
{
onReceivedErrorNotification(pop_msg->taskID());
onReceivedErrorNotification(pop_msg);
return;
}
else if (pop_msg->messageType() == uint8_t(CommonMessageType::PingPeer))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,13 @@ class EcdhMultiPSIImpl : public std::enable_shared_from_this<EcdhMultiPSIImpl>,
void stop() override;

void checkFinishedTask();
void onReceivedErrorNotification(const std::string& _taskID) override;
void onReceivedErrorNotification(ppc::front::PPCMessageFace::Ptr const& _message) override;
void onSelfError(
const std::string& _taskID, bcos::Error::Ptr _error, bool _noticePeer) override;
void executeWorker() override;


protected:
bool m_enableOutputExists = false;
virtual void onReceiveRandomA(PSIMessageInterface::Ptr _msg);
virtual void onReceiveCalCipher(PSIMessageInterface::Ptr _msg);
virtual void handlerPSIReceiveMessage(PSIMessageInterface::Ptr _msg);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ void EcdhMultiPSICalculator::blindData(std::string _taskID, bcos::bytes _randA)
uint64_t dataOffset = 0;
do
{
if (m_taskState->loadFinished())
if (m_taskState->loadFinished() || m_taskState->taskDone())
{
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ void EcdhMultiPSIMaster::blindData()
auto reader = m_taskState->reader();
do
{
if (m_taskState->loadFinished())
if (m_taskState->loadFinished() || m_taskState->taskDone())
{
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ void EcdhMultiPSIPartner::onReceiveRandomA(bcos::bytesPointer _randA)
uint64_t dataOffset = 0;
do
{
if (m_taskState->loadFinished())
if (m_taskState->loadFinished() || m_taskState->taskDone())
{
break;
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/wedpr-computing/ppc-psi/src/ecdh-psi/EcdhPSIImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ bool EcdhPSIImpl::initTaskState(TaskState::Ptr const& _taskState)
if (!server)
{
m_cache->insertServerCipherCache(task->id(), _taskState);
if (!m_enableOutputExists)
if (!_taskState->task()->enableOutputExists())
{
// Note: if the output-resource already exists, will throw exception
m_config->dataResourceLoader()->checkResourceExists(dataResource->outputDesc());
Expand Down
3 changes: 0 additions & 3 deletions cpp/wedpr-computing/ppc-psi/src/ecdh-psi/EcdhPSIImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,5 @@ class EcdhPSIImpl : public PSIFramework, public std::enable_shared_from_this<Ecd
protected:
EcdhPSIConfig::Ptr m_config;
EcdhCache::Ptr m_cache;

// allow the output-path exists, for ut
bool m_enableOutputExists = false;
};
} // namespace ppc::psi
Loading
Loading