Skip to content

Commit

Permalink
refactor(middleware): Refactor internals of CSPMiddleware so that it'…
Browse files Browse the repository at this point in the history
…s easier to extend existing logic without copy/pasting it into subclass
  • Loading branch information
crbunney committed Jul 15, 2024
1 parent f860f6c commit 20b8683
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 57 deletions.
58 changes: 18 additions & 40 deletions csp/contrib/rate_limiting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

from django.conf import settings

from csp.middleware import CSPMiddleware
from csp.utils import build_policy
from csp.middleware import CSPMiddleware, PolicyParts

if TYPE_CHECKING:
from django.http import HttpRequest, HttpResponseBase
Expand All @@ -16,48 +15,27 @@ class RateLimitedCSPMiddleware(CSPMiddleware):
"""A CSP middleware that rate-limits the number of violation reports sent
to report-uri by excluding it from some requests."""

def build_policy(self, request: HttpRequest, response: HttpResponseBase) -> str:
config = getattr(response, "_csp_config", None)
update = getattr(response, "_csp_update", None)
replace = getattr(response, "_csp_replace", {})
nonce = getattr(request, "_csp_nonce", None)

policy = getattr(settings, "CONTENT_SECURITY_POLICY", None)

if policy is None:
return ""

report_percentage = policy.get("REPORT_PERCENTAGE", 100)
remove_report = random.randint(0, 99) >= report_percentage
if remove_report:
replace.update(
{
"report-uri": None,
"report-to": None,
}
)

return build_policy(config=config, update=update, replace=replace, nonce=nonce)

def build_policy_ro(self, request: HttpRequest, response: HttpResponseBase) -> str:
config = getattr(response, "_csp_config_ro", None)
update = getattr(response, "_csp_update_ro", None)
replace = getattr(response, "_csp_replace_ro", {})
nonce = getattr(request, "_csp_nonce", None)

policy = getattr(settings, "CONTENT_SECURITY_POLICY_REPORT_ONLY", None)
def get_policy_parts(self, request: HttpRequest, response: HttpResponseBase, report_only: bool = False) -> PolicyParts:
policy_parts = super().get_policy_parts(request, response, report_only)

csp_setting_name = "CONTENT_SECURITY_POLICY_REPORT_ONLY" if report_only else "CONTENT_SECURITY_POLICY"
policy = getattr(settings, csp_setting_name, None)
if policy is None:
return ""
return policy_parts

report_percentage = policy.get("REPORT_PERCENTAGE", 100)
remove_report = random.randint(0, 99) >= report_percentage
remove_report = random.randint(0, 99) >= policy.get("REPORT_PERCENTAGE", 100)
if remove_report:
replace.update(
{
if policy_parts.replace is None:
policy_parts.replace = {
"report-uri": None,
"report-to": None,
}
)

return build_policy(config=config, update=update, replace=replace, nonce=nonce, report_only=True)
else:
policy_parts.replace.update(
{
"report-uri": None,
"report-to": None,
}
)

return policy_parts
48 changes: 36 additions & 12 deletions csp/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import base64
import http.client as http_client
import os
import warnings
from dataclasses import asdict, dataclass
from functools import partial
from typing import TYPE_CHECKING

Expand All @@ -11,12 +13,21 @@
from django.utils.functional import SimpleLazyObject

from csp.constants import HEADER, HEADER_REPORT_ONLY
from csp.utils import build_policy
from csp.utils import DIRECTIVES_T, build_policy

if TYPE_CHECKING:
from django.http import HttpRequest, HttpResponseBase


@dataclass
class PolicyParts:
# A dataclass is used rather than a namedtuple so that the attributes are mutable
config: DIRECTIVES_T | None = None
update: DIRECTIVES_T | None = None
replace: DIRECTIVES_T | None = None
nonce: str | None = None


class CSPMiddleware(MiddlewareMixin):
"""
Implements the Content-Security-Policy response header, which
Expand All @@ -25,6 +36,7 @@ class CSPMiddleware(MiddlewareMixin):
See http://www.w3.org/TR/CSP/
Can be customised by subclassing and extending the get_policy_parts method.
"""

def _make_nonce(self, request: HttpRequest) -> str:
Expand All @@ -49,7 +61,8 @@ def process_response(self, request: HttpRequest, response: HttpResponseBase) ->
if response.status_code in exempted_debug_codes and settings.DEBUG:
return response

csp = self.build_policy(request, response)
policy_parts = self.get_policy_parts(request=request, response=response)
csp = build_policy(**asdict(policy_parts))
if csp:
# Only set header if not already set and not an excluded prefix and not exempted.
is_not_exempt = getattr(response, "_csp_exempt", False) is False
Expand All @@ -60,7 +73,8 @@ def process_response(self, request: HttpRequest, response: HttpResponseBase) ->
if no_header and is_not_exempt and is_not_excluded:
response[HEADER] = csp

csp_ro = self.build_policy_ro(request, response)
policy_parts_ro = self.get_policy_parts(request=request, response=response, report_only=True)
csp_ro = build_policy(**asdict(policy_parts_ro), report_only=True)
if csp_ro:
# Only set header if not already set and not an excluded prefix and not exempted.
is_not_exempt = getattr(response, "_csp_exempt_ro", False) is False
Expand All @@ -74,15 +88,25 @@ def process_response(self, request: HttpRequest, response: HttpResponseBase) ->
return response

def build_policy(self, request: HttpRequest, response: HttpResponseBase) -> str:
config = getattr(response, "_csp_config", None)
update = getattr(response, "_csp_update", None)
replace = getattr(response, "_csp_replace", None)
nonce = getattr(request, "_csp_nonce", None)
return build_policy(config=config, update=update, replace=replace, nonce=nonce)
warnings.warn("deprecated in favor of get_policy_parts", DeprecationWarning)
policy_parts = self.get_policy_parts(request=request, response=response, report_only=False)
return build_policy(**asdict(policy_parts))

def build_policy_ro(self, request: HttpRequest, response: HttpResponseBase) -> str:
config = getattr(response, "_csp_config_ro", None)
update = getattr(response, "_csp_update_ro", None)
replace = getattr(response, "_csp_replace_ro", None)
warnings.warn("deprecated in favor of get_policy_parts", DeprecationWarning)
policy_parts_ro = self.get_policy_parts(request=request, response=response, report_only=True)
return build_policy(**asdict(policy_parts_ro), report_only=True)

def get_policy_parts(self, request: HttpRequest, response: HttpResponseBase, report_only: bool = False) -> PolicyParts:
if report_only:
config = getattr(response, "_csp_config_ro", None)
update = getattr(response, "_csp_update_ro", None)
replace = getattr(response, "_csp_replace_ro", None)
else:
config = getattr(response, "_csp_config", None)
update = getattr(response, "_csp_update", None)
replace = getattr(response, "_csp_replace", None)

nonce = getattr(request, "_csp_nonce", None)
return build_policy(config=config, update=update, replace=replace, nonce=nonce, report_only=True)

return PolicyParts(config, update, replace, nonce)
10 changes: 5 additions & 5 deletions csp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@
"block-all-mixed-content": None, # Deprecated.
}

