Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix CORS by passing 'Origin' header to OAuthLib #1229

Merged
merged 12 commits into from
Oct 19, 2023
11 changes: 11 additions & 0 deletions docs/settings.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,17 @@ assigned ports.
Note that you may override ``Application.get_allowed_schemes()`` to set this on
a per-application basis.

ALLOWED_SCHEMES
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Default: ``["https"]``

A list of schemes that the ``allowed_origins`` field will be validated against.
Setting this to ``["https"]`` only in production is strongly recommended.
Adding ``"http"`` to the list is considered to be safe only for local development and testing.
Note that `OAUTHLIB_INSECURE_TRANSPORT <https://oauthlib.readthedocs.io/en/latest/oauth2/security.html#envvar-OAUTHLIB_INSECURE_TRANSPORT>`_
environment variable should be also set to allow http origins.


APPLICATION_MODEL
~~~~~~~~~~~~~~~~~
Expand Down
8 changes: 6 additions & 2 deletions oauth2_provider/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .scopes import get_scopes_backend
from .settings import oauth2_settings
from .utils import jwk_from_pem
from .validators import RedirectURIValidator, URIValidator, WildcardSet
from .validators import AllowedURIValidator, RedirectURIValidator, WildcardSet


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -218,7 +218,7 @@ def clean(self):
allowed_origins = self.allowed_origins.strip().split()
if allowed_origins:
# oauthlib allows only https scheme for CORS
validator = URIValidator({"https"})
validator = AllowedURIValidator(oauth2_settings.ALLOWED_SCHEMES, "Origin")
for uri in allowed_origins:
validator(uri)

