diff --git a/esp-mbedtls/src/lib.rs b/esp-mbedtls/src/lib.rs index d29e23c..b9e6799 100644 --- a/esp-mbedtls/src/lib.rs +++ b/esp-mbedtls/src/lib.rs @@ -240,6 +240,7 @@ impl<'a> Certificates<'a> { min_version: TlsVersion, ) -> Result< ( + *mut mbedtls_ctr_drbg_context, *mut mbedtls_ssl_context, *mut mbedtls_ssl_config, *mut mbedtls_x509_crt, @@ -258,9 +259,16 @@ impl<'a> Certificates<'a> { unsafe { error_checked!(psa_crypto_init())?; + let drbg_context = calloc(1, size_of::() as u32) + as *mut mbedtls_ctr_drbg_context; + if drbg_context.is_null() { + return Err(TlsError::OutOfMemory); + } + let ssl_context = calloc(1, size_of::() as u32) as *mut mbedtls_ssl_context; if ssl_context.is_null() { + free(drbg_context as *const _); return Err(TlsError::OutOfMemory); } @@ -273,6 +281,7 @@ impl<'a> Certificates<'a> { let crt = calloc(1, size_of::() as u32) as *mut mbedtls_x509_crt; if crt.is_null() { + free(drbg_context as *const _); free(ssl_context as *const _); free(ssl_config as *const _); return Err(TlsError::OutOfMemory); @@ -281,6 +290,7 @@ impl<'a> Certificates<'a> { let certificate = calloc(1, size_of::() as u32) as *mut mbedtls_x509_crt; if certificate.is_null() { + free(drbg_context as *const _); free(ssl_context as *const _); free(ssl_config as *const _); free(crt as *const _); @@ -290,6 +300,7 @@ impl<'a> Certificates<'a> { let private_key = calloc(1, size_of::() as u32) as *mut mbedtls_pk_context; if private_key.is_null() { + free(drbg_context as *const _); free(ssl_context as *const _); free(ssl_config as *const _); free(crt as *const _); @@ -306,7 +317,9 @@ impl<'a> Certificates<'a> { // Initialize private key mbedtls_pk_init(private_key); (*ssl_config).private_f_dbg = Some(dbg_print); - (*ssl_config).private_f_rng = Some(rng); + // Init RNG + mbedtls_ctr_drbg_init(drbg_context); + mbedtls_ssl_conf_rng(ssl_config, Some(rng), drbg_context as *mut c_void); error_checked!(mbedtls_ssl_config_defaults( ssl_config, @@ -379,13 +392,21 @@ impl<'a> Certificates<'a> { mbedtls_ssl_conf_ca_chain(ssl_config, crt, core::ptr::null_mut()); error_checked!(mbedtls_ssl_setup(ssl_context, ssl_config))?; - Ok((ssl_context, ssl_config, crt, certificate, private_key)) + Ok(( + drbg_context, + ssl_context, + ssl_config, + crt, + certificate, + private_key, + )) } } } pub struct Session { stream: T, + drbg_context: *mut mbedtls_ctr_drbg_context, ssl_context: *mut mbedtls_ssl_context, ssl_config: *mut mbedtls_ssl_config, crt: *mut mbedtls_x509_crt, @@ -420,10 +441,11 @@ impl Session { min_version: TlsVersion, certificates: Certificates, ) -> Result { - let (ssl_context, ssl_config, crt, client_crt, private_key) = + let (drbg_context, ssl_context, ssl_config, crt, client_crt, private_key) = certificates.init_ssl(servername, mode, min_version)?; return Ok(Self { stream, + drbg_context, ssl_context, ssl_config, crt, @@ -561,11 +583,13 @@ impl Drop for Session { RSA_REF = core::mem::transmute(None::); } mbedtls_ssl_close_notify(self.ssl_context); + mbedtls_ctr_drbg_free(self.drbg_context); mbedtls_ssl_config_free(self.ssl_config); mbedtls_ssl_free(self.ssl_context); mbedtls_x509_crt_free(self.crt); mbedtls_x509_crt_free(self.client_crt); mbedtls_pk_free(self.private_key); + free(self.drbg_context as *const _); free(self.ssl_config as *const _); free(self.ssl_context as *const _); free(self.crt as *const _); @@ -627,6 +651,7 @@ pub mod asynch { pub struct Session { stream: T, + drbg_context: *mut mbedtls_ctr_drbg_context, ssl_context: *mut mbedtls_ssl_context, ssl_config: *mut mbedtls_ssl_config, crt: *mut mbedtls_x509_crt, @@ -663,10 +688,11 @@ pub mod asynch { min_version: TlsVersion, certificates: Certificates, ) -> Result { - let (ssl_context, ssl_config, crt, client_crt, private_key) = + let (drbg_context, ssl_context, ssl_config, crt, client_crt, private_key) = certificates.init_ssl(servername, mode, min_version)?; return Ok(Self { stream, + drbg_context, ssl_context, ssl_config, crt, @@ -704,11 +730,13 @@ pub mod asynch { RSA_REF = core::mem::transmute(None::); } mbedtls_ssl_close_notify(self.ssl_context); + mbedtls_ctr_drbg_free(self.drbg_context); mbedtls_ssl_config_free(self.ssl_config); mbedtls_ssl_free(self.ssl_context); mbedtls_x509_crt_free(self.crt); mbedtls_x509_crt_free(self.client_crt); mbedtls_pk_free(self.private_key); + free(self.drbg_context as *const _); free(self.ssl_config as *const _); free(self.ssl_context as *const _); free(self.crt as *const _);