Skip to content

Commit

Permalink
Fix UB: Non-static callbacks are unsound
Browse files Browse the repository at this point in the history
  • Loading branch information
ivmarkov committed Nov 11, 2023
1 parent cdfc601 commit 1c1ffe3
Show file tree
Hide file tree
Showing 18 changed files with 153 additions and 175 deletions.
4 changes: 2 additions & 2 deletions .github/configs/sdkconfig.defaults
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
CONFIG_MBEDTLS_CERTIFICATE_BUNDLE=y
CONFIG_MBEDTLS_CERTIFICATE_BUNDLE_DEFAULT_FULL=y
CONFIG_MBEDTLS_CERTIFICATE_BUNDLE=n
#CONFIG_MBEDTLS_CERTIFICATE_BUNDLE_DEFAULT_FULL=y

# Examples often require a larger than the default stack size for the main thread.
CONFIG_ESP_MAIN_TASK_STACK_SIZE=10000
Expand Down
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ build = "build.rs"
documentation = "https://esp-rs.github.io/esp-idf-svc/"
rust-version = "1.71"

[patch.crates-io]
embedded-svc = { git = "https://github.com/esp-rs/embedded-svc" }

[features]
default = ["std", "native", "binstart"]

Expand Down
9 changes: 4 additions & 5 deletions src/bt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ impl From<esp_bt_uuid_t> for BtUuid {
#[allow(clippy::type_complexity)]
pub(crate) struct BtCallback<A, R> {
initialized: AtomicBool,
callback: UnsafeCell<Option<alloc::boxed::Box<dyn Fn(A) -> R>>>,
callback: UnsafeCell<Option<alloc::boxed::Box<dyn Fn(A) -> R + Send + 'static>>>,
default_result: R,
}

Expand All @@ -155,16 +155,15 @@ where
}
}

