diff --git a/src/auth_server/tests/test_saml_sp.py b/src/auth_server/tests/test_saml_sp.py index eb10f24..0be7d35 100644 --- a/src/auth_server/tests/test_saml_sp.py +++ b/src/auth_server/tests/test_saml_sp.py @@ -3,7 +3,7 @@ from datetime import datetime, timedelta from os import environ from pathlib import Path -from typing import Any, Dict, Mapping +from typing import Any, Dict, Mapping, Optional from unittest import TestCase from urllib.parse import parse_qs, urlparse @@ -191,6 +191,14 @@ def setUp(self) -> None: """ + def _update_app_config(self, config: Optional[Dict] = None): + if config is not None: + environ.clear() + environ.update(config) + self._clear_lru_cache() + self.app = init_auth_server_api() + self.client = TestClient(self.app) + @staticmethod def _clear_lru_cache(): # Clear lru_cache to allow config update @@ -248,8 +256,10 @@ def get_current_saml_request_id(self) -> str: return ids[0] def test_authn_request(self): + self.config["saml2_single_idp"] = self.test_idp + self._update_app_config(config=self.config) authn_url = saml2_router.url_path_for("authenticate", transaction_id=self.test_transaction_state.transaction_id) - response = self.client.get(f"{authn_url}?idp={self.test_idp}", allow_redirects=False) + response = self.client.get(authn_url, allow_redirects=False) assert response.status_code == 303 assert ( response.headers["location"].startswith( @@ -259,6 +269,8 @@ def test_authn_request(self): ) def test_saml_acs(self): + self.config["saml2_single_idp"] = self.test_idp + self._update_app_config(config=self.config) # do authn request authn_url = saml2_router.url_path_for("authenticate", transaction_id=self.test_transaction_state.transaction_id) self.client.get(f"{authn_url}?idp={self.test_idp}", allow_redirects=False) @@ -272,7 +284,7 @@ def test_saml_acs(self): saml2_router.url_path_for("assertion_consumer_service"), data=data, follow_redirects=False ) assert response.status_code == 303 - assert response.headers["location"].startswith("/interaction/redirect/") is True + assert response.headers["location"].startswith("http://testserver/interaction/redirect/") is True # check authentication result transaction_state = self._get_transaction_state_by_id(self.test_transaction_state.transaction_id)