Skip to content

Commit

Permalink
Merge pull request #27 from variable/cors_model
Browse files Browse the repository at this point in the history
added cors model for dynamic cors settings
  • Loading branch information
ottoyiu committed Jun 14, 2014
2 parents 715bebf + 1dece86 commit 484af82
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 0 deletions.
2 changes: 2 additions & 0 deletions corsheaders/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,5 @@
CORS_EXPOSE_HEADERS = getattr(settings, 'CORS_EXPOSE_HEADERS', ())

CORS_URLS_REGEX = getattr(settings, 'CORS_URLS_REGEX', '^.*$')

CORS_MODEL = getattr(settings, 'CORS_MODEL', None)
6 changes: 6 additions & 0 deletions corsheaders/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from urllib.parse import urlparse

from corsheaders import defaults as settings
from django.db.models.loading import get_model


ACCESS_CONTROL_ALLOW_ORIGIN = 'Access-Control-Allow-Origin'
Expand Down Expand Up @@ -41,6 +42,11 @@ def process_response(self, request, response):
# todo: check hostname from db instead
url = urlparse(origin)

if settings.CORS_MODEL is not None:
model = get_model(*settings.CORS_MODEL.split('.'))
if model.objects.filter(cors=url.netloc).count():
response[ACCESS_CONTROL_ALLOW_ORIGIN] = origin

if not settings.CORS_ORIGIN_ALLOW_ALL and self.origin_not_found_in_white_lists(origin, url):
return response

Expand Down
4 changes: 4 additions & 0 deletions corsheaders/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from django.db import models

class CorsModel(models.Model):
cors = models.CharField(max_length=255)
27 changes: 27 additions & 0 deletions corsheaders/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,15 @@ def assertAccessControlAllowOriginEquals(self, response, header):
self.assertEqual(response[ACCESS_CONTROL_ALLOW_ORIGIN], header)

def test_process_response_no_origin(self, settings):
settings.CORS_MODEL = None
settings.CORS_URLS_REGEX = '^.*$'
response = HttpResponse()
request = Mock(path='/', META={})
processed = self.middleware.process_response(request, response)
self.assertNotIn(ACCESS_CONTROL_ALLOW_ORIGIN, processed)

def test_process_response_not_in_whitelist(self, settings):
settings.CORS_MODEL = None
settings.CORS_ORIGIN_ALLOW_ALL = False
settings.CORS_ORIGIN_WHITELIST = ['example.com']
settings.CORS_URLS_REGEX = '^.*$'
Expand All @@ -88,6 +90,7 @@ def test_process_response_not_in_whitelist(self, settings):
self.assertNotIn(ACCESS_CONTROL_ALLOW_ORIGIN, processed)

def test_process_response_in_whitelist(self, settings):
settings.CORS_MODEL = None
settings.CORS_ORIGIN_ALLOW_ALL = False
settings.CORS_ORIGIN_WHITELIST = ['example.com', 'foobar.it']
settings.CORS_URLS_REGEX = '^.*$'
Expand All @@ -97,6 +100,7 @@ def test_process_response_in_whitelist(self, settings):
self.assertAccessControlAllowOriginEquals(processed, 'http://foobar.it')

def test_process_response_expose_headers(self, settings):
settings.CORS_MODEL = None
settings.CORS_ORIGIN_ALLOW_ALL = True
settings.CORS_EXPOSE_HEADERS = ['accept', 'origin', 'content-type']
settings.CORS_URLS_REGEX = '^.*$'
Expand All @@ -107,6 +111,7 @@ def test_process_response_expose_headers(self, settings):
'accept, origin, content-type')

def test_process_response_dont_expose_headers(self, settings):
settings.CORS_MODEL = None
settings.CORS_ORIGIN_ALLOW_ALL = True
settings.CORS_EXPOSE_HEADERS = []
settings.CORS_URLS_REGEX = '^.*$'
Expand All @@ -116,6 +121,7 @@ def test_process_response_dont_expose_headers(self, settings):
self.assertNotIn(ACCESS_CONTROL_EXPOSE_HEADERS, processed)

def test_process_response_allow_credentials(self, settings):
settings.CORS_MODEL = None
settings.CORS_ORIGIN_ALLOW_ALL = True
settings.CORS_ALLOW_CREDENTIALS = True
settings.CORS_URLS_REGEX = '^.*$'
Expand All @@ -125,6 +131,7 @@ def test_process_response_allow_credentials(self, settings):
self.assertEqual(processed[ACCESS_CONTROL_ALLOW_CREDENTIALS], 'true')

