diff --git a/ext/openssl/ossl_pkey.c b/ext/openssl/ossl_pkey.c index 013412c27..e3a71b238 100644 --- a/ext/openssl/ossl_pkey.c +++ b/ext/openssl/ossl_pkey.c @@ -446,6 +446,174 @@ pkey_generate(int argc, VALUE *argv, VALUE self, int genparam) return ossl_pkey_new(gen_arg.pkey); } +#if OSSL_OPENSSL_PREREQ(3, 0, 0) +#include +#include + +struct pkey_from_parameters_alias { + char alias[10]; + char param_name[20]; +}; + +static const struct pkey_from_parameters_alias rsa_aliases[] = { + { "p", OSSL_PKEY_PARAM_RSA_FACTOR1 }, + { "q", OSSL_PKEY_PARAM_RSA_FACTOR2 }, + { "dmp1", OSSL_PKEY_PARAM_RSA_EXPONENT1 }, + { "dmq1", OSSL_PKEY_PARAM_RSA_EXPONENT2 }, + { "iqmp", OSSL_PKEY_PARAM_RSA_COEFFICIENT1 }, + { "", "" } +}; + +static const struct pkey_from_parameters_alias fcc_aliases[] = { + { "pub_key", OSSL_PKEY_PARAM_PUB_KEY }, + { "priv_key", OSSL_PKEY_PARAM_PRIV_KEY }, + { "", "" } +}; + +struct pkey_from_parameters_arg { + OSSL_PARAM_BLD *param_bld; + const OSSL_PARAM *settable_params; + const struct pkey_from_parameters_alias *aliases; +}; + +static int +add_parameter_to_builder(VALUE key, VALUE value, VALUE arg) { + if(NIL_P(value)) + return ST_CONTINUE; + + if (SYMBOL_P(key)) + key = rb_sym2str(key); + + const char *key_ptr = StringValueCStr(key); + const struct pkey_from_parameters_arg *params = (const struct pkey_from_parameters_arg *) arg; + + for(int i = 0; strlen(params->aliases[i].alias) > 0; i++) { + if(strcmp(params->aliases[i].alias, key_ptr) == 0) { + key_ptr = params->aliases[i].param_name; + break; + } + } + + for (const OSSL_PARAM *settable_params = params->settable_params; settable_params->key != NULL; settable_params++) { + if(strcmp(settable_params->key, key_ptr) == 0) { + switch (settable_params->data_type) { + case OSSL_PARAM_INTEGER: + case OSSL_PARAM_UNSIGNED_INTEGER: + if(!OSSL_PARAM_BLD_push_BN(params->param_bld, key_ptr, GetBNPtr(value))) { + OSSL_PARAM_BLD_free(params->param_bld); + ossl_raise(ePKeyError, "OSSL_PARAM_BLD_push_BN"); + } + break; + case OSSL_PARAM_UTF8_STRING: + StringValue(value); + if(!OSSL_PARAM_BLD_push_utf8_string(params->param_bld, key_ptr, RSTRING_PTR(value), RSTRING_LENINT(value))) { + OSSL_PARAM_BLD_free(params->param_bld); + ossl_raise(ePKeyError, "OSSL_PARAM_BLD_push_utf8_string"); + } + break; + + case OSSL_PARAM_OCTET_STRING: + StringValue(value); + if(!OSSL_PARAM_BLD_push_octet_string(params->param_bld, key_ptr, RSTRING_PTR(value), RSTRING_LENINT(value))) { + OSSL_PARAM_BLD_free(params->param_bld); + ossl_raise(ePKeyError, "OSSL_PARAM_BLD_push_octet_string"); + } + break; + case OSSL_PARAM_UTF8_PTR: + case OSSL_PARAM_OCTET_PTR: + OSSL_PARAM_BLD_free(params->param_bld); + ossl_raise(ePKeyError, "Unsupported parameter \"%s\", handling of OSSL_PARAM_UTF8_PTR and OSSL_PARAM_OCTET_PTR not implemented", key_ptr); + break; + } + + return ST_CONTINUE; + } + } + OSSL_PARAM_BLD_free(params->param_bld); + + char message_buffer[512] = { 0 }; + char *cur = message_buffer; + char *end = message_buffer + sizeof(message_buffer); + for (const OSSL_PARAM *settable_params = params->settable_params; settable_params->key != NULL; settable_params++) { + const char *fmt = cur == message_buffer ? "%s" : ", %s"; + if (cur > end) + break; + cur += snprintf(cur, end-cur, fmt, settable_params->key); + } + + for(int i = 0; strlen(params->aliases[i].alias) > 0; i++) { + const char *fmt = cur == message_buffer ? "%s" : ", %s"; + if (cur > end) + break; + cur += snprintf(cur, end-cur, fmt, params->aliases[i].alias); + } + + ossl_raise(ePKeyError, "Invalid parameter \"%s\". Supported parameters: \"%s\"", key_ptr, message_buffer); +} + +static VALUE +pkey_from_parameters(int argc, VALUE *argv, VALUE self) +{ + VALUE alg, options; + rb_scan_args(argc, argv, "11", &alg, &options); + + const char* algorithm = StringValueCStr(alg); + + EVP_PKEY_CTX *ctx = EVP_PKEY_CTX_new_from_name(NULL, algorithm, NULL); + + if (ctx == NULL) + ossl_raise(ePKeyError, "EVP_PKEY_CTX_new_from_name"); + + struct pkey_from_parameters_arg from_params_args = { 0 }; + + from_params_args.param_bld = OSSL_PARAM_BLD_new(); + + if (from_params_args.param_bld == NULL) { + EVP_PKEY_CTX_free(ctx); + ossl_raise(ePKeyError, "OSSL_PARAM_BLD_new"); + } + + from_params_args.settable_params = EVP_PKEY_fromdata_settable(ctx, EVP_PKEY_KEYPAIR); + + if (from_params_args.settable_params == NULL) { + EVP_PKEY_CTX_free(ctx); + ossl_raise(ePKeyError, "EVP_PKEY_fromdata_settable"); + } + + if (strcmp("RSA", algorithm) == 0) + from_params_args.aliases = rsa_aliases; + else + from_params_args.aliases = fcc_aliases; + + rb_hash_foreach(options, &add_parameter_to_builder, (VALUE) &from_params_args); + + OSSL_PARAM *params = OSSL_PARAM_BLD_to_param(from_params_args.param_bld); + OSSL_PARAM_BLD_free(from_params_args.param_bld); + + if (params == NULL) { + EVP_PKEY_CTX_free(ctx); + ossl_raise(ePKeyError, "OSSL_PARAM_BLD_to_param"); + } + + EVP_PKEY *pkey = NULL; + + if (EVP_PKEY_fromdata_init(ctx) <= 0) { + EVP_PKEY_CTX_free(ctx); + ossl_raise(ePKeyError, "EVP_PKEY_fromdata_init"); + } + + if (EVP_PKEY_fromdata(ctx, &pkey, EVP_PKEY_KEYPAIR, params) <= 0) { + EVP_PKEY_CTX_free(ctx); + EVP_PKEY_free(pkey); + ossl_raise(ePKeyError, "EVP_PKEY_fromdata"); + } + + EVP_PKEY_CTX_free(ctx); + + return ossl_pkey_new(pkey); +} +#endif + /* * call-seq: * OpenSSL::PKey.generate_parameters(algo_name [, options]) -> pkey @@ -498,6 +666,33 @@ ossl_pkey_s_generate_key(int argc, VALUE *argv, VALUE self) return pkey_generate(argc, argv, self, 0); } +/* + * call-seq: + * OpenSSL::PKey.from_parameters(algo_name, parameters) -> pkey + * + * Generates a new key based on given key parameters. + * NOTE: Requires OpenSSL 3.0 or later. + * + * The first parameter is the type of the key to create, given as a String, for example RSA, DSA, EC etc. + * Second parameter is the parameters to be used for the key. + * + * For details algorithms and parameters see https://www.openssl.org/docs/man3.0/man3/EVP_PKEY_fromdata.html + * + * == Example + * pkey = OpenSSL::PKey.from_parameters("RSA", n: 3161751493, e: 65537, d: 2064855961) + * pkey.private? #=> true + * pkey.public_key #=> # OpenSSL::BN.new(3161751493), + "e" => OpenSSL::BN.new(65537), + "d" => OpenSSL::BN.new(2064855961)) + + assert_instance_of OpenSSL::PKey::RSA, new_key + assert_equal true, new_key.private? + assert_equal OpenSSL::BN.new(3161751493), new_key.n + assert_equal OpenSSL::BN.new(65537), new_key.e + assert_equal OpenSSL::BN.new(2064855961), new_key.d + end + + def test_s_from_parameters_rsa_with_n_and_e_given + new_key = OpenSSL::PKey.from_parameters("RSA", n: OpenSSL::BN.new(3161751493), + e: OpenSSL::BN.new(65537)) + + assert_instance_of OpenSSL::PKey::RSA, new_key + assert_equal false, new_key.private? + assert_equal OpenSSL::BN.new(3161751493), new_key.n + assert_equal OpenSSL::BN.new(65537), new_key.e + assert_equal nil, new_key.d + end + + def test_s_from_parameters_rsa_with_openssl_internal_names + source = OpenSSL::PKey::RSA.generate(2048) + new_key = OpenSSL::PKey.from_parameters("RSA", n: source.n, + e: source.e, + d: source.d, + "rsa-factor1" => source.p, + "rsa-factor2" => source.q, + "rsa-exponent1" => source.dmp1, + "rsa-exponent2" => source.dmq1, + "rsa-coefficient1" => source.iqmp + ) + + assert_equal source.n, new_key.n + assert_equal source.e, new_key.e + assert_equal source.d, new_key.d + assert_equal source.p, new_key.p + assert_equal source.q, new_key.q + assert_equal source.dmp1, new_key.dmp1 + assert_equal source.dmq1, new_key.dmq1 + assert_equal source.iqmp, new_key.iqmp + + assert_equal source.to_pem, new_key.to_pem + end + + def test_s_from_parameters_rsa_with_simple_names + source = OpenSSL::PKey::RSA.generate(2048) + new_key = OpenSSL::PKey.from_parameters("RSA", n: source.n, + e: source.e, + d: source.d, + p: source.p, + q: source.q, + dmp1: source.dmp1, + dmq1: source.dmq1, + iqmp: source.iqmp + ) + + assert_equal source.n, new_key.n + assert_equal source.e, new_key.e + assert_equal source.d, new_key.d + assert_equal source.p, new_key.p + assert_equal source.q, new_key.q + assert_equal source.dmp1, new_key.dmp1 + assert_equal source.dmq1, new_key.dmq1 + assert_equal source.iqmp, new_key.iqmp + + assert_equal source.to_pem, new_key.to_pem + end + + def test_s_from_parameters_rsa_with_invalid_parameter + e = assert_raise(OpenSSL::PKey::PKeyError) { OpenSSL::PKey.from_parameters("RSA", invalid: 1234) } + assert_match(/Invalid parameter "invalid"/, e.message) + end + + def test_s_from_parameters_ec_pub_given_as_string + source = OpenSSL::PKey::EC.generate("prime256v1") + new_key = OpenSSL::PKey.from_parameters("EC", group: source.group.curve_name, + pub: source.public_key.to_bn.to_s(2)) + assert_instance_of OpenSSL::PKey::EC, new_key + assert_equal source.group.curve_name, new_key.group.curve_name + assert_equal source.public_key, new_key.public_key + assert_equal nil, new_key.private_key + end + + def test_s_from_parameters_ec_priv_given_as_bn + source = OpenSSL::PKey::EC.generate("prime256v1") + new_key = OpenSSL::PKey.from_parameters("EC", group: source.group.curve_name, + priv: source.private_key.to_bn) + assert_instance_of OpenSSL::PKey::EC, new_key + assert_equal source.group.curve_name, new_key.group.curve_name + assert_equal source.private_key, new_key.private_key + assert_equal nil, new_key.public_key + end + + def test_s_from_parameters_ec_priv_given_as_integer + source = OpenSSL::PKey::EC.generate("prime256v1") + new_key = OpenSSL::PKey.from_parameters("EC", group: source.group.curve_name, + priv: source.private_key.to_i) + assert_instance_of OpenSSL::PKey::EC, new_key + assert_equal source.group.curve_name, new_key.group.curve_name + assert_equal source.private_key, new_key.private_key + assert_equal nil, new_key.public_key + end + + def test_s_from_parameters_ec_priv_and_pub_given_for_different_curves + ["prime256v1", "secp256k1", "secp384r1", "secp521r1"].each do |curve| + source = OpenSSL::PKey::EC.generate(curve) + new_key = OpenSSL::PKey.from_parameters("EC", group: source.group.curve_name, + pub: source.public_key.to_bn.to_s(2), + priv: source.private_key.to_i) + assert_instance_of OpenSSL::PKey::EC, new_key + assert_equal source.group.curve_name, new_key.group.curve_name + assert_equal source.private_key, new_key.private_key + assert_equal source.public_key, new_key.public_key + end + end + + def test_s_from_parameters_ec_pub_given_as_integer + e = assert_raise(TypeError) { OpenSSL::PKey.from_parameters("EC", { group: "prime256v1", pub: 12345 }) } + assert_equal "no implicit conversion of Integer into String", e.message + end + + def test_s_from_parameters_ec_with_invalid_parameter + e = assert_raise(OpenSSL::PKey::PKeyError) { OpenSSL::PKey.from_parameters("EC", invalid: 1234) } + assert_match(/Invalid parameter "invalid"/, e.message) + end + + def test_s_from_parameters_scrypt + e = assert_raise(OpenSSL::PKey::PKeyError) { OpenSSL::PKey.from_parameters("SCRYPT", {}) } + assert_match("EVP_PKEY_fromdata_settable", e.message) + end + + def test_s_from_parameters_dsa_with_all_supported_parameters + source = OpenSSL::PKey::DSA.generate(2048) + + new_key = OpenSSL::PKey.from_parameters("DSA", pub: source.params["pub_key"], + priv: source.params["priv_key"], + p: source.params["p"], + q: source.params["q"], + g: source.params["g"]) + + assert_instance_of OpenSSL::PKey::DSA, new_key + assert_equal source.params, new_key.params + end + + def test_s_from_parameters_dsa_with_gem_specific_keys + source = OpenSSL::PKey::DSA.generate(2048) + + new_key = OpenSSL::PKey.from_parameters("DSA", source.params) + + assert_equal source.params, new_key.params + end + + def test_s_from_parameters_dsa_with_invalid_parameter + e = assert_raise(OpenSSL::PKey::PKeyError) { OpenSSL::PKey.from_parameters("DSA", invalid: 1234) } + assert_match(/Invalid parameter "invalid"/, e.message) + end + + def test_s_from_parameters_dh_with_all_supported_parameters + source = OpenSSL::PKey::DH.generate(512) + + new_key = OpenSSL::PKey.from_parameters("DH", source.params) + + assert_instance_of OpenSSL::PKey::DH, new_key + assert_equal source.params, new_key.params + end + + def test_s_from_parameters_dh_with_invalid_parameter + e = assert_raise(OpenSSL::PKey::PKeyError) { OpenSSL::PKey.from_parameters("DH", invalid: 1234) } + assert_match(/Invalid parameter "invalid"/, e.message) + end + + def test_s_from_parameters_ed25519 + key = OpenSSL::PKey.from_parameters("ED25519", pub: "\xD0\x8E\xA8\x96\xB6Fbi{$k\xAC\xB8\xA2V\xF4n\xC3\xD06}R\x8A\xE6I\xA7r\xF6D{W\x84") + assert_instance_of OpenSSL::PKey::PKey, key + assert_equal "-----BEGIN PUBLIC KEY-----\nMCowBQYDK2VwAyEA0I6olrZGYml7JGusuKJW9G7D0DZ9UormSady9kR7V4Q=\n-----END PUBLIC KEY-----\n", key.public_to_pem + end + + def test_s_from_parameters_ed25519_with_invalid_parameters + e = assert_raise(OpenSSL::PKey::PKeyError) { OpenSSL::PKey.from_parameters("ED25519", invalid: 12345) } + assert_equal 'Invalid parameter "invalid". Supported parameters: "pub, priv, pub_key, priv_key"', e.message + end + else + def test_from_parameter_raises_on_pre_3_openssl + e = assert_raise(OpenSSL::PKey::PKeyError) { + OpenSSL::PKey.from_parameters("RSA", {}) + } + assert_equal e.message, "OpenSSL::PKey.from_parameters requires OpenSSL 3.0 or later" + end + end end