Skip to content

Commit

Permalink
Added pipeline step to avoid disconnection from tpa provider
Browse files Browse the repository at this point in the history
  • Loading branch information
mariajgrimaldi committed Jan 20, 2021
1 parent 51ae3fc commit c66b3d5
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 1 deletion.
24 changes: 24 additions & 0 deletions eox_core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@

from eox_core.edxapp_wrapper.users import get_user_profile

try:
from social_core.exceptions import NotAllowedToDisconnect
except ImportError:
NotAllowedToDisconnect = object

LOG = logging.getLogger(__name__)


Expand Down Expand Up @@ -64,3 +69,22 @@ def force_user_post_save_callback(auth_entry, is_new, user=None, *args, **kwargs
instance=user,
created=False
)


def check_disconnect_pipeline_enabled(backend, *args, **kwargs):
"""
This pipeline function checks whether disconnection from the auth provider is enabled or not. That's
done checking for `disableDisconnectPipeline` setting defined in the provider configuration.
For example:
To avoid disconnection from SAML, add the following to `Other config str` in your SAMLConfiguration:
"BACKEND_OPTIONS": { "disableDisconnectPipeline":true },
Now, to avoid disconnection from an Oauth2Provider, add the same setting to `Other settings` in your
Oauth2Provider.
It's recommended to place this function at the beginning of the pipeline.
"""
if backend and backend.setting("BACKEND_OPTIONS", {}).get("disableDisconnectPipeline"):
raise NotAllowedToDisconnect() # pylint: disable=raising-non-exception
35 changes: 34 additions & 1 deletion eox_core/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from django.test import TestCase
from mock import MagicMock, PropertyMock, patch

from eox_core.pipeline import ensure_user_has_profile
from eox_core.pipeline import check_disconnect_pipeline_enabled, ensure_user_has_profile


class EnsureUserProfileTest(TestCase):
Expand Down Expand Up @@ -39,3 +39,36 @@ def test_user_without_profile_works(self, import_mock):

ensure_user_has_profile(self.backend_mock, {}, user=self.user_mock)
backend().get_user_profile().objects.create.assert_called()


class TestDisconnectionPipeline(TestCase):
"""Test disconnection from TPA provider."""

def setUp(self):
self.backend_mock = MagicMock()

def test_disable_disconnect_pipeline(self):
"""
Test disabling disconnection pipeline through TPA provider settings.
"""
self.backend_mock.setting.return_value.get.return_value = True

with self.assertRaises(Exception):
check_disconnect_pipeline_enabled(self.backend_mock)

def test_disconnect_pipeline_enable_explicit(self):
"""
Test explicitly enable disconnection pipeline through TPA provider settings.
"""
self.backend_mock.setting.return_value.get.return_value = False

self.assertIsNone(check_disconnect_pipeline_enabled(self.backend_mock))

def test_disconnect_pipeline_enable_implicit(self):
"""
Test enable disconnection pipeline by not defining disable setting through TPA
provider settings.
"""
self.backend_mock.setting.return_value.get.return_value = None

self.assertIsNone(check_disconnect_pipeline_enabled(self.backend_mock))

0 comments on commit c66b3d5

Please sign in to comment.