From 9c8110934e12ce341205f3c5bedd64cc30b3282f Mon Sep 17 00:00:00 2001 From: Eliza Weisman Date: Thu, 5 Sep 2024 16:46:01 -0700 Subject: [PATCH] [db-queries] Allow join expressions in `paginated_multicolumn` (#6530) Currently, the `paginated_multicolumn` utility in `nexus_db_queries::pagination` only works when the select expression to paginate is a table, and both columns to order by come from that table. This means that it cannot easily be used to fix the bug in `instance_and_vmm_list_by_sled_agent` that @davepacheco describes in [this comment][1], which would require using `paginated_multicolumn` to paginate on two columns in an inner join expression. This commit changes the giant wad of Diesel type ceremony on `paginated_multicolumn` in order to ~~make it even worse~~ allow expressions which are not tables to be paginated. I've added a test demonstrating that this does, in fact, work. Figuring out how to do this was...certainly an experience which I have had. I think I need to lie down now. [1]: https://github.com/oxidecomputer/omicron/pull/6519#pullrequestreview-2283593739 --- nexus/db-queries/src/db/pagination.rs | 266 ++++++++++++++++++++++---- 1 file changed, 228 insertions(+), 38 deletions(-) diff --git a/nexus/db-queries/src/db/pagination.rs b/nexus/db-queries/src/db/pagination.rs index 9920440ade..4872e18136 100644 --- a/nexus/db-queries/src/db/pagination.rs +++ b/nexus/db-queries/src/db/pagination.rs @@ -11,6 +11,7 @@ use diesel::helper_types::*; use diesel::pg::Pg; use diesel::query_builder::AsQuery; use diesel::query_dsl::methods as query_methods; +use diesel::query_source::QuerySource; use diesel::sql_types::{Bool, SqlType}; use diesel::AppearsOnTable; use diesel::Column; @@ -70,7 +71,7 @@ where } } -/// Uses `pagparams` to list a subset of rows in `table`, ordered by `c1, and +/// Uses `pagparams` to list a subset of rows in `query`, ordered by `c1, and /// then by `c2. /// /// This is a two-column variation of the [`paginated`] function. 
@@ -79,40 +80,56 @@ where // columns" implement a subset of ExpressionMethods) or making a macro to generate // all the necessary bounds we need. pub fn paginated_multicolumn( - table: T, + query: T, (c1, c2): (C1, C2), pagparams: &DataPageParams<'_, (M1, M2)>, -) -> BoxedQuery +) -> >::Output where - // T is a table which can create a BoxedQuery. - T: diesel::Table, - T: query_methods::BoxedDsl< - 'static, - Pg, - Output = diesel::internal::table_macro::BoxedSelectStatement< - 'static, - TableSqlType, - diesel::internal::table_macro::FromClause, - Pg, - >, - >, + // T is a table^H^H^H^H^Hquery source which can create a BoxedQuery. + T: QuerySource, + T: AsQuery, + ::DefaultSelection: + Expression::SqlType>, + T::Query: query_methods::BoxedDsl<'static, Pg>, + // Required for...everything. + >::Output: QueryDsl, // C1 & C2 are columns which appear in T. - C1: 'static + Column + Copy + ExpressionMethods + AppearsOnTable, - C2: 'static + Column + Copy + ExpressionMethods + AppearsOnTable, + C1: 'static + Column + Copy + ExpressionMethods, + C2: 'static + Column + Copy + ExpressionMethods, // Required to compare the columns with the marker types. 
C1::SqlType: SqlType, C2::SqlType: SqlType, M1: Clone + AsExpression, M2: Clone + AsExpression, + // Necessary for `query.limit(...)` + >::Output: + query_methods::LimitDsl< + Output = >::Output, + >, // Necessary for "query.order(c1.desc())" - BoxedQuery: query_methods::OrderDsl, Output = BoxedQuery>, + >::Output: + query_methods::OrderDsl< + Desc, + Output = >::Output, + >, // Necessary for "query.order(...).then_order_by(c2.desc())" - BoxedQuery: - query_methods::ThenOrderDsl, Output = BoxedQuery>, + >::Output: + query_methods::ThenOrderDsl< + Desc, + Output = >::Output, + >, // Necessary for "query.order(c1.asc())" - BoxedQuery: query_methods::OrderDsl, Output = BoxedQuery>, + >::Output: + query_methods::OrderDsl< + Asc, + Output = >::Output, + >, // Necessary for "query.order(...).then_order_by(c2.asc())" - BoxedQuery: query_methods::ThenOrderDsl, Output = BoxedQuery>, + >::Output: + query_methods::ThenOrderDsl< + Asc, + Output = >::Output, + >, // We'd like to be able to call: // @@ -126,10 +143,11 @@ where // The RHS (c2.gt(v2)) must be a boolean expression: Gt: Expression, // Putting it together, we should be able to filter by LHS.and(RHS): - BoxedQuery: query_methods::FilterDsl< - And, Gt>, - Output = BoxedQuery, - >, + >::Output: + query_methods::FilterDsl< + And, Gt>, + Output = >::Output, + >, // We'd also like to be able to call: // @@ -138,19 +156,30 @@ where // We've already defined the bound on the LHS, so we add the equivalent // bounds on the RHS for the "Less than" variant. 
Lt: Expression, - BoxedQuery: query_methods::FilterDsl< - And, Lt>, - Output = BoxedQuery, - >, + >::Output: + query_methods::FilterDsl< + And, Lt>, + Output = >::Output, + >, // Necessary for "query.or_filter(c1.gt(v1))" - BoxedQuery: - query_methods::OrFilterDsl, Output = BoxedQuery>, + >::Output: + query_methods::OrFilterDsl< + Gt, + Output = >::Output, + >, // Necessary for "query.or_filter(c1.lt(v1))" - BoxedQuery: - query_methods::OrFilterDsl, Output = BoxedQuery>, + >::Output: + query_methods::OrFilterDsl< + Lt, + Output = >::Output, + >, { - let mut query = table.into_boxed().limit(pagparams.limit.get().into()); + use query_methods::BoxedDsl; + let mut query = query + .as_query() + .internal_into_boxed() + .limit(pagparams.limit.get().into()); let marker = pagparams.marker.map(|m| m.clone()); match pagparams.direction { dropshot::PaginationOrder::Ascending => { @@ -315,6 +344,7 @@ mod test { use crate::db; use async_bb8_diesel::{AsyncRunQueryDsl, AsyncSimpleConnection}; + use diesel::JoinOnDsl; use diesel::SelectableHelper; use dropshot::PaginationOrder; use nexus_test_utils::db::test_setup_database; @@ -333,9 +363,18 @@ mod test { height -> Int8, } } + + table! 
{ + test_phone_numbers (user_id, phone_number) { + user_id -> Uuid, + phone_number -> Int8, + } + } + + allow_tables_to_appear_in_same_query!(test_users, test_phone_numbers,); } - use schema::test_users; + use schema::{test_phone_numbers, test_users}; #[derive(Clone, Debug, Queryable, Insertable, PartialEq, Selectable)] #[diesel(table_name = test_users)] @@ -345,13 +384,39 @@ mod test { height: i64, } + #[derive(Clone, Debug, Queryable, Insertable, PartialEq, Selectable)] + #[diesel(table_name = test_phone_numbers)] + struct PhoneNumber { + user_id: Uuid, + phone_number: i64, + } + + #[derive(Debug)] + struct UserAndPhoneNumber { + user: User, + phone_number: PhoneNumber, + } + + impl PartialEq<((i64, i64), i64)> for UserAndPhoneNumber { + fn eq(&self, &(user, phone): &((i64, i64), i64)) -> bool { + self.user == user && self.phone_number == phone + } + } + impl PartialEq<(i64, i64)> for User { fn eq(&self, other: &(i64, i64)) -> bool { self.age == other.0 && self.height == other.1 } } + impl PartialEq for PhoneNumber { + fn eq(&self, &other: &i64) -> bool { + self.phone_number == other + } + } + async fn populate_users(pool: &db::Pool, values: &Vec<(i64, i64)>) { + use schema::test_phone_numbers::dsl as phone_numbers_dsl; use schema::test_users::dsl; let conn = pool.claim().await.unwrap(); @@ -365,8 +430,17 @@ mod test { height INT NOT NULL ); + CREATE TABLE test_phone_numbers ( + user_id UUID NOT NULL, + -- This is definitely the correct way to store a + -- phone number in the database. 
:) + phone_number INT NOT NULL, + PRIMARY KEY (user_id, phone_number) + ); + CREATE INDEX ON test_users (age, height); - CREATE INDEX ON test_users (height, age);", + CREATE INDEX ON test_users (height, age); + CREATE INDEX ON test_phone_numbers (user_id);", ) .await .unwrap(); @@ -381,7 +455,22 @@ mod test { .collect(); diesel::insert_into(dsl::test_users) - .values(users) + .values(users.clone()) + .execute_async(&*conn) + .await + .unwrap(); + + let mut phone_numbers = Vec::new(); + for (i, user) in users.iter().enumerate() { + for j in 0..3 { + phone_numbers.push(PhoneNumber { + user_id: user.id, + phone_number: (i as i64 + 1) * 10 + j, + }); + } + } + diesel::insert_into(phone_numbers_dsl::test_phone_numbers) + .values(phone_numbers) .execute_async(&*conn) .await .unwrap(); @@ -574,6 +663,107 @@ mod test { logctx.cleanup_successful(); } + #[tokio::test] + async fn test_paginated_multicolumn_works_with_joins() { + use async_bb8_diesel::AsyncConnection; + + let logctx = + dev::test_setup_log("test_paginated_multicolumn_works_with_joins"); + let mut db = test_setup_database(&logctx.log).await; + let cfg = db::Config { url: db.pg_config().clone() }; + let pool = db::Pool::new_single_host(&logctx.log, &cfg); + + use schema::test_phone_numbers::dsl as phone_numbers_dsl; + use schema::test_users::dsl; + + populate_users(&pool, &vec![(1, 1), (1, 2), (2, 1), (2, 3), (3, 1)]) + .await; + + async fn get_page( + pool: &db::Pool, + pagparams: &DataPageParams<'_, (i64, i64)>, + ) -> Vec { + let conn = pool.claim().await.unwrap(); + conn.transaction_async(|conn| async move { + // I couldn't figure out how to make this work without requiring a full + // table scan, and I just want the test to work so that I can get on + // with my life... 
+ conn.batch_execute_async( + crate::db::queries::ALLOW_FULL_TABLE_SCAN_SQL, + ) + .await + .unwrap(); + + paginated_multicolumn( + dsl::test_users.inner_join( + phone_numbers_dsl::test_phone_numbers + .on(phone_numbers_dsl::user_id.eq(dsl::id)), + ), + (dsl::age, phone_numbers_dsl::phone_number), + &pagparams, + ) + .select((User::as_select(), PhoneNumber::as_select())) + .load_async(&conn) + .await + }) + .await + .unwrap() + .into_iter() + .map(|(user, phone_number)| UserAndPhoneNumber { + user, + phone_number, + }) + .collect::>() + } + + // Get the first paginated result. + let mut pagparams = DataPageParams::<(i64, i64)> { + marker: None, + direction: PaginationOrder::Ascending, + limit: NonZeroU32::new(1).unwrap(), + }; + let observed = get_page(&pool, &pagparams).await; + assert_eq!(dbg!(&observed), &[((1, 1), 10)]); + + // Get the next paginated results, check that they arrived in the order + // we expected. + let marker = + (observed[0].user.age, observed[0].phone_number.phone_number); + pagparams.marker = Some(&marker); + pagparams.limit = NonZeroU32::new(10).unwrap(); + let observed = get_page(&pool, &pagparams).await; + assert_eq!( + dbg!(&observed), + &[ + ((1, 1), 11), + ((1, 1), 12), + ((1, 2), 20), + ((1, 2), 21), + ((1, 2), 22), + ((2, 1), 30), + ((2, 1), 31), + ((2, 1), 32), + ((2, 3), 40), + ((2, 3), 41), + ] + ); + + // Get the next paginated results, check that they arrived in the order + // we expected. + let marker = + (observed[9].user.age, observed[9].phone_number.phone_number); + pagparams.marker = Some(&marker); + pagparams.limit = NonZeroU32::new(10).unwrap(); + let observed = get_page(&pool, &pagparams).await; + assert_eq!( + dbg!(&observed), + &[((2, 3), 42), ((3, 1), 50), ((3, 1), 51), ((3, 1), 52)] + ); + + let _ = db.cleanup().await; + logctx.cleanup_successful(); + } + #[test] fn test_paginator() { // The doctest exercises a basic case for Paginator. Here we test some