From 18d0cc2321172a4a4c170dff12e7ed626ef6335a Mon Sep 17 00:00:00 2001 From: Anthony Grondin <104731965+AnthonyGrondin@users.noreply.github.com> Date: Thu, 11 Jul 2024 19:36:06 -0400 Subject: [PATCH] feat: Use builder pattern to enable rsa acceleration. --- esp-mbedtls/src/lib.rs | 57 ++++++++++++++++++++++++++++++----- examples/async_client.rs | 4 +-- examples/async_client_mTLS.rs | 4 +-- examples/async_server.rs | 4 +-- examples/async_server_mTLS.rs | 4 +-- examples/sync_client.rs | 4 +-- examples/sync_client_mTLS.rs | 4 +-- examples/sync_server.rs | 5 +-- examples/sync_server_mTLS.rs | 5 +-- 9 files changed, 67 insertions(+), 24 deletions(-) diff --git a/esp-mbedtls/src/lib.rs b/esp-mbedtls/src/lib.rs index 2b8bb98..0267b53 100644 --- a/esp-mbedtls/src/lib.rs +++ b/esp-mbedtls/src/lib.rs @@ -29,6 +29,12 @@ pub use esp_mbedtls_sys::bindings::{ use esp_mbedtls_sys::c_types::*; /// Hold the RSA peripheral for cryptographic operations. +/// +/// This is initialized when `with_hardware_rsa()` is called on a [Session] and is set back to None +/// when the session that called `with_hardware_rsa()` is dropped. +/// +/// Note: Due to implementation constraints, this session and every other session will use the +/// hardware accelerated RSA driver until the session called with this function is dropped. static mut RSA_REF: Option> = None; // these will come from esp-wifi (i.e. this can only be used together with esp-wifi) @@ -385,6 +391,8 @@ pub struct Session { crt: *mut mbedtls_x509_crt, client_crt: *mut mbedtls_x509_crt, private_key: *mut mbedtls_pk_context, + // Indicate if this session is the one holding the RSA ref + owns_rsa: bool, } impl Session { @@ -399,8 +407,6 @@ impl Session { /// * `min_version` - The minimum TLS version for the connection, that will be accepted. /// * `certificates` - Certificate chain for the connection. Will play a different role /// depending on if running as client or server. See [Certificates] for more information. - /// * `rsa` - Optionally take an RSA driver instance. This session will use the hardware rsa crypto - /// accelerators for the session. Passing None will use the software implementation of RSA which is slower. /// /// # Errors /// @@ -413,11 +419,9 @@ impl Session { mode: Mode, min_version: TlsVersion, certificates: Certificates, - rsa: Option>, ) -> Result { let (ssl_context, ssl_config, crt, client_crt, private_key) = certificates.init_ssl(servername, mode, min_version)?; - unsafe { RSA_REF = core::mem::transmute(rsa.map(|inner| Rsa::new(inner, None))) } return Ok(Self { stream, ssl_context, @@ -425,8 +429,23 @@ impl Session { crt, client_crt, private_key, + owns_rsa: false, }); } + + /// Enable the use of the hardware accelerated RSA peripheral for the [Session]. + /// + /// Note: Due to implementation constraints, this session and every other session will use the + /// hardware accelerated RSA driver until the sesssion called with this session is dropped. + /// + /// # Arguments + /// + /// * `rsa` - The RSA peripheral from the HAL + pub fn with_hardware_rsa(mut self, rsa: impl Peripheral

) -> Self { + unsafe { RSA_REF = core::mem::transmute(Some(Rsa::new(rsa, None))) } + self.owns_rsa = true; + self + } } impl Session @@ -536,6 +555,11 @@ impl Drop for Session { fn drop(&mut self) { log::debug!("session dropped - freeing memory"); unsafe { + // If the struct that owns the RSA reference is dropped + // we remove RSA in static for safety + if self.owns_rsa { + RSA_REF = core::mem::transmute(None::); + } mbedtls_ssl_close_notify(self.ssl_context); mbedtls_ssl_config_free(self.ssl_config); mbedtls_ssl_free(self.ssl_context); @@ -611,6 +635,7 @@ pub mod asynch { eof: bool, tx_buffer: BufferedBytes, rx_buffer: BufferedBytes, + owns_rsa: bool, } impl Session { @@ -625,8 +650,6 @@ pub mod asynch { /// * `min_version` - The minimum TLS version for the connection, that will be accepted. /// * `certificates` - Certificate chain for the connection. Will play a different role /// depending on if running as client or server. See [Certificates] for more information. - /// * `rsa` - Optionally take an RSA driver instance. This session will use the hardware rsa crypto - /// accelerators for the session. Passing None will use the software implementation of RSA which is slower. /// /// # Errors /// @@ -639,11 +662,9 @@ pub mod asynch { mode: Mode, min_version: TlsVersion, certificates: Certificates, - rsa: Option>, ) -> Result { let (ssl_context, ssl_config, crt, client_crt, private_key) = certificates.init_ssl(servername, mode, min_version)?; - unsafe { RSA_REF = core::mem::transmute(rsa.map(|inner| Rsa::new(inner, None))) } return Ok(Self { stream, ssl_context, @@ -654,14 +675,34 @@ pub mod asynch { eof: false, tx_buffer: Default::default(), rx_buffer: Default::default(), + owns_rsa: false, }); } + + /// Enable the use of the hardware accelerated RSA peripheral for the [Session]. + /// + /// Note: Due to implementation constraints, this session and every other session will use the + /// hardware accelerated RSA driver until the sesssion called with this session is dropped. + /// + /// # Arguments + /// + /// * `rsa` - The RSA peripheral from the HAL + pub fn with_hardware_rsa(mut self, rsa: impl Peripheral

) -> Self { + unsafe { RSA_REF = core::mem::transmute(Some(Rsa::new(rsa, None))) } + self.owns_rsa = true; + self + } } impl Drop for Session { fn drop(&mut self) { log::debug!("session dropped - freeing memory"); unsafe { + // If the struct that owns the RSA reference is dropped + // we remove RSA in static for safety + if self.owns_rsa { + RSA_REF = core::mem::transmute(None::); + } mbedtls_ssl_close_notify(self.ssl_context); mbedtls_ssl_config_free(self.ssl_config); mbedtls_ssl_free(self.ssl_context); diff --git a/examples/async_client.rs b/examples/async_client.rs index 2214bc9..128472d 100644 --- a/examples/async_client.rs +++ b/examples/async_client.rs @@ -119,9 +119,9 @@ async fn main(spawner: Spawner) -> ! { .ok(), ..Default::default() }, - Some(peripherals.RSA), ) - .unwrap(); + .unwrap() + .with_hardware_rsa(peripherals.RSA); println!("Start tls connect"); let mut tls = tls.connect().await.unwrap(); diff --git a/examples/async_client_mTLS.rs b/examples/async_client_mTLS.rs index cffacf9..73ab1d4 100644 --- a/examples/async_client_mTLS.rs +++ b/examples/async_client_mTLS.rs @@ -125,9 +125,9 @@ async fn main(spawner: Spawner) -> ! { Mode::Client, TlsVersion::Tls1_3, certificates, - Some(peripherals.RSA), ) - .unwrap(); + .unwrap() + .with_hardware_rsa(peripherals.RSA); println!("Start tls connect"); let mut tls = tls.connect().await.unwrap(); diff --git a/examples/async_server.rs b/examples/async_server.rs index d09639f..ca0938c 100644 --- a/examples/async_server.rs +++ b/examples/async_server.rs @@ -140,9 +140,9 @@ async fn main(spawner: Spawner) -> ! { .ok(), ..Default::default() }, - Some(&mut peripherals.RSA), ) - .unwrap(); + .unwrap() + .with_hardware_rsa(&mut peripherals.RSA); println!("Start tls connect"); match tls.connect().await { diff --git a/examples/async_server_mTLS.rs b/examples/async_server_mTLS.rs index e21a7ed..928fde7 100644 --- a/examples/async_server_mTLS.rs +++ b/examples/async_server_mTLS.rs @@ -159,9 +159,9 @@ async fn main(spawner: Spawner) -> ! { .ok(), ..Default::default() }, - Some(&mut peripherals.RSA), ) - .unwrap(); + .unwrap() + .with_hardware_rsa(&mut peripherals.RSA); println!("Start tls connect"); match tls.connect().await { diff --git a/examples/sync_client.rs b/examples/sync_client.rs index 536f5ed..ddf42e1 100644 --- a/examples/sync_client.rs +++ b/examples/sync_client.rs @@ -116,9 +116,9 @@ fn main() -> ! { .ok(), ..Default::default() }, - Some(peripherals.RSA), ) - .unwrap(); + .unwrap() + .with_hardware_rsa(peripherals.RSA); println!("Start tls connect"); let mut tls = tls.connect().unwrap(); diff --git a/examples/sync_client_mTLS.rs b/examples/sync_client_mTLS.rs index 43a3871..97c5544 100644 --- a/examples/sync_client_mTLS.rs +++ b/examples/sync_client_mTLS.rs @@ -122,9 +122,9 @@ fn main() -> ! { Mode::Client, TlsVersion::Tls1_3, certificates, - Some(peripherals.RSA), ) - .unwrap(); + .unwrap() + .with_hardware_rsa(peripherals.RSA); println!("Start tls connect"); let mut tls = tls.connect().unwrap(); diff --git a/examples/sync_server.rs b/examples/sync_server.rs index 5041e11..2c3093a 100644 --- a/examples/sync_server.rs +++ b/examples/sync_server.rs @@ -137,9 +137,10 @@ fn main() -> ! { .ok(), ..Default::default() }, - Some(&mut peripherals.RSA), ) - .unwrap(); + .unwrap() + .with_hardware_rsa(&mut peripherals.RSA); + match tls.connect() { Ok(mut connected_session) => { loop { diff --git a/examples/sync_server_mTLS.rs b/examples/sync_server_mTLS.rs index 5309a4a..02373b1 100644 --- a/examples/sync_server_mTLS.rs +++ b/examples/sync_server_mTLS.rs @@ -158,9 +158,10 @@ fn main() -> ! { .ok(), ..Default::default() }, - Some(&mut peripherals.RSA), ) - .unwrap(); + .unwrap() + .with_hardware_rsa(&mut peripherals.RSA); + match tls.connect() { Ok(mut connected_session) => { loop {