Skip to content

Commit

Permalink
Merge pull request #2 from fairmath/adding_decryption
Browse files Browse the repository at this point in the history
Adding decryption part
  • Loading branch information
g-arakelov authored Nov 25, 2024
2 parents 9a96ed5 + 09d899c commit 9aae4b8
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 131 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 +89,5 @@ jobs:
- name: 'Upload Artifact'
uses: actions/upload-artifact@v4
with:
name: fairmath-keygen-${{matrix.platform}}
name: fairmath-cli-${{matrix.platform}}
path: .build/install
12 changes: 6 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.22)

project(fairmath-keygen CXX)
project(fairmath-cli CXX)

set(CMAKE_CXX_STANDARD 20)
set(CMAKE_SKIP_BUILD_RPATH FALSE)
Expand Down Expand Up @@ -40,7 +40,7 @@ set(SOURCES
src/fairmathCli.cpp
)

add_executable(fairmath-keygen ${SOURCES})
add_executable(fairmath-cli ${SOURCES})

if (OpenMP_CXX_FOUND)
message(STATUS "FOUND OpenMP: ${OpenMP_CXX_LIBRARIES}")
Expand All @@ -50,11 +50,11 @@ if (OpenMP_CXX_FOUND)
set(lib_openmp_path "${OpenMP_libomp_LIBRARY}")
cmake_path(GET lib_openmp_path PARENT_PATH OMP_LINK_DIR)

target_link_directories(fairmath-keygen PRIVATE ${OMP_LINK_DIR})
target_link_directories(fairmath-cli PRIVATE ${OMP_LINK_DIR})
endif()

target_include_directories(fairmath-keygen PRIVATE ${NLOHMANN_JSON_INCLUDE_DIRS})
target_link_libraries(fairmath-keygen PRIVATE ${OpenFHE_SHARED_LIBRARIES} Boost::program_options)
target_include_directories(fairmath-cli PRIVATE ${NLOHMANN_JSON_INCLUDE_DIRS})
target_link_libraries(fairmath-cli PRIVATE ${OpenFHE_SHARED_LIBRARIES} Boost::program_options)

install(TARGETS fairmath-keygen DESTINATION install)
install(TARGETS fairmath-cli DESTINATION install)
install(DIRECTORY ${OpenFHE_LIBDIR} DESTINATION install)
9 changes: 9 additions & 0 deletions example/run_ciphertext_decryption.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/usr/bin/env bash

../.build/install/fairmath-cli \
--working_mode="ciphertext_decryption" \
--output_decryption_location="./output_decryption" \
--decryption_cryptocontext_location="./cryptocontext_name" \
--ciphertext_location="./ciphertext_1" \
--decryption_key_location="./private_key_name" \
--plaintext_length="10"
8 changes: 8 additions & 0 deletions example/run_config_processing.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#!/usr/bin/env bash

../.build/install/fairmath-cli \
--working_mode="config_processing" \
--input_config_location="./input.json" \
--output_crypto_objects_directory="." \
--output_config_location="./config.json" \
--output_config_json_indent="4"
144 changes: 47 additions & 97 deletions src/configProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

#include <nlohmann/json.hpp>

#include "openfhe/pke/cryptocontext-ser.h"
#include <openfhe/pke/gen-cryptocontext.h>
#include <openfhe/pke/scheme/bfvrns/gen-cryptocontext-bfvrns.h>
#include <openfhe/pke/scheme/bgvrns/gen-cryptocontext-bgvrns.h>
Expand Down Expand Up @@ -75,15 +74,6 @@ class ConfigProcessor final
bool (*serializeFunc)(std::ostream&, const lbcrypto::SerType::SERBINARY&, std::string),
KeyGenFuncParamsTypes&&... keyGenFuncParams);

template <typename CryptoObjectType>
void serialize(const std::string& filename,
const std::shared_ptr<CryptoObjectType>& cryptoObject);
void serialize(const std::string& filename,
bool (*serializeFunc)(std::ostream&, const lbcrypto::SerType::SERBINARY&, std::string)) const;
template <typename CryptoObjectType>
static void deserialize(const std::string& path,
std::shared_ptr<CryptoObjectType>& cryptoObject);

