From cbe429f3fa72fa70994721ddcd0ea30bf4753ee4 Mon Sep 17 00:00:00 2001 From: "gcp-cherry-pick-bot[bot]" <98988430+gcp-cherry-pick-bot[bot]@users.noreply.github.com> Date: Thu, 9 Jan 2025 16:20:49 +0100 Subject: [PATCH] providers/saml: fix invalid SAML Response when assertion and response are signed (cherry-pick #12611) (#12613) providers/saml: fix invalid SAML Response when assertion and response are signed (#12611) * providers/saml: fix invalid SAML Response when assertion and response are signed * validate against schema too --------- Signed-off-by: Jens Langhammer Co-authored-by: Jens L. --- .../providers/saml/processors/assertion.py | 14 +++++++++++++- .../providers/saml/tests/test_auth_n_request.py | 17 +++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/authentik/providers/saml/processors/assertion.py b/authentik/providers/saml/processors/assertion.py index cdce8586616d..dd618f2800ce 100644 --- a/authentik/providers/saml/processors/assertion.py +++ b/authentik/providers/saml/processors/assertion.py @@ -256,7 +256,7 @@ def get_assertion(self) -> Element: assertion.attrib["IssueInstant"] = self._issue_instant assertion.append(self.get_issuer()) - if self.provider.signing_kp: + if self.provider.signing_kp and self.provider.sign_assertion: sign_algorithm_transform = SIGN_ALGORITHM_TRANSFORM_MAP.get( self.provider.signature_algorithm, xmlsec.constants.TransformRsaSha1 ) @@ -295,6 +295,18 @@ def get_response(self) -> Element: response.append(self.get_issuer()) + if self.provider.signing_kp and self.provider.sign_response: + sign_algorithm_transform = SIGN_ALGORITHM_TRANSFORM_MAP.get( + self.provider.signature_algorithm, xmlsec.constants.TransformRsaSha1 + ) + signature = xmlsec.template.create( + response, + xmlsec.constants.TransformExclC14N, + sign_algorithm_transform, + ns=xmlsec.constants.DSigNs, + ) + response.append(signature) + status = SubElement(response, f"{{{NS_SAML_PROTOCOL}}}Status") status_code = SubElement(status, f"{{{NS_SAML_PROTOCOL}}}StatusCode") status_code.attrib["Value"] = "urn:oasis:names:tc:SAML:2.0:status:Success" diff --git a/authentik/providers/saml/tests/test_auth_n_request.py b/authentik/providers/saml/tests/test_auth_n_request.py index 1bd58f04d523..48d6d713b686 100644 --- a/authentik/providers/saml/tests/test_auth_n_request.py +++ b/authentik/providers/saml/tests/test_auth_n_request.py @@ -2,8 +2,10 @@ from base64 import b64encode +from defusedxml.lxml import fromstring from django.http.request import QueryDict from django.test import TestCase +from lxml import etree # nosec from authentik.blueprints.tests import apply_blueprint from authentik.core.tests.utils import create_test_admin_user, create_test_cert, create_test_flow @@ -11,12 +13,14 @@ from authentik.events.models import Event, EventAction from authentik.lib.generators import generate_id from authentik.lib.tests.utils import get_request +from authentik.lib.xml import lxml_from_string from authentik.providers.saml.models import SAMLPropertyMapping, SAMLProvider from authentik.providers.saml.processors.assertion import AssertionProcessor from authentik.providers.saml.processors.authn_request_parser import AuthNRequestParser from authentik.sources.saml.exceptions import MismatchedRequestID from authentik.sources.saml.models import SAMLSource from authentik.sources.saml.processors.constants import ( + NS_MAP, SAML_BINDING_REDIRECT, SAML_NAME_ID_FORMAT_EMAIL, SAML_NAME_ID_FORMAT_UNSPECIFIED, @@ -185,6 +189,19 @@ def test_request_signed_both(self): self.assertEqual(response.count(response_proc._assertion_id), 2) self.assertEqual(response.count(response_proc._response_id), 2) + schema = etree.XMLSchema( + etree.parse("schemas/saml-schema-protocol-2.0.xsd", parser=etree.XMLParser()) # nosec + ) + self.assertTrue(schema.validate(lxml_from_string(response))) + + response_xml = fromstring(response) + self.assertEqual( + len(response_xml.xpath("//saml:Assertion/ds:Signature", namespaces=NS_MAP)), 1 + ) + self.assertEqual( + len(response_xml.xpath("//samlp:Response/ds:Signature", namespaces=NS_MAP)), 1 + ) + # Now parse the response (source) http_request.POST = QueryDict(mutable=True) http_request.POST["SAMLResponse"] = b64encode(response.encode()).decode()