Skip to content

Commit

Permalink
Add some tests and comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jsdt committed Dec 31, 2024
1 parent dc1e02f commit 62a30d9
Showing 1 changed file with 120 additions and 0 deletions.
120 changes: 120 additions & 0 deletions crates/core/src/subscription/module_subscription_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ impl ClientInfo {
#[derive(Debug)]
struct QueryState {
query: Query,
// For legacy clients that subscribe to a set of queries, we track them here.
legacy_subscribers: HashSet<ClientId>,
// For clients that subscribe to a single query, we track them here.
subscriptions: HashSet<SubscriptionId>,
}

Expand Down Expand Up @@ -146,6 +148,8 @@ impl SubscriptionManager {
}
}

/// Remove a single subscription for a client.
/// This will return an error if the client does not have a subscription with the given query id.
pub fn remove_subscription(&mut self, client_id: ClientId, query_id: ClientQueryId) -> Result<Query, DBError> {
let subscription_id = (client_id, query_id);
let Some(ci) = self.clients.get_mut(&client_id) else {
Expand Down Expand Up @@ -613,6 +617,122 @@ mod tests {
Ok(())
}

#[test]
fn test_unsubscribe_doesnt_remove_other_clients() -> ResultTest<()> {
let db = TestDB::durable()?;

let table_id = create_table(&db, "T")?;
let sql = "select * from T";
let plan = compile_plan(&db, sql)?;
let hash = plan.hash();

let clients = (0..3).map(|i| Arc::new(client(i))).collect::<Vec<_>>();

// All of the clients are using the same query id.
let query_id: ClientQueryId = QueryId::new(1);
let mut subscriptions = SubscriptionManager::default();
subscriptions.add_subscription(clients[0].clone(), plan.clone(), query_id)?;
subscriptions.add_subscription(clients[1].clone(), plan.clone(), query_id)?;
subscriptions.add_subscription(clients[2].clone(), plan.clone(), query_id)?;

assert!(subscriptions.query_reads_from_table(&hash, &table_id));

let client_ids = clients
.iter()
.map(|client| (client.id.identity, client.id.address))
.collect::<Vec<_>>();
subscriptions.remove_subscription(client_ids[0], query_id)?;
// There are still two left.
assert!(subscriptions.query_reads_from_table(&hash, &table_id));
subscriptions.remove_subscription(client_ids[1], query_id)?;
// There is still one left.
assert!(subscriptions.query_reads_from_table(&hash, &table_id));
subscriptions.remove_subscription(client_ids[2], query_id)?;
// Now there are no subscribers.
assert!(!subscriptions.query_reads_from_table(&hash, &table_id));

Ok(())
}

#[test]
fn test_unsubscribe_all_doesnt_remove_other_clients() -> ResultTest<()> {
let db = TestDB::durable()?;

let table_id = create_table(&db, "T")?;
let sql = "select * from T";
let plan = compile_plan(&db, sql)?;
let hash = plan.hash();

let clients = (0..3).map(|i| Arc::new(client(i))).collect::<Vec<_>>();

// All of the clients are using the same query id.
let query_id: ClientQueryId = QueryId::new(1);
let mut subscriptions = SubscriptionManager::default();
subscriptions.add_subscription(clients[0].clone(), plan.clone(), query_id)?;
subscriptions.add_subscription(clients[1].clone(), plan.clone(), query_id)?;
subscriptions.add_subscription(clients[2].clone(), plan.clone(), query_id)?;

assert!(subscriptions.query_reads_from_table(&hash, &table_id));

let client_ids = clients
.iter()
.map(|client| (client.id.identity, client.id.address))
.collect::<Vec<_>>();
subscriptions.remove_all_subscriptions(&client_ids[0]);
// There are still two left.
assert!(subscriptions.query_reads_from_table(&hash, &table_id));
subscriptions.remove_all_subscriptions(&client_ids[1]);
// There is still one left.
assert!(subscriptions.query_reads_from_table(&hash, &table_id));
subscriptions.remove_all_subscriptions(&client_ids[2]);
// Now there are no subscribers.
assert!(!subscriptions.query_reads_from_table(&hash, &table_id));

Ok(())
}

// This test has a single client with 3 queries of different tables, and tests removing them.
#[test]
fn test_multiple_queries() -> ResultTest<()> {
let db = TestDB::durable()?;

let table_names = ["T", "S", "U"];
let table_ids = table_names
.iter()
.map(|name| create_table(&db, name))
.collect::<ResultTest<Vec<_>>>()?;
let queries = table_names
.iter()
.map(|name| format!("select * from {}", name))
.map(|sql| compile_plan(&db, &sql))
.collect::<ResultTest<Vec<_>>>()?;

let client = Arc::new(client(0));
let mut subscriptions = SubscriptionManager::default();
subscriptions.add_subscription(client.clone(), queries[0].clone(), QueryId::new(1))?;
subscriptions.add_subscription(client.clone(), queries[1].clone(), QueryId::new(2))?;
subscriptions.add_subscription(client.clone(), queries[2].clone(), QueryId::new(3))?;
for i in 0..3 {
assert!(subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
}

let client_id = (client.id.identity, client.id.address);
subscriptions.remove_subscription(client_id, QueryId::new(1))?;
assert!(!subscriptions.query_reads_from_table(&queries[0].hash(), &table_ids[0]));
// Assert that the rest are there.
for i in 1..3 {
assert!(subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
}

// Now remove the final two at once.
subscriptions.remove_all_subscriptions(&client_id);
for i in 0..3 {
assert!(!subscriptions.query_reads_from_table(&queries[i].hash(), &table_ids[i]));
}

Ok(())
}

#[test]
fn test_subscribe_fails_with_duplicate_request_id() -> ResultTest<()> {
let db = TestDB::durable()?;
Expand Down

0 comments on commit 62a30d9

Please sign in to comment.