|  | 
|  | 1 | +from unittest.mock import patch | 
|  | 2 | +from urllib.parse import parse_qs, urlparse | 
|  | 3 | + | 
| 1 | 4 | import pytest | 
| 2 | 5 | from django.contrib.auth import get_user | 
| 3 | 6 | from django.contrib.auth.models import AnonymousUser | 
|  | 
| 12 | 15 |     InvalidOIDCClientError, | 
| 13 | 16 |     InvalidOIDCRedirectURIError, | 
| 14 | 17 | ) | 
| 15 |  | -from oauth2_provider.models import get_access_token_model, get_id_token_model, get_refresh_token_model | 
|  | 18 | +from oauth2_provider.models import ( | 
|  | 19 | +    get_access_token_model, | 
|  | 20 | +    get_application_model, | 
|  | 21 | +    get_id_token_model, | 
|  | 22 | +    get_refresh_token_model, | 
|  | 23 | +) | 
| 16 | 24 | from oauth2_provider.oauth2_validators import OAuth2Validator | 
| 17 | 25 | from oauth2_provider.settings import oauth2_settings | 
| 18 | 26 | from oauth2_provider.views.oidc import RPInitiatedLogoutView, _load_id_token, _validate_claims | 
| @@ -47,6 +55,7 @@ def test_get_connect_discovery_info(self): | 
| 47 | 55 |             "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"], | 
| 48 | 56 |             "code_challenge_methods_supported": ["plain", "S256"], | 
| 49 | 57 |             "claims_supported": ["sub"], | 
|  | 58 | +            "prompt_values_supported": ["none", "login"], | 
| 50 | 59 |         } | 
| 51 | 60 |         response = self.client.get("/o/.well-known/openid-configuration") | 
| 52 | 61 |         self.assertEqual(response.status_code, 200) | 
| @@ -74,6 +83,7 @@ def test_get_connect_discovery_info_deprecated(self): | 
| 74 | 83 |             "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"], | 
| 75 | 84 |             "code_challenge_methods_supported": ["plain", "S256"], | 
| 76 | 85 |             "claims_supported": ["sub"], | 
|  | 86 | +            "prompt_values_supported": ["none", "login"], | 
| 77 | 87 |         } | 
| 78 | 88 |         response = self.client.get("/o/.well-known/openid-configuration/") | 
| 79 | 89 |         self.assertEqual(response.status_code, 200) | 
| @@ -101,6 +111,7 @@ def expect_json_response_with_rp_logout(self, base): | 
| 101 | 111 |             "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"], | 
| 102 | 112 |             "code_challenge_methods_supported": ["plain", "S256"], | 
| 103 | 113 |             "claims_supported": ["sub"], | 
|  | 114 | +            "prompt_values_supported": ["none", "login"], | 
| 104 | 115 |             "end_session_endpoint": f"{base}/logout/", | 
| 105 | 116 |         } | 
| 106 | 117 |         response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info")) | 
| @@ -135,6 +146,7 @@ def test_get_connect_discovery_info_without_issuer_url(self): | 
| 135 | 146 |             "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"], | 
| 136 | 147 |             "code_challenge_methods_supported": ["plain", "S256"], | 
| 137 | 148 |             "claims_supported": ["sub"], | 
|  | 149 | +            "prompt_values_supported": ["none", "login"], | 
| 138 | 150 |         } | 
| 139 | 151 |         response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info")) | 
| 140 | 152 |         self.assertEqual(response.status_code, 200) | 
| @@ -206,6 +218,79 @@ def test_get_jwks_info_multiple_rsa_keys(self): | 
| 206 | 218 |         assert response.json() == expected_response | 
| 207 | 219 | 
 | 
| 208 | 220 | 
 | 