_DIRECTIVES = Dict[str, Any]
DIRECTIVES_T = Dict[str, Any]


def default_config(csp: _DIRECTIVES | None) -> _DIRECTIVES | None:
def default_config(csp: DIRECTIVES_T | None) -> DIRECTIVES_T | None:
if csp is None:
return None
# Make a copy of the passed in config to avoid mutating it, and also to drop any unknown keys.
Expand All @@ -66,9 +66,9 @@ def default_config(csp: _DIRECTIVES | None) -> _DIRECTIVES | None:


def build_policy(
config: _DIRECTIVES | None = None,
update: _DIRECTIVES | None = None,
replace: _DIRECTIVES | None = None,
config: DIRECTIVES_T | None = None,
update: DIRECTIVES_T | None = None,
replace: DIRECTIVES_T | None = None,
nonce: str | None = None,
report_only: bool = False,
) -> str:
Expand Down
65 changes: 65 additions & 0 deletions docs/migration-guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,71 @@ decorator now requires parentheses when used with and without arguments. For exa
Look for uses of the following decorators in your code: ``@csp``, ``@csp_update``, ``@csp_replace``,
and ``@csp_exempt``.

Migrating Custom Middleware
===========================
The `CSPMiddleware` has changed in order to support easier extension via subclassing.

The `CSPMiddleware.build_policy` and `CSPMiddleware.build_policy_ro` methods have been deprecated
in 4.0 and replaced with a new method `CSPMiddleware.build_policy_parts`.

.. note::
The deprecated methods will be removed in 4.1.

Unlike the old methods, which returned the built CSP policy header string, `build_policy_parts`
returns a dataclass that can be modified and updated before the policy is built. This allows
custom middleware to modify the policy whilst inheriting behaviour from the base classes.

An existing custom middleware, such as this:

.. code-block:: python
from django.http import HttpRequest, HttpResponseBase
from csp.middleware import CSPMiddleware, PolicyParts
class ACustomMiddleware(CSPMiddleware):
def build_policy(self, request: HttpRequest, response: HttpResponseBase) -> str:
config = getattr(response, "_csp_config", None)
update = getattr(response, "_csp_update", None)
replace = getattr(response, "_csp_replace", {})
nonce = getattr(request, "_csp_nonce", None)
# ... do custom CSP policy logic ...
return build_policy(config=config, update=update, replace=replace, nonce=nonce)
def build_policy_ro(self, request: HttpRequest, response: HttpResponseBase) -> str:
config = getattr(response, "_csp_config_ro", None)
update = getattr(response, "_csp_update_ro", None)
replace = getattr(response, "_csp_replace_ro", {})
nonce = getattr(request, "_csp_nonce", None)
# ... do custom CSP report only policy logic ...
return build_policy(config=config, update=update, replace=replace, nonce=nonce)
can be replaced with this:

.. code-block:: python
from django.http import HttpRequest, HttpResponseBase
from csp.middleware import CSPMiddleware, PolicyParts
class ACustomMiddleware(CSPMiddleware):
def get_policy_parts(self, request: HttpRequest, response: HttpResponseBase, report_only: bool = False) -> PolicyParts:
policy_parts = super().get_policy_parts(request, response, report_only)
if report_only:
# ... do custom CSP report only policy logic ...
else:
# ... do custom CSP policy logic ...
return policy_parts
Conclusion
==========

Expand Down

0 comments on commit 20b8683

Please sign in to comment.