Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Properly initialize RNG context #33

Merged
merged 1 commit into from
Jul 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions esp-mbedtls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ impl<'a> Certificates<'a> {
min_version: TlsVersion,
) -> Result<
(
*mut mbedtls_ctr_drbg_context,
*mut mbedtls_ssl_context,
*mut mbedtls_ssl_config,
*mut mbedtls_x509_crt,
Expand All @@ -252,9 +253,16 @@ impl<'a> Certificates<'a> {
unsafe {
error_checked!(psa_crypto_init())?;

let drbg_context = calloc(1, size_of::<mbedtls_ctr_drbg_context>() as u32)
as *mut mbedtls_ctr_drbg_context;
if drbg_context.is_null() {
return Err(TlsError::OutOfMemory);
}

let ssl_context =
calloc(1, size_of::<mbedtls_ssl_context>() as u32) as *mut mbedtls_ssl_context;
if ssl_context.is_null() {
free(drbg_context as *const _);
return Err(TlsError::OutOfMemory);
}

Expand All @@ -267,6 +275,7 @@ impl<'a> Certificates<'a> {

let crt = calloc(1, size_of::<mbedtls_x509_crt>() as u32) as *mut mbedtls_x509_crt;
if crt.is_null() {
free(drbg_context as *const _);
free(ssl_context as *const _);
free(ssl_config as *const _);
return Err(TlsError::OutOfMemory);
Expand All @@ -275,6 +284,7 @@ impl<'a> Certificates<'a> {
let certificate =
calloc(1, size_of::<mbedtls_x509_crt>() as u32) as *mut mbedtls_x509_crt;
if certificate.is_null() {
free(drbg_context as *const _);
free(ssl_context as *const _);
free(ssl_config as *const _);
free(crt as *const _);
Expand All @@ -284,6 +294,7 @@ impl<'a> Certificates<'a> {
let private_key =
calloc(1, size_of::<mbedtls_pk_context>() as u32) as *mut mbedtls_pk_context;
if private_key.is_null() {
free(drbg_context as *const _);
free(ssl_context as *const _);
free(ssl_config as *const _);
free(crt as *const _);
Expand All @@ -300,7 +311,9 @@ impl<'a> Certificates<'a> {
// Initialize private key
mbedtls_pk_init(private_key);
(*ssl_config).private_f_dbg = Some(dbg_print);
(*ssl_config).private_f_rng = Some(rng);
// Init RNG
mbedtls_ctr_drbg_init(drbg_context);
mbedtls_ssl_conf_rng(ssl_config, Some(rng), drbg_context as *mut c_void);

error_checked!(mbedtls_ssl_config_defaults(
ssl_config,
Expand Down Expand Up @@ -373,13 +386,21 @@ impl<'a> Certificates<'a> {

mbedtls_ssl_conf_ca_chain(ssl_config, crt, core::ptr::null_mut());
error_checked!(mbedtls_ssl_setup(ssl_context, ssl_config))?;
Ok((ssl_context, ssl_config, crt, certificate, private_key))
Ok((
drbg_context,
ssl_context,
ssl_config,
crt,
certificate,
private_key,
))
}
}
}

pub struct Session<T> {
stream: T,
drbg_context: *mut mbedtls_ctr_drbg_context,
ssl_context: *mut mbedtls_ssl_context,
ssl_config: *mut mbedtls_ssl_config,
crt: *mut mbedtls_x509_crt,
Expand Down Expand Up @@ -415,11 +436,12 @@ impl<T> Session<T> {
certificates: Certificates,
rsa: Option<impl Peripheral<P = RSA>>,
) -> Result<Self, TlsError> {
let (ssl_context, ssl_config, crt, client_crt, private_key) =
let (drbg_context, ssl_context, ssl_config, crt, client_crt, private_key) =
certificates.init_ssl(servername, mode, min_version)?;
unsafe { RSA_REF = core::mem::transmute(rsa.map(|inner| Rsa::new(inner, None))) }
return Ok(Self {
stream,
drbg_context,
ssl_context,
ssl_config,
crt,
Expand Down Expand Up @@ -537,11 +559,13 @@ impl<T> Drop for Session<T> {
log::debug!("session dropped - freeing memory");
unsafe {
mbedtls_ssl_close_notify(self.ssl_context);
mbedtls_ctr_drbg_free(self.drbg_context);
mbedtls_ssl_config_free(self.ssl_config);
mbedtls_ssl_free(self.ssl_context);
mbedtls_x509_crt_free(self.crt);
mbedtls_x509_crt_free(self.client_crt);
mbedtls_pk_free(self.private_key);
free(self.drbg_context as *const _);
free(self.ssl_config as *const _);
free(self.ssl_context as *const _);
free(self.crt as *const _);
Expand Down Expand Up @@ -603,6 +627,7 @@ pub mod asynch {

pub struct Session<T, const BUFFER_SIZE: usize = 4096> {
stream: T,
drbg_context: *mut mbedtls_ctr_drbg_context,
ssl_context: *mut mbedtls_ssl_context,
ssl_config: *mut mbedtls_ssl_config,
crt: *mut mbedtls_x509_crt,
Expand Down Expand Up @@ -641,11 +666,12 @@ pub mod asynch {
certificates: Certificates,
rsa: Option<impl Peripheral<P = RSA>>,
) -> Result<Self, TlsError> {
let (ssl_context, ssl_config, crt, client_crt, private_key) =
let (drbg_context, ssl_context, ssl_config, crt, client_crt, private_key) =
certificates.init_ssl(servername, mode, min_version)?;
unsafe { RSA_REF = core::mem::transmute(rsa.map(|inner| Rsa::new(inner, None))) }
return Ok(Self {
stream,
drbg_context,
ssl_context,
ssl_config,
crt,
Expand All @@ -663,11 +689,13 @@ pub mod asynch {
log::debug!("session dropped - freeing memory");
unsafe {
mbedtls_ssl_close_notify(self.ssl_context);
mbedtls_ctr_drbg_free(self.drbg_context);
mbedtls_ssl_config_free(self.ssl_config);
mbedtls_ssl_free(self.ssl_context);
mbedtls_x509_crt_free(self.crt);
mbedtls_x509_crt_free(self.client_crt);
mbedtls_pk_free(self.private_key);
free(self.drbg_context as *const _);
free(self.ssl_config as *const _);
free(self.ssl_context as *const _);
free(self.crt as *const _);
Expand Down
Loading