Skip to content

Commit

Permalink
psi support retry
Browse files Browse the repository at this point in the history
  • Loading branch information
cyjseagull committed Nov 25, 2024
1 parent 860fcc0 commit 2f81eaa
Show file tree
Hide file tree
Showing 30 changed files with 66 additions and 44 deletions.
4 changes: 4 additions & 0 deletions cpp/ppc-framework/protocol/DataResource.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ class DataResource

inline std::string printDataResourceInfo(DataResource::ConstPtr _dataResource)
{
if (!_dataResource)
{
return "empty";
}
std::ostringstream stringstream;
stringstream << LOG_KV("dataResource", _dataResource->resourceID());
if (_dataResource->desc())
Expand Down
4 changes: 4 additions & 0 deletions cpp/ppc-framework/protocol/PartyResource.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ using ConstParties = std::vector<PartyResource::ConstPtr>;

inline std::string printPartyInfo(PartyResource::ConstPtr _party)
{
if (!_party)
{
return "empty";
}
std::ostringstream stringstream;
stringstream << LOG_KV("partyId", _party->id()) << LOG_KV("partyIndex", _party->partyIndex())
<< LOG_KV("desc", _party->desc());
Expand Down
14 changes: 14 additions & 0 deletions cpp/ppc-framework/protocol/Task.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,15 @@ class Task
// decode the task
virtual void decode(std::string_view _taskData) = 0;
virtual std::string encode() const = 0;

virtual bool enableOutputExists() const { return m_enableOutputExists; }
virtual void setEnableOutputExists(bool enableOutputExists)
{
m_enableOutputExists = enableOutputExists;
}

protected:
bool m_enableOutputExists = false;
};

class TaskFactory
Expand All @@ -160,9 +169,14 @@ class TaskFactory
inline std::string printTaskInfo(Task::ConstPtr _task)
{
std::ostringstream stringstream;
if (!task)
{
return "empty";
}
stringstream << LOG_KV("id", _task->id())
<< LOG_KV("type", (ppc::protocol::TaskType)_task->type())
<< LOG_KV("algorithm", (ppc::protocol::TaskAlgorithmType)_task->algorithm())
<< LOG_KV("enableOutputExists", _task->enableOutputExists())
<< LOG_KV("taskPtr", _task);
if (_task->selfParty())
{
Expand Down
3 changes: 2 additions & 1 deletion cpp/wedpr-computing/ppc-pir/src/OtPIRImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,8 @@ void OtPIRImpl::asyncRunTask()
<< LOG_KV("requestAgencyDataset", taskMessage.requestAgencyDataset)
<< LOG_KV("prefixLength", taskMessage.prefixLength)
<< LOG_KV("searchId", taskMessage.searchId);
auto writer = loadWriter(task->id(), dataResource, m_enableOutputExists);
auto writer =
loadWriter(task->id(), dataResource, m_taskState->task()->enableOutputExists());
m_taskState->setWriter(writer);
runSenderGenerateCipher(taskMessage);
}
Expand Down
2 changes: 0 additions & 2 deletions cpp/wedpr-computing/ppc-pir/src/OtPIRImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,6 @@ class OtPIRImpl : public std::enable_shared_from_this<OtPIRImpl>,
m_senders.erase(it);
}
}
// allow the output-path exists, for ut
bool m_enableOutputExists = false;
// 为true时启动时会从配置中加载文件作为匹配源
bool m_enableMemoryFile = false;
ppc::protocol::DataResource m_resource;
Expand Down
4 changes: 1 addition & 3 deletions cpp/wedpr-computing/ppc-pir/tests/FakeOtPIRFactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ class FakeOtPIRImpl : public OtPIRImpl
using Ptr = std::shared_ptr<FakeOtPIRImpl>;
FakeOtPIRImpl(OtPIRConfig::Ptr const& _config, unsigned _idleTimeMs = 0)
: OtPIRImpl(_config, _idleTimeMs)
{
m_enableOutputExists = true;
}
{}
~FakeOtPIRImpl() override = default;
};

