From 16efec911a361946a10ddfcdc8d2ef8706544350 Mon Sep 17 00:00:00 2001 From: Sam Clark <3758302+goatgoose@users.noreply.github.com> Date: Sat, 25 May 2024 01:11:46 -0400 Subject: [PATCH] feat(bindings): Associate an application context with a Connection (#4563) --- bindings/rust/s2n-tls/src/connection.rs | 110 ++++++++++++++++++- bindings/rust/s2n-tls/src/testing/s2n_tls.rs | 62 ++++++++++- 2 files changed, 170 insertions(+), 2 deletions(-) diff --git a/bindings/rust/s2n-tls/src/connection.rs b/bindings/rust/s2n-tls/src/connection.rs index 442a51f0c39..f8c4ccb008a 100644 --- a/bindings/rust/s2n-tls/src/connection.rs +++ b/bindings/rust/s2n-tls/src/connection.rs @@ -23,7 +23,7 @@ use core::{ }; use libc::c_void; use s2n_tls_sys::*; -use std::ffi::CStr; +use std::{any::Any, ffi::CStr}; mod builder; pub use builder::*; @@ -1049,6 +1049,45 @@ impl Connection { pub fn resumed(&self) -> bool { unsafe { s2n_connection_is_session_resumed(self.connection.as_ptr()) == 1 } } + + /// Associates an arbitrary application context with the Connection to be later retrieved via + /// the [`Self::application_context()`] and [`Self::application_context_mut()`] APIs. + /// + /// This API will override an existing application context set on the Connection. + pub fn set_application_context(&mut self, app_context: T) { + self.context_mut().app_context = Some(Box::new(app_context)); + } + + /// Retrieves a reference to the application context associated with the Connection. + /// + /// If an application context hasn't already been set on the Connection, or if the set + /// application context isn't of type T, None will be returned. + /// + /// To set a context on the connection, use [`Self::set_application_context()`]. To retrieve a + /// mutable reference to the context, use [`Self::application_context_mut()`]. + pub fn application_context(&self) -> Option<&T> { + match self.context().app_context.as_ref() { + None => None, + // The Any trait keeps track of the application context's type. downcast_ref() returns + // Some only if the correct type is provided: + // https://doc.rust-lang.org/std/any/trait.Any.html#method.downcast_ref + Some(app_context) => app_context.downcast_ref::(), + } + } + + /// Retrieves a mutable reference to the application context associated with the Connection. + /// + /// If an application context hasn't already been set on the Connection, or if the set + /// application context isn't of type T, None will be returned. + /// + /// To set a context on the connection, use [`Self::set_application_context()`]. To retrieve an + /// immutable reference to the context, use [`Self::application_context()`]. + pub fn application_context_mut(&mut self) -> Option<&mut T> { + match self.context_mut().app_context.as_mut() { + None => None, + Some(app_context) => app_context.downcast_mut::(), + } + } } struct Context { @@ -1057,6 +1096,7 @@ struct Context { async_callback: Option, verify_host_callback: Option>, connection_initialized: bool, + app_context: Option>, } impl Context { @@ -1067,6 +1107,7 @@ impl Context { async_callback: None, verify_host_callback: None, connection_initialized: false, + app_context: None, } } } @@ -1181,4 +1222,71 @@ mod tests { fn assert_sync() {} assert_sync::(); } + + /// Test that an application context can be set and retrieved. + #[test] + fn test_app_context_set_and_retrieve() { + let mut connection = Connection::new_server(); + + // Before a context is set, None is returned. + assert!(connection.application_context::().is_none()); + + let test_value: u32 = 1142; + connection.set_application_context(test_value); + + // After a context is set, the application data is returned. + assert_eq!(*connection.application_context::().unwrap(), 1142); + } + + /// Test that an application context can be modified. + #[test] + fn test_app_context_modify() { + let test_value: u64 = 0; + + let mut connection = Connection::new_server(); + connection.set_application_context(test_value); + + let context_value = connection.application_context_mut::().unwrap(); + *context_value += 1; + + assert_eq!(*connection.application_context::().unwrap(), 1); + } + + /// Test that an application context can be overridden. + #[test] + fn test_app_context_override() { + let mut connection = Connection::new_server(); + + let test_value: u16 = 1142; + connection.set_application_context(test_value); + + assert_eq!(*connection.application_context::().unwrap(), 1142); + + // Override the context with a new value. + let test_value: u16 = 10; + connection.set_application_context(test_value); + + assert_eq!(*connection.application_context::().unwrap(), 10); + + // Override the context with a new type. + let test_value: i16 = -20; + connection.set_application_context(test_value); + + assert_eq!(*connection.application_context::().unwrap(), -20); + } + + /// Test that a context of another type can't be retrieved. + #[test] + fn test_app_context_invalid_type() { + let mut connection = Connection::new_server(); + + let test_value: u32 = 0; + connection.set_application_context(test_value); + + // A context type that wasn't set shouldn't be returned. + assert!(connection.application_context::().is_none()); + + // Retrieving the correct type succeeds. + assert!(connection.application_context::().is_some()); + } } diff --git a/bindings/rust/s2n-tls/src/testing/s2n_tls.rs b/bindings/rust/s2n-tls/src/testing/s2n_tls.rs index 4f186c3c020..eb18e7d16bd 100644 --- a/bindings/rust/s2n-tls/src/testing/s2n_tls.rs +++ b/bindings/rust/s2n-tls/src/testing/s2n_tls.rs @@ -233,7 +233,7 @@ impl<'a, T: 'a + Context> Callback<'a, T> { #[cfg(test)] mod tests { use crate::{ - callbacks::{ClientHelloCallback, ConnectionFuture}, + callbacks::{ClientHelloCallback, ConnectionFuture, ConnectionFutureResult}, enums::ClientAuthType, error::ErrorType, testing::{client_hello::*, s2n_tls::*, *}, @@ -970,4 +970,64 @@ mod tests { init::init(); assert!(init::fips_mode().unwrap().is_enabled()); } + + /// Test that a context can be used from within a callback. + #[test] + fn test_app_context_callback() { + struct TestApplicationContext { + invoked_count: u32, + } + + struct TestClientHelloHandler {} + + impl ClientHelloCallback for TestClientHelloHandler { + fn on_client_hello( + &self, + connection: &mut connection::Connection, + ) -> ConnectionFutureResult { + let app_context = connection + .application_context_mut::() + .unwrap(); + app_context.invoked_count += 1; + Ok(None) + } + } + + let config = { + let keypair = CertKeyPair::default(); + let mut builder = Builder::new(); + builder + .set_verify_host_callback(InsecureAcceptAllCertificatesHandler {}) + .unwrap(); + builder + .set_client_hello_callback(TestClientHelloHandler {}) + .unwrap(); + builder.load_pem(keypair.cert, keypair.key).unwrap(); + builder.trust_pem(keypair.cert).unwrap(); + builder.build().unwrap() + }; + + let mut pair = tls_pair(config); + pair.server + .0 + .connection_mut() + .set_waker(Some(&noop_waker())) + .unwrap(); + + let context = TestApplicationContext { invoked_count: 0 }; + pair.server + .0 + .connection_mut() + .set_application_context(context); + + assert!(poll_tls_pair_result(&mut pair).is_ok()); + + let context = pair + .server + .0 + .connection() + .application_context::() + .unwrap(); + assert_eq!(context.invoked_count, 1); + } }