Expand Down Expand Up @@ -808,6 +808,10 @@ def is_origin_allowed(origin, allowed_origins):
"""

parsed_origin = urlparse(origin)

if parsed_origin.scheme not in oauth2_settings.ALLOWED_SCHEMES:
return False

for allowed_origin in allowed_origins:
parsed_allowed_origin = urlparse(allowed_origin)
if (
Expand Down
1 change: 1 addition & 0 deletions oauth2_provider/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
"REFRESH_TOKEN_ADMIN_CLASS": "oauth2_provider.admin.RefreshTokenAdmin",
"REQUEST_APPROVAL_PROMPT": "force",
"ALLOWED_REDIRECT_URI_SCHEMES": ["http", "https"],
"ALLOWED_SCHEMES": ["https"],
"OIDC_ENABLED": False,
"OIDC_ISS_ENDPOINT": "",
"OIDC_USERINFO_ENDPOINT": "",
Expand Down
27 changes: 27 additions & 0 deletions oauth2_provider/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,33 @@ def __call__(self, value):
raise ValidationError("Redirect URIs must not contain fragments")


class AllowedURIValidator(URIValidator):
def __init__(self, schemes, name, allow_path=False, allow_query=False, allow_fragments=False):
"""
:param schemes: List of allowed schemes. E.g.: ["https"]
:param name: Name of the validated URI. It is required for validation message. E.g.: "Origin"
:param allow_path: If URI can contain path part
:param allow_query: If URI can contain query part
:param allow_fragments: If URI can contain fragments part
"""
super().__init__(schemes=schemes)
self.name = name
self.allow_path = allow_path
self.allow_query = allow_query
self.allow_fragments = allow_fragments

def __call__(self, value):
super().__call__(value)
value = force_str(value)
scheme, netloc, path, query, fragment = urlsplit(value)
if query and not self.allow_query:
raise ValidationError("{} URIs must not contain query".format(self.name))
if fragment and not self.allow_fragments:
raise ValidationError("{} URIs must not contain fragments".format(self.name))
if path and not self.allow_path:
raise ValidationError("{} URIs must not contain path".format(self.name))


##
# WildcardSet is a special set that contains everything.
# This is required in order to move validation of the scheme from
Expand Down
12 changes: 12 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,18 @@ def public_application():
)


@pytest.fixture
def cors_application():
return Application.objects.create(
name="Test CORS Application",
client_type=Application.CLIENT_CONFIDENTIAL,
authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE,
algorithm=Application.RS256_ALGORITHM,
client_secret=CLEARTEXT_SECRET,
allowed_origins="https://example.com http://example.com",
)


@pytest.fixture
def logged_in_client(test_user):
from django.test.client import Client
Expand Down
8 changes: 8 additions & 0 deletions tests/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,11 @@
"READ_SCOPE": "read",
"WRITE_SCOPE": "write",
}

ALLOWED_SCHEMES_DEFAULT = {
"ALLOWED_SCHEMES": ["https"],
}

ALLOWED_SCHEMES_HTTP = {
"ALLOWED_SCHEMES": ["https", "http"],
}
16 changes: 16 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,3 +594,19 @@ def test_application_clean(oauth2_settings, application):
assert "Enter a valid URL" in str(exc.value)
application.allowed_origins = "https://example.com"
application.clean()


@pytest.mark.django_db
@pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_DEFAULT)
def test_application_origin_allowed_default_https(oauth2_settings, cors_application):
"""Test that http schemes are not allowed because ALLOWED_SCHEMES allows only https"""
assert cors_application.origin_allowed("https://example.com")
assert not cors_application.origin_allowed("http://example.com")


@pytest.mark.django_db
@pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_HTTP)
def test_application_origin_allowed_http(oauth2_settings, cors_application):
"""Test that http schemes are allowed because http was added to ALLOWED_SCHEMES"""
assert cors_application.origin_allowed("https://example.com")
assert cors_application.origin_allowed("http://example.com")
71 changes: 70 additions & 1 deletion tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from django.core.validators import ValidationError
from django.test import TestCase

from oauth2_provider.validators import RedirectURIValidator
from oauth2_provider.validators import AllowedURIValidator, RedirectURIValidator


@pytest.mark.usefixtures("oauth2_settings")
Expand Down Expand Up @@ -36,6 +36,11 @@ def test_validate_custom_uri_scheme(self):
# Check ValidationError not thrown
validator(uri)

validator = AllowedURIValidator(["my-scheme", "https", "git+ssh"], "Origin")
for uri in good_uris:
# Check ValidationError not thrown
validator(uri)

def test_validate_bad_uris(self):
validator = RedirectURIValidator(allowed_schemes=["https"])
self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https", "good"]
Expand All @@ -61,3 +66,67 @@ def test_validate_bad_uris(self):
for uri in bad_uris:
with self.assertRaises(ValidationError):
validator(uri)

def test_validate_good_origin_uris(self):
"""
Test AllowedURIValidator validates origin URIs if they match requirements
"""
validator = AllowedURIValidator(
["https"],
"Origin",
allow_path=False,
allow_query=False,
allow_fragments=False,
)
good_uris = [
"https://example.com",
"https://example.com:8080",
"https://example",
"https://localhost",
"https://1.1.1.1",
"https://127.0.0.1",
"https://255.255.255.255",
]
for uri in good_uris:
# Check ValidationError not thrown
validator(uri)

def test_validate_bad_origin_uris(self):
"""
Test AllowedURIValidator rejects origin URIs if they do not match requirements
"""
validator = AllowedURIValidator(
["https"],
"Origin",
allow_path=False,
allow_query=False,
allow_fragments=False,
)
bad_uris = [
"http:/example.com",
"HTTP://localhost",
"HTTP://example.com",
"HTTP://example.com.",
"http://example.com/#fragment",
"123://example.com",
"http://fe80::1",
"git+ssh://example.com",
"my-scheme://example.com",
"uri-without-a-scheme",
"https://example.com/#fragment",
"good://example.com/#fragment",
" ",
"",
# Bad IPv6 URL, urlparse behaves differently for these
'https://["><script>alert()</script>',
# Origin uri should not contain path, query of fragment parts
# https://www.rfc-editor.org/rfc/rfc6454#section-7.1
"https://example.com/",
"https://example.com/test",
"https://example.com/?q=test",
"https://example.com/#test",
]

for uri in bad_uris:
with self.assertRaises(ValidationError):
validator(uri)