Skip to content

Commit

Permalink
Improve event handlers and remove remainder assertion
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex Coats committed Aug 2, 2023
1 parent 36733c0 commit db98c29
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 22 deletions.
2 changes: 1 addition & 1 deletion sdk/src/wallet/core/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ impl<S: SecretManage> WalletInner<S> {
pub async fn listen<F, I: IntoIterator<Item = WalletEventType> + Send>(&self, events: I, handler: F)
where
I::IntoIter: Send,
F: Fn(&Event) + 'static + Clone + Send + Sync,
F: Fn(&Event) + 'static + Send + Sync,
{
let mut emitter = self.event_emitter.write().await;
emitter.on(events, handler);
Expand Down
13 changes: 6 additions & 7 deletions sdk/src/wallet/events/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@

pub mod types;

use alloc::sync::Arc;
use std::{
collections::HashMap,
fmt::{Debug, Formatter, Result},
};

pub use self::types::{Event, WalletEvent, WalletEventType};

type Handler<T> = Box<dyn Fn(&T) + Send + Sync + 'static>;
type Handler<T> = Arc<dyn Fn(&T) + Send + Sync + 'static>;

pub struct EventEmitter {
handlers: HashMap<WalletEventType, Vec<Handler<Event>>>,
Expand All @@ -28,9 +29,10 @@ impl EventEmitter {
/// multiple listeners for a single event.
pub fn on<F>(&mut self, events: impl IntoIterator<Item = WalletEventType>, handler: F)
where
F: Fn(&Event) + 'static + Clone + Send + Sync,
F: Fn(&Event) + 'static + Send + Sync,
{
let mut events = events.into_iter().peekable();
let handler = Arc::new(handler);
// if no event is provided the handler is registered for all event types
if events.peek().is_none() {
// we could use a crate like strum or a macro to iterate over all values, but not sure if it's worth it
Expand All @@ -43,14 +45,11 @@ impl EventEmitter {
#[cfg(feature = "ledger_nano")]
WalletEventType::LedgerAddressGeneration,
] {
self.handlers
.entry(event_type)
.or_default()
.push(Box::new(handler.clone()));
self.handlers.entry(event_type).or_default().push(handler.clone());
}
}
for event in events {
self.handlers.entry(event).or_default().push(Box::new(handler.clone()));
self.handlers.entry(event).or_default().push(handler.clone());
}
}

Expand Down
23 changes: 10 additions & 13 deletions sdk/tests/wallet/address_generation.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
// Copyright 2023 IOTA Stiftung
// SPDX-License-Identifier: Apache-2.0

#[cfg(feature = "ledger_nano")]
use std::sync::{Arc, Mutex};

#[cfg(feature = "stronghold")]
use crypto::keys::bip39::Mnemonic;
#[cfg(feature = "stronghold")]
Expand Down Expand Up @@ -94,7 +91,7 @@ async fn wallet_address_generation_stronghold() -> Result<()> {
}

#[tokio::test]
#[cfg(feature = "ledger_nano")]
#[cfg(all(feature = "ledger_nano", feature = "events"))]
#[ignore = "requires ledger nano instance"]
async fn wallet_address_generation_ledger() -> Result<()> {
let storage_path = "test-storage/wallet_address_generation_ledger";
Expand Down Expand Up @@ -127,16 +124,16 @@ async fn wallet_address_generation_ledger() -> Result<()> {
"smr1qqdnv60ryxynaeyu8paq3lp9rkll7d7d92vpumz88fdj4l0pn5mruy3qdpm"
);

let address_event = Arc::new(Mutex::new(None));
let address_event_clone = address_event.clone();
let (sender, mut receiver) = tokio::sync::mpsc::channel(1);

#[cfg(feature = "events")]
wallet
.listen([WalletEventType::LedgerAddressGeneration], move |event| {
if let WalletEvent::LedgerAddressGeneration(address) = &event.event {
*address_event_clone.lock().unwrap() = Some(address.address);
sender
.try_send(address.address)
.expect("too many LedgerAddressGeneration events");
} else {
panic!("expected LedgerAddressGeneration")
panic!("expected LedgerAddressGeneration event")
}
})
.await;
Expand All @@ -162,10 +159,10 @@ async fn wallet_address_generation_ledger() -> Result<()> {
);

assert_eq!(
address_event
.lock()
.unwrap()
.unwrap()
receiver
.recv()
.await
.expect("never received event")
.inner()
.to_bech32_unchecked("smr"),
// Address generated with bip32 path: [44, 4218, 0, 0, 0].
Expand Down
1 change: 0 additions & 1 deletion sdk/tests/wallet/transactions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,6 @@ async fn prepare_transaction_ledger() -> Result<()> {
assert_eq!(sign.output, input.output);
assert_eq!(sign.output_metadata, input.metadata);
}
assert!(data.remainder.is_none());

tear_down(storage_path)
}

0 comments on commit db98c29

Please sign in to comment.