Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 163 additions & 49 deletions nexus/db-queries/src/db/datastore/scim_provider_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,26 +298,47 @@ impl<'a> CrdbScimProviderStore<'a> {
Ok(convert_to_scim_user(new_user, None))
}

async fn list_users_in_txn(
async fn list_users_with_groups(
&self,
conn: &async_bb8_diesel::Connection<DbConnection>,
err: OptionalError<ProviderStoreError>,
filter: Option<FilterOp>,
) -> Result<Vec<StoredParts<User>>, diesel::result::Error> {
use nexus_db_schema::schema::silo_user::dsl;

let mut query = dsl::silo_user
.filter(dsl::silo_id.eq(self.authz_silo.id()))
.filter(dsl::user_provision_type.eq(model::UserProvisionType::Scim))
.filter(dsl::time_deleted.is_null())
use nexus_db_schema::schema::silo_group::dsl as group_dsl;
use nexus_db_schema::schema::silo_group_membership::dsl as membership_dsl;
use nexus_db_schema::schema::silo_user::dsl as user_dsl;
use std::collections::HashMap;

let mut query = user_dsl::silo_user
.left_join(
membership_dsl::silo_group_membership
.on(user_dsl::id.eq(membership_dsl::silo_user_id)),
)
.left_join(
group_dsl::silo_group.on(membership_dsl::silo_group_id
.eq(group_dsl::id)
.and(group_dsl::silo_id.eq(self.authz_silo.id()))
.and(
group_dsl::user_provision_type
.eq(model::UserProvisionType::Scim),
)
.and(group_dsl::time_deleted.is_null())),
)
.filter(user_dsl::silo_id.eq(self.authz_silo.id()))
.filter(
user_dsl::user_provision_type
.eq(model::UserProvisionType::Scim),
)
.filter(user_dsl::time_deleted.is_null())
.into_boxed();

match filter {
Some(FilterOp::UserNameEq(username)) => {
// userName is defined as `"caseExact" : false` in RFC 7643,
// section 8.7.1
query = query
.filter(lower(dsl::user_name).eq(lower(username.clone())));
query = query.filter(
lower(user_dsl::user_name).eq(lower(username.clone())),
);
}

None => {
Expand All @@ -334,17 +355,74 @@ impl<'a> CrdbScimProviderStore<'a> {
}
}

let users = query
.select(model::SiloUser::as_returning())
// Select user fields and optional group fields
// Note: We need to explicitly select the group columns to work with the
// boxed query and LEFT JOIN
type UserRow = (model::SiloUser, Option<(Uuid, Option<String>)>);
type UsersMap = HashMap<
Uuid,
(Option<model::SiloUser>, Vec<(SiloGroupUuid, String)>),
>;

let rows: Vec<UserRow> = query
.select((
model::SiloUser::as_select(),
(group_dsl::id, group_dsl::display_name).nullable(),
))
.load_async(conn)
.await?;

let mut returned_users = Vec::with_capacity(users.len());
// Group the results by user_id
let mut users_map: UsersMap = HashMap::new();

for user in users {
let groups = self
.get_user_groups_for_user_in_txn(conn, user.identity.id.into())
.await?;
for (user, maybe_group_info) in rows {
let user_id = user.identity.id.into_untyped_uuid();

let entry =
users_map.entry(user_id).or_insert_with(|| (None, Vec::new()));

// Store the user on first occurrence
if entry.0.is_none() {
entry.0 = Some(user);
}

// If this row has a group, add it to the user's groups
if let Some((group_id, maybe_display_name)) = maybe_group_info {
let display_name = maybe_display_name.expect(
"the constraint `display_name_consistency` prevents a \
group with provision type 'scim' from having a null \
display_name",
);

entry.1.push((
SiloGroupUuid::from_untyped_uuid(group_id),
display_name,
));
}
}

// Convert to the expected return type
let mut returned_users = Vec::with_capacity(users_map.len());

for (_user_id, (maybe_user, groups)) in users_map {
let user = maybe_user.expect("user should always be present");

let groups = if groups.is_empty() {
None
} else {
Some(
groups
.into_iter()
.map(|(group_id, display_name)| UserGroup {
// Note: neither the scim2-rs crate nor Nexus supports
// nested groups
member_type: Some(UserGroupType::Direct),
value: Some(group_id.to_string()),
display: Some(display_name),
})
.collect(),
)
};

let SiloUser::Scim(user) = user.into() else {
// With the user provision type filter, this should never be
Expand Down Expand Up @@ -839,26 +917,36 @@ impl<'a> CrdbScimProviderStore<'a> {
Ok(convert_to_scim_group(new_group, members))
}

async fn list_groups_in_txn(
async fn list_groups_with_members(
&self,
conn: &async_bb8_diesel::Connection<DbConnection>,
err: OptionalError<ProviderStoreError>,
filter: Option<FilterOp>,
) -> Result<Vec<StoredParts<Group>>, diesel::result::Error> {
use nexus_db_schema::schema::silo_group::dsl;
use nexus_db_schema::schema::silo_group::dsl as group_dsl;
use nexus_db_schema::schema::silo_group_membership::dsl as membership_dsl;
use std::collections::HashMap;

let mut query = dsl::silo_group
.filter(dsl::silo_id.eq(self.authz_silo.id()))
.filter(dsl::user_provision_type.eq(model::UserProvisionType::Scim))
.filter(dsl::time_deleted.is_null())
let mut query = group_dsl::silo_group
.left_join(
membership_dsl::silo_group_membership
.on(group_dsl::id.eq(membership_dsl::silo_group_id)),
)
.filter(group_dsl::silo_id.eq(self.authz_silo.id()))
.filter(
group_dsl::user_provision_type
.eq(model::UserProvisionType::Scim),
)
.filter(group_dsl::time_deleted.is_null())
.into_boxed();

match filter {
Some(FilterOp::DisplayNameEq(display_name)) => {
// displayName is defined as `"caseExact" : false` in RFC 7643,
// section 8.7.1
query = query.filter(
lower(dsl::display_name).eq(lower(display_name.clone())),
lower(group_dsl::display_name)
.eq(lower(display_name.clone())),
);
}

Expand All @@ -876,20 +964,60 @@ impl<'a> CrdbScimProviderStore<'a> {
}
}

let groups = query
.select(model::SiloGroup::as_returning())
// Select group fields and optional member user_id
type GroupRow = (model::SiloGroup, Option<Uuid>);
type GroupsMap =
HashMap<Uuid, (Option<model::SiloGroup>, Vec<SiloUserUuid>)>;

let rows: Vec<GroupRow> = query
.select((
model::SiloGroup::as_select(),
membership_dsl::silo_user_id.nullable(),
))
.load_async(conn)
.await?;

let mut returned_groups = Vec::with_capacity(groups.len());
// Group the results by group_id
let mut groups_map: GroupsMap = HashMap::new();

for group in groups {
let members = self
.get_group_members_for_group_in_txn(
conn,
group.identity.id.into(),
)
.await?;
for (group, maybe_user_id) in rows {
let group_id = group.identity.id.into_untyped_uuid();

let entry = groups_map
.entry(group_id)
.or_insert_with(|| (None, Vec::new()));

// Store the group on first occurrence
if entry.0.is_none() {
entry.0 = Some(group);
}

// If this row has a member, add it to the group's members
if let Some(user_id) = maybe_user_id {
entry.1.push(SiloUserUuid::from_untyped_uuid(user_id));
}
}

// Convert to the expected return type
let mut returned_groups = Vec::with_capacity(groups_map.len());

for (_group_id, (maybe_group, members)) in groups_map {
let group = maybe_group.expect("group should always be present");

let members = if members.is_empty() {
None
} else {
let mut id_ord_map = IdOrdMap::with_capacity(members.len());
for user_id in members {
id_ord_map
.insert_unique(GroupMember {
resource_type: Some(ResourceType::User.to_string()),
value: Some(user_id.to_string()),
})
.expect("user_id should be unique");
}
Some(id_ord_map)
};

let SiloGroup::Scim(group) = group.into() else {
// With the user provision type filter, this should never be
Expand Down Expand Up @@ -1387,14 +1515,7 @@ impl<'a> ProviderStore for CrdbScimProviderStore<'a> {
let err: OptionalError<ProviderStoreError> = OptionalError::new();

let users = self
.datastore
.transaction_retry_wrapper("scim_list_users")
.transaction(&conn, |conn| {
let err = err.clone();
let filter = filter.clone();

async move { self.list_users_in_txn(&conn, err, filter).await }
})
.list_users_with_groups(&conn, err.clone(), filter)
Copy link
Contributor

@papertigers papertigers Oct 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I think doing this in a transaction (txn) is load-bearing; is that right, @jmpesp?

Copy link
Contributor Author

@david-crespo david-crespo Oct 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought it was only because there were 50 queries in there. It can easily be put back.

.await
.map_err(|e| {
if let Some(e) = err.take() {
Expand Down Expand Up @@ -1652,14 +1773,7 @@ impl<'a> ProviderStore for CrdbScimProviderStore<'a> {
let err: OptionalError<ProviderStoreError> = OptionalError::new();

let groups = self
.datastore
.transaction_retry_wrapper("scim_list_groups")
.transaction(&conn, |conn| {
let err = err.clone();
let filter = filter.clone();

async move { self.list_groups_in_txn(&conn, err, filter).await }
})
.list_groups_with_members(&conn, err.clone(), filter)
.await
.map_err(|e| {
if let Some(e) = err.take() {
Expand Down
Loading
Loading