diff --git a/src/auth_server/tests/test_app.py b/src/auth_server/tests/test_app.py index 7e6d218..bc6b49f 100644 --- a/src/auth_server/tests/test_app.py +++ b/src/auth_server/tests/test_app.py @@ -8,6 +8,7 @@ from typing import Any, Dict, Mapping, Optional from unittest import TestCase, mock from unittest.mock import AsyncMock +from urllib.parse import parse_qs, urlparse import yaml from cryptography import x509 @@ -23,7 +24,11 @@ AccessTokenFlags, AccessTokenRequest, Client, + ContinueRequest, + FinishInteraction, + FinishInteractionMethod, GrantRequest, + HashMethod, InteractionRequest, Key, Proof, @@ -37,7 +42,7 @@ from auth_server.tests.utils import create_tls_fed_metadata, tls_fed_metadata_to_jws from auth_server.time_utils import utc_now from auth_server.tls_fed_auth import get_tls_fed_metadata -from auth_server.utils import get_signing_key, hash_with, load_jwks +from auth_server.utils import get_hash_by_name, get_signing_key, hash_with, load_jwks __author__ = "lundberg" @@ -1034,6 +1039,110 @@ def test_transaction_jws_continue(self): assert claims["saml_issuer"] == "https://idp.example.com" assert claims["saml_eppn"] == "test@example.com" + def test_transaction_jws_continue_redirect_finish(self): + self.config["auth_flows"] = json.dumps(["InteractionFlow"]) + self._update_app_config(config=self.config) + + client_key_dict = self.client_jwk.export(as_dict=True) + client_jwk = ECJWK(**client_key_dict) + + client_nonce = "client_nonce" + req = GrantRequest( + client=Client(key=Key(proof=Proof(method=ProofMethod.JWS), jwk=client_jwk)), + access_token=[AccessTokenRequest(flags=[AccessTokenFlags.BEARER])], + interact=InteractionRequest( + start=[StartInteractionMethod.REDIRECT], + finish=FinishInteraction( + method=FinishInteractionMethod.REDIRECT, + uri="https://example.com/redirect", + nonce=client_nonce, + ), + ), + ) + transaction_url = "http://testserver/transaction" + jws_header = { + "typ": SupportedJWSType.JWS, + "alg": SupportedAlgorithms.ES256.value, + "kid": self.client_jwk.key_id, + "htm": SupportedHTTPMethods.POST.value, + "uri": transaction_url, + "created": int(utc_now().timestamp()), + } + _jws = jws.JWS(payload=req.json(exclude_unset=True)) + _jws.add_signature( + key=self.client_jwk, + protected=json.dumps(jws_header), + ) + content = _jws.serialize(compact=True) + + client_header = {"Content-Type": "application/jose+json"} + response = self.client.post("/transaction", content=content, headers=client_header) + assert response.status_code == 200 + + # continue response with no continue reference in uri + assert "continue" in response.json() + continue_response = response.json()["continue"] + assert continue_response["uri"].startswith("http://testserver/continue") is True + assert continue_response["access_token"]["value"] is not None + + # do interaction + interaction_response = response.json()["interact"] + transaction_id = interaction_response["redirect"].split("http://testserver/interaction/redirect/")[1] + as_nonce = interaction_response["finish"] + + # check redirect to SAML SP + response = self.client.get(interaction_response["redirect"], allow_redirects=False) + assert response.status_code == 307 + assert response.headers["location"].startswith("http://testserver/saml2/sp/authn/") + + # fake a completed SAML authentication + self._fake_saml_authentication(transaction_id=transaction_id) + + # complete interaction + response = self.client.get(interaction_response["redirect"], allow_redirects=False) + assert response.status_code == 307 + + # "receive" redirect back to our endpoint and pick out hash and interact_ref + urlparsed_redirect_location = urlparse(response.headers["location"]) + qs = parse_qs(urlparsed_redirect_location.query) + interact_hash = qs["hash"][0] + interact_ref = qs["interact_ref"][0] + + # verify hash + hash_alg = get_hash_by_name(hash_name=HashMethod.SHA_256.value) # defaults to SHA256 + plaintext = f"{client_nonce}\n{as_nonce}\n{interact_ref}\n{transaction_url}".encode(encoding="ascii") + hash_res = hash_with(hash_alg, plaintext) + assert base64.urlsafe_b64encode(hash_res).decode(encoding="ascii") == interact_hash + + # continue request after interaction is completed + jws_header["uri"] = continue_response["uri"] + jws_header["created"] = int(utc_now().timestamp()) + # calculate ath header value + access_token_hash = hash_with(SHA256(), continue_response["access_token"]["value"].encode()) + jws_header["ath"] = base64.urlsafe_b64encode(access_token_hash).decode("utf-8") + # create jws from continue request + _jws = jws.JWS(payload=ContinueRequest(interact_ref=interact_ref).json(exclude_unset=True)) + _jws.add_signature( + key=self.client_jwk, + protected=json.dumps(jws_header), + ) + continue_data = _jws.serialize(compact=True) + authorization_header = f'GNAP {continue_response["access_token"]["value"]}' + client_header["Authorization"] = authorization_header + response = self.client.post(continue_response["uri"], content=continue_data, headers=client_header) + + assert response.status_code == 200 + assert "access_token" in response.json() + access_token = response.json()["access_token"] + assert AccessTokenFlags.BEARER.value in access_token["flags"] + + # Verify token and check claims + claims = self._get_access_token_claims(access_token=access_token, client=self.client) + assert claims["auth_source"] == AuthSource.INTERACTION + assert claims["aud"] == "some_audience" + assert claims["saml_issuer"] == "https://idp.example.com" + assert claims["saml_eppn"] == "test@example.com" + def test_transaction_jwsd_continue(self): self.config["auth_flows"] = json.dumps(["InteractionFlow"]) self._update_app_config(config=self.config)