def test_process_response_dont_allow_credentials(self, settings):
settings.CORS_MODEL = None
settings.CORS_ORIGIN_ALLOW_ALL = True
settings.CORS_ALLOW_CREDENTIALS = False
settings.CORS_URLS_REGEX = '^.*$'
Expand All @@ -134,6 +141,7 @@ def test_process_response_dont_allow_credentials(self, settings):
self.assertNotIn(ACCESS_CONTROL_ALLOW_CREDENTIALS, processed)

def test_process_response_options_method(self, settings):
settings.CORS_MODEL = None
settings.CORS_ORIGIN_ALLOW_ALL = True
settings.CORS_ALLOW_HEADERS = ['content-type', 'origin']
settings.CORS_ALLOW_METHODS = ['GET', 'OPTIONS']
Expand All @@ -149,6 +157,7 @@ def test_process_response_options_method(self, settings):
self.assertEqual(processed[ACCESS_CONTROL_MAX_AGE], '1002')

def test_process_response_options_method_no_max_age(self, settings):
settings.CORS_MODEL = None
settings.CORS_ORIGIN_ALLOW_ALL = True
settings.CORS_ALLOW_HEADERS = ['content-type', 'origin']
settings.CORS_ALLOW_METHODS = ['GET', 'OPTIONS']
Expand All @@ -164,6 +173,7 @@ def test_process_response_options_method_no_max_age(self, settings):
self.assertNotIn(ACCESS_CONTROL_MAX_AGE, processed)

def test_process_response_whitelist_with_port(self, settings):
settings.CORS_MODEL = None
settings.CORS_ORIGIN_ALLOW_ALL = False
settings.CORS_ALLOW_METHODS = ['OPTIONS']
settings.CORS_ORIGIN_WHITELIST = ('localhost:9000',)
Expand All @@ -175,6 +185,7 @@ def test_process_response_whitelist_with_port(self, settings):
self.assertEqual(processed.get(ACCESS_CONTROL_ALLOW_CREDENTIALS), 'true')

def test_process_response_adds_origin_when_domain_found_in_origin_regex_whitelist(self, settings):
settings.CORS_MODEL = None
settings.CORS_ORIGIN_REGEX_WHITELIST = ('^http?://(\w+\.)?google\.com$', )
settings.CORS_ALLOW_CREDENTIALS = True
settings.CORS_ORIGIN_ALLOW_ALL = False
Expand All @@ -187,6 +198,7 @@ def test_process_response_adds_origin_when_domain_found_in_origin_regex_whitelis
self.assertEqual(processed.get(ACCESS_CONTROL_ALLOW_ORIGIN), 'http://foo.google.com')

def test_process_response_will_not_add_origin_when_domain_not_found_in_origin_regex_whitelist(self, settings):
settings.CORS_MODEL = None
settings.CORS_ORIGIN_REGEX_WHITELIST = ('^http?://(\w+\.)?yahoo\.com$', )
settings.CORS_ALLOW_CREDENTIALS = True
settings.CORS_ORIGIN_ALLOW_ALL = False
Expand All @@ -197,3 +209,18 @@ def test_process_response_will_not_add_origin_when_domain_not_found_in_origin_re
request = Mock(path='/', META=request_headers, method='OPTIONS')
processed = self.middleware.process_response(request, response)
self.assertEqual(processed.get(ACCESS_CONTROL_ALLOW_ORIGIN), None)

def test_process_response_when_custom_model_enabled(self, settings):
from corsheaders.models import CorsModel
c = CorsModel.objects.create(cors='foo.google.com')
settings.CORS_ORIGIN_REGEX_WHITELIST = ()
settings.CORS_ALLOW_CREDENTIALS = False
settings.CORS_ORIGIN_ALLOW_ALL = False
settings.CORS_ALLOW_METHODS = settings.default_methods
settings.CORS_URLS_REGEX = '^.*$'
settings.INSTALLED_APP + ('corsheaders',)
settings.CORS_MODEL = 'corsheaders.CorsModel'
response = HttpResponse()
request = Mock(path='/', META={'HTTP_ORIGIN': 'http://foo.google.com'})
processed = self.middleware.process_response(request, response)
self.assertEqual(processed.get(ACCESS_CONTROL_ALLOW_ORIGIN), 'http://foo.google.com')

0 comments on commit 484af82

Please sign in to comment.