diff --git a/openwisp_radius/saml/utils.py b/openwisp_radius/saml/utils.py index c5da5b05..b87f7fff 100644 --- a/openwisp_radius/saml/utils.py +++ b/openwisp_radius/saml/utils.py @@ -9,3 +9,16 @@ def get_url_or_path(url): if parsed_url.netloc: return f'{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path}' return parsed_url.path + + +def get_email_from_ava(ava): + email_keys = ( + 'email', + 'mail', + 'uid', + ) + for key in email_keys: + email = ava.get(key, None) + if email is not None: + return email[0] + return None diff --git a/openwisp_radius/saml/views.py b/openwisp_radius/saml/views.py index 3ab0b9d2..133731e2 100644 --- a/openwisp_radius/saml/views.py +++ b/openwisp_radius/saml/views.py @@ -2,6 +2,8 @@ from urllib.parse import parse_qs, urlparse import swapper +from allauth.account.models import EmailAddress +from allauth.utils import ValidationError from django.conf import settings from django.contrib.auth import logout from django.core.exceptions import ObjectDoesNotExist, PermissionDenied @@ -16,7 +18,7 @@ from .. import settings as app_settings from ..api.views import RadiusTokenMixin from ..utils import get_organization_radius_settings, load_model -from .utils import get_url_or_path +from .utils import get_email_from_ava, get_url_or_path logger = logging.getLogger(__name__) @@ -66,6 +68,37 @@ def post_login_hook(self, request, user, session_info): try: user.registered_user except ObjectDoesNotExist: + email = None + uid_is_email = 'email' in getattr( + settings, 'SAML_ATTRIBUTE_MAPPING', {} + ).get('uid', ()) + if uid_is_email: + email = session_info['name_id'].text + if email is None: + email = get_email_from_ava(session_info['ava']) + if email: + user.email = email + try: + user.full_clean() + user.save() + EmailAddress.objects.create( + user=user, email=email, verified=True, primary=True + ) + except ValidationError: + assertion_email = get_email_from_ava(session_info['ava']) + if assertion_email and assertion_email != email: + user.email = assertion_email + try: + user.full_clean() + user.save() + EmailAddress.objects.create( + user=user, + email=assertion_email, + verified=True, + primary=True, + ) + except ValidationError: + raise ValidationError('Email Verification Failed') registered_user = RegisteredUser( user=user, method='saml', is_verified=app_settings.SAML_IS_VERIFIED ) diff --git a/openwisp_radius/tests/test_saml/test_views.py b/openwisp_radius/tests/test_saml/test_views.py index 4ef1868d..a0234838 100644 --- a/openwisp_radius/tests/test_saml/test_views.py +++ b/openwisp_radius/tests/test_saml/test_views.py @@ -3,7 +3,9 @@ from urllib.parse import parse_qs, urlparse import swapper +from allauth.account.models import EmailAddress from django.contrib.auth import SESSION_KEY, get_user_model +from django.core.validators import ValidationError from django.test import TestCase, override_settings from django.urls import reverse from djangosaml2.tests import auth_response, conf @@ -58,17 +60,19 @@ class TestAssertionConsumerServiceView(TestSamlMixin, TestCase): def _get_relay_state(self, redirect_url, org_slug): return f'{redirect_url}?org={org_slug}' - def _get_saml_response_for_acs_view(self, relay_state): + def _get_saml_response_for_acs_view(self, relay_state, uid='org_user@example.com'): response = self.client.get(self.login_url, {'RelayState': relay_state}) saml2_req = saml2_from_httpredirect_request(response.url) session_id = get_session_id_from_saml2(saml2_req) self.add_outstanding_query(session_id, relay_state) - return auth_response(session_id, 'org_user@example.com'), relay_state + return auth_response(session_id, uid), relay_state def _post_successful_auth_assertions(self, query_params, org_slug): self.assertEqual(User.objects.count(), 1) user_id = self.client.session[SESSION_KEY] user = User.objects.get(id=user_id) + email = EmailAddress.objects.filter(user=user) + self.assertEqual(email.count(), 1) self.assertEqual(user.username, 'org_user@example.com') self.assertEqual(OrganizationUser.objects.count(), 1) org_user = OrganizationUser.objects.get(user_id=user_id) @@ -100,6 +104,24 @@ def test_organization_slug_present(self): query_params = parse_qs(urlparse(response.url).query) self._post_successful_auth_assertions(query_params, org_slug) + @capture_any_output() + def test_invalid_email_raise_validation_error(self): + invalid_email = 'invalid_email@example' + relay_state = self._get_relay_state( + redirect_url='https://captive-portal.example.com', org_slug='default' + ) + saml_response, relay_state = self._get_saml_response_for_acs_view( + relay_state, uid=invalid_email + ) + with self.assertRaises(ValidationError): + self.client.post( + reverse('radius:saml2_acs'), + { + 'SAMLResponse': self.b64_for_post(saml_response), + 'RelayState': relay_state, + }, + ) + @capture_any_output() def test_relay_state_relative_path(self): expected_redirect_path = '/captive/portal/page'