Skip to content

Commit

Permalink
Move logic to its trait.
Browse files Browse the repository at this point in the history
I moved the logic to send the initial values to the subscriptions onto a
generic trait implemented in nut17. The main goal is to have the same behavior
regardless of whether the subscriptions come from web sockets or internally
from other parts of the systems or other crates.
  • Loading branch information
crodas committed Nov 9, 2024
1 parent 5a199bb commit 50d7afc
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 91 deletions.
79 changes: 1 addition & 78 deletions crates/cdk-axum/src/ws/subscribe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@ use super::{
WsContext, WsError, JSON_RPC_VERSION,
};
use cdk::{
nuts::{
nut17::{Kind, NotificationPayload, Params},
MeltQuoteBolt11Response, MintQuoteBolt11Response, ProofState, PublicKey,
},
nuts::nut17::{NotificationPayload, Params},
pub_sub::SubId,
};

Expand Down Expand Up @@ -67,80 +64,6 @@ impl WsHandle for Method {
.subscribe(self.0.clone())
.await;
let publisher = context.publisher.clone();

let current_notification_to_send: Vec<NotificationPayload> = match self.0.kind {
Kind::Bolt11MeltQuote => {
let queries = self
.0
.filters
.iter()
.map(|id| context.state.mint.localstore.get_melt_quote(id))
.collect::<Vec<_>>();

futures::future::try_join_all(queries)
.await
.map(|quotes| {
quotes
.into_iter()
.filter_map(|quote| quote.map(|x| x.into()))
.map(|x: MeltQuoteBolt11Response| x.into())
.collect::<Vec<_>>()
})
.unwrap_or_default()
}
Kind::Bolt11MintQuote => {
let queries = self
.0
.filters
.iter()
.map(|id| context.state.mint.localstore.get_mint_quote(id))
.collect::<Vec<_>>();

futures::future::try_join_all(queries)
.await
.map(|quotes| {
quotes
.into_iter()
.filter_map(|quote| quote.map(|x| x.into()))
.map(|x: MintQuoteBolt11Response| x.into())
.collect::<Vec<_>>()
})
.unwrap_or_default()
}
Kind::ProofState => {
if let Ok(public_keys) = self
.0
.filters
.iter()
.map(PublicKey::from_hex)
.collect::<Result<Vec<PublicKey>, _>>()
{
context
.state
.mint
.localstore
.get_proofs_states(&public_keys)
.await
.map(|x| {
x.into_iter()
.enumerate()
.filter_map(|(idx, state)| {
state.map(|state| (public_keys[idx], state).into())
})
.map(|x: ProofState| x.into())
.collect::<Vec<_>>()
})
.unwrap_or_default()
} else {
vec![]
}
}
};

for notification in current_notification_to_send.into_iter() {
let _ = publisher.send((sub_id.clone(), notification)).await;
}

context.subscriptions.insert(
sub_id.clone(),
tokio::spawn(async move {
Expand Down
2 changes: 1 addition & 1 deletion crates/cdk-integration-tests/tests/regtest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async fn get_notification<T: StreamExt<Item = Result<Message, E>> + Unpin, E: De
.unwrap();

let mut response: serde_json::Value =
serde_json::from_str(&msg.to_text().unwrap()).expect("valid json");
serde_json::from_str(msg.to_text().unwrap()).expect("valid json");

let mut params_raw = response
.as_object_mut()
Expand Down
2 changes: 1 addition & 1 deletion crates/cdk/src/mint/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ impl Mint {
Ok(Self {
mint_url: MintUrl::from_str(mint_url)?,
keysets: Arc::new(RwLock::new(active_keysets)),
pubsub_manager: Default::default(),
pubsub_manager: Arc::new(localstore.clone().into()),
secp_ctx,
quote_ttl,
xpriv,
Expand Down
122 changes: 116 additions & 6 deletions crates/cdk/src/nuts/nut17.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
//! Specific Subscription for the cdk crate

use crate::{
cdk_database::{self, MintDatabase},
nuts::{
MeltQuoteBolt11Response, MeltQuoteState, MintQuoteBolt11Response, MintQuoteState,
ProofState,
},
pub_sub::{self, Index, Indexable, SubscriptionGlobalId},
pub_sub::{self, Index, Indexable, OnNewSubscription, SubscriptionGlobalId},
};
use serde::{Deserialize, Serialize};
use std::ops::Deref;
use std::{collections::HashMap, ops::Deref, sync::Arc};

/// Subscription Parameter according to the standard
#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down Expand Up @@ -59,7 +60,7 @@ impl Default for SupportedMethods {

pub use crate::pub_sub::SubId;

use super::{BlindSignature, CurrencyUnit, PaymentMethod};
use super::{BlindSignature, CurrencyUnit, PaymentMethod, PublicKey};

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
Expand Down Expand Up @@ -144,16 +145,125 @@ impl From<Params> for Vec<Index<(String, Kind)>> {
}
}

/// Manager
#[derive(Default)]
/// Subscription Init
///
/// This struct triggers code when a new subscription is created.
///
/// It is used to send the initial state of the subscription to the client.
pub struct SubscriptionInit(Option<Arc<dyn MintDatabase<Err = cdk_database::Error> + Send + Sync>>);

#[async_trait::async_trait]
impl OnNewSubscription for SubscriptionInit {
type Event = NotificationPayload;
type Index = (String, Kind);

async fn on_new_subscription(
&self,
request: &[&Self::Index],
) -> Result<Vec<Self::Event>, String> {
let datastore = if let Some(localstore) = self.0.as_ref() {
localstore
} else {
return Ok(vec![]);
};

let mut to_return = vec![];

for (kind, values) in request.iter().fold(
HashMap::new(),
|mut acc: HashMap<&Kind, Vec<&String>>, (data, kind)| {
acc.entry(kind).or_default().push(data);
acc
},
) {
match kind {
Kind::Bolt11MeltQuote => {
let queries = values
.iter()
.map(|id| datastore.get_melt_quote(id))
.collect::<Vec<_>>();

to_return.extend(
futures::future::try_join_all(queries)
.await
.map(|quotes| {
quotes
.into_iter()
.filter_map(|quote| quote.map(|x| x.into()))
.map(|x: MeltQuoteBolt11Response| x.into())
.collect::<Vec<_>>()
})
.map_err(|e| e.to_string())?,
);
}
Kind::Bolt11MintQuote => {
let queries = values
.iter()
.map(|id| datastore.get_mint_quote(id))
.collect::<Vec<_>>();

to_return.extend(
futures::future::try_join_all(queries)
.await
.map(|quotes| {
quotes
.into_iter()
.filter_map(|quote| quote.map(|x| x.into()))
.map(|x: MintQuoteBolt11Response| x.into())
.collect::<Vec<_>>()
})
.map_err(|e| e.to_string())?,
);
}
Kind::ProofState => {
let public_keys = values
.iter()
.map(PublicKey::from_hex)
.collect::<Result<Vec<PublicKey>, _>>()
.map_err(|e| e.to_string())?;

to_return.extend(
datastore
.get_proofs_states(&public_keys)
.await
.map_err(|e| e.to_string())?
.into_iter()
.enumerate()
.filter_map(|(idx, state)| {
state.map(|state| (public_keys[idx], state).into())
})
.map(|state: ProofState| state.into()),
);
}
}
}

Ok(to_return)
}
}

/// Manager
/// Publish–subscribe manager
///
/// Nut-17 implementation is system-wide and not only through the WebSocket, so
/// it is possible for another part of the system to subscribe to events.
pub struct PubSubManager(pub_sub::Manager<NotificationPayload, (String, Kind)>);
pub struct PubSubManager(pub_sub::Manager<NotificationPayload, (String, Kind), SubscriptionInit>);

impl Default for PubSubManager {
fn default() -> Self {
PubSubManager(SubscriptionInit::default().into())
}
}

impl From<Arc<dyn MintDatabase<Err = cdk_database::Error> + Send + Sync>> for PubSubManager {
fn from(val: Arc<dyn MintDatabase<Err = cdk_database::Error> + Send + Sync>) -> Self {
PubSubManager(SubscriptionInit(Some(val)).into())
}
}

impl Deref for PubSubManager {
type Target = pub_sub::Manager<NotificationPayload, (String, Kind)>;
type Target = pub_sub::Manager<NotificationPayload, (String, Kind), SubscriptionInit>;

fn deref(&self) -> &Self::Target {
&self.0
Expand Down
69 changes: 64 additions & 5 deletions crates/cdk/src/pub_sub/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,25 @@ pub const DEFAULT_REMOVE_SIZE: usize = 10_000;
/// Default channel size for subscription buffering
pub const DEFAULT_CHANNEL_SIZE: usize = 10;

#[async_trait::async_trait]
/// On New Subscription trait
///
/// This trait is optional and it is used to notify the application when a new
/// subscription is created. This is useful when the application needs to send
/// the initial state to the subscriber upon subscription
pub trait OnNewSubscription {
/// Index type
type Index;
/// Subscription event type
type Event;

/// Called when a new subscription is created
async fn on_new_subscription(
&self,
request: &[&Self::Index],
) -> Result<Vec<Self::Event>, String>;
}

/// Subscription manager
///
/// This object keep track of all subscription listener and it is also
Expand All @@ -45,21 +64,24 @@ pub const DEFAULT_CHANNEL_SIZE: usize = 10;
/// The content of the notification is not relevant to this scope and it is up
/// to the application, therefore the generic T is used instead of a specific
/// type
pub struct Manager<T, I>
pub struct Manager<T, I, F>
where
T: Indexable<Type = I> + Clone + Send + Sync + 'static,
I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static,
F: OnNewSubscription<Index = I, Event = T> + 'static,
{
indexes: IndexTree<T, I>,
on_new_subscription: Option<F>,
unsubscription_sender: mpsc::Sender<(SubId, Vec<Index<I>>)>,
active_subscriptions: Arc<AtomicUsize>,
background_subscription_remover: Option<JoinHandle<()>>,
}

impl<T, I> Default for Manager<T, I>
impl<T, I, F> Default for Manager<T, I, F>
where
T: Indexable<Type = I> + Clone + Send + Sync + 'static,
I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static,
F: OnNewSubscription<Index = I, Event = T> + 'static,
{
fn default() -> Self {
let (sender, receiver) = mpsc::channel(DEFAULT_REMOVE_SIZE);
Expand All @@ -72,17 +94,32 @@ where
storage.clone(),
active_subscriptions.clone(),
))),
on_new_subscription: None,
unsubscription_sender: sender,
active_subscriptions,
indexes: storage,
}
}
}

impl<T, I> Manager<T, I>
impl<T, I, F> From<F> for Manager<T, I, F>
where
T: Indexable<Type = I> + Clone + Send + Sync + 'static,
I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static,
I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static,
F: OnNewSubscription<Index = I, Event = T> + 'static,
{
fn from(value: F) -> Self {
let mut manager: Self = Default::default();
manager.on_new_subscription = Some(value);
manager
}
}

impl<T, I, F> Manager<T, I, F>
where
T: Indexable<Type = I> + Clone + Send + Sync + 'static,
I: PartialOrd + Clone + Debug + Ord + Send + Sync + 'static,
F: OnNewSubscription<Index = I, Event = T> + 'static,
{
#[inline]
/// Broadcast an event to all listeners
Expand Down Expand Up @@ -132,8 +169,29 @@ where
) -> ActiveSubscription<T, I> {
let (sender, receiver) = mpsc::channel(10);
let sub_id: SubId = params.as_ref().clone();

let indexes: Vec<Index<I>> = params.into();

if let Some(on_new_subscription) = self.on_new_subscription.as_ref() {
match on_new_subscription
.on_new_subscription(&indexes.iter().map(|x| x.deref()).collect::<Vec<_>>())
.await
{
Ok(events) => {
for event in events {
let _ = sender.try_send((sub_id.clone(), event));
}
}
Err(err) => {
tracing::info!(
"Failed to get initial state for subscription: {:?}, {}",
sub_id,
err
);
}
}
}

let mut index_storage = self.indexes.write().await;
for index in indexes.clone() {
index_storage.insert(index, sender.clone());
Expand Down Expand Up @@ -180,10 +238,11 @@ where
}

/// Manager goes out of scope, stop all background tasks
impl<T, I> Drop for Manager<T, I>
impl<T, I, F> Drop for Manager<T, I, F>
where
T: Indexable<Type = I> + Clone + Send + Sync + 'static,
I: Clone + Debug + PartialOrd + Ord + Send + Sync + 'static,
F: OnNewSubscription<Index = I, Event = T> + 'static,
{
fn drop(&mut self) {
if let Some(handler) = self.background_subscription_remover.take() {
Expand Down

0 comments on commit 50d7afc

Please sign in to comment.