template <typename KeyType>
[[nodiscard]] std::shared_ptr<KeyType> aquireKey(
const std::string_view keyName);
Expand Down Expand Up @@ -129,34 +119,34 @@ class ConfigProcessor final
{
auto itKey = getKeyMap<lbcrypto::PublicKeyImpl<lbcrypto::DCRTPoly>>().emplace(
linkedKeyName, std::move(getKeyFromKeyPair<lbcrypto::PublicKeyImpl<lbcrypto::DCRTPoly>>(keyPair))).first;
serialize<lbcrypto::PublicKeyImpl<lbcrypto::DCRTPoly>>(linkedKeyName, itKey->second);
utils::serialize<lbcrypto::PublicKeyImpl<lbcrypto::DCRTPoly>>(m_outputCryptoObjectsDirectory + linkedKeyName, itKey->second);
}
else
{
auto itKey = getKeyMap<lbcrypto::PrivateKeyImpl<lbcrypto::DCRTPoly>>().emplace(
linkedKeyName, std::move(getKeyFromKeyPair<lbcrypto::PrivateKeyImpl<lbcrypto::DCRTPoly>>(keyPair))).first;
serialize<lbcrypto::PrivateKeyImpl<lbcrypto::DCRTPoly>>(linkedKeyName, itKey->second);
utils::serialize<lbcrypto::PrivateKeyImpl<lbcrypto::DCRTPoly>>(m_outputCryptoObjectsDirectory + linkedKeyName, itKey->second);
}
updateSource(linkedKeyName, m_configJson[linkedKeyName]);
}
else
{
if constexpr (std::is_same_v<KeyType, lbcrypto::PrivateKeyImpl<lbcrypto::DCRTPoly>>)
{
serialize<lbcrypto::PublicKeyImpl<lbcrypto::DCRTPoly>>(std::string(keyName) + "." +
m_configJson[linkedKeyName]["type"].get<std::string>(),
utils::serialize<lbcrypto::PublicKeyImpl<lbcrypto::DCRTPoly>>(m_outputCryptoObjectsDirectory +
std::string(keyName) + "." + m_configJson[linkedKeyName]["type"].get<std::string>(),
getKeyFromKeyPair<lbcrypto::PublicKeyImpl<lbcrypto::DCRTPoly>>(keyPair));
}
else
{
serialize<lbcrypto::PrivateKeyImpl<lbcrypto::DCRTPoly>>(std::string(keyName) + "." +
m_configJson[linkedKeyName]["type"].get<std::string>(),
utils::serialize<lbcrypto::PrivateKeyImpl<lbcrypto::DCRTPoly>>(m_outputCryptoObjectsDirectory +
std::string(keyName) + "." + m_configJson[linkedKeyName]["type"].get<std::string>(),
getKeyFromKeyPair<lbcrypto::PrivateKeyImpl<lbcrypto::DCRTPoly>>(keyPair));
}
}
auto itKey = getKeyMap<KeyType>().emplace(keyName, std::move(getKeyFromKeyPair<KeyType>(keyPair))).first;
const std::string keyNameStr(itKey->first);
serialize<KeyType>(keyNameStr, itKey->second);
utils::serialize<KeyType>(m_outputCryptoObjectsDirectory + keyNameStr, itKey->second);
updateSource(keyNameStr, keyContent);
}

Expand Down Expand Up @@ -212,45 +202,45 @@ void ConfigProcessor::generateOutputConfig()
{
for (auto& [argName, argContent] : m_configJson.items())
{
switch (strHash(argContent["type"].get<std::string_view>()))
switch (utils::strHash(argContent["type"].get<std::string_view>()))
{
case strHash("cryptocontext"): generateCCIfNotExist(
case utils::strHash("cryptocontext"): generateCCIfNotExist(
argName, argContent); break;
case strHash("private_key"): generateKeyAndSerializeIfNotExist<
case utils::strHash("private_key"): generateKeyAndSerializeIfNotExist<
lbcrypto::PrivateKeyImpl<lbcrypto::DCRTPoly>>(argName, argContent); break;
case strHash("public_key"): generateKeyAndSerializeIfNotExist<
case utils::strHash("public_key"): generateKeyAndSerializeIfNotExist<
lbcrypto::PublicKeyImpl<lbcrypto::DCRTPoly>>(argName, argContent); break;
case strHash("ciphertext"): generateCiphertextAndSerializeIfNotExist(
case utils::strHash("ciphertext"): generateCiphertextAndSerializeIfNotExist(
argName, argContent); break;
case strHash("sum_key"): generateEvalKeyAndSerializeIfNotExist(
case utils::strHash("sum_key"): generateEvalKeyAndSerializeIfNotExist(
argName, argContent,
&lbcrypto::CryptoContextImpl<lbcrypto::DCRTPoly>::EvalSumKeyGen,
&lbcrypto::CryptoContextImpl<lbcrypto::DCRTPoly>::SerializeEvalSumKey<
lbcrypto::SerType::SERBINARY>,
nullptr); break;
case strHash("mult_key"): generateEvalKeyAndSerializeIfNotExist(
case utils::strHash("mult_key"): generateEvalKeyAndSerializeIfNotExist(
argName, argContent,
&lbcrypto::CryptoContextImpl<lbcrypto::DCRTPoly>::EvalMultKeyGen,
&lbcrypto::CryptoContextImpl<lbcrypto::DCRTPoly>::SerializeEvalMultKey<
lbcrypto::SerType::SERBINARY>); break;
case strHash("rotation_key"): generateEvalKeyAndSerializeIfNotExist(
case utils::strHash("rotation_key"): generateEvalKeyAndSerializeIfNotExist(
argName, argContent,
&lbcrypto::CryptoContextImpl<lbcrypto::DCRTPoly>::EvalRotateKeyGen,
&lbcrypto::CryptoContextImpl<lbcrypto::DCRTPoly>::SerializeEvalAutomorphismKey<
lbcrypto::SerType::SERBINARY>,
argContent["indexes"].get<std::vector<int32_t>>(),
nullptr); break;
case strHash("i8"): [[fallthrough]];
case strHash("i16"): [[fallthrough]];
case strHash("i32"): [[fallthrough]];
case strHash("i64"): [[fallthrough]];
case strHash("u8"): [[fallthrough]];
case strHash("u16"): [[fallthrough]];
case strHash("u32"): [[fallthrough]];
case strHash("u64"): [[fallthrough]];
case strHash("f32"): [[fallthrough]];
case strHash("f64"): [[fallthrough]];
case strHash("bool"): break;
case utils::strHash("i8"): [[fallthrough]];
case utils::strHash("i16"): [[fallthrough]];
case utils::strHash("i32"): [[fallthrough]];
case utils::strHash("i64"): [[fallthrough]];
case utils::strHash("u8"): [[fallthrough]];
case utils::strHash("u16"): [[fallthrough]];
case utils::strHash("u32"): [[fallthrough]];
case utils::strHash("u64"): [[fallthrough]];
case utils::strHash("f32"): [[fallthrough]];
case utils::strHash("f64"): [[fallthrough]];
case utils::strHash("bool"): break;
default: throw std::runtime_error("Argument type is not supported");
}
}
Expand All @@ -260,7 +250,8 @@ void ConfigProcessor::generateOutputConfig()
for (const std::string_view ccName : m_generatedCCVec)
{
const std::string ccNameStr(ccName);
serialize<lbcrypto::CryptoContextImpl<lbcrypto::DCRTPoly>>(ccNameStr, m_ccMap[ccName]);
utils::serialize<lbcrypto::CryptoContextImpl<lbcrypto::DCRTPoly>>(m_outputCryptoObjectsDirectory +
ccNameStr, m_ccMap[ccName]);
updateSource(ccNameStr, m_configJson[ccName]);
}

Expand Down Expand Up @@ -309,7 +300,7 @@ void ConfigProcessor::generateCiphertextAndSerializeIfNotExist(
aquireKey<lbcrypto::PublicKeyImpl<lbcrypto::DCRTPoly>>(ciphertextContent["public_key"].get<std::string_view>());

const lbcrypto::Ciphertext<lbcrypto::DCRTPoly> ciphertext = cc->Encrypt(publicKey, plaintext);
serialize<lbcrypto::CiphertextImpl<lbcrypto::DCRTPoly>>(ciphertextName, ciphertext);
utils::serialize<lbcrypto::CiphertextImpl<lbcrypto::DCRTPoly>>(m_outputCryptoObjectsDirectory + ciphertextName, ciphertext);
updateSource(ciphertextName, ciphertextContent);
}

Expand Down Expand Up @@ -348,51 +339,10 @@ void ConfigProcessor::generateEvalKeyAndSerializeIfNotExist(

std::invoke(keyGenFunc, cc, privateKey, std::forward<KeyGenFuncParamsTypes>(keyGenFuncParams)...);

serialize(keyName, serializeFunc);
utils::serialize(m_outputCryptoObjectsDirectory + keyName, serializeFunc);
updateSource(keyName, keyContent);
}

template <typename CryptoObjectType>
void ConfigProcessor::serialize(
const std::string& filename, const std::shared_ptr<CryptoObjectType>& cryptoObject)
{
if (!lbcrypto::Serial::SerializeToFile(
m_outputCryptoObjectsDirectory + filename, cryptoObject, lbcrypto::SerType::BINARY))
{
throw std::runtime_error("Unable to serialize " + filename);
}
}

void ConfigProcessor::serialize(
const std::string& filename,
bool (*serializeFunc)(std::ostream&, const lbcrypto::SerType::SERBINARY&, std::string)) const
{
std::ofstream ofs(m_outputCryptoObjectsDirectory + filename, std::ios::out | std::ios::binary);
if (ofs.is_open())
{
if (!serializeFunc(ofs, lbcrypto::SerType::BINARY, ""))
{
ofs.close();
throw std::runtime_error("Unable to serialize " + filename);
}
ofs.close();
}
else
{
throw std::runtime_error("Unable to open " + m_outputCryptoObjectsDirectory +
filename + " file for writing serialization");
}
}

template <typename CryptoObjectType>
void ConfigProcessor::deserialize(const std::string& path, std::shared_ptr<CryptoObjectType>& cryptoObject)
{
if (!lbcrypto::Serial::DeserializeFromFile(path, cryptoObject, lbcrypto::SerType::BINARY))
{
throw std::runtime_error("Unable to deserialize " + path);
}
}

template <typename KeyType>
[[nodiscard]] std::shared_ptr<KeyType> ConfigProcessor::aquireKey(const std::string_view keyName)
{
Expand All @@ -409,7 +359,7 @@ template <typename KeyType>
{
std::shared_ptr<KeyType> key;
static constexpr size_t mainPathStartingIndex = 8;
deserialize<KeyType>(
utils::deserialize<KeyType>(
m_configJson[keyName]["source"].get_ref<const std::string&>().substr(mainPathStartingIndex), key);
itKey = keyMap.emplace(keyName, std::move(key)).first;
}
Expand All @@ -432,7 +382,7 @@ template <typename KeyType>
{
std::shared_ptr<lbcrypto::CryptoContextImpl<lbcrypto::DCRTPoly>> cc;
static constexpr size_t mainPathStartingIndex = 8;
deserialize<lbcrypto::CryptoContextImpl<lbcrypto::DCRTPoly>>(
utils::deserialize<lbcrypto::CryptoContextImpl<lbcrypto::DCRTPoly>>(
m_configJson[ccName]["source"].get_ref<const std::string&>().substr(mainPathStartingIndex), cc);
itCC = m_ccMap.emplace(ccName, std::move(cc)).first;
}
Expand All @@ -445,13 +395,13 @@ template <typename KeyType>
const std::string_view ccName, const nlohmann::json& ccContent)
{
std::shared_ptr<lbcrypto::CryptoContextImpl<lbcrypto::DCRTPoly>> cc;
switch (strHash(ccContent["scheme"].get<std::string_view>()))
switch (utils::strHash(ccContent["scheme"].get<std::string_view>()))
{
case strHash("BFVRNS_SCHEME"): cc = lbcrypto::GenCryptoContext(
case utils::strHash("BFVRNS_SCHEME"): cc = lbcrypto::GenCryptoContext(
getCCParams<lbcrypto::CryptoContextBFVRNS>(ccContent)); break;
case strHash("BGVRNS_SCHEME"): cc = lbcrypto::GenCryptoContext(
case utils::strHash("BGVRNS_SCHEME"): cc = lbcrypto::GenCryptoContext(
getCCParams<lbcrypto::CryptoContextBGVRNS>(ccContent)); break;
case strHash("CKKSRNS_SCHEME"): cc = lbcrypto::GenCryptoContext(
case utils::strHash("CKKSRNS_SCHEME"): cc = lbcrypto::GenCryptoContext(
getCCParams<lbcrypto::CryptoContextCKKSRNS>(ccContent)); break;
default: throw std::runtime_error("Scheme type is not supported");
}
Expand Down Expand Up @@ -501,31 +451,31 @@ template <typename SchemeType>
if (ccContent.contains("thresholdNumOfParties")) {
params.SetThresholdNumOfParties(ccContent["thresholdNumOfParties"].get<uint32_t>()); }
if (ccContent.contains("secretKeyDist")) {
params.SetSecretKeyDist(getSecretKeyDist(ccContent["secretKeyDist"].get<std::string_view>())); }
params.SetSecretKeyDist(utils::getSecretKeyDist(ccContent["secretKeyDist"].get<std::string_view>())); }
if (ccContent.contains("ksTech")) {
params.SetKeySwitchTechnique(getKeySwitchTechnique(ccContent["ksTech"].get<std::string_view>())); }
params.SetKeySwitchTechnique(utils::getKeySwitchTechnique(ccContent["ksTech"].get<std::string_view>())); }
if (ccContent.contains("scalTech")) {
params.SetScalingTechnique(getScalingTechnique(ccContent["scalTech"].get<std::string_view>())); }
params.SetScalingTechnique(utils::getScalingTechnique(ccContent["scalTech"].get<std::string_view>())); }
if (ccContent.contains("securityLevel")) {
params.SetSecurityLevel(getSecurityLevel(ccContent["securityLevel"].get<std::string_view>())); }
params.SetSecurityLevel(utils::getSecurityLevel(ccContent["securityLevel"].get<std::string_view>())); }
if (ccContent.contains("encryptionTechnique")) {
params.SetEncryptionTechnique(
getEncryptionTechnique(ccContent["encryptionTechnique"].get<std::string_view>())); }
utils::getEncryptionTechnique(ccContent["encryptionTechnique"].get<std::string_view>())); }
if (ccContent.contains("multiplicationTechnique")) {
params.SetMultiplicationTechnique(
getMultiplicationTechnique(ccContent["multiplicationTechnique"].get<std::string_view>())); }
utils::getMultiplicationTechnique(ccContent["multiplicationTechnique"].get<std::string_view>())); }
if (ccContent.contains("PREMode")) {
params.SetPREMode(getProxyReEncryptionMode(ccContent["PREMode"].get<std::string_view>())); }
params.SetPREMode(utils::getProxyReEncryptionMode(ccContent["PREMode"].get<std::string_view>())); }
if (ccContent.contains("multipartyMode")) {
params.SetMultipartyMode(getMultipartyMode(ccContent["multipartyMode"].get<std::string_view>())); }
params.SetMultipartyMode(utils::getMultipartyMode(ccContent["multipartyMode"].get<std::string_view>())); }
if (ccContent.contains("executionMode")) {
params.SetExecutionMode(getExecutionMode(ccContent["executionMode"].get<std::string_view>())); }
params.SetExecutionMode(utils::getExecutionMode(ccContent["executionMode"].get<std::string_view>())); }
if (ccContent.contains("decryptionNoiseMode")) {
params.SetDecryptionNoiseMode(
getDecryptionNoiseMode(ccContent["decryptionNoiseMode"].get<std::string_view>())); }
utils::getDecryptionNoiseMode(ccContent["decryptionNoiseMode"].get<std::string_view>())); }
if (ccContent.contains("interactiveBootCompressionLevel")) {
params.SetInteractiveBootCompressionLevel(
getCompressionLevel(ccContent["interactiveBootCompressionLevel"].get<std::string_view>())); }
utils::getCompressionLevel(ccContent["interactiveBootCompressionLevel"].get<std::string_view>())); }

return params;
}
Expand Down
Loading

0 comments on commit 9aae4b8

Please sign in to comment.