|  | 221 | +@pytest.mark.usefixtures("oauth2_settings") | 
|  | 222 | +@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_REGISTRATION) | 
|  | 223 | +class TestRPInitiatedRegistration(TestCase): | 
|  | 224 | +    def test_connect_discovery_info_has_create(self): | 
|  | 225 | +        expected_response = { | 
|  | 226 | +            "issuer": "http://localhost/o", | 
|  | 227 | +            "authorization_endpoint": "http://localhost/o/authorize/", | 
|  | 228 | +            "token_endpoint": "http://localhost/o/token/", | 
|  | 229 | +            "userinfo_endpoint": "http://localhost/o/userinfo/", | 
|  | 230 | +            "jwks_uri": "http://localhost/o/.well-known/jwks.json", | 
|  | 231 | +            "scopes_supported": ["read", "write", "openid"], | 
|  | 232 | +            "response_types_supported": [ | 
|  | 233 | +                "code", | 
|  | 234 | +                "token", | 
|  | 235 | +                "id_token", | 
|  | 236 | +                "id_token token", | 
|  | 237 | +                "code token", | 
|  | 238 | +                "code id_token", | 
|  | 239 | +                "code id_token token", | 
|  | 240 | +            ], | 
|  | 241 | +            "subject_types_supported": ["public"], | 
|  | 242 | +            "id_token_signing_alg_values_supported": ["RS256", "HS256"], | 
|  | 243 | +            "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"], | 
|  | 244 | +            "code_challenge_methods_supported": ["plain", "S256"], | 
|  | 245 | +            "claims_supported": ["sub"], | 
|  | 246 | +            "prompt_values_supported": ["none", "login", "create"], | 
|  | 247 | +        } | 
|  | 248 | +        response = self.client.get("/o/.well-known/openid-configuration") | 
|  | 249 | +        self.assertEqual(response.status_code, 200) | 
|  | 250 | +        assert response.json() == expected_response | 
|  | 251 | + | 
|  | 252 | +    def test_prompt_create_redirects_to_registration_view(self): | 
|  | 253 | +        Application = get_application_model() | 
|  | 254 | +        application = Application.objects.create( | 
|  | 255 | +            name="Test Application", | 
|  | 256 | +            redirect_uris="http://localhost http://example.com", | 
|  | 257 | +            client_type=Application.CLIENT_CONFIDENTIAL, | 
|  | 258 | +            authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, | 
|  | 259 | +        ) | 
|  | 260 | + | 
|  | 261 | +        auth_url = reverse("oauth2_provider:authorize") | 
|  | 262 | +        query_params = { | 
|  | 263 | +            "response_type": "code", | 
|  | 264 | +            "client_id": application.client_id, | 
|  | 265 | +            "redirect_uri": "http://localhost", | 
|  | 266 | +            "scope": "openid", | 
|  | 267 | +            "prompt": "create", | 
|  | 268 | +        } | 
|  | 269 | + | 
|  | 270 | +        with patch("oauth2_provider.views.base.reverse") as patched_reverse: | 
|  | 271 | +            patched_reverse.return_value = "/register-test/" | 
|  | 272 | +            response = self.client.get(f"{auth_url}?{'&'.join(f'{k}={v}' for k, v in query_params.items())}") | 
|  | 273 | + | 
|  | 274 | +        self.assertEqual(response.status_code, 302) | 
|  | 275 | +        redirect_url = response.url | 
|  | 276 | +        parsed_url = urlparse(redirect_url) | 
|  | 277 | + | 
|  | 278 | +        # Verify it's the registration URL | 
|  | 279 | +        self.assertEqual(parsed_url.path, "/register-test/") | 
|  | 280 | + | 
|  | 281 | +        # Verify the query parameters | 
|  | 282 | +        query = parse_qs(parsed_url.query) | 
|  | 283 | +        self.assertIn("next", query) | 
|  | 284 | + | 
|  | 285 | +        # Verify the next parameter doesn't contain prompt=create | 
|  | 286 | +        next_url = query["next"][0] | 
|  | 287 | +        self.assertNotIn("prompt=create", next_url) | 
|  | 288 | + | 
|  | 289 | +        # But it should contain the other original parameters | 
|  | 290 | +        self.assertIn("response_type=code", next_url) | 
|  | 291 | +        self.assertIn(f"client_id={application.client_id}", next_url) | 
|  | 292 | + | 
|  | 293 | + | 
| 209 | 294 | def mock_request(): | 
| 210 | 295 |     """ | 
| 211 | 296 |     Dummy request with an AnonymousUser attached. | 
|  | 
0 commit comments