diff --git a/ChangeLog.rst b/ChangeLog.rst index 88707b1a73..93c1be4cbc 100644 --- a/ChangeLog.rst +++ b/ChangeLog.rst @@ -10,6 +10,9 @@ Note worthy changes - New provider: Miro. + - It is now possible to manage OpenID Connect providers via the Django + admin. Simply add a `SocialApp` for each OpenID Connect provider. + Security notice --------------- @@ -35,6 +38,27 @@ Backwards incompatible changes - The Mozilla Persona provider has been removed. The project was shut down on November 30th 2016. +- A large internal refactor has been performed to be able to add support for + providers oferring one or more subproviders. This refactor has the following + impact: + + - The provider registry methods ``get_list()``, ``by_id()`` have been + removed. The registry now only providers access to the provider classes, not + the instances. + + - ``provider.get_app()`` has been removed -- use ``provider.app`` instead. + + - ``SocialApp.objects.get_current()`` has been removed. + + - The ``SocialApp`` model now has additional fields ``provider_id``, and + ``settings``. + + - The OpenID Connect provider ``SOCIALACCOUNT_PROVIDERS`` settings structure + changed. Instead of the OpenID Connect specific ``SERVERS`` construct, it + now uses the regular ``APPS`` approach. Please refer to the OpenID Connect + documentation for details. + + 0.54.0 (2023-03-31) ******************* diff --git a/allauth/socialaccount/adapter.py b/allauth/socialaccount/adapter.py index 8b357fb839..37a2c78947 100644 --- a/allauth/socialaccount/adapter.py +++ b/allauth/socialaccount/adapter.py @@ -1,6 +1,11 @@ from __future__ import absolute_import -from django.core.exceptions import ValidationError +from django.core.exceptions import ( + ImproperlyConfigured, + MultipleObjectsReturned, + ValidationError, +) +from django.db.models import Q from django.urls import reverse from django.utils.translation import gettext_lazy as _ @@ -188,18 +193,120 @@ def deserialize_instance(self, model, data): def serialize_instance(self, instance): return serialize_instance(instance) - def get_app(self, request, provider, config=None): + def list_providers(self, request): + from allauth.socialaccount.providers import registry + + ret = [] + provider_classes = registry.get_class_list() + apps = self.list_apps(request) + apps_map = {} + for app in apps: + apps_map.setdefault(app.provider, []).append(app) + for provider_class in provider_classes: + provider_apps = apps_map.get(provider_class.id, []) + if not provider_apps: + if provider_class.uses_apps: + continue + provider_apps = [None] + for app in provider_apps: + provider = provider_class(request=request, app=app) + ret.append(provider) + return ret + + def get_provider(self, request, provider): + """Looks up a `provider`, supporting subproviders by looking up by + `provider_id`. + """ + from allauth.socialaccount.providers import registry + + provider_class = registry.get_class(provider) + if provider_class is None or provider_class.uses_apps: + app = self.get_app(request, provider=provider) + if not provider_class: + # In this case, the `provider` argument passed was a + # `provider_id`. + provider_class = registry.get_class(app.provider) + if not provider_class: + raise ImproperlyConfigured(f"unknown provider: {app.provider}") + return provider_class(request, app=app) + elif provider_class: + assert not provider_class.uses_apps + return provider_class(request, app=None) + else: + raise ImproperlyConfigured(f"unknown provider: {app.provider}") + + def list_apps(self, request, provider=None, client_id=None): + """SocialApp's can be setup in the database, or, via + `settings.SOCIALACCOUNT_PROVIDERS`. This methods returns a uniform list + of all known apps matching the specified criteria, and blends both + (db/settings) sources of data. + """ # NOTE: Avoid loading models at top due to registry boot... from allauth.socialaccount.models import SocialApp - config = config or app_settings.PROVIDERS.get(provider, {}).get("APP") - if config: - app = SocialApp(provider=provider) - for field in ["client_id", "secret", "key", "certificate_key"]: - setattr(app, field, config.get(field)) - else: - app = SocialApp.objects.get_current(provider, request) - return app + # Map provider to the list of apps. + provider_to_apps = {} + + # First, populate it with the DB backed apps. + db_apps = SocialApp.objects.on_site(request) + if provider: + db_apps = db_apps.filter( + Q(provider_id="", provider=provider) | Q(provider_id=provider) + ) + if client_id: + db_apps = db_apps.filter(client_id=client_id) + for app in db_apps: + apps = provider_to_apps.setdefault(app.provider, []) + apps.append(app) + + # Then, extend it with the settings backed apps. + for p, pcfg in app_settings.PROVIDERS.items(): + app_configs = pcfg.get("APPS") + if app_configs is None: + app_config = pcfg.get("APP") + if app_config is None: + continue + app_configs = [app_config] + + apps = provider_to_apps.setdefault(p, []) + for config in app_configs: + app = SocialApp(provider=p) + for field in [ + "name", + "provider_id", + "client_id", + "secret", + "key", + "certificate_key", + "settings", + ]: + if field in config: + setattr(app, field, config[field]) + if client_id and app.client_id != client_id: + continue + if ( + provider + and app.provider_id != provider + and app.provider != provider + ): + continue + apps.append(app) + + # Flatten the list of apps. + apps = [] + for provider_apps in provider_to_apps.values(): + apps.extend(provider_apps) + return apps + + def get_app(self, request, provider, client_id=None): + from allauth.socialaccount.models import SocialApp + + apps = self.list_apps(request, provider=provider, client_id=client_id) + if len(apps) > 1: + raise MultipleObjectsReturned + elif len(apps) == 0: + raise SocialApp.DoesNotExist() + return apps[0] def get_adapter(request=None): diff --git a/allauth/socialaccount/app_settings.py b/allauth/socialaccount/app_settings.py index 2740c1e8ba..e61eae59c5 100644 --- a/allauth/socialaccount/app_settings.py +++ b/allauth/socialaccount/app_settings.py @@ -36,7 +36,35 @@ def PROVIDERS(self): """ Provider specific settings """ - return self._setting("PROVIDERS", {}) + ret = self._setting("PROVIDERS", {}) + oidc = ret.get("openid_connect") + if oidc: + ret["openid_connect"] = self._migrate_oidc(oidc) + return ret + + def _migrate_oidc(self, oidc): + servers = oidc.get("SERVERS") + if servers is None: + return oidc + ret = {} + apps = [] + for server in servers: + app = dict(**server["APP"]) + app_settings = {} + if "token_auth_method" in server: + app_settings["token_auth_method"] = server["token_auth_method"] + app_settings["server_url"] = server["server_url"] + app.update( + { + "name": server.get("name", ""), + "provider_id": server["id"], + "settings": app_settings, + } + ) + assert app["provider_id"] + apps.append(app) + ret["APPS"] = apps + return ret @property def EMAIL_REQUIRED(self): diff --git a/allauth/socialaccount/migrations/0004_app_provider_id_settings.py b/allauth/socialaccount/migrations/0004_app_provider_id_settings.py new file mode 100644 index 0000000000..1fde4bd33a --- /dev/null +++ b/allauth/socialaccount/migrations/0004_app_provider_id_settings.py @@ -0,0 +1,29 @@ +# Generated by Django 3.2.19 on 2023-06-30 13:16 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("socialaccount", "0003_extra_data_default_dict"), + ] + + operations = [ + migrations.AddField( + model_name="socialapp", + name="provider_id", + field=models.CharField( + blank=True, max_length=200, verbose_name="provider ID" + ), + ), + migrations.AddField( + model_name="socialapp", + name="settings", + field=models.JSONField(blank=True, default=dict), + ), + migrations.AlterField( + model_name="socialaccount", + name="provider", + field=models.CharField(max_length=200, verbose_name="provider"), + ), + ] diff --git a/allauth/socialaccount/models.py b/allauth/socialaccount/models.py index 66d1e375aa..15c75b2194 100644 --- a/allauth/socialaccount/models.py +++ b/allauth/socialaccount/models.py @@ -19,30 +19,31 @@ class SocialAppManager(models.Manager): - def get_current(self, provider, request=None): - cache = {} - if request: - cache = getattr(request, "_socialapp_cache", {}) - request._socialapp_cache = cache - app = cache.get(provider) - if not app: - if allauth.app_settings.SITES_ENABLED: - site = get_current_site(request) - app = self.get(sites__id=site.id, provider=provider) - else: - app = self.get(provider=provider) - cache[provider] = app - return app + def on_site(self, request): + if allauth.app_settings.SITES_ENABLED: + site = get_current_site(request) + return self.filter(sites__id=site.id) + return self.all() class SocialApp(models.Model): objects = SocialAppManager() + # The provider type, e.g. "google", "telegram", "saml". provider = models.CharField( verbose_name=_("provider"), max_length=30, choices=providers.registry.as_choices(), ) + # For providers that support subproviders, such as OpenID Connect and SAML, + # this ID identifies that instance. SocialAccount's originating from app + # will have their `provider` field set to the `provider_id` if available, + # else `provider`. + provider_id = models.CharField( + verbose_name=_("provider ID"), + max_length=200, + blank=True, + ) name = models.CharField(verbose_name=_("name"), max_length=40) client_id = models.CharField( verbose_name=_("client id"), @@ -58,6 +59,8 @@ class SocialApp(models.Model): key = models.CharField( verbose_name=_("key"), max_length=191, blank=True, help_text=_("Key") ) + settings = models.JSONField(default=dict, blank=True) + if allauth.app_settings.SITES_ENABLED: # Most apps can be used across multiple domains, therefore we use # a ManyToManyField. Note that Facebook requires an app per domain @@ -79,13 +82,18 @@ class Meta: def __str__(self): return self.name + def get_provider(self, request): + provider_class = providers.registry.get_class(self.provider) + return provider_class(request=request, app=self) + class SocialAccount(models.Model): user = models.ForeignKey(allauth.app_settings.USER_MODEL, on_delete=models.CASCADE) + # Given a `SocialApp` from which this account originates, this field equals + # the app's `app.provider_id` if available, `app.provider` otherwise. provider = models.CharField( verbose_name=_("provider"), - max_length=30, - choices=providers.registry.as_choices(), + max_length=200, ) # Just in case you're wondering if an OpenID identity URL is going # to fit in a 'uid': @@ -129,8 +137,15 @@ def get_profile_url(self): def get_avatar_url(self): return self.get_provider_account().get_avatar_url() - def get_provider(self): - return providers.registry.by_id(self.provider) + def get_provider(self, request=None): + provider = getattr(self, "_provider", None) + if provider: + return provider + adapter = get_adapter(request) + provider = self._provider = adapter.get_provider( + request, provider=self.provider + ) + return provider def get_provider_account(self): return self.get_provider().wrap_account(self) diff --git a/allauth/socialaccount/providers/__init__.py b/allauth/socialaccount/providers/__init__.py index f5e74ab69e..f052d5616f 100644 --- a/allauth/socialaccount/providers/__init__.py +++ b/allauth/socialaccount/providers/__init__.py @@ -2,6 +2,9 @@ from collections import OrderedDict from django.apps import apps +from django.conf import settings + +from allauth.utils import import_attribute class ProviderRegistry(object): @@ -9,16 +12,15 @@ def __init__(self): self.provider_map = OrderedDict() self.loaded = False - def get_list(self, request=None): + def get_class_list(self): self.load() - return [provider_cls(request) for provider_cls in self.provider_map.values()] + return list(self.provider_map.values()) def register(self, cls): self.provider_map[cls.id] = cls - def by_id(self, id, request=None): - self.load() - return self.provider_map[id](request=request) + def get_class(self, id): + return self.provider_map.get(id) def as_choices(self): self.load() @@ -41,7 +43,13 @@ def load(self): except ImportError: pass else: + provider_settings = getattr(settings, "SOCIALACCOUNT_PROVIDERS", {}) for cls in getattr(provider_module, "provider_classes", []): + provider_class = provider_settings.get(cls.id, {}).get( + "provider_class" + ) + if provider_class: + cls = import_attribute(provider_class) self.register(cls) self.loaded = True diff --git a/allauth/socialaccount/providers/apple/client.py b/allauth/socialaccount/providers/apple/client.py index 5be90d73c1..77a990429c 100644 --- a/allauth/socialaccount/providers/apple/client.py +++ b/allauth/socialaccount/providers/apple/client.py @@ -36,7 +36,7 @@ class AppleOAuth2Client(OAuth2Client): def generate_client_secret(self): """Create a JWT signed with an apple provided private key""" now = datetime.utcnow() - app = get_adapter().get_app(self.request, "apple") + app = get_adapter(self.request).get_app(self.request, "apple") if not app.key: raise ImproperlyConfigured("Apple 'key' missing") if not app.certificate_key: diff --git a/allauth/socialaccount/providers/asana/tests.py b/allauth/socialaccount/providers/asana/tests.py index c33655658d..2ca56c21de 100644 --- a/allauth/socialaccount/providers/asana/tests.py +++ b/allauth/socialaccount/providers/asana/tests.py @@ -1,11 +1,12 @@ -from allauth.socialaccount.providers import registry -from allauth.socialaccount.tests import create_oauth2_tests -from allauth.tests import MockedResponse +from allauth.socialaccount.tests import OAuth2TestsMixin +from allauth.tests import MockedResponse, TestCase from .provider import AsanaProvider -class AsanaTests(create_oauth2_tests(registry.by_id(AsanaProvider.id))): +class AsanaTests(OAuth2TestsMixin, TestCase): + provider_id = AsanaProvider.id + def get_mocked_response(self): return MockedResponse( 200, diff --git a/allauth/socialaccount/providers/base/provider.py b/allauth/socialaccount/providers/base/provider.py index e3f7647efa..e619847972 100644 --- a/allauth/socialaccount/providers/base/provider.py +++ b/allauth/socialaccount/providers/base/provider.py @@ -1,3 +1,5 @@ +from django.core.exceptions import ImproperlyConfigured + from allauth.account.models import EmailAddress from allauth.socialaccount import app_settings from allauth.socialaccount.adapter import get_adapter @@ -9,9 +11,13 @@ class ProviderException(Exception): class Provider(object): slug = None + uses_apps = True - def __init__(self, request): + def __init__(self, request, app=None): self.request = request + if self.uses_apps and app is None: + raise ValueError("missing: app") + self.app = app @classmethod def get_slug(cls): @@ -24,10 +30,6 @@ def get_login_url(self, request, next=None, **kwargs): """ raise NotImplementedError("get_login_url() for " + self.name) - def get_app(self, request, config=None): - adapter = get_adapter(request) - return adapter.get_app(request, self.id, config=config) - def media_js(self, request): """ Some providers may require extra scripts (e.g. a Facebook connect) @@ -61,9 +63,20 @@ def sociallogin_from_response(self, request, response): adapter = get_adapter(request) uid = self.extract_uid(response) + if not isinstance(uid, str): + raise ValueError(f"uid must be a string: {repr(uid)}") + if len(uid) > app_settings.UID_MAX_LENGTH: + raise ImproperlyConfigured( + f"SOCIALACCOUNT_UID_MAX_LENGTH too small (<{len(uid)})" + ) + extra_data = self.extract_extra_data(response) common_fields = self.extract_common_fields(response) - socialaccount = SocialAccount(extra_data=extra_data, uid=uid, provider=self.id) + socialaccount = SocialAccount( + extra_data=extra_data, + uid=uid, + provider=self.app.provider_id or self.app.provider, + ) email_addresses = self.extract_email_addresses(response) self.cleanup_email_addresses(common_fields.get("email"), email_addresses) sociallogin = SocialLogin( diff --git a/allauth/socialaccount/providers/bitbucket_oauth2/tests.py b/allauth/socialaccount/providers/bitbucket_oauth2/tests.py index 4996714453..41363fda42 100644 --- a/allauth/socialaccount/providers/bitbucket_oauth2/tests.py +++ b/allauth/socialaccount/providers/bitbucket_oauth2/tests.py @@ -6,17 +6,16 @@ from django.test.utils import override_settings from allauth.socialaccount.models import SocialAccount -from allauth.socialaccount.providers import registry -from allauth.socialaccount.tests import create_oauth2_tests -from allauth.tests import MockedResponse, patch +from allauth.socialaccount.tests import OAuth2TestsMixin +from allauth.tests import MockedResponse, TestCase, patch from .provider import BitbucketOAuth2Provider @override_settings(SOCIALACCOUNT_QUERY_EMAIL=True, SOCIALACCOUNT_STORE_TOKENS=True) -class BitbucketOAuth2Tests( - create_oauth2_tests(registry.by_id(BitbucketOAuth2Provider.id)) -): +class BitbucketOAuth2Tests(OAuth2TestsMixin, TestCase): + provider_id = BitbucketOAuth2Provider.id + response_data = """ { "created_on": "2011-12-20T16:34:07.132459+00:00", diff --git a/allauth/socialaccount/providers/clever/provider.py b/allauth/socialaccount/providers/clever/provider.py index 5e93141822..807111866d 100644 --- a/allauth/socialaccount/providers/clever/provider.py +++ b/allauth/socialaccount/providers/clever/provider.py @@ -22,7 +22,7 @@ class CleverProvider(OAuth2Provider): account_class = CleverAccount def extract_uid(self, data): - return data.get("data", {}).get("id") + return data["data"]["id"] def get_user_type(self, data): return list(data.get("data", {}).get("roles", {}).keys())[0] diff --git a/allauth/socialaccount/providers/clever/tests.py b/allauth/socialaccount/providers/clever/tests.py index c0e10c8e2d..8e2441090d 100644 --- a/allauth/socialaccount/providers/clever/tests.py +++ b/allauth/socialaccount/providers/clever/tests.py @@ -39,6 +39,7 @@ def get_mocked_response(self): 200, """{ "data": { + "id": "62027798269867124d10259e", "roles": { "district_admin": {}, "contact": {} diff --git a/allauth/socialaccount/providers/clever/views.py b/allauth/socialaccount/providers/clever/views.py index 928a131c92..1b4ec48fed 100644 --- a/allauth/socialaccount/providers/clever/views.py +++ b/allauth/socialaccount/providers/clever/views.py @@ -32,7 +32,7 @@ def get_data(self, token): if resp.status_code != 200: raise OAuth2Error() resp = resp.json() - user_id = resp.get("data", {}).get("id") + user_id = resp["data"]["id"] user_details = requests.get( "{}/{}".format(self.user_details_url, user_id), headers={"Authorization": "Bearer {}".format(token)}, diff --git a/allauth/socialaccount/providers/doximity/provider.py b/allauth/socialaccount/providers/doximity/provider.py index abf44c35a2..cad6be7db4 100644 --- a/allauth/socialaccount/providers/doximity/provider.py +++ b/allauth/socialaccount/providers/doximity/provider.py @@ -20,7 +20,7 @@ class DoximityProvider(OAuth2Provider): account_class = DoximityAccount def extract_uid(self, data): - return data[str("id")] # the Doximity id is long + return str(data["id"]) # the Doximity id is long def extract_common_fields(self, data): return dict( diff --git a/allauth/socialaccount/providers/draugiem/tests.py b/allauth/socialaccount/providers/draugiem/tests.py index c84ecb0b6b..d033b0a580 100644 --- a/allauth/socialaccount/providers/draugiem/tests.py +++ b/allauth/socialaccount/providers/draugiem/tests.py @@ -1,11 +1,11 @@ from hashlib import md5 from django.contrib.auth.models import User +from django.test import RequestFactory from django.urls import reverse from django.utils.http import urlencode from allauth import app_settings -from allauth.socialaccount import providers from allauth.socialaccount.models import SocialApp, SocialToken from allauth.tests import Mock, TestCase, patch @@ -22,14 +22,15 @@ def setUp(self): ) self.client.login(username="anakin", password="s1thrul3s") - self.provider = providers.registry.by_id(DraugiemProvider.id) app = SocialApp.objects.create( - provider=self.provider.id, - name=self.provider.id, + provider=DraugiemProvider.id, + name=DraugiemProvider.id, client_id="app123id", - key=self.provider.id, + key=DraugiemProvider.id, secret="dummy", ) + request = RequestFactory().get("/") + self.provider = app.get_provider(request) if app_settings.SITES_ENABLED: from django.contrib.sites.models import Site diff --git a/allauth/socialaccount/providers/draugiem/views.py b/allauth/socialaccount/providers/draugiem/views.py index 9f5e612b89..55c6924e46 100644 --- a/allauth/socialaccount/providers/draugiem/views.py +++ b/allauth/socialaccount/providers/draugiem/views.py @@ -6,7 +6,7 @@ from django.utils.http import urlencode from django.views.decorators.csrf import csrf_exempt -from allauth.socialaccount import providers +from allauth.socialaccount.adapter import get_adapter from allauth.socialaccount.helpers import ( complete_social_login, render_authentication_error, @@ -26,7 +26,7 @@ class DraugiemApiError(Exception): def login(request): - app = providers.registry.by_id(DraugiemProvider.id, request).get_app(request) + app = get_adapter(request).get_app(request, DraugiemProvider.id) redirect_url = request.build_absolute_uri(reverse(callback)) redirect_url_hash = md5((app.secret + redirect_url).encode("utf-8")).hexdigest() params = { @@ -58,7 +58,7 @@ def callback(request): ret = None auth_exception = None try: - app = providers.registry.by_id(DraugiemProvider.id, request).get_app(request) + app = get_adapter(request).get_app(request, DraugiemProvider.id) login = draugiem_complete_login(request, app, request.GET["dr_auth_code"]) login.state = SocialLogin.unstash_state(request) @@ -75,7 +75,7 @@ def callback(request): def draugiem_complete_login(request, app, code): - provider = providers.registry.by_id(DraugiemProvider.id, request) + provider = get_adapter(request).get_provider(request, DraugiemProvider.id) response = requests.get( ACCESS_TOKEN_URL, {"action": "authorize", "app": app.secret, "code": code}, diff --git a/allauth/socialaccount/providers/facebook/provider.py b/allauth/socialaccount/providers/facebook/provider.py index e8cb343f67..ffc55a8365 100644 --- a/allauth/socialaccount/providers/facebook/provider.py +++ b/allauth/socialaccount/providers/facebook/provider.py @@ -61,9 +61,9 @@ class FacebookProvider(OAuth2Provider): name = "Facebook" account_class = FacebookAccount - def __init__(self, request): + def __init__(self, *args, **kwargs): self._locale_callable_cache = None - super(FacebookProvider, self).__init__(request) + super().__init__(*args, **kwargs) def get_method(self): return self.get_settings().get("METHOD", "oauth2") @@ -154,25 +154,14 @@ def get_sdk_url(self, request): return sdk_url def media_js(self, request): - # NOTE: Avoid loading models at top due to registry boot... - from allauth.socialaccount.models import SocialApp - - try: - app = self.get_app(request) - except SocialApp.DoesNotExist: - # It's a problem that Facebook isn't configured; but don't raise - # an error. Other providers don't raise errors when they're missing - # SocialApps in media_js(). - return "" - def abs_uri(name): return request.build_absolute_uri(reverse(name)) fb_data = { - "appId": app.client_id, + "appId": self.app.client_id, "version": GRAPH_API_VERSION, "sdkUrl": self.get_sdk_url(request), - "initParams": self.get_init_params(request, app), + "initParams": self.get_init_params(request, self.app), "loginOptions": self.get_fb_login_options(request), "loginByTokenUrl": abs_uri("facebook_login_by_token"), "cancelUrl": abs_uri("socialaccount_login_cancelled"), diff --git a/allauth/socialaccount/providers/facebook/tests.py b/allauth/socialaccount/providers/facebook/tests.py index 8635682fd4..bd010b188a 100644 --- a/allauth/socialaccount/providers/facebook/tests.py +++ b/allauth/socialaccount/providers/facebook/tests.py @@ -6,7 +6,6 @@ from allauth.account import app_settings as account_settings from allauth.account.models import EmailAddress -from allauth.socialaccount import providers from allauth.socialaccount.models import SocialAccount from allauth.socialaccount.tests import OAuth2TestsMixin from allauth.tests import MockedResponse, TestCase, patch @@ -73,20 +72,11 @@ def test_username_based_on_provider_with_simple_name(self): self.assertEqual(socialaccount.user.username, "harvey") def test_media_js(self): - provider = providers.registry.by_id(FacebookProvider.id) request = RequestFactory().get(reverse("account_login")) request.session = {} - script = provider.media_js(request) + script = self.provider.media_js(request) self.assertTrue('"appId": "app123id"' in script) - def test_media_js_when_not_configured(self): - provider = providers.registry.by_id(FacebookProvider.id) - provider.get_app(None).delete() - request = RequestFactory().get(reverse("account_login")) - request.session = {} - script = provider.media_js(request) - self.assertEqual(script, "") - def test_login_by_token(self): resp = self.client.get(reverse("account_login")) with patch( diff --git a/allauth/socialaccount/providers/facebook/views.py b/allauth/socialaccount/providers/facebook/views.py index 002fd3cdc4..de8018941d 100644 --- a/allauth/socialaccount/providers/facebook/views.py +++ b/allauth/socialaccount/providers/facebook/views.py @@ -6,7 +6,8 @@ from django.utils import timezone -from allauth.socialaccount import app_settings, providers +from allauth.socialaccount import app_settings +from allauth.socialaccount.adapter import get_adapter from allauth.socialaccount.helpers import ( complete_social_login, render_authentication_error, @@ -35,7 +36,7 @@ def compute_appsecret_proof(app, token): def fb_complete_login(request, app, token): - provider = providers.registry.by_id(app.provider, request) + provider = app.get_provider(request) resp = requests.get( GRAPH_API_URL + "/me", params={ @@ -77,9 +78,10 @@ def login_by_token(request): form = FacebookConnectForm(request.POST) if form.is_valid(): try: - provider = providers.registry.by_id(FacebookProvider.id, request) + adapter = get_adapter(request) + provider = adapter.get_provider(request, FacebookProvider.id) login_options = provider.get_fb_login_options(request) - app = provider.get_app(request) + app = provider.app access_token = form.cleaned_data["access_token"] expires_at = None if login_options.get("auth_type") == "reauthenticate": diff --git a/allauth/socialaccount/providers/feishu/tests.py b/allauth/socialaccount/providers/feishu/tests.py index 57582a2ea7..cbb7ddf44b 100644 --- a/allauth/socialaccount/providers/feishu/tests.py +++ b/allauth/socialaccount/providers/feishu/tests.py @@ -2,14 +2,15 @@ from __future__ import unicode_literals -from allauth.socialaccount.providers import registry -from allauth.socialaccount.tests import create_oauth2_tests -from allauth.tests import MockedResponse +from allauth.socialaccount.tests import OAuth2TestsMixin +from allauth.tests import MockedResponse, TestCase from .provider import FeishuProvider -class FeishuTests(create_oauth2_tests(registry.by_id(FeishuProvider.id))): +class FeishuTests(OAuth2TestsMixin, TestCase): + provider_id = FeishuProvider.id + def get_mocked_response(self): return [ MockedResponse( diff --git a/allauth/socialaccount/providers/fivehundredpx/provider.py b/allauth/socialaccount/providers/fivehundredpx/provider.py index f125c1a5f6..891881d6f0 100644 --- a/allauth/socialaccount/providers/fivehundredpx/provider.py +++ b/allauth/socialaccount/providers/fivehundredpx/provider.py @@ -25,7 +25,7 @@ def get_default_scope(self): return [] def extract_uid(self, data): - return data["id"] + return str(data["id"]) def extract_common_fields(self, data): return dict( diff --git a/allauth/socialaccount/providers/keycloak/provider.py b/allauth/socialaccount/providers/keycloak/provider.py index 06e82a2603..f6f1e699ca 100644 --- a/allauth/socialaccount/providers/keycloak/provider.py +++ b/allauth/socialaccount/providers/keycloak/provider.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from django.conf import settings from allauth.socialaccount import app_settings @@ -25,12 +24,23 @@ class KeycloakProvider(OpenIDConnectProvider): name = OVERRIDE_NAME account_class = KeycloakAccount + def get_login_url(self, request, **kwargs): + return super(OpenIDConnectProvider, self).get_login_url(request, **kwargs) + + def get_callback_url(self): + return super(OpenIDConnectProvider, self).get_callback_url() + + @property + def server_url(self): + return self.wk_server_url(self.base_server_url) + @property - def _server_url(self): + def base_server_url(self): other_url = self.settings.get("KEYCLOAK_URL_ALT") if other_url is None: other_url = self.settings.get("KEYCLOAK_URL") - return "{0}/realms/{1}".format(other_url, self.settings.get("KEYCLOAK_REALM")) + url = "{0}/realms/{1}".format(other_url, self.settings.get("KEYCLOAK_REALM")) + return url @property def provider_base_url(self): diff --git a/allauth/socialaccount/providers/keycloak/views.py b/allauth/socialaccount/providers/keycloak/views.py index e92c768366..16e45e66ad 100644 --- a/allauth/socialaccount/providers/keycloak/views.py +++ b/allauth/socialaccount/providers/keycloak/views.py @@ -22,15 +22,25 @@ def authorize_url(self): @property def access_token_url(self): return "{0}/protocol/openid-connect/token".format( - self.get_provider()._server_url + self.get_provider().base_server_url ) @property def profile_url(self): return "{0}/protocol/openid-connect/userinfo".format( - self.get_provider()._server_url + self.get_provider().base_server_url ) -oauth2_login = OAuth2LoginView.adapter_view(KeycloakOAuth2Adapter) -oauth2_callback = OAuth2CallbackView.adapter_view(KeycloakOAuth2Adapter) +def oauth2_login(request): + view = OAuth2LoginView.adapter_view( + KeycloakOAuth2Adapter(request, KeycloakProvider.id) + ) + return view(request) + + +def oauth2_callback(request): + view = OAuth2CallbackView.adapter_view( + KeycloakOAuth2Adapter(request, KeycloakProvider.id) + ) + return view(request) diff --git a/allauth/socialaccount/providers/linkedin/views.py b/allauth/socialaccount/providers/linkedin/views.py index 4136836517..55736c9bf2 100644 --- a/allauth/socialaccount/providers/linkedin/views.py +++ b/allauth/socialaccount/providers/linkedin/views.py @@ -1,7 +1,7 @@ from xml.etree import ElementTree from xml.parsers.expat import ExpatError -from allauth.socialaccount import providers +from allauth.socialaccount.adapter import get_adapter from allauth.socialaccount.providers.oauth.client import OAuth from allauth.socialaccount.providers.oauth.views import ( OAuthAdapter, @@ -16,9 +16,9 @@ class LinkedInAPI(OAuth): url = "https://api.linkedin.com/v1/people/~" def get_user_info(self): - fields = providers.registry.by_id( - LinkedInProvider.id, self.request - ).get_profile_fields() + adapter = get_adapter(self.request) + provider = adapter.get_provider(self.request, LinkedInProvider.id) + fields = provider.get_profile_fields() url = self.url + ":(%s)" % ",".join(fields) raw_xml = self.query(url) try: diff --git a/allauth/socialaccount/providers/linkedin_oauth2/tests.py b/allauth/socialaccount/providers/linkedin_oauth2/tests.py index dfd57282a2..d272c7d734 100644 --- a/allauth/socialaccount/providers/linkedin_oauth2/tests.py +++ b/allauth/socialaccount/providers/linkedin_oauth2/tests.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- from json import loads -from django.test.client import RequestFactory from django.test.utils import override_settings from allauth.socialaccount.models import SocialAccount @@ -536,5 +535,6 @@ def test_id_missing(self): "Id": "1234567" } """ - provider = LinkedInOAuth2Provider(RequestFactory().get("/login")) - self.assertRaises(ProviderException, provider.extract_uid, loads(extra_data)) + self.assertRaises( + ProviderException, self.provider.extract_uid, loads(extra_data) + ) diff --git a/allauth/socialaccount/providers/meetup/tests.py b/allauth/socialaccount/providers/meetup/tests.py index b0119019b8..c6a2002f82 100644 --- a/allauth/socialaccount/providers/meetup/tests.py +++ b/allauth/socialaccount/providers/meetup/tests.py @@ -1,11 +1,12 @@ -from allauth.socialaccount.providers import registry -from allauth.socialaccount.tests import create_oauth2_tests -from allauth.tests import MockedResponse +from allauth.socialaccount.tests import OAuth2TestsMixin +from allauth.tests import MockedResponse, TestCase from .provider import MeetupProvider -class MeetupTests(create_oauth2_tests(registry.by_id(MeetupProvider.id))): +class MeetupTests(OAuth2TestsMixin, TestCase): + provider_id = MeetupProvider.id + def get_mocked_response(self): return MockedResponse( 200, diff --git a/allauth/socialaccount/providers/oauth/views.py b/allauth/socialaccount/providers/oauth/views.py index dc711460c1..9ac68b63cc 100644 --- a/allauth/socialaccount/providers/oauth/views.py +++ b/allauth/socialaccount/providers/oauth/views.py @@ -2,7 +2,7 @@ from django.urls import reverse -from allauth.socialaccount import providers +from allauth.socialaccount.adapter import get_adapter from allauth.socialaccount.helpers import ( complete_social_login, render_authentication_error, @@ -30,7 +30,9 @@ def complete_login(self, request, app): raise NotImplementedError def get_provider(self): - return providers.registry.by_id(self.provider_id, self.request) + adapter = get_adapter(self.request) + app = adapter.get_app(self.request, provider=self.provider_id) + return app.get_provider(self.request) class OAuthView(object): @@ -46,7 +48,7 @@ def view(request, *args, **kwargs): def _get_client(self, request, callback_url): provider = self.adapter.get_provider() - app = provider.get_app(request) + app = provider.app scope = " ".join(provider.get_scope(request)) parameters = {} if scope: @@ -101,7 +103,7 @@ def dispatch(self, request): error=error, extra_context=extra_context, ) - app = self.adapter.get_provider().get_app(request) + app = self.adapter.get_provider().app try: access_token = client.get_access_token() token = SocialToken( diff --git a/allauth/socialaccount/providers/oauth2/provider.py b/allauth/socialaccount/providers/oauth2/provider.py index 18b0912da0..46d9352789 100644 --- a/allauth/socialaccount/providers/oauth2/provider.py +++ b/allauth/socialaccount/providers/oauth2/provider.py @@ -17,6 +17,9 @@ def get_login_url(self, request, **kwargs): url = url + "?" + urlencode(kwargs) return url + def get_callback_url(self): + return reverse(self.id + "_callback") + def get_pkce_params(self): settings = self.get_settings() if settings.get("OAUTH_PKCE_ENABLED", self.pkce_enabled_default): diff --git a/allauth/socialaccount/providers/oauth2/views.py b/allauth/socialaccount/providers/oauth2/views.py index 6ce9fe25c7..46a1465487 100644 --- a/allauth/socialaccount/providers/oauth2/views.py +++ b/allauth/socialaccount/providers/oauth2/views.py @@ -9,7 +9,7 @@ from django.utils import timezone from allauth.exceptions import ImmediateHttpResponse -from allauth.socialaccount import providers +from allauth.socialaccount.adapter import get_adapter from allauth.socialaccount.helpers import ( complete_social_login, render_authentication_error, @@ -43,7 +43,9 @@ def __init__(self, request): self.request = request def get_provider(self): - return providers.registry.by_id(self.provider_id, self.request) + return get_adapter(self.request).get_provider( + self.request, provider=self.provider_id + ) def complete_login(self, request, app, access_token, **kwargs): """ @@ -76,7 +78,10 @@ def adapter_view(cls, adapter): def view(request, *args, **kwargs): self = cls() self.request = request - self.adapter = adapter(request) + if not isinstance(adapter, OAuth2Adapter): + self.adapter = adapter(request) + else: + self.adapter = adapter try: return self.dispatch(request, *args, **kwargs) except ImmediateHttpResponse as e: @@ -106,7 +111,7 @@ def get_client(self, request, app): class OAuth2LoginView(OAuthLoginMixin, OAuth2View): def login(self, request, *args, **kwargs): provider = self.adapter.get_provider() - app = provider.get_app(self.request) + app = provider.app client = self.get_client(request, app) action = request.GET.get("action", AuthAction.AUTHENTICATE) auth_url = self.adapter.authorize_url @@ -137,7 +142,7 @@ def dispatch(self, request, *args, **kwargs): return render_authentication_error( request, self.adapter.provider_id, error=error ) - app = self.adapter.get_provider().get_app(self.request) + app = self.adapter.get_provider().app client = self.get_client(self.request, app) try: diff --git a/allauth/socialaccount/providers/openid/provider.py b/allauth/socialaccount/providers/openid/provider.py index b7c47f2515..f39eaaaf47 100644 --- a/allauth/socialaccount/providers/openid/provider.py +++ b/allauth/socialaccount/providers/openid/provider.py @@ -39,6 +39,7 @@ class OpenIDProvider(Provider): id = "openid" name = "OpenID" account_class = OpenIDAccount + uses_apps = False def get_login_url(self, request, **kwargs): url = reverse("openid_login") diff --git a/allauth/socialaccount/providers/openid/views.py b/allauth/socialaccount/providers/openid/views.py index 802495c096..8b4fa2e659 100644 --- a/allauth/socialaccount/providers/openid/views.py +++ b/allauth/socialaccount/providers/openid/views.py @@ -9,7 +9,6 @@ from openid.extensions.ax import AttrInfo, FetchRequest from openid.extensions.sreg import SRegRequest -from allauth.socialaccount import providers from allauth.socialaccount.app_settings import QUERY_EMAIL from allauth.socialaccount.helpers import ( complete_social_login, @@ -131,9 +130,7 @@ def get(self, request): response = self.get_openid_response(client) if response.status == consumer.SUCCESS: - login = providers.registry.by_id( - self.provider.id, request - ).sociallogin_from_response(request, response) + login = provider.sociallogin_from_response(request, response) login.state = SocialLogin.unstash_state(request) return self.complete_login(login) else: diff --git a/allauth/socialaccount/providers/openid_connect/provider.py b/allauth/socialaccount/providers/openid_connect/provider.py index c94ad717be..a3cee245f2 100644 --- a/allauth/socialaccount/providers/openid_connect/provider.py +++ b/allauth/socialaccount/providers/openid_connect/provider.py @@ -1,6 +1,7 @@ -# -*- coding: utf-8 -*- +from django.urls import reverse +from django.utils.http import urlencode + from allauth.account.models import EmailAddress -from allauth.socialaccount import app_settings from allauth.socialaccount.providers.base import ProviderAccount from allauth.socialaccount.providers.oauth2.provider import OAuth2Provider @@ -14,25 +15,40 @@ def to_str(self): class OpenIDConnectProvider(OAuth2Provider): id = "openid_connect" name = "OpenID Connect" - _server_id = None - _server_url = None account_class = OpenIDConnectProviderAccount + @property + def name(self): + return self.app.name + @property def server_url(self): + url = self.app.settings["server_url"] + return self.wk_server_url(url) + + def wk_server_url(self, url): well_known_uri = "/.well-known/openid-configuration" - url = self._server_url if not url.endswith(well_known_uri): url += well_known_uri return url + def get_login_url(self, request, **kwargs): + url = reverse( + self.app.provider + "_login", kwargs={"provider_id": self.app.provider_id} + ) + if kwargs: + url = url + "?" + urlencode(kwargs) + return url + + def get_callback_url(self): + return reverse( + self.app.provider + "_callback", + kwargs={"provider_id": self.app.provider_id}, + ) + @property def token_auth_method(self): - return app_settings.PROVIDERS.get(self.id, {}).get("token_auth_method") - - @classmethod - def get_slug(cls): - return cls._server_id or super().get_slug() + return self.app.settings.get("token_auth_method") def get_default_scope(self): return ["openid", "profile", "email"] @@ -63,27 +79,4 @@ def extract_email_addresses(self, data): return addresses -def _provider_factory(server_settings): - class OpenIDConnectProviderServer(OpenIDConnectProvider): - name = server_settings.get("name", OpenIDConnectProvider.name) - id = server_settings["id"] - _server_id = server_settings["id"] - _server_url = server_settings["server_url"] - - def get_app(self, request, config=None): - return super().get_app(request, config=server_settings.get("APP")) - - OpenIDConnectProviderServer.__name__ = ( - "OpenIDConnectProviderServer_" + server_settings["id"] - ) - app_settings.PROVIDERS.setdefault(OpenIDConnectProviderServer.id, {}) - app_settings.PROVIDERS[OpenIDConnectProviderServer.id].update(server_settings) - return OpenIDConnectProviderServer - - -provider_classes = [ - _provider_factory(server_settings) - for server_settings in app_settings.PROVIDERS.get(OpenIDConnectProvider.id, {}).get( - "SERVERS", [] - ) -] +provider_classes = [OpenIDConnectProvider] diff --git a/allauth/socialaccount/providers/openid_connect/tests.py b/allauth/socialaccount/providers/openid_connect/tests.py index df193fc10b..83122ab8b1 100644 --- a/allauth/socialaccount/providers/openid_connect/tests.py +++ b/allauth/socialaccount/providers/openid_connect/tests.py @@ -1,56 +1,10 @@ -# -*- coding: utf-8 -*- -from unittest import TestSuite - -from allauth.socialaccount.providers.openid_connect.provider import ( - OpenIDConnectProvider, - provider_classes, -) from allauth.socialaccount.tests import OpenIDConnectTests from allauth.tests import TestCase -from ... import app_settings - - -class OpenIDConnectTestsBase(OpenIDConnectTests): - provider_id = None - - def test_oidc_base_and_provider_settings_sync(self): - # Retrieve settings via this OpenID Connect server's specific provider ID - provider_settings = self.provider.get_settings() - # Retrieve settings via the base OpenID Connect provider ID - oidc_server_settings = app_settings.PROVIDERS[OpenIDConnectProvider.id][ - "SERVERS" - ] - # Find the matching entry in the base OpenID Connect provider's servers list - matching_servers = list( - filter( - lambda server_settings: server_settings["id"] - == provider_settings["id"], - oidc_server_settings, - ) - ) - # Make sure there's only one matching entry and that it's identical - self.assertEqual(len(matching_servers), 1) - self.assertDictEqual(matching_servers[0], provider_settings) - - -def _test_class_factory(provider_class): - class Provider_OpenIDConnectTests(OpenIDConnectTestsBase, TestCase): - provider_id = provider_class.id - Provider_OpenIDConnectTests.__name__ = ( - "Provider_OpenIDConnectTests_" + provider_class.id - ) - return Provider_OpenIDConnectTests +class OpenIDConnectTests(OpenIDConnectTests, TestCase): + provider_id = "unittest-server" -def load_tests(loader, tests, pattern): - suite = TestSuite() - assert len( - provider_classes - ), "No OpenID Connect servers are configured in test_settings.py" - for provider_class in provider_classes: - suite.addTests( - loader.loadTestsFromTestCase(_test_class_factory(provider_class)) - ) - return suite +class OtherOpenIDConnectTests(OpenIDConnectTests, TestCase): + provider_id = "ther-server" diff --git a/allauth/socialaccount/providers/openid_connect/urls.py b/allauth/socialaccount/providers/openid_connect/urls.py index 4227ae7d44..458f855e86 100644 --- a/allauth/socialaccount/providers/openid_connect/urls.py +++ b/allauth/socialaccount/providers/openid_connect/urls.py @@ -1,45 +1,24 @@ -# -*- coding: utf-8 -*- -import itertools - -from django.urls import include, path - -from allauth.socialaccount.providers.oauth2.views import ( - OAuth2CallbackView, - OAuth2LoginView, -) -from allauth.socialaccount.providers.openid_connect.provider import ( - provider_classes, -) -from allauth.socialaccount.providers.openid_connect.views import ( - OpenIDConnectAdapter, -) - - -def _factory(_provider_id): - class OpenIDConnectAdapterServer(OpenIDConnectAdapter): - provider_id = _provider_id - - return OpenIDConnectAdapterServer - - -def default_urlpatterns(provider): - adapter_class = _factory(provider.id) - urlpatterns = [ - path( - "login/", - OAuth2LoginView.adapter_view(adapter_class), - name=provider.id + "_login", - ), - path( - "login/callback/", - OAuth2CallbackView.adapter_view(adapter_class), - name=provider.id + "_callback", +from django.urls import include, path, re_path + +from . import views + + +urlpatterns = [ + re_path( + r"^(?P[^/]+)/", + include( + [ + path( + "login/", + views.login, + name="openid_connect_login", + ), + path( + "login/callback/", + views.callback, + name="openid_connect_callback", + ), + ] ), - ] - - return [path(provider.get_slug() + "/", include(urlpatterns))] - - -urlpatterns = itertools.chain.from_iterable( - [default_urlpatterns(provider_class) for provider_class in provider_classes] -) + ) +] diff --git a/allauth/socialaccount/providers/openid_connect/views.py b/allauth/socialaccount/providers/openid_connect/views.py index 05b7a70eb8..5b55769748 100644 --- a/allauth/socialaccount/providers/openid_connect/views.py +++ b/allauth/socialaccount/providers/openid_connect/views.py @@ -1,24 +1,34 @@ -# -*- coding: utf-8 -*- import requests -from allauth.socialaccount.providers.oauth2.views import OAuth2Adapter +from django.urls import reverse + +from allauth.socialaccount.providers.oauth2.views import ( + OAuth2Adapter, + OAuth2CallbackView, + OAuth2LoginView, +) +from allauth.utils import build_absolute_uri class OpenIDConnectAdapter(OAuth2Adapter): supports_state = True - provider_id = "openid_connect" + + def __init__(self, request, provider_id): + self.provider_id = provider_id + super().__init__(request) @property def openid_config(self): if not hasattr(self, "_openid_config"): - resp = requests.get(self.get_provider().server_url) + server_url = self.get_provider().server_url + resp = requests.get(server_url) resp.raise_for_status() self._openid_config = resp.json() return self._openid_config @property def basic_auth(self): - token_auth_method = self.get_provider().token_auth_method + token_auth_method = self.get_provider().app.settings.get("token_auth_method") if token_auth_method: return token_auth_method == "client_secret_basic" return "client_secret_basic" in self.openid_config.get( @@ -44,3 +54,20 @@ def complete_login(self, request, app, token, response): response.raise_for_status() extra_data = response.json() return self.get_provider().sociallogin_from_response(request, extra_data) + + def get_callback_url(self, request, app): + callback_url = reverse( + "openid_connect_callback", kwargs={"provider_id": self.provider_id} + ) + protocol = self.redirect_uri_protocol + return build_absolute_uri(request, callback_url, protocol) + + +def login(request, provider_id): + view = OAuth2LoginView.adapter_view(OpenIDConnectAdapter(request, provider_id)) + return view(request) + + +def callback(request, provider_id): + view = OAuth2CallbackView.adapter_view(OpenIDConnectAdapter(request, provider_id)) + return view(request) diff --git a/allauth/socialaccount/providers/pocket/views.py b/allauth/socialaccount/providers/pocket/views.py index 053fba14d5..551435dae6 100644 --- a/allauth/socialaccount/providers/pocket/views.py +++ b/allauth/socialaccount/providers/pocket/views.py @@ -16,7 +16,7 @@ def complete_login(self, request, app, token, response): class PocketOAuthLoginView(OAuthLoginView): def _get_client(self, request, callback_url): provider = self.adapter.get_provider() - app = provider.get_app(request) + app = provider.app scope = " ".join(provider.get_scope(request)) parameters = {} if scope: @@ -37,7 +37,7 @@ def _get_client(self, request, callback_url): class PocketOAuthCallbackView(OAuthCallbackView): def _get_client(self, request, callback_url): provider = self.adapter.get_provider() - app = provider.get_app(request) + app = provider.app scope = " ".join(provider.get_scope(request)) parameters = {} if scope: diff --git a/allauth/socialaccount/providers/reddit/tests.py b/allauth/socialaccount/providers/reddit/tests.py index 0a46e38b3d..5dbae2c889 100755 --- a/allauth/socialaccount/providers/reddit/tests.py +++ b/allauth/socialaccount/providers/reddit/tests.py @@ -1,14 +1,12 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals - -from allauth.socialaccount.providers import registry -from allauth.socialaccount.tests import create_oauth2_tests -from allauth.tests import MockedResponse +from allauth.socialaccount.tests import OAuth2TestsMixin +from allauth.tests import MockedResponse, TestCase from .provider import RedditProvider -class RedditTests(create_oauth2_tests(registry.by_id(RedditProvider.id))): +class RedditTests(OAuth2TestsMixin, TestCase): + provider_id = RedditProvider.id + def get_mocked_response(self): return [ MockedResponse( diff --git a/allauth/socialaccount/providers/robinhood/tests.py b/allauth/socialaccount/providers/robinhood/tests.py index ddf580bca5..9110c340cb 100644 --- a/allauth/socialaccount/providers/robinhood/tests.py +++ b/allauth/socialaccount/providers/robinhood/tests.py @@ -1,11 +1,12 @@ -from allauth.socialaccount.providers import registry -from allauth.socialaccount.tests import create_oauth2_tests -from allauth.tests import MockedResponse +from allauth.socialaccount.tests import OAuth2TestsMixin +from allauth.tests import MockedResponse, TestCase from .provider import RobinhoodProvider -class RobinhoodTests(create_oauth2_tests(registry.by_id(RobinhoodProvider.id))): +class RobinhoodTests(OAuth2TestsMixin, TestCase): + provider_id = RobinhoodProvider.id + def get_mocked_response(self): return MockedResponse( 200, diff --git a/allauth/socialaccount/providers/salesforce/tests.py b/allauth/socialaccount/providers/salesforce/tests.py index 498ddca82d..6ea8d81647 100644 --- a/allauth/socialaccount/providers/salesforce/tests.py +++ b/allauth/socialaccount/providers/salesforce/tests.py @@ -1,14 +1,15 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, unicode_literals -from allauth.socialaccount.providers import registry -from allauth.socialaccount.tests import create_oauth2_tests -from allauth.tests import MockedResponse +from allauth.socialaccount.tests import OAuth2TestsMixin +from allauth.tests import MockedResponse, TestCase from .provider import SalesforceProvider -class SalesforceTests(create_oauth2_tests(registry.by_id(SalesforceProvider.id))): +class SalesforceTests(OAuth2TestsMixin, TestCase): + provider_id = SalesforceProvider.id + def get_mocked_response( self, last_name="Penners", diff --git a/allauth/socialaccount/providers/salesforce/views.py b/allauth/socialaccount/providers/salesforce/views.py index 978a44ebae..63a30a07d4 100644 --- a/allauth/socialaccount/providers/salesforce/views.py +++ b/allauth/socialaccount/providers/salesforce/views.py @@ -14,7 +14,7 @@ class SalesforceOAuth2Adapter(OAuth2Adapter): @property def base_url(self): - return self.get_provider().get_app(self.request).key + return self.get_provider().app.key @property def authorize_url(self): diff --git a/allauth/socialaccount/providers/shopify/tests.py b/allauth/socialaccount/providers/shopify/tests.py index 8cc616402d..415f7cdc1c 100644 --- a/allauth/socialaccount/providers/shopify/tests.py +++ b/allauth/socialaccount/providers/shopify/tests.py @@ -6,14 +6,15 @@ from django.utils.http import urlencode from allauth.socialaccount.models import SocialAccount -from allauth.socialaccount.providers import registry -from allauth.socialaccount.tests import create_oauth2_tests -from allauth.tests import MockedResponse, mocked_response +from allauth.socialaccount.tests import OAuth2TestsMixin +from allauth.tests import MockedResponse, TestCase, mocked_response from .provider import ShopifyProvider -class ShopifyTests(create_oauth2_tests(registry.by_id(ShopifyProvider.id))): +class ShopifyTests(OAuth2TestsMixin, TestCase): + provider_id = ShopifyProvider.id + def _complete_shopify_login(self, q, resp, resp_mock, with_refresh_token): complete_url = reverse(self.provider.id + "_callback") self.assertGreater(q["redirect_uri"][0].find(complete_url), 0) diff --git a/allauth/socialaccount/providers/steam/provider.py b/allauth/socialaccount/providers/steam/provider.py index 85b3fa978c..8314dd1757 100644 --- a/allauth/socialaccount/providers/steam/provider.py +++ b/allauth/socialaccount/providers/steam/provider.py @@ -55,7 +55,7 @@ def get_login_url(self, request, **kwargs): def sociallogin_from_response(self, request, response): steam_id = extract_steam_id(response.identity_url) - steam_api_key = self.get_app(request).secret + steam_api_key = self.app.secret response._extra = request_steam_account_summary(steam_api_key, steam_id) return super(SteamOpenIDProvider, self).sociallogin_from_response( request, response diff --git a/allauth/socialaccount/providers/stocktwits/provider.py b/allauth/socialaccount/providers/stocktwits/provider.py index 35c9eb1dc0..4d69279ebe 100644 --- a/allauth/socialaccount/providers/stocktwits/provider.py +++ b/allauth/socialaccount/providers/stocktwits/provider.py @@ -17,7 +17,7 @@ class StocktwitsProvider(OAuth2Provider): account_class = StocktwitsAccount def extract_uid(self, data): - return data.get("user", {}).get("id") + return str(data["user"]["id"]) def extract_common_fields(self, data): return dict( diff --git a/allauth/socialaccount/providers/strava/provider.py b/allauth/socialaccount/providers/strava/provider.py index a97c75f332..2a2f48a269 100644 --- a/allauth/socialaccount/providers/strava/provider.py +++ b/allauth/socialaccount/providers/strava/provider.py @@ -26,7 +26,7 @@ class StravaProvider(OAuth2Provider): account_class = StravaAccount def extract_uid(self, data): - return data.get("id") + return str(data["id"]) def extract_common_fields(self, data): extra_common = super(StravaProvider, self).extract_common_fields(data) diff --git a/allauth/socialaccount/providers/telegram/provider.py b/allauth/socialaccount/providers/telegram/provider.py index 1d5523a741..d90b033d92 100644 --- a/allauth/socialaccount/providers/telegram/provider.py +++ b/allauth/socialaccount/providers/telegram/provider.py @@ -9,6 +9,7 @@ class TelegramProvider(Provider): id = "telegram" name = "Telegram" account_class = TelegramAccount + uses_apps = False def get_login_url(self, request, **kwargs): # TODO: Find a way to better wrap the iframed button diff --git a/allauth/socialaccount/providers/telegram/views.py b/allauth/socialaccount/providers/telegram/views.py index 9cd83c507e..d52f339d9b 100644 --- a/allauth/socialaccount/providers/telegram/views.py +++ b/allauth/socialaccount/providers/telegram/views.py @@ -2,7 +2,7 @@ import hmac import time -from allauth.socialaccount import providers +from allauth.socialaccount.adapter import get_adapter from allauth.socialaccount.helpers import ( complete_social_login, render_authentication_error, @@ -12,7 +12,8 @@ def telegram_login(request): - provider = providers.registry.by_id(TelegramProvider.id, request) + request = get_adapter(request) + provider = request.get_provider(request, TelegramProvider.id) data = dict(request.GET.items()) hash = data.pop("hash") payload = "\n".join(sorted(["{}={}".format(k, v) for k, v in data.items()])) diff --git a/allauth/socialaccount/providers/trainingpeaks/provider.py b/allauth/socialaccount/providers/trainingpeaks/provider.py index 56417d4bc9..2de1f45b15 100644 --- a/allauth/socialaccount/providers/trainingpeaks/provider.py +++ b/allauth/socialaccount/providers/trainingpeaks/provider.py @@ -26,7 +26,7 @@ class TrainingPeaksProvider(OAuth2Provider): account_class = TrainingPeaksAccount def extract_uid(self, data): - return data.get("Id") + return str(data["Id"]) def extract_common_fields(self, data): extra_common = super(TrainingPeaksProvider, self).extract_common_fields(data) diff --git a/allauth/socialaccount/providers/trello/provider.py b/allauth/socialaccount/providers/trello/provider.py index 8287407027..2eceb3d734 100644 --- a/allauth/socialaccount/providers/trello/provider.py +++ b/allauth/socialaccount/providers/trello/provider.py @@ -30,9 +30,8 @@ def extract_common_fields(self, data): def get_auth_params(self, request, action): data = super(TrelloProvider, self).get_auth_params(request, action) - app = self.get_app(request) data["type"] = "web_server" - data["name"] = app.name + data["name"] = self.app.name data["scope"] = self.get_scope(request) # define here for how long it will be, this can be configured on the # social app diff --git a/allauth/socialaccount/providers/twitch/tests.py b/allauth/socialaccount/providers/twitch/tests.py index 1d43bdce7e..fe7c4f864b 100644 --- a/allauth/socialaccount/providers/twitch/tests.py +++ b/allauth/socialaccount/providers/twitch/tests.py @@ -75,7 +75,7 @@ def _run_just_complete_login(self, resp_mock): {"process": "login"}, ) adapter = TwitchOAuth2Adapter(request) - app = adapter.get_provider().get_app(request) + app = adapter.get_provider().app token = SocialToken(token="this-is-my-fake-token") with mocked_response(resp_mock): diff --git a/allauth/socialaccount/providers/twitter/provider.py b/allauth/socialaccount/providers/twitter/provider.py index 7d2c883a13..72adab1099 100644 --- a/allauth/socialaccount/providers/twitter/provider.py +++ b/allauth/socialaccount/providers/twitter/provider.py @@ -40,7 +40,7 @@ def get_auth_url(self, request, action): return url def extract_uid(self, data): - return data["id"] + return str(data["id"]) def extract_common_fields(self, data): return dict( diff --git a/allauth/socialaccount/providers/untappd/tests.py b/allauth/socialaccount/providers/untappd/tests.py index 87bad781bb..abb450d309 100644 --- a/allauth/socialaccount/providers/untappd/tests.py +++ b/allauth/socialaccount/providers/untappd/tests.py @@ -1,11 +1,12 @@ -from allauth.socialaccount.providers import registry -from allauth.socialaccount.tests import create_oauth2_tests -from allauth.tests import MockedResponse +from allauth.socialaccount.tests import OAuth2TestsMixin +from allauth.tests import MockedResponse, TestCase from .provider import UntappdProvider -class UntappdTests(create_oauth2_tests(registry.by_id(UntappdProvider.id))): +class UntappdTests(OAuth2TestsMixin, TestCase): + provider_id = UntappdProvider.id + def get_login_response_json(self, with_refresh_token=True): return """ { diff --git a/allauth/socialaccount/providers/wahoo/provider.py b/allauth/socialaccount/providers/wahoo/provider.py index ad66683a1f..f37fac74a5 100644 --- a/allauth/socialaccount/providers/wahoo/provider.py +++ b/allauth/socialaccount/providers/wahoo/provider.py @@ -15,7 +15,7 @@ class WahooProvider(OAuth2Provider): account_class = WahooAccount def extract_uid(self, data): - return data.get("id") + return str(data["id"]) def extract_common_fields(self, data): extra_common = super(WahooProvider, self).extract_common_fields(data) diff --git a/allauth/socialaccount/providers/weixin/tests.py b/allauth/socialaccount/providers/weixin/tests.py index 4763cb7454..04a0b9aa04 100644 --- a/allauth/socialaccount/providers/weixin/tests.py +++ b/allauth/socialaccount/providers/weixin/tests.py @@ -1,14 +1,15 @@ # -*- coding: utf-8 -*- from __future__ import unicode_literals -from allauth.socialaccount.providers import registry -from allauth.socialaccount.tests import create_oauth2_tests -from allauth.tests import MockedResponse +from allauth.socialaccount.tests import OAuth2TestsMixin +from allauth.tests import MockedResponse, TestCase from .provider import WeixinProvider -class WeixinTests(create_oauth2_tests(registry.by_id(WeixinProvider.id))): +class WeixinTests(OAuth2TestsMixin, TestCase): + provider_id = WeixinProvider.id + def get_mocked_response(self): return MockedResponse( 200, diff --git a/allauth/socialaccount/providers/ynab/tests.py b/allauth/socialaccount/providers/ynab/tests.py index bbe09c443d..f62e4c5ac6 100644 --- a/allauth/socialaccount/providers/ynab/tests.py +++ b/allauth/socialaccount/providers/ynab/tests.py @@ -51,7 +51,7 @@ def raise_for_status(self): ) adapter = YNABOAuth2Adapter(request) - app = adapter.get_provider().get_app(request) + app = adapter.get_provider().app token = SocialToken(token="some_token") response_with_401 = LessMockedResponse( 401, diff --git a/allauth/socialaccount/providers/zoho/provider.py b/allauth/socialaccount/providers/zoho/provider.py index 2eafbbbe53..9321450f6d 100644 --- a/allauth/socialaccount/providers/zoho/provider.py +++ b/allauth/socialaccount/providers/zoho/provider.py @@ -18,7 +18,7 @@ def get_default_scope(self): return ["aaaserver.profile.READ"] def extract_uid(self, data): - return data["ZUID"] + return str(data["ZUID"]) def extract_common_fields(self, data): return dict( diff --git a/allauth/socialaccount/templatetags/socialaccount.py b/allauth/socialaccount/templatetags/socialaccount.py index 37033c1e70..76bf337c24 100644 --- a/allauth/socialaccount/templatetags/socialaccount.py +++ b/allauth/socialaccount/templatetags/socialaccount.py @@ -1,7 +1,7 @@ from django import template from django.template.defaulttags import token_kwargs -from allauth.socialaccount import providers +from allauth.socialaccount.adapter import get_adapter from allauth.utils import get_request_param @@ -14,9 +14,11 @@ def __init__(self, provider_id, params): self.params = params def render(self, context): - provider_id = self.provider_id_var.resolve(context) + provider = self.provider_id_var.resolve(context) request = context.get("request") - provider = providers.registry.by_id(provider_id, request) + if isinstance(provider, str): + adapter = get_adapter(request) + provider = adapter.get_provider(request, provider) query = dict( [(str(name), var.resolve(context)) for name, var in self.params.items()] ) @@ -55,9 +57,8 @@ def provider_login_url(parser, token): class ProvidersMediaJSNode(template.Node): def render(self, context): request = context["request"] - ret = "\n".join( - p.media_js(request) for p in providers.registry.get_list(request) - ) + providers = get_adapter(request).list_providers(request) + ret = "\n".join(p.media_js(request) for p in providers) return ret @@ -83,8 +84,8 @@ def get_social_accounts(user): return accounts -@register.simple_tag -def get_providers(): +@register.simple_tag(takes_context=True) +def get_providers(context): """ Returns a list of social authentication providers. @@ -93,4 +94,7 @@ def get_providers(): Then within the template context, `socialaccount_providers` will hold a list of social providers configured for the current site. """ - return providers.registry.get_list() + request = context["request"] + adapter = get_adapter(request) + providers = adapter.list_providers(request) + return sorted(providers, key=lambda p: p.name) diff --git a/allauth/socialaccount/tests/__init__.py b/allauth/socialaccount/tests/__init__.py index 34a82c6d27..c927597391 100644 --- a/allauth/socialaccount/tests/__init__.py +++ b/allauth/socialaccount/tests/__init__.py @@ -7,6 +7,7 @@ from urllib.parse import parse_qs, urlparse from django.conf import settings +from django.test import RequestFactory from django.test.utils import override_settings from django.urls import reverse from django.utils.http import urlencode @@ -14,7 +15,8 @@ import allauth.app_settings from allauth.account.models import EmailAddress from allauth.account.utils import user_email, user_username -from allauth.socialaccount import app_settings, providers +from allauth.socialaccount import app_settings +from allauth.socialaccount.adapter import get_adapter from allauth.socialaccount.models import SocialAccount, SocialApp from allauth.tests import ( Mock, @@ -26,20 +28,23 @@ from allauth.utils import get_user_model -def setup_app(provider): - app = None - if not app_settings.PROVIDERS.get(provider.id, {}).get("APP"): - app = SocialApp.objects.create( - provider=provider.id, - name=provider.id, - client_id="app123id", - key=provider.id, - secret="dummy", - ) - if allauth.app_settings.SITES_ENABLED: - from django.contrib.sites.models import Site +def setup_app(provider_id): + request = RequestFactory().get("/") + apps = get_adapter(request).list_apps(request, provider_id) + if apps: + return apps[0] + + app = SocialApp.objects.create( + provider=provider_id, + name=provider_id, + client_id="app123id", + key=provider_id, + secret="dummy", + ) + if allauth.app_settings.SITES_ENABLED: + from django.contrib.sites.models import Site - app.sites.add(Site.objects.get_current()) + app.sites.add(Site.objects.get_current()) return app @@ -51,8 +56,9 @@ def get_mocked_response(self): def setUp(self): super(OAuthTestsMixin, self).setUp() - self.provider = providers.registry.by_id(self.provider_id) - self.app = setup_app(self.provider) + self.app = setup_app(self.provider_id) + request = RequestFactory().get("/") + self.provider = self.app.get_provider(request) @override_settings(SOCIALACCOUNT_AUTO_SIGNUP=False) def test_login(self): @@ -163,26 +169,30 @@ def get_login_response_json(self, with_refresh_token=True): def setUp(self): super(OAuth2TestsMixin, self).setUp() - self.provider = providers.registry.by_id(self.provider_id) - self.app = setup_app(self.provider) + self.setup_provider() + + def setup_provider(self): + self.app = setup_app(self.provider_id) + self.request = RequestFactory().get("/") + self.provider = self.app.get_provider(self.request) def test_provider_has_no_pkce_params(self): - provider_settings = app_settings.PROVIDERS.get(self.provider_id, {}) + provider_settings = app_settings.PROVIDERS.get(self.app.provider, {}) provider_settings_with_pkce_set = provider_settings.copy() provider_settings_with_pkce_set["OAUTH_PKCE_ENABLED"] = False with self.settings( - SOCIALACCOUNT_PROVIDERS={self.provider_id: provider_settings_with_pkce_set} + SOCIALACCOUNT_PROVIDERS={self.app.provider: provider_settings_with_pkce_set} ): self.assertEqual(self.provider.get_pkce_params(), {}) def test_provider_has_pkce_params(self): - provider_settings = app_settings.PROVIDERS.get(self.provider_id, {}) + provider_settings = app_settings.PROVIDERS.get(self.app.provider, {}) provider_settings_with_pkce_set = provider_settings.copy() provider_settings_with_pkce_set["OAUTH_PKCE_ENABLED"] = True with self.settings( - SOCIALACCOUNT_PROVIDERS={self.provider_id: provider_settings_with_pkce_set} + SOCIALACCOUNT_PROVIDERS={self.app.provider: provider_settings_with_pkce_set} ): pkce_params = self.provider.get_pkce_params() self.assertEqual( @@ -209,13 +219,13 @@ def test_login(self): @override_settings(SOCIALACCOUNT_AUTO_SIGNUP=False) def test_login_with_pkce_disabled(self): - provider_settings = app_settings.PROVIDERS.get(self.provider_id, {}) + provider_settings = app_settings.PROVIDERS.get(self.app.provider, {}) provider_settings_with_pkce_disabled = provider_settings.copy() provider_settings_with_pkce_disabled["OAUTH_PKCE_ENABLED"] = False with self.settings( SOCIALACCOUNT_PROVIDERS={ - self.provider_id: provider_settings_with_pkce_disabled + self.app.provider: provider_settings_with_pkce_disabled } ): resp_mock = self.get_mocked_response() @@ -231,12 +241,12 @@ def test_login_with_pkce_disabled(self): @override_settings(SOCIALACCOUNT_AUTO_SIGNUP=False) def test_login_with_pkce_enabled(self): - provider_settings = app_settings.PROVIDERS.get(self.provider_id, {}) + provider_settings = app_settings.PROVIDERS.get(self.app.provider, {}) provider_settings_with_pkce_enabled = provider_settings.copy() provider_settings_with_pkce_enabled["OAUTH_PKCE_ENABLED"] = True with self.settings( SOCIALACCOUNT_PROVIDERS={ - self.provider_id: provider_settings_with_pkce_enabled + self.app.provider: provider_settings_with_pkce_enabled } ): resp_mock = self.get_mocked_response() @@ -300,15 +310,12 @@ def test_account_refresh_token_saved_next_login(self): def login(self, resp_mock=None, process="login", with_refresh_token=True): resp = self.client.post( - reverse(self.provider.id + "_login") - + "?" - + urlencode(dict(process=process)) + self.provider.get_login_url(self.request, process=process) ) - p = urlparse(resp["location"]) q = parse_qs(p.query) - pkce_enabled = app_settings.PROVIDERS.get(self.provider_id, {}).get( + pkce_enabled = app_settings.PROVIDERS.get(self.app.provider, {}).get( "OAUTH_PKCE_ENABLED", self.provider.pkce_enabled_default ) @@ -318,7 +325,7 @@ def login(self, resp_mock=None, process="login", with_refresh_token=True): code_challenge = q["code_challenge"][0] self.assertEqual(q["code_challenge_method"][0], "S256") - complete_url = reverse(self.provider.id + "_callback") + complete_url = self.provider.get_callback_url() self.assertGreater(q["redirect_uri"][0].find(complete_url), 0) response_json = self.get_login_response_json( with_refresh_token=with_refresh_token @@ -366,7 +373,7 @@ def get_complete_parameters(self, q): return {"code": "test", "state": q["state"][0]} def test_authentication_error(self): - resp = self.client.get(reverse(self.provider.id + "_callback")) + resp = self.client.get(self.provider.get_callback_url()) self.assertTemplateUsed( resp, "socialaccount/authentication_error.%s" @@ -374,16 +381,6 @@ def test_authentication_error(self): ) -# For backward-compatibility with third-party provider tests that call -# create_oauth2_tests() rather than using the mixin directly. -def create_oauth2_tests(provider): - class Class(OAuth2TestsMixin, TestCase): - provider_id = provider.id - - Class.__name__ = "OAuth2Tests_" + provider.id - return Class - - class OpenIDConnectTests(OAuth2TestsMixin): oidc_info_content = { "authorization_endpoint": "/login", @@ -414,6 +411,18 @@ def setUp(self): self.mock_requests = patcher.start() self.addCleanup(patcher.stop) + def setup_provider(self): + self.app = setup_app(self.provider_id) + if self.provider_id not in ["keycloak"]: + self.app.provider_id = self.provider_id + self.app.provider = "openid_connect" + self.app.settings = { + "server_url": "https://unittest.example.com", + } + self.app.save() + self.request = RequestFactory().get("/") + self.provider = self.app.get_provider(self.request) + def get_mocked_response(self): # Enable test_login in OAuth2TestsMixin, but this response mock is unused return True @@ -428,5 +437,9 @@ def _mocked_responses(self, url, *args, **kwargs): def test_login_auto_signup(self): resp = self.login() self.assertRedirects(resp, "/accounts/profile/", fetch_redirect_response=False) - sa = SocialAccount.objects.get(provider=self.provider.id) + sa = SocialAccount.objects.get( + # For Keycloak, `provider_id` is empty. + provider=self.app.provider_id + or self.app.provider + ) self.assertDictEqual(sa.extra_data, self.extra_data) diff --git a/allauth/socialaccount/tests/test_registry.py b/allauth/socialaccount/tests/test_registry.py index b4e6672181..fac460ac0a 100644 --- a/allauth/socialaccount/tests/test_registry.py +++ b/allauth/socialaccount/tests/test_registry.py @@ -18,13 +18,15 @@ class ProviderRegistryTests(TestCase): ) def test_load_provider_with_default_app_config(self): registry = providers.ProviderRegistry() - provider_list = registry.get_list() + provider_list = registry.get_class_list() self.assertTrue(registry.loaded) self.assertEqual(1, len(provider_list)) - self.assertIsInstance( - provider_list[0], - providers.facebook.provider.FacebookProvider, + self.assertTrue( + issubclass( + provider_list[0], + providers.facebook.provider.FacebookProvider, + ) ) app_config_list = list(apps.get_app_configs()) @@ -40,13 +42,15 @@ def test_load_provider_with_default_app_config(self): ) def test_load_provider_with_custom_app_config(self): registry = providers.ProviderRegistry() - provider_list = registry.get_list() + provider_list = registry.get_class_list() self.assertTrue(registry.loaded) self.assertEqual(1, len(provider_list)) - self.assertIsInstance( - provider_list[0], - providers.facebook.provider.FacebookProvider, + self.assertTrue( + issubclass( + provider_list[0], + providers.facebook.provider.FacebookProvider, + ) ) app_config_list = list(apps.get_app_configs()) diff --git a/allauth/socialaccount/tests/test_signup.py b/allauth/socialaccount/tests/test_signup.py index 0a3112b62e..c6480f23a8 100644 --- a/allauth/socialaccount/tests/test_signup.py +++ b/allauth/socialaccount/tests/test_signup.py @@ -21,7 +21,9 @@ class SignupTests(TestCase): def setUp(self): super().setUp() - for provider in providers.registry.get_list(): + for provider in providers.registry.get_class_list(): + if provider.id == "openid_connect": + continue app = SocialApp.objects.create( provider=provider.id, name=provider.id, diff --git a/allauth/templates/socialaccount/snippets/provider_list.html b/allauth/templates/socialaccount/snippets/provider_list.html index e76a29695a..a7d53fe033 100644 --- a/allauth/templates/socialaccount/snippets/provider_list.html +++ b/allauth/templates/socialaccount/snippets/provider_list.html @@ -6,15 +6,15 @@ {% if provider.id == "openid" %} {% for brand in provider.get_brands %}
  • - {{brand.name}}
  • {% endfor %} {% endif %}
  • - {{provider.name}} + {{provider.name}}
  • {% endfor %} diff --git a/allauth/tests.py b/allauth/tests.py index dbb018180e..1ed6700a89 100644 --- a/allauth/tests.py +++ b/allauth/tests.py @@ -199,6 +199,14 @@ def test_int_to_base36(self): def test_templatetag_with_csrf_failure(self): # Generate a fictitious GET request + from allauth.socialaccount.models import SocialApp + + app = SocialApp.objects.create(provider="google") + if app_settings.SITES_ENABLED: + from django.contrib.sites.models import Site + + app.sites.add(Site.objects.get_current()) + request = self.factory.get("/tests/test_403_csrf.html") # Simulate a CSRF failure by calling the View directly # This template is using the `provider_login_url` templatetag diff --git a/allauth/urls.py b/allauth/urls.py index 0c54012c4c..aa572a308e 100644 --- a/allauth/urls.py +++ b/allauth/urls.py @@ -14,12 +14,20 @@ # Provider urlpatterns, as separate attribute (for reusability). provider_urlpatterns = [] -for provider in providers.registry.get_list(): +provider_classes = providers.registry.get_class_list() + +# We need to move the OpenID Connect provider to the end. The reason is that +# matches URLs that the builtin providers also match. +provider_classes = [cls for cls in provider_classes if cls.id != "openid_connect"] + [ + cls for cls in provider_classes if cls.id == "openid_connect" +] +for provider_class in provider_classes: try: - prov_mod = import_module(provider.get_package() + ".urls") + prov_mod = import_module(provider_class.get_package() + ".urls") except ImportError: continue prov_urlpatterns = getattr(prov_mod, "urlpatterns", None) if prov_urlpatterns: provider_urlpatterns += prov_urlpatterns + urlpatterns += provider_urlpatterns diff --git a/docs/providers.rst b/docs/providers.rst index 23a7700a44..d59b2f46b4 100644 --- a/docs/providers.rst +++ b/docs/providers.rst @@ -1523,36 +1523,37 @@ The OpenID provider can be forced to operate in stateless mode as follows:: OpenID Connect -------------- -The OpenID Connect provider provides a dynamic instance for each configured -server. To expose an OpenID Connect server as an authentication method, -configuration of one or more servers is required: +The OpenID Connect provider provides access to multiple independent OpenID +Connect (sub)providers. You configure these (sub)providers by adding apps to the +configuration of the overall OpenID connect provider. Each app represents a +standalone OpenID Connect provider: .. code-block:: python SOCIALACCOUNT_PROVIDERS = { "openid_connect": { - "SERVERS": [ + "APPS": [ { - "id": "my-server", # 30 characters or less + "provider_id": "my-server", "name": "My Login Server", - "server_url": "https://my.server.example.com", - # Optional token endpoint authentication method. - # May be one of "client_secret_basic", "client_secret_post" - # If omitted, a method from the the server's - # token auth methods list is used - "token_auth_method": "client_secret_basic", - "APP": { - "client_id": "your.service.id", - "secret": "your.service.secret", + "client_id": "your.service.id", + "secret": "your.service.secret", + "settings": { + "server_url": "https://my.server.example.com", + # Optional token endpoint authentication method. + # May be one of "client_secret_basic", "client_secret_post" + # If omitted, a method from the the server's + # token auth methods list is used + "token_auth_method": "client_secret_basic", }, }, { - "id": "other-server", # 30 characters or less + "provider_id": "other-server", "name": "Other Login Server", - "server_url": "https://other.server.example.com", - "APP": { - "client_id": "your.other.service.id", - "secret": "your.other.service.secret", + "client_id": "your.other.service.id", + "secret": "your.other.service.secret", + "settings": { + "server_url": "https://other.server.example.com", }, }, ] @@ -1563,8 +1564,8 @@ This configuration example will create two independent provider instances, ``My Login Server`` and ``Other Login Server``. The OpenID Connect callback URL for each configured server is at -``/accounts/{id}/login/callback/`` where ``{id}`` is the configured -server's ``id`` value (``my-server`` or ``other-server`` in the above example). +``/accounts/{id}/login/callback/`` where ``{id}`` is the configured app's +``provider_id`` value (``my-server`` or ``other-server`` in the above example). OpenStreetMap diff --git a/example/example/demo/apps.py b/example/example/demo/apps.py index f336efcae7..a9045a8946 100644 --- a/example/example/demo/apps.py +++ b/example/example/demo/apps.py @@ -16,22 +16,22 @@ def setup_dummy_social_apps(sender, **kwargs): from allauth.socialaccount.providers.oauth.provider import OAuthProvider site = Site.objects.get_current() - for provider in registry.get_list(): - if isinstance(provider, OAuth2Provider) or isinstance(provider, OAuthProvider): + for provider_class in registry.get_class_list(): + if issubclass(provider_class, (OAuthProvider,OAuth2Provider)): try: - SocialApp.objects.get(provider=provider.id, sites=site) + SocialApp.objects.get(provider=provider_class.id, sites=site) except SocialApp.DoesNotExist: print( "Installing dummy application credentials for %s." " Authentication via this provider will not work" " until you configure proper credentials via the" - " Django admin (`SocialApp` models)" % provider.id + " Django admin (`SocialApp` models)" % provider_class.id ) app = SocialApp.objects.create( - provider=provider.id, + provider=provider_class.id, secret="secret", client_id="client-id", - name="Dummy %s app" % provider.id, + name="Dummy %s app" % provider_class.id, ) app.sites.add(site) @@ -39,7 +39,7 @@ def setup_dummy_social_apps(sender, **kwargs): class DemoConfig(AppConfig): name = "example.demo" verbose_name = _("Demo") - default_auto_field = 'django.db.models.AutoField' + default_auto_field = "django.db.models.AutoField" def ready(self): post_migrate.connect(setup_dummy_social_apps, sender=self) diff --git a/example/example/settings.py b/example/example/settings.py index 810b3fc1d3..862068cd6c 100644 --- a/example/example/settings.py +++ b/example/example/settings.py @@ -144,6 +144,7 @@ "allauth.socialaccount.providers.linkedin", "allauth.socialaccount.providers.mediawiki", "allauth.socialaccount.providers.openid", + "allauth.socialaccount.providers.openid_connect", "allauth.socialaccount.providers.pinterest", "allauth.socialaccount.providers.pocket", "allauth.socialaccount.providers.reddit", @@ -174,7 +175,7 @@ ] ALLOWED_HOSTS = ["127.0.0.1", "localhost"] - +SECURE_PROXY_SSL_HEADER = ("HTTP_X_FORWARDED_PROTO", "https") try: from .local_settings import * # noqa except ImportError: diff --git a/test_settings.py b/test_settings.py index 4684f6a6c5..d33a141b78 100644 --- a/test_settings.py +++ b/test_settings.py @@ -1,5 +1,7 @@ import os +from django.contrib.auth.hashers import PBKDF2PasswordHasher + SECRET_KEY = "psst" SITE_ID = 1 @@ -23,7 +25,9 @@ TEMPLATES = [ { "BACKEND": "django.template.backends.django.DjangoTemplates", - "DIRS": [os.path.join(os.path.dirname(__file__), "example", "example", "templates")], + "DIRS": [ + os.path.join(os.path.dirname(__file__), "example", "example", "templates") + ], "APP_DIRS": True, "OPTIONS": { "context_processors": [ @@ -184,8 +188,6 @@ STATIC_ROOT = "/tmp/" # Dummy STATIC_URL = "/static/" -from django.contrib.auth.hashers import PBKDF2PasswordHasher - class MyPBKDF2PasswordHasher(PBKDF2PasswordHasher): """ @@ -212,16 +214,24 @@ class MyPBKDF2PasswordHasher(PBKDF2PasswordHasher): SOCIALACCOUNT_PROVIDERS = { "openid_connect": { - "SERVERS": [ + "APPS": [ { - "id": "unittest-server", + "provider_id": "unittest-server", "name": "Unittest Server", - "server_url": "https://unittest.example.com", + "client_id": "Unittest client_id", + "client_secret": "Unittest client_secret", + "settings": { + "server_url": "https://unittest.example.com", + }, }, { - "id": "other-server", + "provider_id": "other-server", "name": "Other Example Server", - "server_url": "https://other.example.com", + "client_id": "other client_id", + "client_secret": "other client_secret", + "settings": { + "server_url": "https://other.example.com", + }, }, ], }