diff --git a/libs/mbedtls/mbedtls.ml b/libs/mbedtls/mbedtls.ml index dac738dde37..513cca45052 100644 --- a/libs/mbedtls/mbedtls.ml +++ b/libs/mbedtls/mbedtls.ml @@ -43,8 +43,8 @@ external mbedtls_ssl_setup : mbedtls_ssl_context -> mbedtls_ssl_config -> mbedtl external mbedtls_ssl_write : mbedtls_ssl_context -> bytes -> int -> int -> mbedtls_result = "ml_mbedtls_ssl_write" external mbedtls_pk_init : unit -> mbedtls_pk_context = "ml_mbedtls_pk_init" -external mbedtls_pk_parse_key : mbedtls_pk_context -> bytes -> string option -> mbedtls_result = "ml_mbedtls_pk_parse_key" -external mbedtls_pk_parse_keyfile : mbedtls_pk_context -> string -> string option -> mbedtls_result = "ml_mbedtls_pk_parse_keyfile" +external mbedtls_pk_parse_key : mbedtls_pk_context -> bytes -> string option -> mbedtls_ctr_drbg_context -> mbedtls_result = "ml_mbedtls_pk_parse_key" +external mbedtls_pk_parse_keyfile : mbedtls_pk_context -> string -> string option -> mbedtls_ctr_drbg_context -> mbedtls_result = "ml_mbedtls_pk_parse_keyfile" external mbedtls_pk_parse_public_keyfile : mbedtls_pk_context -> string -> mbedtls_result = "ml_mbedtls_pk_parse_public_keyfile" external mbedtls_pk_parse_public_key : mbedtls_pk_context -> bytes -> mbedtls_result = "ml_mbedtls_pk_parse_public_key" diff --git a/libs/mbedtls/mbedtls_stubs.c b/libs/mbedtls/mbedtls_stubs.c index f675e63213f..9d8132d9e07 100644 --- a/libs/mbedtls/mbedtls_stubs.c +++ b/libs/mbedtls/mbedtls_stubs.c @@ -1,4 +1,3 @@ -#include #include #include @@ -18,13 +17,10 @@ #include #include -#include "mbedtls/debug.h" #include "mbedtls/error.h" -#include "mbedtls/config.h" #include "mbedtls/ssl.h" #include "mbedtls/entropy.h" #include "mbedtls/ctr_drbg.h" -#include "mbedtls/certs.h" #include "mbedtls/oid.h" #define PVoid_val(v) (*((void**) Data_custom_val(v))) @@ -84,7 +80,7 @@ CAMLprim value ml_mbedtls_ctr_drbg_init(void) { CAMLprim value ml_mbedtls_ctr_drbg_random(value p_rng, value output, value output_len) { CAMLparam3(p_rng, output, output_len); - CAMLreturn(Val_int(mbedtls_ctr_drbg_random(CtrDrbg_val(p_rng), String_val(output), Int_val(output_len)))); + CAMLreturn(Val_int(mbedtls_ctr_drbg_random(CtrDrbg_val(p_rng), Bytes_val(output), Int_val(output_len)))); } CAMLprim value ml_mbedtls_ctr_drbg_seed(value ctx, value p_entropy, value custom) { @@ -124,7 +120,7 @@ CAMLprim value ml_mbedtls_entropy_init(void) { CAMLprim value ml_mbedtls_entropy_func(value data, value output, value len) { CAMLparam3(data, output, len); - CAMLreturn(Val_int(mbedtls_entropy_func(PVoid_val(data), String_val(output), Int_val(len)))); + CAMLreturn(Val_int(mbedtls_entropy_func(PVoid_val(data), Bytes_val(output), Int_val(len)))); } // Certificate @@ -171,7 +167,7 @@ CAMLprim value ml_mbedtls_x509_next(value chain) { CAMLprim value ml_mbedtls_x509_crt_parse(value chain, value bytes) { CAMLparam2(chain, bytes); - const char* buf = String_val(bytes); + const unsigned char* buf = Bytes_val(bytes); int len = caml_string_length(bytes); CAMLreturn(Val_int(mbedtls_x509_crt_parse(X509Crt_val(chain), buf, len + 1))); } @@ -191,8 +187,7 @@ CAMLprim value ml_mbedtls_x509_crt_parse_path(value chain, value path) { value caml_string_of_asn1_buf(mbedtls_asn1_buf* dat) { CAMLparam0(); CAMLlocal1(s); - s = caml_alloc_string(dat->len); - memcpy(String_val(s), dat->p, dat->len); + s = caml_alloc_initialized_string(dat->len, (const char *)dat->p); CAMLreturn(s); } @@ -200,7 +195,11 @@ CAMLprim value hx_cert_get_alt_names(value chain) { CAMLparam1(chain); CAMLlocal1(obj); mbedtls_x509_crt* cert = X509Crt_val(chain); - if (cert->ext_types & MBEDTLS_X509_EXT_SUBJECT_ALT_NAME == 0 || &cert->subject_alt_names == NULL) { +#if MBEDTLS_VERSION_MAJOR >= 3 + if (!mbedtls_x509_crt_has_ext_type(cert, MBEDTLS_X509_EXT_SUBJECT_ALT_NAME)) { +#else + if ((cert->ext_types & MBEDTLS_X509_EXT_SUBJECT_ALT_NAME) == 0) { +#endif obj = Atom(0); } else { mbedtls_asn1_sequence* cur = &cert->subject_alt_names; @@ -366,29 +365,39 @@ CAMLprim value ml_mbedtls_pk_init(void) { CAMLreturn(obj); } -CAMLprim value ml_mbedtls_pk_parse_key(value ctx, value key, value password) { - CAMLparam3(ctx, key, password); - const char* pwd = NULL; +CAMLprim value ml_mbedtls_pk_parse_key(value ctx, value key, value password, value rng) { + CAMLparam4(ctx, key, password, rng); + const unsigned char* pwd = NULL; size_t pwdlen = 0; if (password != Val_none) { - pwd = String_val(Field(password, 0)); + pwd = Bytes_val(Field(password, 0)); pwdlen = caml_string_length(Field(password, 0)); } - CAMLreturn(mbedtls_pk_parse_key(PkContext_val(ctx), String_val(key), caml_string_length(key) + 1, pwd, pwdlen)); + #if MBEDTLS_VERSION_MAJOR >= 3 + mbedtls_ctr_drbg_context *ctr_drbg = CtrDrbg_val(rng); + CAMLreturn(mbedtls_pk_parse_key(PkContext_val(ctx), Bytes_val(key), caml_string_length(key) + 1, pwd, pwdlen, mbedtls_ctr_drbg_random, NULL)); + #else + CAMLreturn(mbedtls_pk_parse_key(PkContext_val(ctx), Bytes_val(key), caml_string_length(key) + 1, pwd, pwdlen)); + #endif } -CAMLprim value ml_mbedtls_pk_parse_keyfile(value ctx, value path, value password) { - CAMLparam3(ctx, path, password); +CAMLprim value ml_mbedtls_pk_parse_keyfile(value ctx, value path, value password, value rng) { + CAMLparam4(ctx, path, password, rng); const char* pwd = NULL; if (password != Val_none) { pwd = String_val(Field(password, 0)); } + #if MBEDTLS_VERSION_MAJOR >= 3 + mbedtls_ctr_drbg_context *ctr_drbg = CtrDrbg_val(rng); + CAMLreturn(mbedtls_pk_parse_keyfile(PkContext_val(ctx), String_val(path), pwd, mbedtls_ctr_drbg_random, ctr_drbg)); + #else CAMLreturn(mbedtls_pk_parse_keyfile(PkContext_val(ctx), String_val(path), pwd)); + #endif } CAMLprim value ml_mbedtls_pk_parse_public_key(value ctx, value key) { CAMLparam2(ctx, key); - CAMLreturn(mbedtls_pk_parse_public_key(PkContext_val(ctx), String_val(key), caml_string_length(key) + 1)); + CAMLreturn(mbedtls_pk_parse_public_key(PkContext_val(ctx), Bytes_val(key), caml_string_length(key) + 1)); } CAMLprim value ml_mbedtls_pk_parse_public_keyfile(value ctx, value path) { @@ -446,15 +455,14 @@ CAMLprim value ml_mbedtls_ssl_handshake(value ssl) { CAMLprim value ml_mbedtls_ssl_read(value ssl, value buf, value pos, value len) { CAMLparam4(ssl, buf, pos, len); - CAMLreturn(Val_int(mbedtls_ssl_read(SslContext_val(ssl), String_val(buf) + Int_val(pos), Int_val(len)))); + CAMLreturn(Val_int(mbedtls_ssl_read(SslContext_val(ssl), Bytes_val(buf) + Int_val(pos), Int_val(len)))); } static int bio_write_cb(void* ctx, const unsigned char* buf, size_t len) { CAMLparam0(); CAMLlocal3(r, s, vctx); - vctx = (value)ctx; - s = caml_alloc_string(len); - memcpy(String_val(s), buf, len); + vctx = *(value*)ctx; + s = caml_alloc_initialized_string(len, (const char*)buf); r = caml_callback2(Field(vctx, 1), Field(vctx, 0), s); CAMLreturn(Int_val(r)); } @@ -462,7 +470,7 @@ static int bio_write_cb(void* ctx, const unsigned char* buf, size_t len) { static int bio_read_cb(void* ctx, unsigned char* buf, size_t len) { CAMLparam0(); CAMLlocal3(r, s, vctx); - vctx = (value)ctx; + vctx = *(value*)ctx; s = caml_alloc_string(len); r = caml_callback2(Field(vctx, 2), Field(vctx, 0), s); memcpy(buf, String_val(s), len); @@ -476,7 +484,11 @@ CAMLprim value ml_mbedtls_ssl_set_bio(value ssl, value p_bio, value f_send, valu Store_field(ctx, 0, p_bio); Store_field(ctx, 1, f_send); Store_field(ctx, 2, f_recv); - mbedtls_ssl_set_bio(SslContext_val(ssl), (void*)ctx, bio_write_cb, bio_read_cb, NULL); + // TODO: this allocation is leaked + value *location = malloc(sizeof(value)); + *location = ctx; + caml_register_generational_global_root(location); + mbedtls_ssl_set_bio(SslContext_val(ssl), (void*)location, bio_write_cb, bio_read_cb, NULL); CAMLreturn(Val_unit); } @@ -492,7 +504,7 @@ CAMLprim value ml_mbedtls_ssl_setup(value ssl, value conf) { CAMLprim value ml_mbedtls_ssl_write(value ssl, value buf, value pos, value len) { CAMLparam4(ssl, buf, pos, len); - CAMLreturn(Val_int(mbedtls_ssl_write(SslContext_val(ssl), String_val(buf) + Int_val(pos), Int_val(len)))); + CAMLreturn(Val_int(mbedtls_ssl_write(SslContext_val(ssl), Bytes_val(buf) + Int_val(pos), Int_val(len)))); } // glue @@ -520,36 +532,23 @@ CAMLprim value hx_cert_load_defaults(value certificate) { #endif #ifdef __APPLE__ - CFMutableDictionaryRef search; - CFArrayRef result; - SecKeychainRef keychain; - SecCertificateRef item; - CFDataRef dat; - // Load keychain - if (SecKeychainOpen("/System/Library/Keychains/SystemRootCertificates.keychain", &keychain) == errSecSuccess) { - // Search for certificates - search = CFDictionaryCreateMutable(NULL, 0, NULL, NULL); - CFDictionarySetValue(search, kSecClass, kSecClassCertificate); - CFDictionarySetValue(search, kSecMatchLimit, kSecMatchLimitAll); - CFDictionarySetValue(search, kSecReturnRef, kCFBooleanTrue); - CFDictionarySetValue(search, kSecMatchSearchList, CFArrayCreate(NULL, (const void **)&keychain, 1, NULL)); - if (SecItemCopyMatching(search, (CFTypeRef *)&result) == errSecSuccess) { - CFIndex n = CFArrayGetCount(result); - for (CFIndex i = 0; i < n; i++) { - item = (SecCertificateRef)CFArrayGetValueAtIndex(result, i); - - // Get certificate in DER format - dat = SecCertificateCopyData(item); - if (dat) { - r = mbedtls_x509_crt_parse_der(chain, (unsigned char *)CFDataGetBytePtr(dat), CFDataGetLength(dat)); - CFRelease(dat); - if (r != 0) { - CAMLreturn(Val_int(r)); - } + CFArrayRef certs; + if (SecTrustCopyAnchorCertificates(&certs) == errSecSuccess) { + CFIndex count = CFArrayGetCount(certs); + for(CFIndex i = 0; i < count; i++) { + SecCertificateRef item = (SecCertificateRef)CFArrayGetValueAtIndex(certs, i); + + // Get certificate in DER format + CFDataRef data = SecCertificateCopyData(item); + if(data) { + r = mbedtls_x509_crt_parse_der(chain, (unsigned char *)CFDataGetBytePtr(data), CFDataGetLength(data)); + CFRelease(data); + if (r != 0) { + CAMLreturn(Val_int(r)); } } } - CFRelease(keychain); + CFRelease(certs); } #endif diff --git a/src/macro/eval/evalSsl.ml b/src/macro/eval/evalSsl.ml index 3e22ec4dfd1..4da4b0355d0 100644 --- a/src/macro/eval/evalSsl.ml +++ b/src/macro/eval/evalSsl.ml @@ -160,11 +160,11 @@ let init_fields init_fields builtins = "strerror",vfun1 (fun code -> encode_string (mbedtls_strerror (decode_int code))); ] []; init_fields builtins (["mbedtls"],"PkContext") [] [ - "parse_key",vifun2 (fun this key password -> - vint (mbedtls_pk_parse_key (as_pk_context this) (decode_bytes key) (match password with VNull -> None | _ -> Some (decode_string password))); + "parse_key",vifun3 (fun this key password rng -> + vint (mbedtls_pk_parse_key (as_pk_context this) (decode_bytes key) (match password with VNull -> None | _ -> Some (decode_string password)) (as_ctr_drbg rng)); ); - "parse_keyfile",vifun2 (fun this path password -> - vint (mbedtls_pk_parse_keyfile (as_pk_context this) (decode_string path) (match password with VNull -> None | _ -> Some (decode_string password))); + "parse_keyfile",vifun3 (fun this path password rng -> + vint (mbedtls_pk_parse_keyfile (as_pk_context this) (decode_string path) (match password with VNull -> None | _ -> Some (decode_string password)) (as_ctr_drbg rng)); ); "parse_public_key",vifun1 (fun this key -> vint (mbedtls_pk_parse_public_key (as_pk_context this) (decode_bytes key)); diff --git a/std/eval/_std/mbedtls/PkContext.hx b/std/eval/_std/mbedtls/PkContext.hx index 0c83a4a47f4..6c0da6e935f 100644 --- a/std/eval/_std/mbedtls/PkContext.hx +++ b/std/eval/_std/mbedtls/PkContext.hx @@ -5,8 +5,8 @@ import haxe.io.Bytes; extern class PkContext { function new():Void; - function parse_key(key:Bytes, ?pwd:String):Int; - function parse_keyfile(path:String, ?password:String):Int; + function parse_key(key:Bytes, ?pwd:String, ctr_dbg: CtrDrbg):Int; + function parse_keyfile(path:String, ?password:String, ctr_dbg: CtrDrbg):Int; function parse_public_key(key:Bytes):Int; function parse_public_keyfile(path:String):Int; } diff --git a/std/eval/_std/sys/ssl/Key.hx b/std/eval/_std/sys/ssl/Key.hx index 67ea51a5cf8..b756a3dacf3 100644 --- a/std/eval/_std/sys/ssl/Key.hx +++ b/std/eval/_std/sys/ssl/Key.hx @@ -38,7 +38,7 @@ class Key { var code = if (isPublic) { key.native.parse_public_keyfile(file); } else { - key.native.parse_keyfile(file, pass); + key.native.parse_keyfile(file, pass, Mbedtls.getDefaultCtrDrbg()); } if (code != 0) { throw(mbedtls.Error.strerror(code)); @@ -51,7 +51,7 @@ class Key { var code = if (isPublic) { key.native.parse_public_key(data); } else { - key.native.parse_key(data); + key.native.parse_key(data, null, Mbedtls.getDefaultCtrDrbg()); } if (code != 0) { throw(mbedtls.Error.strerror(code));