Expand Down
2 changes: 2 additions & 0 deletions cpp/wedpr-computing/ppc-pir/tests/TestBaseOT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ void testOTPIRImplFunc(const std::string& _taskID, const std::string& _params, b
auto senderPIRTask = std::make_shared<JsonTaskImpl>(senderAgencyName);
senderPIRTask->setId(_taskID);
senderPIRTask->setParam(_params);
senderPIRTask->setEnableOutputExists(true);
senderPIRTask->setSelf(_senderParty);
senderPIRTask->addParty(_receiverParty);
senderPIRTask->setSyncResultToPeer(_syncResults);
Expand All @@ -203,6 +204,7 @@ void testOTPIRImplFunc(const std::string& _taskID, const std::string& _params, b
auto receiverPIRTask = std::make_shared<JsonTaskImpl>(receiverAgencyName);
receiverPIRTask->setId(_taskID);
receiverPIRTask->setParam(_params);
receiverPIRTask->setEnableOutputExists(true);
receiverPIRTask->setSelf(_receiverParty);
receiverPIRTask->addParty(_senderParty);
receiverPIRTask->setSyncResultToPeer(_syncResults);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ void CM2020PSIImpl::asyncRunTask()
auto role = task->selfParty()->partyIndex();
if (role == uint16_t(PartyType::Client) || task->syncResultToPeer())
{
auto writer = loadWriter(task->id(), dataResource, m_enableOutputExists);
auto writer = loadWriter(task->id(), dataResource, task->enableOutputExists());
taskState->setWriter(writer);
}

Expand Down
4 changes: 0 additions & 4 deletions cpp/wedpr-computing/ppc-psi/src/cm2020-psi/CM2020PSIImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,6 @@ class CM2020PSIImpl : public bcos::Worker,
}
}

protected:
// allow the output-path exists, for ut
bool m_enableOutputExists = false;

