diff --git a/eox_core/pipeline.py b/eox_core/pipeline.py index 5de15911..138e05b0 100644 --- a/eox_core/pipeline.py +++ b/eox_core/pipeline.py @@ -7,6 +7,11 @@ from eox_core.edxapp_wrapper.users import get_user_profile +try: + from social_core.exceptions import NotAllowedToDisconnect +except ImportError: + NotAllowedToDisconnect = object + LOG = logging.getLogger(__name__) @@ -64,3 +69,22 @@ def force_user_post_save_callback(auth_entry, is_new, user=None, *args, **kwargs instance=user, created=False ) + + +def check_disconnect_pipeline_enabled(backend, *args, **kwargs): + """ + This pipeline function checks whether disconnection from the auth provider is enabled or not. That's + done checking for `disableDisconnectPipeline` setting defined in the provider configuration. + + For example: + To avoid disconnection from SAML, add the following to `Other config str` in your SAMLConfiguration: + + "BACKEND_OPTIONS": { "disableDisconnectPipeline":true }, + + Now, to avoid disconnection from an Oauth2Provider, add the same setting to `Other settings` in your + Oauth2Provider. + + It's recommended to place this function at the beginning of the pipeline. + """ + if backend and backend.setting("BACKEND_OPTIONS", {}).get("disableDisconnectPipeline"): + raise NotAllowedToDisconnect() # pylint: disable=raising-non-exception diff --git a/eox_core/tests/test_pipeline.py b/eox_core/tests/test_pipeline.py index ff6fed57..3295d397 100644 --- a/eox_core/tests/test_pipeline.py +++ b/eox_core/tests/test_pipeline.py @@ -5,7 +5,7 @@ from django.test import TestCase from mock import MagicMock, PropertyMock, patch -from eox_core.pipeline import ensure_user_has_profile +from eox_core.pipeline import check_disconnect_pipeline_enabled, ensure_user_has_profile class EnsureUserProfileTest(TestCase): @@ -39,3 +39,36 @@ def test_user_without_profile_works(self, import_mock): ensure_user_has_profile(self.backend_mock, {}, user=self.user_mock) backend().get_user_profile().objects.create.assert_called() + + +class TestDisconnectionPipeline(TestCase): + """Test disconnection from TPA provider.""" + + def setUp(self): + self.backend_mock = MagicMock() + + def test_disable_disconnect_pipeline(self): + """ + Test disabling disconnection pipeline through TPA provider settings. + """ + self.backend_mock.setting.return_value.get.return_value = True + + with self.assertRaises(Exception): + check_disconnect_pipeline_enabled(self.backend_mock) + + def test_disconnect_pipeline_enable_explicit(self): + """ + Test explicitly enable disconnection pipeline through TPA provider settings. + """ + self.backend_mock.setting.return_value.get.return_value = False + + self.assertIsNone(check_disconnect_pipeline_enabled(self.backend_mock)) + + def test_disconnect_pipeline_enable_implicit(self): + """ + Test enable disconnection pipeline by not defining disable setting through TPA + provider settings. + """ + self.backend_mock.setting.return_value.get.return_value = None + + self.assertIsNone(check_disconnect_pipeline_enabled(self.backend_mock))