Skip to content

Commit

Permalink
feat: subtle.encrypt() & subtle.decrypt() support for rsa (marg…
Browse files Browse the repository at this point in the history
  • Loading branch information
boorad authored Jul 5, 2024
1 parent 888c7f7 commit 959fe2e
Show file tree
Hide file tree
Showing 16 changed files with 992 additions and 112 deletions.
125 changes: 125 additions & 0 deletions cpp/Cipher/MGLRsa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@

#include "MGLRsa.h"
#ifdef ANDROID
#include "Cipher/MGLPublicCipher.h"
#include "JSIUtils/MGLJSIMacros.h"
#include "JSIUtils/MGLJSIUtils.h"
#include "Utils/MGLUtils.h"
#else
#include "MGLPublicCipher.h"
#include "MGLJSIMacros.h"
#include "MGLJSIUtils.h"
#include "MGLUtils.h"
#endif

#include <string>
Expand Down Expand Up @@ -186,6 +192,125 @@ std::pair<jsi::Value, jsi::Value> generateRsaKeyPair(
return {std::move(publicBuffer), std::move(privateBuffer)};
}

template <MGLPublicCipher::EVP_PKEY_cipher_init_t init,
MGLPublicCipher::EVP_PKEY_cipher_t cipher>
WebCryptoCipherStatus RSA_Cipher(const RSACipherConfig& params, ByteSource* out) {
CHECK_NE(params.key->GetKeyType(), kKeyTypeSecret);
ManagedEVPPKey m_pkey = params.key->GetAsymmetricKey();
// Mutex::ScopedLock lock(*m_pkey.mutex());

EVPKeyCtxPointer ctx(EVP_PKEY_CTX_new(m_pkey.get(), nullptr));

if (!ctx || init(ctx.get()) <= 0)
return WebCryptoCipherStatus::FAILED;

if (EVP_PKEY_CTX_set_rsa_padding(ctx.get(), params.padding) <= 0) {
return WebCryptoCipherStatus::FAILED;
}

if (params.digest != nullptr &&
(EVP_PKEY_CTX_set_rsa_oaep_md(ctx.get(), params.digest) <= 0 ||
EVP_PKEY_CTX_set_rsa_mgf1_md(ctx.get(), params.digest) <= 0)) {
return WebCryptoCipherStatus::FAILED;
}

if (!SetRsaOaepLabel(ctx, params.label)) return WebCryptoCipherStatus::FAILED;

size_t out_len = 0;
if (cipher(
ctx.get(),
nullptr,
&out_len,
params.data.data<unsigned char>(),
params.data.size()) <= 0) {
return WebCryptoCipherStatus::FAILED;
}

ByteSource::Builder buf(out_len);

if (cipher(ctx.get(),
buf.data<unsigned char>(),
&out_len,
params.data.data<unsigned char>(),
params.data.size()) <= 0) {
return WebCryptoCipherStatus::FAILED;
}

*out = std::move(buf).release(out_len);
return WebCryptoCipherStatus::OK;
}

RSACipherConfig RSACipher::GetParamsFromJS(jsi::Runtime &rt,
const jsi::Value *args) {
RSACipherConfig params;
unsigned int offset = 0;

// padding
params.padding = RSA_PKCS1_OAEP_PADDING;

// mode (encrypt/decrypt)
params.mode = static_cast<WebCryptoCipherMode>((int)args[offset].getNumber());
offset++;

// key (handle)
if (!args[offset].isObject()) {
throw std::runtime_error("arg is not a KeyObjectHandle: key");
}
std::shared_ptr<KeyObjectHandle> handle =
std::static_pointer_cast<KeyObjectHandle>(
args[offset].asObject(rt).getHostObject(rt));
params.key = handle->Data();
offset++;

// data
params.data = GetByteSourceFromJS(rt, args[offset], "data");
offset++;

// variant
if (CheckIsInt32(args[offset])) {
params.variant = static_cast<RSAKeyVariant>((int)args[offset].getNumber());
}
// offset++; // The below variant-dependent params advance offset themselves

std::string digest;
switch (params.variant) {
case kKeyVariantRSA_OAEP:
// hash (digest)
CHECK(args[offset + 1].isString());
digest = args[offset + 1].asString(rt).utf8(rt);
params.digest = EVP_get_digestbyname(digest.c_str());
if (params.digest == nullptr) {
throw jsi::JSError(rt, "invalid digest: " + digest);
return params;
}

// label
if (args[offset + 2].isUndefined()) {
params.label = ByteSource();
} else {
params.label = GetByteSourceFromJS(rt, args[offset + 2], "label");
}

break;
default:
throw jsi::JSError(rt, "Invalid RSA key variant");
}

return params;
}

WebCryptoCipherStatus RSACipher::DoCipher(const RSACipherConfig &params,
ByteSource *out) {
switch (params.mode) {
case kEncrypt:
CHECK_EQ(params.key->GetKeyType(), kKeyTypePublic);
return RSA_Cipher<EVP_PKEY_encrypt_init, EVP_PKEY_encrypt>(params, out);
case kDecrypt:
CHECK_EQ(params.key->GetKeyType(), kKeyTypePrivate);
return RSA_Cipher<EVP_PKEY_decrypt_init, EVP_PKEY_decrypt>(params, out);
}
}

jsi::Value ExportJWKRsaKey(jsi::Runtime &rt,
std::shared_ptr<KeyObjectData> key,
jsi::Object &target) {
Expand Down
26 changes: 26 additions & 0 deletions cpp/Cipher/MGLRsa.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ namespace margelo {

namespace jsi = facebook::jsi;

// TODO: keep in in sync with JS side (src/rsa.ts)
enum RSAKeyVariant {
kKeyVariantRSA_SSA_PKCS1_v1_5,
kKeyVariantRSA_PSS,
kKeyVariantRSA_OAEP
};

// On node there is a complete madness of structs/classes that encapsulate and
// initialize the data in a generic manner this is to be later be used to
// generate the keys in a thread-safe manner (I think) I'm however too dumb and
Expand Down Expand Up @@ -78,6 +85,25 @@ class RsaKeyExport {
RsaKeyExportConfig params_;
};

struct RSACipherConfig final {
WebCryptoCipherMode mode;
std::shared_ptr<KeyObjectData> key;
ByteSource data;
RSAKeyVariant variant;
ByteSource label;
int padding = 0;
const EVP_MD* digest = nullptr;

RSACipherConfig() = default;
};

class RSACipher {
public:
RSACipher() {}
RSACipherConfig GetParamsFromJS(jsi::Runtime &rt, const jsi::Value *args);
WebCryptoCipherStatus DoCipher(const RSACipherConfig &params, ByteSource *out);
};

} // namespace margelo

#endif /* MGLRsa_hpp */
28 changes: 28 additions & 0 deletions cpp/Utils/MGLUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,19 @@ ByteSource ByteSource::FromBN(const BIGNUM* bn, size_t size) {
return std::move(out).release();
}

ByteSource GetByteSourceFromJS(jsi::Runtime &rt,
const jsi::Value &value,
std::string name) {
if (!value.isObject() || !value.asObject(rt).isArrayBuffer(rt)) {
throw jsi::JSError(rt, "arg is not an array buffer: " + name);
}
ByteSource data = ByteSource::FromStringOrBuffer(rt, value);
if (data.size() > INT_MAX) {
throw jsi::JSError(rt, "arg is too big (> int32): " + name);
}
return data;
}

std::string EncodeBignum(const BIGNUM* bn,
size_t size,
bool url) {
Expand Down Expand Up @@ -264,4 +277,19 @@ MUST_USE_RESULT CSPRNGResult CSPRNG(void* buffer, size_t length) {
return {false};
}

bool SetRsaOaepLabel(const EVPKeyCtxPointer& ctx, const ByteSource& label) {
if (label.size() != 0) {
// OpenSSL takes ownership of the label, so we need to create a copy.
void* label_copy = OPENSSL_memdup(label.data(), label.size());
CHECK_NOT_NULL(label_copy);
int ret = EVP_PKEY_CTX_set0_rsa_oaep_label(
ctx.get(), static_cast<unsigned char*>(label_copy), label.size());
if (ret <= 0) {
OPENSSL_free(label_copy);
return false;
}
}
return true;
}

} // namespace margelo
19 changes: 19 additions & 0 deletions cpp/Utils/MGLUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -304,13 +304,19 @@ inline jsi::Value toJSI(jsi::Runtime& rt, ByteSource value) {
return o;
}

ByteSource GetByteSourceFromJS(jsi::Runtime &rt,
const jsi::Value &value,
std::string name);

std::string EncodeBignum(const BIGNUM* bn,
size_t size,
bool url = false);

std::string EncodeBase64(const std::string data, bool url = false);
std::string DecodeBase64(const std::string &in, bool remove_linebreaks = false);

bool SetRsaOaepLabel(const EVPKeyCtxPointer& ctx, const ByteSource& label);

// TODO: until shared, keep in sync with JS side (src/NativeQuickCrypto/Cipher.ts)
enum KeyVariant {
kvRSA_SSA_PKCS1_v1_5,
Expand All @@ -334,6 +340,19 @@ enum WebCryptoKeyFormat {
kWebCryptoKeyFormatJWK
};

enum WebCryptoCipherMode {
kEncrypt,
kDecrypt,
// kWrapKey,
// kUnwrapKey,
};

enum class WebCryptoCipherStatus {
OK,
INVALID_KEY_TYPE,
FAILED
};

} // namespace margelo

#endif /* MGLUtils_h */
13 changes: 13 additions & 0 deletions cpp/webcrypto/MGLWebCrypto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,18 @@ MGLWebCryptoHostObject::MGLWebCryptoHostObject(
return jsi::Value(std::move(out));
};

auto rsaCipher = JSIF([=]) {
auto rsa = RSACipher();
auto params = rsa.GetParamsFromJS(runtime, arguments);
ByteSource out;
WebCryptoCipherStatus status = rsa.DoCipher(params, &out);
if (status != WebCryptoCipherStatus::OK) {
throw jsi::JSError(runtime, "error in DoCipher, status: " +
std::to_string(static_cast<int>(status)));
}
return toJSI(runtime, std::move(out));
};

auto rsaExportKey = JSIF([=]) {
ByteSource out;
auto rsa = new RsaKeyExport();
Expand All @@ -105,6 +117,7 @@ MGLWebCryptoHostObject::MGLWebCryptoHostObject(
this->fields.push_back(buildPair("ecExportKey", ecExportKey));
this->fields.push_back(GenerateSecretKeyFieldDefinition(jsCallInvoker, workerQueue));
this->fields.push_back(buildPair("generateSecretKeySync", generateSecretKeySync));
this->fields.push_back(buildPair("rsaCipher", rsaCipher));
this->fields.push_back(buildPair("rsaExportKey", rsaExportKey));
this->fields.push_back(buildPair("signVerify", signVerify));
};
Expand Down
13 changes: 0 additions & 13 deletions cpp/webcrypto/crypto_aes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,19 +308,6 @@ WebCryptoCipherStatus AES_CTR_Cipher(
return status;
}

ByteSource GetByteSourceFromJS(jsi::Runtime &rt,
const jsi::Value &value,
std::string name) {
if (!value.isObject() || !value.asObject(rt).isArrayBuffer(rt)) {
throw jsi::JSError(rt, "arg is not an array buffer: " + name);
}
ByteSource data = ByteSource::FromStringOrBuffer(rt, value);
if (data.size() > INT_MAX) {
throw jsi::JSError(rt, "arg is too big (> int32): " + name);
}
return data;
}

bool ValidateIV(
jsi::Runtime &rt,
const jsi::Value &value,
Expand Down
6 changes: 0 additions & 6 deletions cpp/webcrypto/crypto_aes.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,6 @@ enum AESKeyVariant {
#undef V
};

enum class WebCryptoCipherStatus {
OK,
INVALID_KEY_TYPE,
FAILED
};

struct AESCipherConfig final {
enum Mode {
kEncrypt,
Expand Down
Loading

0 comments on commit 959fe2e

Please sign in to comment.