private:
void waitSignal()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void EcdhConnPSIImpl::asyncRunTask(
if (role == uint16_t(PartyType::Client))
{
ECDH_CONN_LOG(INFO) << LOG_DESC("Client do the Task") << LOG_KV("taskID", _task->id());
auto writer = loadWriter(_task->id(), dataResource, m_enableOutputExists);
auto writer = loadWriter(_task->id(), dataResource, _task->enableOutputExists());
ecdhTaskState->setWriter(writer);
auto client = std::make_shared<EcdhConnPSIClient>(m_config, ecdhTaskState);
addClient(client);
Expand All @@ -66,7 +66,7 @@ void EcdhConnPSIImpl::asyncRunTask(
ECDH_CONN_LOG(INFO) << LOG_DESC("Server do the Task") << LOG_KV("taskID", _task->id());
if (_task->syncResultToPeer())
{
auto writer = loadWriter(_task->id(), dataResource, m_enableOutputExists);
auto writer = loadWriter(_task->id(), dataResource, _task->enableOutputExists());
ecdhTaskState->setWriter(writer);
}
auto server = std::make_shared<EcdhConnPSIServer>(m_config, ecdhTaskState);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,6 @@ class EcdhConnPSIImpl : public std::enable_shared_from_this<EcdhConnPSIImpl>,
const std::string& _taskID, bcos::Error::Ptr _error, bool _noticePeer) override;
void executeWorker() override;

protected:
bool m_enableOutputExists = false;

void handlerPSIReceiveMessage(PSIConnMessage::Ptr _msg);
void onHandShakeRequestHandler(const std::string& _taskId, const bcos::bytes& _msg);
void onHandShakeResponseHandler(const std::string& _taskId, const bcos::bytes& _msg);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ void EcdhMultiPSIImpl::asyncRunTask(
<< LOG_KV("roleId", role);
if (role == uint16_t(PartiesType::Calculator))
{
auto writer = loadWriter(_task->id(), dataResource, m_enableOutputExists);
auto writer = loadWriter(_task->id(), dataResource, _task->enableOutputExists());
taskState->setWriter(writer);
ECDH_MULTI_LOG(INFO) << LOG_DESC("Calculator do the Task")
<< LOG_KV("taskID", _task->id());
Expand All @@ -170,7 +170,7 @@ void EcdhMultiPSIImpl::asyncRunTask(
if (_task->syncResultToPeer() && std::find(receivers.begin(), receivers.end(),
m_config->selfParty()) != receivers.end())
{
auto writer = loadWriter(_task->id(), dataResource, m_enableOutputExists);
auto writer = loadWriter(_task->id(), dataResource, _task->enableOutputExists());
taskState->setWriter(writer);
}
auto partner = std::make_shared<EcdhMultiPSIPartner>(m_config, taskState);
Expand All @@ -183,7 +183,7 @@ void EcdhMultiPSIImpl::asyncRunTask(
if (_task->syncResultToPeer() && std::find(receivers.begin(), receivers.end(),
m_config->selfParty()) != receivers.end())
{
auto writer = loadWriter(_task->id(), dataResource, m_enableOutputExists);
auto writer = loadWriter(_task->id(), dataResource, _task->enableOutputExists());
taskState->setWriter(writer);
}
auto master = std::make_shared<EcdhMultiPSIMaster>(m_config, taskState);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ class EcdhMultiPSIImpl : public std::enable_shared_from_this<EcdhMultiPSIImpl>,


protected:
bool m_enableOutputExists = false;
virtual void onReceiveRandomA(PSIMessageInterface::Ptr _msg);
virtual void onReceiveCalCipher(PSIMessageInterface::Ptr _msg);
virtual void handlerPSIReceiveMessage(PSIMessageInterface::Ptr _msg);
Expand Down
2 changes: 1 addition & 1 deletion cpp/wedpr-computing/ppc-psi/src/ecdh-psi/EcdhPSIImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ bool EcdhPSIImpl::initTaskState(TaskState::Ptr const& _taskState)
if (!server)
{
m_cache->insertServerCipherCache(task->id(), _taskState);
if (!m_enableOutputExists)
if (!_taskState->task()->enableOutputExists())
{
// Note: if the output-resource already exists, will throw exception
m_config->dataResourceLoader()->checkResourceExists(dataResource->outputDesc());
Expand Down
3 changes: 0 additions & 3 deletions cpp/wedpr-computing/ppc-psi/src/ecdh-psi/EcdhPSIImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,5 @@ class EcdhPSIImpl : public PSIFramework, public std::enable_shared_from_this<Ecd
protected:
EcdhPSIConfig::Ptr m_config;
EcdhCache::Ptr m_cache;

// allow the output-path exists, for ut
bool m_enableOutputExists = false;
};
} // namespace ppc::psi
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ void LabeledPSIImpl::saveSenderCache(const ppc::protocol::Task::ConstPtr& _task)
auto dataResource = _task->selfParty()->dataResource();

LineWriter::Ptr writer;
if (!m_enableOutputExists)
if (!_task->enableOutputExists())
{
// Note: if the output-resource already exists, will throw exception
m_config->dataResourceLoader()->checkResourceExists(dataResource->outputDesc());
Expand Down
3 changes: 0 additions & 3 deletions cpp/wedpr-computing/ppc-psi/src/labeled-psi/LabeledPSIImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,6 @@ class LabeledPSIImpl : public bcos::Worker,
}
}

protected:
// allow the output-path exists, for ut
bool m_enableOutputExists = false;

private:
void waitSignal()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ void RA2018PSIImpl::runClientPSI(TaskState::Ptr const& _taskState)
{
return;
}
if (!m_enableOutputExists)
if (!_taskState->task()->enableOutputExists())
{
// Note: if the output-resource already exists, will throw exception
m_config->dataResourceLoader()->checkResourceExists(dataResource->outputDesc());
Expand Down
3 changes: 0 additions & 3 deletions cpp/wedpr-computing/ppc-psi/src/ra2018-psi/RA2018PSIImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,5 @@ class RA2018PSIImpl : public PSIFramework, public std::enable_shared_from_this<R
bcos::ThreadPool::Ptr m_worker;
// the flag means that response to the sdk once handling the task or after the task completed
bool m_waitResult;

// allow the output-path exists, for ut
bool m_enableOutputExists = false;
};
} // namespace ppc::psi
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ class FakeCM2020PSIImpl : public CM2020PSIImpl
using Ptr = std::shared_ptr<FakeCM2020PSIImpl>;
FakeCM2020PSIImpl(CM2020PSIConfig::Ptr const& _config, unsigned _idleTimeMs = 0)
: CM2020PSIImpl(_config, _idleTimeMs)
{
m_enableOutputExists = true;
}
{}
~FakeCM2020PSIImpl() override = default;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ void testCM2020PSIImplFunc(const std::string& _taskID, const std::string& _param
senderPSITask->setId(_taskID);
senderPSITask->setParam(_params);
senderPSITask->setSelf(_senderParty);
senderPSITask->setEnableOutputExists(true);
senderPSITask->addParty(_receiverParty);
senderPSITask->setSyncResultToPeer(_syncResults);
senderPSITask->setAlgorithm((uint8_t)TaskAlgorithmType::CM_PSI_2PC);
Expand All @@ -161,6 +162,7 @@ void testCM2020PSIImplFunc(const std::string& _taskID, const std::string& _param
receiverPSITask->setId(_taskID);
receiverPSITask->setParam(_params);
receiverPSITask->setSelf(_receiverParty);
receiverPSITask->setEnableOutputExists(true);
receiverPSITask->addParty(_senderParty);
receiverPSITask->setSyncResultToPeer(_syncResults);
receiverPSITask->setAlgorithm((uint8_t)TaskAlgorithmType::CM_PSI_2PC);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ class FakeLabeledPSIImpl : public LabeledPSIImpl
using Ptr = std::shared_ptr<FakeLabeledPSIImpl>;
FakeLabeledPSIImpl(LabeledPSIConfig::Ptr const& _config, unsigned _idleTimeMs = 0)
: LabeledPSIImpl(_config, _idleTimeMs)
{
m_enableOutputExists = true;
}
{}
~FakeLabeledPSIImpl() override = default;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ void runSetup(LabeledPSIImpl::Ptr _sender)

auto senderPSITask = std::make_shared<JsonTaskImpl>(senderAgencyName);
senderPSITask->setId("0x00000000");
senderPSITask->setEnableOutputExists(true);
senderPSITask->setParam(R"(["setup_sender_db","32"])");
senderPSITask->setSelf(senderParty);
senderPSITask->setAlgorithm((uint8_t)TaskAlgorithmType::LABELED_PSI_2PC);
Expand Down Expand Up @@ -95,6 +96,7 @@ void saveCache(LabeledPSIImpl::Ptr _sender)

auto senderPSITask = std::make_shared<JsonTaskImpl>(senderAgencyName);
senderPSITask->setId("0x00000012");
senderPSITask->setEnableOutputExists(true);
senderPSITask->setParam(R"(["save_sender_cache"])");
senderPSITask->setSelf(senderParty);
senderPSITask->setAlgorithm((uint8_t)TaskAlgorithmType::LABELED_PSI_2PC);
Expand Down Expand Up @@ -234,13 +236,15 @@ void testLabeledPSIImplFunc(const std::string& _taskID, const std::string& _para
// trigger the psi task
auto senderPSITask = std::make_shared<JsonTaskImpl>(senderAgencyName);
senderPSITask->setId(_taskID);
senderPSITask->setEnableOutputExists(true);
senderPSITask->setParam(_params);
senderPSITask->setSelf(_senderParty);
senderPSITask->setAlgorithm((uint8_t)TaskAlgorithmType::LABELED_PSI_2PC);
senderPSITask->addParty(_receiverParty);

auto receiverPSITask = std::make_shared<JsonTaskImpl>(receiverAgencyName);
receiverPSITask->setId(_taskID);
receiverPSITask->setEnableOutputExists(true);
receiverPSITask->setSelf(_receiverParty);
receiverPSITask->setAlgorithm((uint8_t)TaskAlgorithmType::LABELED_PSI_2PC);
receiverPSITask->addParty(_senderParty);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ void testEcdhImplFunc(int64_t _dataBatchSize, std::string const& _serverPSIDataS
auto clientPSITask = std::make_shared<JsonTaskImpl>(clientAgencyName);
std::string taskID = "runPSI";
clientPSITask->setId(taskID);
clientPSITask->setEnableOutputExists(true);
clientPSITask->setEnableOutputExists(true);
clientPSITask->setType((int8_t)TaskType::PSI);
clientPSITask->setAlgorithm((int8_t)TaskAlgorithmType::ECDH_PSI_2PC);
clientPSITask->setSelf(clientParty);
Expand All @@ -103,6 +105,8 @@ void testEcdhImplFunc(int64_t _dataBatchSize, std::string const& _serverPSIDataS
{
serverPSITask->setId(taskID);
}
serverPSITask->setEnableOutputExists(true);
serverPSITask->setEnableOutputExists(true);
serverPSITask->setType((int8_t)TaskType::PSI);
serverPSITask->setAlgorithm((int8_t)TaskAlgorithmType::ECDH_PSI_2PC);
serverPSITask->addParty(clientParty);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ void testRA2018PSIImplFunc(int _dataBatchSize, CuckoofilterOption::Ptr option,
auto offlineFullEvaluateTask = std::make_shared<JsonTaskImpl>(serverAgencyName);
offlineFullEvaluateTask->setId("offlineFullEvaluate");
offlineFullEvaluateTask->setSelf(serverParty);
offlineFullEvaluateTask->setEnableOutputExists(true);
// insert operation
std::string param = "[\"data_preprocessing\", 0]";
offlineFullEvaluateTask->setParam(param);
Expand All @@ -127,6 +128,7 @@ void testRA2018PSIImplFunc(int _dataBatchSize, CuckoofilterOption::Ptr option,
auto clientPSITask = std::make_shared<JsonTaskImpl>(clientAgencyName);
std::string taskID = "runPSI";
clientPSITask->setId(taskID);
clientPSITask->setEnableOutputExists(true);
clientPSITask->setSelf(clientParty);
clientPSITask->addParty(serverParty);
param = "[\"ra2018_psi\"]";
Expand All @@ -135,6 +137,7 @@ void testRA2018PSIImplFunc(int _dataBatchSize, CuckoofilterOption::Ptr option,
// the server task
auto serverPSITask = std::make_shared<JsonTaskImpl>(serverAgencyName);
serverPSITask->setSelf(serverParty);
serverPSITask->setEnableOutputExists(true);
serverPSITask->setId(taskID);
serverPSITask->setParam(param);
serverPSITask->addParty(clientParty);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ class FakeEcdhPSIImpl : public EcdhPSIImpl
{
// set the m_started flag to be true
m_started = true;
m_enableOutputExists = true;
m_taskSyncTimer->registerTimeoutHandler([this]() { syncTaskInfo(); });
m_taskSyncTimer->start();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ class FakeRA2018Impl : public RA2018PSIImpl
{
// set the m_started flag to be true
m_started = true;
m_enableOutputExists = true;
m_taskSyncTimer->registerTimeoutHandler([this]() { syncTaskInfo(); });
m_taskSyncTimer->start();
}
Expand Down
5 changes: 5 additions & 0 deletions cpp/wedpr-protocol/protocol/src/JsonTaskImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ void JsonTaskImpl::decodeJsonValue(Json::Value const& root)
{
BOOST_THROW_EXCEPTION(InvalidParameter() << errinfo_comment("Must specify the taskType"));
}
if (root.isMember("enableOutputExists"))
{
m_enableOutputExists = root["enableOutputExists"].asBool();
}
checkNull(root["type"], "taskType");
m_type = root["type"].asUInt();
// the taskAlgorithm
Expand Down Expand Up @@ -309,6 +313,7 @@ std::string JsonTaskImpl::encode() const
// sync-result or not
taskInfo["syncResult"] = m_syncResultToPeer;
taskInfo["lowBandwidth"] = m_lowBandwidth;
taskInfo["enableOutputExists"] = m_enableOutputExists;

Json::Value receiverList;
for (auto const& it : m_receiverLists)
Expand Down
Loading

0 comments on commit 2f81eaa

Please sign in to comment.