diff --git a/nexus/db-queries/src/db/datastore/scim_provider_store.rs b/nexus/db-queries/src/db/datastore/scim_provider_store.rs index b9414c4fbf..2585869438 100644 --- a/nexus/db-queries/src/db/datastore/scim_provider_store.rs +++ b/nexus/db-queries/src/db/datastore/scim_provider_store.rs @@ -298,26 +298,47 @@ impl<'a> CrdbScimProviderStore<'a> { Ok(convert_to_scim_user(new_user, None)) } - async fn list_users_in_txn( + async fn list_users_with_groups( &self, conn: &async_bb8_diesel::Connection, err: OptionalError, filter: Option, ) -> Result>, diesel::result::Error> { - use nexus_db_schema::schema::silo_user::dsl; - - let mut query = dsl::silo_user - .filter(dsl::silo_id.eq(self.authz_silo.id())) - .filter(dsl::user_provision_type.eq(model::UserProvisionType::Scim)) - .filter(dsl::time_deleted.is_null()) + use nexus_db_schema::schema::silo_group::dsl as group_dsl; + use nexus_db_schema::schema::silo_group_membership::dsl as membership_dsl; + use nexus_db_schema::schema::silo_user::dsl as user_dsl; + use std::collections::HashMap; + + let mut query = user_dsl::silo_user + .left_join( + membership_dsl::silo_group_membership + .on(user_dsl::id.eq(membership_dsl::silo_user_id)), + ) + .left_join( + group_dsl::silo_group.on(membership_dsl::silo_group_id + .eq(group_dsl::id) + .and(group_dsl::silo_id.eq(self.authz_silo.id())) + .and( + group_dsl::user_provision_type + .eq(model::UserProvisionType::Scim), + ) + .and(group_dsl::time_deleted.is_null())), + ) + .filter(user_dsl::silo_id.eq(self.authz_silo.id())) + .filter( + user_dsl::user_provision_type + .eq(model::UserProvisionType::Scim), + ) + .filter(user_dsl::time_deleted.is_null()) .into_boxed(); match filter { Some(FilterOp::UserNameEq(username)) => { // userName is defined as `"caseExact" : false` in RFC 7643, // section 8.7.1 - query = query - .filter(lower(dsl::user_name).eq(lower(username.clone()))); + query = query.filter( + lower(user_dsl::user_name).eq(lower(username.clone())), + ); } None => { @@ -334,17 +355,74 @@ impl<'a> CrdbScimProviderStore<'a> { } } - let users = query - .select(model::SiloUser::as_returning()) + // Select user fields and optional group fields + // Note: We need to explicitly select the group columns to work with the + // boxed query and LEFT JOIN + type UserRow = (model::SiloUser, Option<(Uuid, Option)>); + type UsersMap = HashMap< + Uuid, + (Option, Vec<(SiloGroupUuid, String)>), + >; + + let rows: Vec = query + .select(( + model::SiloUser::as_select(), + (group_dsl::id, group_dsl::display_name).nullable(), + )) .load_async(conn) .await?; - let mut returned_users = Vec::with_capacity(users.len()); + // Group the results by user_id + let mut users_map: UsersMap = HashMap::new(); - for user in users { - let groups = self - .get_user_groups_for_user_in_txn(conn, user.identity.id.into()) - .await?; + for (user, maybe_group_info) in rows { + let user_id = user.identity.id.into_untyped_uuid(); + + let entry = + users_map.entry(user_id).or_insert_with(|| (None, Vec::new())); + + // Store the user on first occurrence + if entry.0.is_none() { + entry.0 = Some(user); + } + + // If this row has a group, add it to the user's groups + if let Some((group_id, maybe_display_name)) = maybe_group_info { + let display_name = maybe_display_name.expect( + "the constraint `display_name_consistency` prevents a \ + group with provision type 'scim' from having a null \ + display_name", + ); + + entry.1.push(( + SiloGroupUuid::from_untyped_uuid(group_id), + display_name, + )); + } + } + + // Convert to the expected return type + let mut returned_users = Vec::with_capacity(users_map.len()); + + for (_user_id, (maybe_user, groups)) in users_map { + let user = maybe_user.expect("user should always be present"); + + let groups = if groups.is_empty() { + None + } else { + Some( + groups + .into_iter() + .map(|(group_id, display_name)| UserGroup { + // Note neither the scim2-rs crate or Nexus supports + // nested groups + member_type: Some(UserGroupType::Direct), + value: Some(group_id.to_string()), + display: Some(display_name), + }) + .collect(), + ) + }; let SiloUser::Scim(user) = user.into() else { // With the user provision type filter, this should never be @@ -839,18 +917,27 @@ impl<'a> CrdbScimProviderStore<'a> { Ok(convert_to_scim_group(new_group, members)) } - async fn list_groups_in_txn( + async fn list_groups_with_members( &self, conn: &async_bb8_diesel::Connection, err: OptionalError, filter: Option, ) -> Result>, diesel::result::Error> { - use nexus_db_schema::schema::silo_group::dsl; + use nexus_db_schema::schema::silo_group::dsl as group_dsl; + use nexus_db_schema::schema::silo_group_membership::dsl as membership_dsl; + use std::collections::HashMap; - let mut query = dsl::silo_group - .filter(dsl::silo_id.eq(self.authz_silo.id())) - .filter(dsl::user_provision_type.eq(model::UserProvisionType::Scim)) - .filter(dsl::time_deleted.is_null()) + let mut query = group_dsl::silo_group + .left_join( + membership_dsl::silo_group_membership + .on(group_dsl::id.eq(membership_dsl::silo_group_id)), + ) + .filter(group_dsl::silo_id.eq(self.authz_silo.id())) + .filter( + group_dsl::user_provision_type + .eq(model::UserProvisionType::Scim), + ) + .filter(group_dsl::time_deleted.is_null()) .into_boxed(); match filter { @@ -858,7 +945,8 @@ impl<'a> CrdbScimProviderStore<'a> { // displayName is defined as `"caseExact" : false` in RFC 7643, // section 8.7.1 query = query.filter( - lower(dsl::display_name).eq(lower(display_name.clone())), + lower(group_dsl::display_name) + .eq(lower(display_name.clone())), ); } @@ -876,20 +964,60 @@ impl<'a> CrdbScimProviderStore<'a> { } } - let groups = query - .select(model::SiloGroup::as_returning()) + // Select group fields and optional member user_id + type GroupRow = (model::SiloGroup, Option); + type GroupsMap = + HashMap, Vec)>; + + let rows: Vec = query + .select(( + model::SiloGroup::as_select(), + membership_dsl::silo_user_id.nullable(), + )) .load_async(conn) .await?; - let mut returned_groups = Vec::with_capacity(groups.len()); + // Group the results by group_id + let mut groups_map: GroupsMap = HashMap::new(); - for group in groups { - let members = self - .get_group_members_for_group_in_txn( - conn, - group.identity.id.into(), - ) - .await?; + for (group, maybe_user_id) in rows { + let group_id = group.identity.id.into_untyped_uuid(); + + let entry = groups_map + .entry(group_id) + .or_insert_with(|| (None, Vec::new())); + + // Store the group on first occurrence + if entry.0.is_none() { + entry.0 = Some(group); + } + + // If this row has a member, add it to the group's members + if let Some(user_id) = maybe_user_id { + entry.1.push(SiloUserUuid::from_untyped_uuid(user_id)); + } + } + + // Convert to the expected return type + let mut returned_groups = Vec::with_capacity(groups_map.len()); + + for (_group_id, (maybe_group, members)) in groups_map { + let group = maybe_group.expect("group should always be present"); + + let members = if members.is_empty() { + None + } else { + let mut id_ord_map = IdOrdMap::with_capacity(members.len()); + for user_id in members { + id_ord_map + .insert_unique(GroupMember { + resource_type: Some(ResourceType::User.to_string()), + value: Some(user_id.to_string()), + }) + .expect("user_id should be unique"); + } + Some(id_ord_map) + }; let SiloGroup::Scim(group) = group.into() else { // With the user provision type filter, this should never be @@ -1387,14 +1515,7 @@ impl<'a> ProviderStore for CrdbScimProviderStore<'a> { let err: OptionalError = OptionalError::new(); let users = self - .datastore - .transaction_retry_wrapper("scim_list_users") - .transaction(&conn, |conn| { - let err = err.clone(); - let filter = filter.clone(); - - async move { self.list_users_in_txn(&conn, err, filter).await } - }) + .list_users_with_groups(&conn, err.clone(), filter) .await .map_err(|e| { if let Some(e) = err.take() { @@ -1652,14 +1773,7 @@ impl<'a> ProviderStore for CrdbScimProviderStore<'a> { let err: OptionalError = OptionalError::new(); let groups = self - .datastore - .transaction_retry_wrapper("scim_list_groups") - .transaction(&conn, |conn| { - let err = err.clone(); - let filter = filter.clone(); - - async move { self.list_groups_in_txn(&conn, err, filter).await } - }) + .list_groups_with_members(&conn, err.clone(), filter) .await .map_err(|e| { if let Some(e) = err.take() { diff --git a/nexus/tests/integration_tests/scim.rs b/nexus/tests/integration_tests/scim.rs index 425a7e1055..7f31e5444d 100644 --- a/nexus/tests/integration_tests/scim.rs +++ b/nexus/tests/integration_tests/scim.rs @@ -1990,3 +1990,251 @@ async fn test_scim_user_admin_group_priv_conflict( .await .expect("expected 200"); } + +#[nexus_test] +async fn test_scim_list_users_and_groups(cptestctx: &ControlPlaneTestContext) { + let client = &cptestctx.external_client; + let nexus = &cptestctx.server.server_context().nexus; + let opctx = OpContext::for_tests( + cptestctx.logctx.log.new(o!()), + nexus.datastore().clone(), + ); + + const SILO_NAME: &str = "saml-scim-silo"; + create_silo(&client, SILO_NAME, true, shared::SiloIdentityMode::SamlScim) + .await; + + grant_iam( + client, + &format!("/v1/system/silos/{SILO_NAME}"), + shared::SiloRole::Admin, + opctx.authn.actor().unwrap().silo_user_id().unwrap(), + AuthnMode::PrivilegedUser, + ) + .await; + + let created_token: views::ScimClientBearerTokenValue = + object_create_no_body( + client, + &format!("/v1/system/scim/tokens?silo={}", SILO_NAME), + ) + .await; + + // Create 5 users + let mut users = Vec::new(); + for i in 1..=5 { + let user: scim2_rs::User = NexusRequest::new( + RequestBuilder::new(client, Method::POST, "/scim/v2/Users") + .header(http::header::CONTENT_TYPE, "application/scim+json") + .header( + http::header::AUTHORIZATION, + format!("Bearer oxide-scim-{}", created_token.bearer_token), + ) + .allow_non_dropshot_errors() + .raw_body(Some( + serde_json::to_string(&serde_json::json!({ + "userName": format!("user{}", i), + "externalId": format!("user{}@example.com", i), + })) + .unwrap(), + )) + .expect_status(Some(StatusCode::CREATED)), + ) + .execute_and_parse_unwrap() + .await; + users.push(user); + } + + // Create 3 groups with various membership patterns: + // - group1: user1, user2, user3 + // - group2: user1, user4 + // - group3: no members + let group1: scim2_rs::Group = NexusRequest::new( + RequestBuilder::new(client, Method::POST, "/scim/v2/Groups") + .header(http::header::CONTENT_TYPE, "application/scim+json") + .header( + http::header::AUTHORIZATION, + format!("Bearer oxide-scim-{}", created_token.bearer_token), + ) + .allow_non_dropshot_errors() + .raw_body(Some( + serde_json::to_string(&serde_json::json!({ + "displayName": "group1", + "externalId": "group1@example.com", + "members": [ + {"value": users[0].id}, + {"value": users[1].id}, + {"value": users[2].id}, + ], + })) + .unwrap(), + )) + .expect_status(Some(StatusCode::CREATED)), + ) + .execute_and_parse_unwrap() + .await; + + let group2: scim2_rs::Group = NexusRequest::new( + RequestBuilder::new(client, Method::POST, "/scim/v2/Groups") + .header(http::header::CONTENT_TYPE, "application/scim+json") + .header( + http::header::AUTHORIZATION, + format!("Bearer oxide-scim-{}", created_token.bearer_token), + ) + .allow_non_dropshot_errors() + .raw_body(Some( + serde_json::to_string(&serde_json::json!({ + "displayName": "group2", + "externalId": "group2@example.com", + "members": [ + {"value": users[0].id}, + {"value": users[3].id}, + ], + })) + .unwrap(), + )) + .expect_status(Some(StatusCode::CREATED)), + ) + .execute_and_parse_unwrap() + .await; + + let group3: scim2_rs::Group = NexusRequest::new( + RequestBuilder::new(client, Method::POST, "/scim/v2/Groups") + .header(http::header::CONTENT_TYPE, "application/scim+json") + .header( + http::header::AUTHORIZATION, + format!("Bearer oxide-scim-{}", created_token.bearer_token), + ) + .allow_non_dropshot_errors() + .raw_body(Some( + serde_json::to_string(&serde_json::json!({ + "displayName": "group3", + "externalId": "group3@example.com", + })) + .unwrap(), + )) + .expect_status(Some(StatusCode::CREATED)), + ) + .execute_and_parse_unwrap() + .await; + + // List all users and verify group memberships + let response: scim2_rs::ListResponse = NexusRequest::new( + RequestBuilder::new(client, Method::GET, "/scim/v2/Users") + .header(http::header::CONTENT_TYPE, "application/scim+json") + .header( + http::header::AUTHORIZATION, + format!("Bearer oxide-scim-{}", created_token.bearer_token), + ) + .allow_non_dropshot_errors() + .expect_status(Some(StatusCode::OK)), + ) + .execute_and_parse_unwrap() + .await; + + let returned_users: Vec = serde_json::from_value( + serde_json::to_value(&response.resources).unwrap(), + ) + .unwrap(); + + // Find our created users in the response + let find_user = |user_id: &str| { + returned_users + .iter() + .find(|u| u.id == user_id) + .expect("user should be in list") + }; + + // user1 should be in group1 and group2 + let user1 = find_user(&users[0].id); + assert!(user1.groups.is_some()); + let user1_groups = user1.groups.as_ref().unwrap(); + assert_eq!(user1_groups.len(), 2); + let user1_group_ids: std::collections::HashSet<_> = user1_groups + .iter() + .map(|g| g.value.as_ref().unwrap().as_str()) + .collect(); + assert!(user1_group_ids.contains(group1.id.as_str())); + assert!(user1_group_ids.contains(group2.id.as_str())); + + // user2 should be in group1 only + let user2 = find_user(&users[1].id); + assert!(user2.groups.is_some()); + let user2_groups = user2.groups.as_ref().unwrap(); + assert_eq!(user2_groups.len(), 1); + assert_eq!(user2_groups[0].value.as_ref().unwrap(), &group1.id); + + // user3 should be in group1 only + let user3 = find_user(&users[2].id); + assert!(user3.groups.is_some()); + let user3_groups = user3.groups.as_ref().unwrap(); + assert_eq!(user3_groups.len(), 1); + assert_eq!(user3_groups[0].value.as_ref().unwrap(), &group1.id); + + // user4 should be in group2 only + let user4 = find_user(&users[3].id); + assert!(user4.groups.is_some()); + let user4_groups = user4.groups.as_ref().unwrap(); + assert_eq!(user4_groups.len(), 1); + assert_eq!(user4_groups[0].value.as_ref().unwrap(), &group2.id); + + // user5 should have no groups + let user5 = find_user(&users[4].id); + assert!(user5.groups.is_none()); + + // List all groups and verify members + let response: scim2_rs::ListResponse = NexusRequest::new( + RequestBuilder::new(client, Method::GET, "/scim/v2/Groups") + .header(http::header::CONTENT_TYPE, "application/scim+json") + .header( + http::header::AUTHORIZATION, + format!("Bearer oxide-scim-{}", created_token.bearer_token), + ) + .allow_non_dropshot_errors() + .expect_status(Some(StatusCode::OK)), + ) + .execute_and_parse_unwrap() + .await; + + let returned_groups: Vec = serde_json::from_value( + serde_json::to_value(&response.resources).unwrap(), + ) + .unwrap(); + + // Find our created groups in the response + let find_group = |group_id: &str| { + returned_groups + .iter() + .find(|g| g.id == group_id) + .expect("group should be in list") + }; + + // group1 should have 3 members + let returned_group1 = find_group(&group1.id); + assert!(returned_group1.members.is_some()); + let group1_members = returned_group1.members.as_ref().unwrap(); + assert_eq!(group1_members.len(), 3); + let group1_member_ids: std::collections::HashSet<_> = group1_members + .iter() + .map(|m| m.value.as_ref().unwrap().as_str()) + .collect(); + assert!(group1_member_ids.contains(users[0].id.as_str())); + assert!(group1_member_ids.contains(users[1].id.as_str())); + assert!(group1_member_ids.contains(users[2].id.as_str())); + + // group2 should have 2 members + let returned_group2 = find_group(&group2.id); + assert!(returned_group2.members.is_some()); + let group2_members = returned_group2.members.as_ref().unwrap(); + assert_eq!(group2_members.len(), 2); + let group2_member_ids: std::collections::HashSet<_> = group2_members + .iter() + .map(|m| m.value.as_ref().unwrap().as_str()) + .collect(); + assert!(group2_member_ids.contains(users[0].id.as_str())); + assert!(group2_member_ids.contains(users[3].id.as_str())); + + // group3 should have no members + let returned_group3 = find_group(&group3.id); + assert!(returned_group3.members.is_none()); +}