From db98c29cc68cb2fab63d9a9011624235547815be Mon Sep 17 00:00:00 2001 From: Alex Coats Date: Wed, 2 Aug 2023 13:24:07 -0400 Subject: [PATCH] Improve event handlers and remove remainder assertion --- sdk/src/wallet/core/mod.rs | 2 +- sdk/src/wallet/events/mod.rs | 13 ++++++------- sdk/tests/wallet/address_generation.rs | 23 ++++++++++------------- sdk/tests/wallet/transactions.rs | 1 - 4 files changed, 17 insertions(+), 22 deletions(-) diff --git a/sdk/src/wallet/core/mod.rs b/sdk/src/wallet/core/mod.rs index f007ac896a..e48de98f89 100644 --- a/sdk/src/wallet/core/mod.rs +++ b/sdk/src/wallet/core/mod.rs @@ -178,7 +178,7 @@ impl WalletInner { pub async fn listen + Send>(&self, events: I, handler: F) where I::IntoIter: Send, - F: Fn(&Event) + 'static + Clone + Send + Sync, + F: Fn(&Event) + 'static + Send + Sync, { let mut emitter = self.event_emitter.write().await; emitter.on(events, handler); diff --git a/sdk/src/wallet/events/mod.rs b/sdk/src/wallet/events/mod.rs index 010804ed9f..140159d0a6 100644 --- a/sdk/src/wallet/events/mod.rs +++ b/sdk/src/wallet/events/mod.rs @@ -3,6 +3,7 @@ pub mod types; +use alloc::sync::Arc; use std::{ collections::HashMap, fmt::{Debug, Formatter, Result}, @@ -10,7 +11,7 @@ use std::{ pub use self::types::{Event, WalletEvent, WalletEventType}; -type Handler = Box; +type Handler = Arc; pub struct EventEmitter { handlers: HashMap>>, @@ -28,9 +29,10 @@ impl EventEmitter { /// multiple listeners for a single event. pub fn on(&mut self, events: impl IntoIterator, handler: F) where - F: Fn(&Event) + 'static + Clone + Send + Sync, + F: Fn(&Event) + 'static + Send + Sync, { let mut events = events.into_iter().peekable(); + let handler = Arc::new(handler); // if no event is provided the handler is registered for all event types if events.peek().is_none() { // we could use a crate like strum or a macro to iterate over all values, but not sure if it's worth it @@ -43,14 +45,11 @@ impl EventEmitter { #[cfg(feature = "ledger_nano")] WalletEventType::LedgerAddressGeneration, ] { - self.handlers - .entry(event_type) - .or_default() - .push(Box::new(handler.clone())); + self.handlers.entry(event_type).or_default().push(handler.clone()); } } for event in events { - self.handlers.entry(event).or_default().push(Box::new(handler.clone())); + self.handlers.entry(event).or_default().push(handler.clone()); } } diff --git a/sdk/tests/wallet/address_generation.rs b/sdk/tests/wallet/address_generation.rs index 3309a63a60..7677855da0 100644 --- a/sdk/tests/wallet/address_generation.rs +++ b/sdk/tests/wallet/address_generation.rs @@ -1,9 +1,6 @@ // Copyright 2023 IOTA Stiftung // SPDX-License-Identifier: Apache-2.0 -#[cfg(feature = "ledger_nano")] -use std::sync::{Arc, Mutex}; - #[cfg(feature = "stronghold")] use crypto::keys::bip39::Mnemonic; #[cfg(feature = "stronghold")] @@ -94,7 +91,7 @@ async fn wallet_address_generation_stronghold() -> Result<()> { } #[tokio::test] -#[cfg(feature = "ledger_nano")] +#[cfg(all(feature = "ledger_nano", feature = "events"))] #[ignore = "requires ledger nano instance"] async fn wallet_address_generation_ledger() -> Result<()> { let storage_path = "test-storage/wallet_address_generation_ledger"; @@ -127,16 +124,16 @@ async fn wallet_address_generation_ledger() -> Result<()> { "smr1qqdnv60ryxynaeyu8paq3lp9rkll7d7d92vpumz88fdj4l0pn5mruy3qdpm" ); - let address_event = Arc::new(Mutex::new(None)); - let address_event_clone = address_event.clone(); + let (sender, mut receiver) = tokio::sync::mpsc::channel(1); - #[cfg(feature = "events")] wallet .listen([WalletEventType::LedgerAddressGeneration], move |event| { if let WalletEvent::LedgerAddressGeneration(address) = &event.event { - *address_event_clone.lock().unwrap() = Some(address.address); + sender + .try_send(address.address) + .expect("too many LedgerAddressGeneration events"); } else { - panic!("expected LedgerAddressGeneration") + panic!("expected LedgerAddressGeneration event") } }) .await; @@ -162,10 +159,10 @@ async fn wallet_address_generation_ledger() -> Result<()> { ); assert_eq!( - address_event - .lock() - .unwrap() - .unwrap() + receiver + .recv() + .await + .expect("never received event") .inner() .to_bech32_unchecked("smr"), // Address generated with bip32 path: [44, 4218, 0, 0, 0]. diff --git a/sdk/tests/wallet/transactions.rs b/sdk/tests/wallet/transactions.rs index c76a349376..14952185fa 100644 --- a/sdk/tests/wallet/transactions.rs +++ b/sdk/tests/wallet/transactions.rs @@ -301,7 +301,6 @@ async fn prepare_transaction_ledger() -> Result<()> { assert_eq!(sign.output, input.output); assert_eq!(sign.output_metadata, input.metadata); } - assert!(data.remainder.is_none()); tear_down(storage_path) }