diff --git a/Cargo.toml b/Cargo.toml index 52244e8..ae5472b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,9 @@ log = { version = "0.4", features = ["std"] } once_cell = "1.9.0" rand = "0.8" +[dev-dependencies] +futures-executor = "0.3" + [features] failpoints = [] diff --git a/src/lib.rs b/src/lib.rs index f23cc44..c72184c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -843,6 +843,27 @@ macro_rules! fail_point { }}; } +/// Like [fail_point], but accept a future. +#[macro_export] +#[cfg(feature = "failpoints")] +macro_rules! async_fail_point { + ($name:expr) => {{ + $crate::eval($name, |_| { + panic!("Return is not supported for the fail point \"{}\"", $name); + }); + }}; + ($name:expr, $e:expr) => {{ + if let Some(res) = $crate::eval($name, $e) { + return res.await; + } + }}; + ($name:expr, $cond:expr, $e:expr) => {{ + if $cond { + async_fail_point!($name, $e); + } + }}; +} + /// Define a fail point (disabled, see `failpoints` feature). #[macro_export] #[cfg(not(feature = "failpoints"))] @@ -852,6 +873,15 @@ macro_rules! fail_point { ($name:expr, $cond:expr, $e:expr) => {{}}; } +/// Define an async fail point (disabled, see `failpoints` feature). +#[macro_export] +#[cfg(not(feature = "failpoints"))] +macro_rules! async_fail_point { + ($name:expr, $e:expr) => {{}}; + ($name:expr) => {{}}; + ($name:expr, $cond:expr, $e:expr) => {{}}; +} + #[cfg(test)] mod tests { use super::*; diff --git a/tests/tests.rs b/tests/tests.rs index 778ab8d..676236e 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -5,7 +5,7 @@ use std::sync::*; use std::time::*; use std::*; -use fail::fail_point; +use fail::{async_fail_point, fail_point}; #[test] fn test_off() { @@ -36,6 +36,46 @@ fn test_return() { assert_eq!(f(), 2); } +#[cfg_attr(not(feature = "failpoints"), ignore)] +#[test] +fn test_async_return() { + async fn async_fn() -> usize { + async_fail_point!("async_return", move |s: Option| async { + (async {}).await; + s.map_or(2, |s| s.parse().unwrap()) + }); + 0 + } + + futures_executor::block_on(async move { + fail::cfg("async_return", "return(1000)").unwrap(); + assert_eq!(async_fn().await, 1000); + + fail::cfg("async_return", "return").unwrap(); + assert_eq!(async_fn().await, 2); + }) +} + +#[cfg_attr(not(feature = "failpoints"), ignore)] +#[test] +fn test_async_move_return() { + async fn async_fn() -> usize { + async_fail_point!("async_return", |s: Option| async move { + (async {}).await; + s.map_or(2, |s| s.parse().unwrap()) + }); + 0 + } + + futures_executor::block_on(async move { + fail::cfg("async_return", "return(1000)").unwrap(); + assert_eq!(async_fn().await, 1000); + + fail::cfg("async_return", "return").unwrap(); + assert_eq!(async_fn().await, 2); + }) +} + #[test] #[cfg_attr(not(feature = "failpoints"), ignore)] fn test_sleep() {