pub fn set<'d, F>(&self, callback: F) -> Result<(), EspError>
pub fn set<F>(&self, callback: F) -> Result<(), EspError>
where
F: Fn(A) -> R + Send + 'd,
F: Fn(A) -> R + Send + 'static,
{
self.initialized
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
.map_err(|_| EspError::from_infallible::<ESP_ERR_INVALID_STATE>())?;

let b: alloc::boxed::Box<dyn Fn(A) -> R + 'd> = alloc::boxed::Box::new(callback);
let b: alloc::boxed::Box<dyn Fn(A) -> R + 'static> = unsafe { core::mem::transmute(b) };
let b: alloc::boxed::Box<dyn Fn(A) -> R + 'static> = alloc::boxed::Box::new(callback);
*unsafe { self.callback.get().as_mut() }.unwrap() = Some(b);

Ok(())
Expand Down
8 changes: 4 additions & 4 deletions src/bt/a2dp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ where

pub fn initialize<F>(&self, events_cb: F) -> Result<(), EspError>
where
F: Fn(A2dpEvent) + Send + 'd,
F: Fn(A2dpEvent) + Send + 'static,
{
self.internal_initialize(move |event| {
events_cb(event);
Expand Down Expand Up @@ -343,7 +343,7 @@ where

pub fn initialize<F>(&self, events_cb: F) -> Result<(), EspError>
where
F: Fn(A2dpEvent) -> usize + Send + 'd,
F: Fn(A2dpEvent) -> usize + Send + 'static,
{
self.internal_initialize(events_cb)
}
Expand Down Expand Up @@ -374,7 +374,7 @@ where

pub fn initialize<F>(&self, events_cb: F) -> Result<(), EspError>
where
F: Fn(A2dpEvent) -> usize + Send + 'd,
F: Fn(A2dpEvent) -> usize + Send + 'static,
{
self.internal_initialize(events_cb)
}
Expand All @@ -388,7 +388,7 @@ where
{
fn internal_initialize<F>(&self, events_cb: F) -> Result<(), EspError>
where
F: Fn(A2dpEvent) -> usize + Send + 'd,
F: Fn(A2dpEvent) -> usize + Send + 'static,
{
CALLBACK.set(events_cb)?;

Expand Down
2 changes: 1 addition & 1 deletion src/bt/avrc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ pub mod controller {

pub fn initialize<F>(&self, events_cb: F) -> Result<(), EspError>
where
F: Fn(AvrccEvent) + Send + 'd,
F: Fn(AvrccEvent) + Send + 'static,
{
CALLBACK.set(events_cb)?;

Expand Down
2 changes: 1 addition & 1 deletion src/bt/ble/gap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ where

pub fn initialize<F>(&self, events_cb: F) -> Result<(), EspError>
where
F: Fn(GapEvent) + Send + 'd,
F: Fn(GapEvent) + Send + 'static,
{
CALLBACK.set(events_cb)?;

Expand Down
2 changes: 1 addition & 1 deletion src/bt/gap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ where

pub fn initialize<F>(&self, events_cb: F) -> Result<(), EspError>
where
F: Fn(GapEvent) + Send + 'd,
F: Fn(GapEvent) + Send + 'static,
{
CALLBACK.set(events_cb)?;

Expand Down
2 changes: 1 addition & 1 deletion src/bt/hfp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ pub mod client {

pub fn initialize<F>(&self, events_cb: F) -> Result<(), EspError>
where
F: Fn(HfpcEvent) -> usize + Send + 'd,
F: Fn(HfpcEvent) -> usize + Send + 'static,
{
CALLBACK.set(events_cb)?;

Expand Down
22 changes: 8 additions & 14 deletions src/espnow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ impl From<u32> for SendStatus {

pub type PeerInfo = esp_now_peer_info_t;

pub struct EspNow<'a>(PhantomData<&'a ()>);
pub struct EspNow(());

impl<'a> EspNow<'a> {
impl EspNow {
pub fn take() -> Result<Self, EspError> {
let mut taken = TAKEN.lock();

Expand All @@ -67,7 +67,7 @@ impl<'a> EspNow<'a> {

*taken = true;

Ok(Self(PhantomData))
Ok(Self(()))
}

pub fn send(&self, peer_addr: [u8; 6], data: &[u8]) -> Result<(), EspError> {
Expand Down Expand Up @@ -130,13 +130,10 @@ impl<'a> EspNow<'a> {

pub fn register_recv_cb<F>(&self, callback: F) -> Result<(), EspError>
where
F: FnMut(&[u8], &[u8]) + Send + 'a,
F: FnMut(&[u8], &[u8]) + Send + 'static,
{
#[allow(clippy::type_complexity)]
let callback: Box<dyn FnMut(&[u8], &[u8]) + Send + 'a> = Box::new(callback);
#[allow(clippy::type_complexity)]
let callback: Box<dyn FnMut(&[u8], &[u8]) + Send + 'static> =
unsafe { core::mem::transmute(callback) };
let callback: Box<dyn FnMut(&[u8], &[u8]) + Send + 'static> = Box::new(callback);

*RECV_CALLBACK.lock() = Some(Box::new(callback));
esp!(unsafe { esp_now_register_recv_cb(Some(Self::recv_callback)) })?;
Expand All @@ -153,13 +150,10 @@ impl<'a> EspNow<'a> {

pub fn register_send_cb<F>(&self, callback: F) -> Result<(), EspError>
where
F: FnMut(&[u8], SendStatus) + Send + 'a,
F: FnMut(&[u8], SendStatus) + Send + 'static,
{
#[allow(clippy::type_complexity)]
let callback: Box<dyn FnMut(&[u8], SendStatus) + Send + 'a> = Box::new(callback);
#[allow(clippy::type_complexity)]
let callback: Box<dyn FnMut(&[u8], SendStatus) + Send + 'static> =
unsafe { core::mem::transmute(callback) };
let callback: Box<dyn FnMut(&[u8], SendStatus) + Send + 'static> = Box::new(callback);

*SEND_CALLBACK.lock() = Some(Box::new(callback));
esp!(unsafe { esp_now_register_send_cb(Some(Self::send_callback)) })?;
Expand Down Expand Up @@ -203,7 +197,7 @@ impl<'a> EspNow<'a> {
}
}

impl<'a> Drop for EspNow<'a> {
impl Drop for EspNow {
fn drop(&mut self) {
let mut taken = TAKEN.lock();

Expand Down
16 changes: 8 additions & 8 deletions src/eth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,13 @@ struct RawHandleImpl(esp_eth_handle_t);

unsafe impl Send for RawHandleImpl {}

type RawCallback<'a> = Box<dyn FnMut(&[u8]) + Send + 'a>;
type RawCallback = Box<dyn FnMut(&[u8]) + Send + 'static>;

struct UnsafeCallback<'a>(*mut RawCallback<'a>);
struct UnsafeCallback(*mut RawCallback);

impl<'a> UnsafeCallback<'a> {
impl UnsafeCallback {
#[allow(clippy::type_complexity)]
fn from(boxed: &mut Box<RawCallback<'a>>) -> Self {
fn from(boxed: &mut Box<RawCallback>) -> Self {
Self(boxed.as_mut())
}

Expand Down Expand Up @@ -178,8 +178,8 @@ pub struct EthDriver<'d, T> {
_flavor: T,
handle: esp_eth_handle_t,
status: Arc<mutex::Mutex<Status>>,
_subscription: EspSubscription<'static, System>,
callback: Option<Box<RawCallback<'d>>>,
_subscription: EspSubscription<System>,
callback: Option<Box<RawCallback>>,
_p: PhantomData<&'d mut ()>,
}

Expand Down Expand Up @@ -665,9 +665,9 @@ impl<'d, T> EthDriver<'d, T> {
Ok(())
}

pub fn set_rx_callback<C>(&mut self, mut callback: C) -> Result<(), EspError>
pub fn set_rx_callback<F>(&mut self, mut callback: F) -> Result<(), EspError>
where
C: for<'a> FnMut(&[u8]) + Send + 'd,
F: FnMut(&[u8]) + Send + 'static,
{
let _ = self.stop();

Expand Down
50 changes: 25 additions & 25 deletions src/eventloop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ pub use asyncify::*;
#[cfg(all(feature = "alloc", esp_idf_comp_esp_timer_enabled))]
pub use async_wait::*;

pub type EspSystemSubscription<'a> = EspSubscription<'a, System>;
pub type EspBackgroundSubscription<'a> = EspSubscription<'a, User<Background>>;
pub type EspExplicitSubscription<'a> = EspSubscription<'a, User<Explicit>>;
pub type EspSystemSubscription = EspSubscription<System>;
pub type EspBackgroundSubscription = EspSubscription<User<Background>>;
pub type EspExplicitSubscription = EspSubscription<User<Explicit>>;

pub type EspSystemEventLoop = EspEventLoop<System>;
pub type EspBackgroundEventLoop = EspEventLoop<User<Background>>;
Expand Down Expand Up @@ -211,11 +211,11 @@ impl EspEventFetchData {
}
}

struct UnsafeCallback<'a>(*mut Box<dyn FnMut(&EspEventFetchData) + Send + 'a>);
struct UnsafeCallback(*mut Box<dyn FnMut(&EspEventFetchData) + Send + 'static>);

impl<'a> UnsafeCallback<'a> {
impl UnsafeCallback {
#[allow(clippy::type_complexity)]
fn from(boxed: &mut Box<Box<dyn FnMut(&EspEventFetchData) + Send + 'a>>) -> Self {
fn from(boxed: &mut Box<Box<dyn FnMut(&EspEventFetchData) + Send + 'stastic>>) -> Self {
Self(boxed.as_mut())
}

Expand All @@ -234,7 +234,7 @@ impl<'a> UnsafeCallback<'a> {
}
}

pub struct EspSubscription<'a, T>
pub struct EspSubscription<T>
where
T: EspEventLoopType,
{
Expand All @@ -243,10 +243,10 @@ where
source: *const ffi::c_char,
event_id: i32,
#[allow(clippy::type_complexity)]
_callback: Box<Box<dyn FnMut(&EspEventFetchData) + Send + 'a>>,
_callback: Box<Box<dyn FnMut(&EspEventFetchData) + Send + 'static>>,
}

impl<'a, T> EspSubscription<'a, T>
impl<T> EspSubscription<T>
where
T: EspEventLoopType,
{
Expand All @@ -268,9 +268,9 @@ where
}
}

unsafe impl<'a, T> Send for EspSubscription<'a, T> where T: EspEventLoopType {}
unsafe impl<T> Send for EspSubscription<T> where T: EspEventLoopType {}

impl<'a, T> Drop for EspSubscription<'a, T>
impl<T> Drop for EspSubscription<T>
where
T: EspEventLoopType,
{
Expand Down Expand Up @@ -301,7 +301,7 @@ where
}
}

impl<'a, T> RawHandle for EspSubscription<'a, User<T>>
impl<T> RawHandle for EspSubscription<User<T>>
where
T: EspEventLoopType,
{
Expand Down Expand Up @@ -395,18 +395,18 @@ where
T: EspEventLoopType,
{
#[allow(clippy::not_unsafe_ptr_arg_deref)]
pub fn subscribe_raw<'a, F>(
pub fn subscribe_raw<F>(
&self,
source: *const ffi::c_char,
event_id: i32,
mut callback: F,
) -> Result<EspSubscription<'a, T>, EspError>
) -> Result<EspSubscription<T>, EspError>
where
F: FnMut(&EspEventFetchData) + Send + 'a,
F: FnMut(&EspEventFetchData) + Send + 'static,
{
let mut handler_instance: esp_event_handler_instance_t = ptr::null_mut();

let callback: Box<dyn FnMut(&EspEventFetchData) + Send + 'a> =
let callback: Box<dyn FnMut(&EspEventFetchData) + Send + 'static> =
Box::new(move |data| callback(data));
let mut callback = Box::new(callback);

Expand Down Expand Up @@ -528,10 +528,10 @@ where
}
}

pub fn subscribe<'a, P, F>(&self, mut callback: F) -> Result<EspSubscription<'a, T>, EspError>
pub fn subscribe<P, F>(&self, mut callback: F) -> Result<EspSubscription<T>, EspError>
where
P: EspTypedEventDeserializer<P>,
F: FnMut(&P) + Send + 'a,
F: FnMut(&P) + Send + 'static,
{
self.subscribe_raw(
P::source(),
Expand Down Expand Up @@ -656,11 +656,11 @@ where
P: EspTypedEventDeserializer<P>,
T: EspEventLoopType,
{
type Subscription<'a> = EspSubscription<'a, T> where Self: 'a;
type Subscription<'a> = EspSubscription<T> where Self: 'a;

fn subscribe<'a, F>(&'a self, callback: F) -> Result<Self::Subscription<'a>, Self::Error>
where
F: FnMut(&P) + Send + 'a,
F: FnMut(&P) + Send + 'static,
{
EspEventLoop::subscribe(self, callback)
}
Expand Down Expand Up @@ -789,11 +789,11 @@ where
M: EspTypedEventDeserializer<P>,
T: EspEventLoopType,
{
type Subscription<'a> = EspSubscription<'a, T> where Self: 'a;
type Subscription<'a> = EspSubscription<T> where Self: 'a;

fn subscribe<'a, F>(&'a self, mut callback: F) -> Result<Self::Subscription<'a>, EspError>
where
F: FnMut(&P) + Send + 'a,
F: FnMut(&P) + Send + 'static,
{
self.untyped_event_loop.subscribe_raw(
M::source(),
Expand All @@ -808,11 +808,11 @@ where
M: EspTypedEventDeserializer<P>,
T: EspEventLoopType,
{
type Subscription<'a> = EspSubscription<'a, T> where Self: 'a;
type Subscription<'a> = EspSubscription<T> where Self: 'a;

fn subscribe<'a, F>(&'a self, mut callback: F) -> Result<Self::Subscription<'a>, EspError>
where
F: FnMut(&P) + Send + 'a,
F: FnMut(&P) + Send + 'static,
{
self.untyped_event_loop.subscribe_raw(
M::source(),
Expand Down Expand Up @@ -924,7 +924,7 @@ where
E: EspTypedEventDeserializer<E> + Debug,
T: EspEventLoopType,
{
pub fn new<F: FnMut(&E) -> bool + Send + 'a>(
pub fn new<F: FnMut(&E) -> bool + Send + 'static>(
event_loop: &EspEventLoop<T>,
mut waiter: F,
) -> Result<Self, EspError> {
Expand Down
Loading

0 comments on commit 1c1ffe3

Please sign in to comment.