diff --git a/data_safe_haven/commands/users.py b/data_safe_haven/commands/users.py index fe413fa781..8c8b232ceb 100644 --- a/data_safe_haven/commands/users.py +++ b/data_safe_haven/commands/users.py @@ -5,6 +5,7 @@ import typer +from data_safe_haven import console from data_safe_haven.administration.users import UserHandler from data_safe_haven.config import ContextManager, DSHPulumiConfig, SHMConfig, SREConfig from data_safe_haven.exceptions import DataSafeHavenError @@ -120,9 +121,9 @@ def register( # Load SHMConfig try: shm_config = SHMConfig.from_remote(context) - except DataSafeHavenError: + except DataSafeHavenError as exc: logger.error("Have you deployed the SHM?") - raise + raise typer.Exit(1) from exc # Load Pulumi config pulumi_config = DSHPulumiConfig.from_remote(context) @@ -132,7 +133,7 @@ def register( if sre_config.name not in pulumi_config.project_names: msg = f"Could not load Pulumi settings for '{sre_config.name}'. Have you deployed the SRE?" logger.error(msg) - raise DataSafeHavenError(msg) + raise typer.Exit(1) # Load GraphAPI graph_api = GraphApi.from_scopes( @@ -146,16 +147,29 @@ def register( # List users users = UserHandler(context, graph_api) - available_usernames = users.get_usernames_entra_id() + available_users = users.entra_users.list() + user_dict = { + user.preferred_username.split("@")[0]: user.preferred_username.split("@")[1] + for user in available_users + } usernames_to_register = [] for username in usernames: - if username in available_usernames: - usernames_to_register.append(username) + if user_domain := user_dict.get(username): + if shm_config.shm.fqdn not in user_domain: + console.print( + f"User [green]'{username}[/green]'s principal domain name is [blue]'{user_domain}'[/blue].\n" + f"SRE [yellow]'{sre}'[/yellow] belongs to SHM domain [blue]'{shm_config.shm.fqdn}'[/blue]." + ) + logger.error( + "The user's principal domain name must match the domain of the SRE to be registered." + ) + else: + usernames_to_register.append(username) else: logger.error( f"Username '{username}' does not belong to this Data Safe Haven deployment." - " Please use 'dsh users add' to create it." ) + console.print("Please use 'dsh users add' to create this user.") users.register(sre_config.name, usernames_to_register) except DataSafeHavenError as exc: logger.critical(f"Could not register Data Safe Haven users with SRE '{sre}'.") @@ -259,8 +273,8 @@ def unregister( else: logger.error( f"Username '{username}' does not belong to this Data Safe Haven deployment." - " Please use 'dsh users add' to create it." ) + console.print("Please use 'dsh users add' to create it.") for group_name in ( f"{sre_config.name} Users", f"{sre_config.name} Privileged Users", diff --git a/tests/commands/conftest.py b/tests/commands/conftest.py index d675398bfc..de60eb29d0 100644 --- a/tests/commands/conftest.py +++ b/tests/commands/conftest.py @@ -1,6 +1,8 @@ from pytest import fixture from typer.testing import CliRunner +from data_safe_haven.administration.users.entra_users import EntraUsers +from data_safe_haven.administration.users.research_user import ResearchUser from data_safe_haven.config import ( Context, ContextManager, @@ -260,3 +262,14 @@ def tmp_contexts_none(tmp_path, context_yaml): with open(config_file_path, "w") as f: f.write(context_yaml) return tmp_path + + +@fixture +def mock_entra_user_list(mocker): + test_user = ResearchUser( + given_name="Harry", + surname="Lime", + sam_account_name="harry.lime", + user_principal_name="harry.lime@acme.testing", + ) + mocker.patch.object(EntraUsers, "list", return_value=[test_user]) diff --git a/tests/commands/test_users.py b/tests/commands/test_users.py index c1b183c922..5c11e29cc9 100644 --- a/tests/commands/test_users.py +++ b/tests/commands/test_users.py @@ -52,6 +52,26 @@ def test_invalid_shm( assert result.exit_code == 1 assert "Have you deployed the SHM?" in result.stdout + def test_mismatched_domain( + self, + mock_graphapi_get_credential, # noqa: ARG002 + mock_pulumi_config_no_key_from_remote, # noqa: ARG002 + mock_shm_config_from_remote, # noqa: ARG002 + mock_sre_config_from_remote, # noqa: ARG002 + mock_entra_user_list, # noqa: ARG002 + runner, + tmp_contexts, # noqa: ARG002 + ): + result = runner.invoke( + users_command_group, ["register", "-u", "harry.lime", "sandbox"] + ) + + assert result.exit_code == 0 + assert ( + "principal domain name must match the domain of the SRE to be registered" + in result.stdout + ) + def test_invalid_sre( self, mock_pulumi_config_from_remote, # noqa: ARG002